Stella Laurenzo | 436c6c9 | 2021-03-19 11:57:01 -0700 | [diff] [blame] | 1 | //===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "IRModule.h" |
| 10 | |
| 11 | #include "PybindUtils.h" |
| 12 | |
| 13 | #include "mlir-c/AffineMap.h" |
| 14 | #include "mlir-c/Bindings/Python/Interop.h" |
| 15 | #include "mlir-c/IntegerSet.h" |
| 16 | |
| 17 | namespace py = pybind11; |
| 18 | using namespace mlir; |
| 19 | using namespace mlir::python; |
| 20 | |
| 21 | using llvm::SmallVector; |
| 22 | using llvm::StringRef; |
| 23 | using llvm::Twine; |
| 24 | |
// Docstring shared by the Python `dump` methods registered in this file.
static constexpr char kDumpDocstring[] =
    "Dumps a debug representation of the object to stderr.";
| 27 | |
| 28 | /// Attempts to populate `result` with the content of `list` casted to the |
| 29 | /// appropriate type (Python and C types are provided as template arguments). |
| 30 | /// Throws errors in case of failure, using "action" to describe what the caller |
| 31 | /// was attempting to do. |
| 32 | template <typename PyType, typename CType> |
| 33 | static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result, |
| 34 | StringRef action) { |
| 35 | result.reserve(py::len(list)); |
| 36 | for (py::handle item : list) { |
| 37 | try { |
| 38 | result.push_back(item.cast<PyType>()); |
| 39 | } catch (py::cast_error &err) { |
| 40 | std::string msg = (llvm::Twine("Invalid expression when ") + action + |
| 41 | " (" + err.what() + ")") |
| 42 | .str(); |
| 43 | throw py::cast_error(msg); |
| 44 | } catch (py::reference_cast_error &err) { |
| 45 | std::string msg = (llvm::Twine("Invalid expression (None?) when ") + |
| 46 | action + " (" + err.what() + ")") |
| 47 | .str(); |
| 48 | throw py::cast_error(msg); |
| 49 | } |
| 50 | } |
| 51 | } |
| 52 | |
/// Returns true if `permutation` contains every value in
/// [0, permutation.size()) exactly once, i.e. it is a valid permutation of
/// its own index range. Any negative, out-of-range or duplicated entry makes
/// it invalid. An empty vector is a (trivial) permutation.
///
/// Takes the vector by const reference so callers do not pay for a copy.
template <typename PermutationTy>
static bool isPermutation(const std::vector<PermutationTy> &permutation) {
  const uint64_t size = permutation.size();
  // One flag per position; std::vector<char> avoids the vector<bool> proxy.
  std::vector<char> seen(permutation.size(), 0);
  for (const PermutationTy &val : permutation) {
    // Integral -> uint64_t conversion is well defined (modular) for every
    // integral type. Negative inputs therefore map to huge values, so this
    // single range check rejects both negative and too-large entries without
    // an implicit signed/unsigned comparison.
    const uint64_t idx = static_cast<uint64_t>(val);
    if (idx >= size || seen[idx])
      return false;
    seen[idx] = 1;
  }
  return true;
}
| 67 | |
| 68 | namespace { |
| 69 | |
| 70 | /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr |
| 71 | /// and should be castable from it. Intermediate hierarchy classes can be |
| 72 | /// modeled by specifying BaseTy. |
| 73 | template <typename DerivedTy, typename BaseTy = PyAffineExpr> |
| 74 | class PyConcreteAffineExpr : public BaseTy { |
| 75 | public: |
| 76 | // Derived classes must define statics for: |
| 77 | // IsAFunctionTy isaFunction |
| 78 | // const char *pyClassName |
| 79 | // and redefine bindDerived. |
| 80 | using ClassTy = py::class_<DerivedTy, BaseTy>; |
| 81 | using IsAFunctionTy = bool (*)(MlirAffineExpr); |
| 82 | |
| 83 | PyConcreteAffineExpr() = default; |
| 84 | PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) |
| 85 | : BaseTy(std::move(contextRef), affineExpr) {} |
| 86 | PyConcreteAffineExpr(PyAffineExpr &orig) |
| 87 | : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {} |
| 88 | |
| 89 | static MlirAffineExpr castFrom(PyAffineExpr &orig) { |
| 90 | if (!DerivedTy::isaFunction(orig)) { |
| 91 | auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); |
| 92 | throw SetPyError(PyExc_ValueError, |
| 93 | Twine("Cannot cast affine expression to ") + |
| 94 | DerivedTy::pyClassName + " (from " + origRepr + ")"); |
| 95 | } |
| 96 | return orig; |
| 97 | } |
| 98 | |
| 99 | static void bind(py::module &m) { |
Stella Laurenzo | f05ff4f | 2021-08-23 20:01:07 -0700 | [diff] [blame] | 100 | auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); |
Stella Laurenzo | a6e7d02 | 2021-11-28 14:08:06 -0800 | [diff] [blame] | 101 | cls.def(py::init<PyAffineExpr &>(), py::arg("expr")); |
| 102 | cls.def_static( |
| 103 | "isinstance", |
| 104 | [](PyAffineExpr &otherAffineExpr) -> bool { |
| 105 | return DerivedTy::isaFunction(otherAffineExpr); |
| 106 | }, |
| 107 | py::arg("other")); |
Stella Laurenzo | 436c6c9 | 2021-03-19 11:57:01 -0700 | [diff] [blame] | 108 | DerivedTy::bindDerived(cls); |
| 109 | } |
| 110 | |
| 111 | /// Implemented by derived classes to add methods to the Python subclass. |
| 112 | static void bindDerived(ClassTy &m) {} |
| 113 | }; |
| 114 | |
| 115 | class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> { |
| 116 | public: |
| 117 | static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant; |
| 118 | static constexpr const char *pyClassName = "AffineConstantExpr"; |
| 119 | using PyConcreteAffineExpr::PyConcreteAffineExpr; |
| 120 | |
| 121 | static PyAffineConstantExpr get(intptr_t value, |
| 122 | DefaultingPyMlirContext context) { |
| 123 | MlirAffineExpr affineExpr = |
| 124 | mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value)); |
| 125 | return PyAffineConstantExpr(context->getRef(), affineExpr); |
| 126 | } |
| 127 | |
| 128 | static void bindDerived(ClassTy &c) { |
| 129 | c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"), |
| 130 | py::arg("context") = py::none()); |
| 131 | c.def_property_readonly("value", [](PyAffineConstantExpr &self) { |
| 132 | return mlirAffineConstantExprGetValue(self); |
| 133 | }); |
| 134 | } |
| 135 | }; |
| 136 | |
| 137 | class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> { |
| 138 | public: |
| 139 | static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim; |
| 140 | static constexpr const char *pyClassName = "AffineDimExpr"; |
| 141 | using PyConcreteAffineExpr::PyConcreteAffineExpr; |
| 142 | |
| 143 | static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) { |
| 144 | MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos); |
| 145 | return PyAffineDimExpr(context->getRef(), affineExpr); |
| 146 | } |
| 147 | |
| 148 | static void bindDerived(ClassTy &c) { |
| 149 | c.def_static("get", &PyAffineDimExpr::get, py::arg("position"), |
| 150 | py::arg("context") = py::none()); |
| 151 | c.def_property_readonly("position", [](PyAffineDimExpr &self) { |
| 152 | return mlirAffineDimExprGetPosition(self); |
| 153 | }); |
| 154 | } |
| 155 | }; |
| 156 | |
| 157 | class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> { |
| 158 | public: |
| 159 | static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol; |
| 160 | static constexpr const char *pyClassName = "AffineSymbolExpr"; |
| 161 | using PyConcreteAffineExpr::PyConcreteAffineExpr; |
| 162 | |
| 163 | static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) { |
| 164 | MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos); |
| 165 | return PyAffineSymbolExpr(context->getRef(), affineExpr); |
| 166 | } |
| 167 | |
| 168 | static void bindDerived(ClassTy &c) { |
| 169 | c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"), |
| 170 | py::arg("context") = py::none()); |
| 171 | c.def_property_readonly("position", [](PyAffineSymbolExpr &self) { |
| 172 | return mlirAffineSymbolExprGetPosition(self); |
| 173 | }); |
| 174 | } |
| 175 | }; |
| 176 | |
| 177 | class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> { |
| 178 | public: |
| 179 | static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary; |
| 180 | static constexpr const char *pyClassName = "AffineBinaryExpr"; |
| 181 | using PyConcreteAffineExpr::PyConcreteAffineExpr; |
| 182 | |
| 183 | PyAffineExpr lhs() { |
| 184 | MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get()); |
| 185 | return PyAffineExpr(getContext(), lhsExpr); |
| 186 | } |
| 187 | |
| 188 | PyAffineExpr rhs() { |
| 189 | MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get()); |
| 190 | return PyAffineExpr(getContext(), rhsExpr); |
| 191 | } |
| 192 | |
| 193 | static void bindDerived(ClassTy &c) { |
| 194 | c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs); |
| 195 | c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs); |
| 196 | } |
| 197 | }; |
| 198 | |
| 199 | class PyAffineAddExpr |
| 200 | : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> { |
| 201 | public: |
| 202 | static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd; |
| 203 | static constexpr const char *pyClassName = "AffineAddExpr"; |
| 204 | using PyConcreteAffineExpr::PyConcreteAffineExpr; |
| 205 | |
| 206 | static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { |
| 207 | MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs); |
| 208 | return PyAffineAddExpr(lhs.getContext(), expr); |
| 209 | } |
| 210 | |
Alex Zinenko | fc7594c | 2021-11-02 14:15:25 +0100 | [diff] [blame] | 211 | static PyAffineAddExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { |
| 212 | MlirAffineExpr expr = mlirAffineAddExprGet( |
| 213 | lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); |
| 214 | return PyAffineAddExpr(lhs.getContext(), expr); |
| 215 | } |
| 216 | |
| 217 | static PyAffineAddExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { |
| 218 | MlirAffineExpr expr = mlirAffineAddExprGet( |
| 219 | mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); |
| 220 | return PyAffineAddExpr(rhs.getContext(), expr); |
| 221 | } |
| 222 | |
Stella Laurenzo | 436c6c9 | 2021-03-19 11:57:01 -0700 | [diff] [blame] | 223 | static void bindDerived(ClassTy &c) { |
| 224 | c.def_static("get", &PyAffineAddExpr::get); |
| 225 | } |
| 226 | }; |
| 227 | |
| 228 | class PyAffineMulExpr |
| 229 | : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> { |
| 230 | public: |
| 231 | static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul; |
| 232 | static constexpr const char *pyClassName = "AffineMulExpr"; |
| 233 | using PyConcreteAffineExpr::PyConcreteAffineExpr; |
| 234 | |
| 235 | static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { |
| 236 | MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs); |
| 237 | return PyAffineMulExpr(lhs.getContext(), expr); |
| 238 | } |
| 239 | |
Alex Zinenko | fc7594c | 2021-11-02 14:15:25 +0100 | [diff] [blame] | 240 | static PyAffineMulExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { |
| 241 | MlirAffineExpr expr = mlirAffineMulExprGet( |
| 242 | lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); |
| 243 | return PyAffineMulExpr(lhs.getContext(), expr); |
| 244 | } |
| 245 | |
| 246 | static PyAffineMulExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { |
| 247 | MlirAffineExpr expr = mlirAffineMulExprGet( |
| 248 | mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); |
| 249 | return PyAffineMulExpr(rhs.getContext(), expr); |
| 250 | } |
| 251 | |
Stella Laurenzo | 436c6c9 | 2021-03-19 11:57:01 -0700 | [diff] [blame] | 252 | static void bindDerived(ClassTy &c) { |
| 253 | c.def_static("get", &PyAffineMulExpr::get); |
| 254 | } |
| 255 | }; |
| 256 | |
| 257 | class PyAffineModExpr |
| 258 | : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> { |
| 259 | public: |
| 260 | static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod; |
| 261 | static constexpr const char *pyClassName = "AffineModExpr"; |
| 262 | using PyConcreteAffineExpr::PyConcreteAffineExpr; |
| 263 | |
| 264 | static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { |
| 265 | MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs); |
| 266 | return PyAffineModExpr(lhs.getContext(), expr); |
| 267 | } |
| 268 | |
Alex Zinenko | fc7594c | 2021-11-02 14:15:25 +0100 | [diff] [blame] | 269 | static PyAffineModExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { |
| 270 | MlirAffineExpr expr = mlirAffineModExprGet( |
| 271 | lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); |
| 272 | return PyAffineModExpr(lhs.getContext(), expr); |
| 273 | } |
| 274 | |
| 275 | static PyAffineModExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { |
| 276 | MlirAffineExpr expr = mlirAffineModExprGet( |
| 277 | mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); |
| 278 | return PyAffineModExpr(rhs.getContext(), expr); |
| 279 | } |
| 280 | |
Stella Laurenzo | 436c6c9 | 2021-03-19 11:57:01 -0700 | [diff] [blame] | 281 | static void bindDerived(ClassTy &c) { |
| 282 | c.def_static("get", &PyAffineModExpr::get); |
| 283 | } |
| 284 | }; |
| 285 | |
| 286 | class PyAffineFloorDivExpr |
| 287 | : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> { |
| 288 | public: |
| 289 | static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv; |
| 290 | static constexpr const char *pyClassName = "AffineFloorDivExpr"; |
| 291 | using PyConcreteAffineExpr::PyConcreteAffineExpr; |
| 292 | |
| 293 | static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { |
| 294 | MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs); |
| 295 | return PyAffineFloorDivExpr(lhs.getContext(), expr); |
| 296 | } |
| 297 | |
Alex Zinenko | fc7594c | 2021-11-02 14:15:25 +0100 | [diff] [blame] | 298 | static PyAffineFloorDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { |
| 299 | MlirAffineExpr expr = mlirAffineFloorDivExprGet( |
| 300 | lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); |
| 301 | return PyAffineFloorDivExpr(lhs.getContext(), expr); |
| 302 | } |
| 303 | |
| 304 | static PyAffineFloorDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { |
| 305 | MlirAffineExpr expr = mlirAffineFloorDivExprGet( |
| 306 | mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); |
| 307 | return PyAffineFloorDivExpr(rhs.getContext(), expr); |
| 308 | } |
| 309 | |
Stella Laurenzo | 436c6c9 | 2021-03-19 11:57:01 -0700 | [diff] [blame] | 310 | static void bindDerived(ClassTy &c) { |
| 311 | c.def_static("get", &PyAffineFloorDivExpr::get); |
| 312 | } |
| 313 | }; |
| 314 | |
| 315 | class PyAffineCeilDivExpr |
| 316 | : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> { |
| 317 | public: |
| 318 | static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv; |
| 319 | static constexpr const char *pyClassName = "AffineCeilDivExpr"; |
| 320 | using PyConcreteAffineExpr::PyConcreteAffineExpr; |
| 321 | |
| 322 | static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { |
| 323 | MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs); |
| 324 | return PyAffineCeilDivExpr(lhs.getContext(), expr); |
| 325 | } |
| 326 | |
Alex Zinenko | fc7594c | 2021-11-02 14:15:25 +0100 | [diff] [blame] | 327 | static PyAffineCeilDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { |
| 328 | MlirAffineExpr expr = mlirAffineCeilDivExprGet( |
| 329 | lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); |
| 330 | return PyAffineCeilDivExpr(lhs.getContext(), expr); |
| 331 | } |
| 332 | |
| 333 | static PyAffineCeilDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { |
| 334 | MlirAffineExpr expr = mlirAffineCeilDivExprGet( |
| 335 | mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); |
| 336 | return PyAffineCeilDivExpr(rhs.getContext(), expr); |
| 337 | } |
| 338 | |
Stella Laurenzo | 436c6c9 | 2021-03-19 11:57:01 -0700 | [diff] [blame] | 339 | static void bindDerived(ClassTy &c) { |
| 340 | c.def_static("get", &PyAffineCeilDivExpr::get); |
| 341 | } |
| 342 | }; |
| 343 | |
| 344 | } // namespace |
| 345 | |
| 346 | bool PyAffineExpr::operator==(const PyAffineExpr &other) { |
| 347 | return mlirAffineExprEqual(affineExpr, other.affineExpr); |
| 348 | } |
| 349 | |
| 350 | py::object PyAffineExpr::getCapsule() { |
| 351 | return py::reinterpret_steal<py::object>( |
| 352 | mlirPythonAffineExprToCapsule(*this)); |
| 353 | } |
| 354 | |
| 355 | PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) { |
| 356 | MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr()); |
| 357 | if (mlirAffineExprIsNull(rawAffineExpr)) |
| 358 | throw py::error_already_set(); |
| 359 | return PyAffineExpr( |
| 360 | PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)), |
| 361 | rawAffineExpr); |
| 362 | } |
| 363 | |
| 364 | //------------------------------------------------------------------------------ |
| 365 | // PyAffineMap and utilities. |
| 366 | //------------------------------------------------------------------------------ |
| 367 | namespace { |
| 368 | |
| 369 | /// A list of expressions contained in an affine map. Internally these are |
| 370 | /// stored as a consecutive array leading to inexpensive random access. Both |
| 371 | /// the map and the expression are owned by the context so we need not bother |
| 372 | /// with lifetime extension. |
| 373 | class PyAffineMapExprList |
| 374 | : public Sliceable<PyAffineMapExprList, PyAffineExpr> { |
| 375 | public: |
| 376 | static constexpr const char *pyClassName = "AffineExprList"; |
| 377 | |
| 378 | PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0, |
| 379 | intptr_t length = -1, intptr_t step = 1) |
| 380 | : Sliceable(startIndex, |
| 381 | length == -1 ? mlirAffineMapGetNumResults(map) : length, |
| 382 | step), |
| 383 | affineMap(map) {} |
| 384 | |
| 385 | intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); } |
| 386 | |
| 387 | PyAffineExpr getElement(intptr_t pos) { |
| 388 | return PyAffineExpr(affineMap.getContext(), |
| 389 | mlirAffineMapGetResult(affineMap, pos)); |
| 390 | } |
| 391 | |
| 392 | PyAffineMapExprList slice(intptr_t startIndex, intptr_t length, |
| 393 | intptr_t step) { |
| 394 | return PyAffineMapExprList(affineMap, startIndex, length, step); |
| 395 | } |
| 396 | |
| 397 | private: |
| 398 | PyAffineMap affineMap; |
| 399 | }; |
| 400 | } // end namespace |
| 401 | |
| 402 | bool PyAffineMap::operator==(const PyAffineMap &other) { |
| 403 | return mlirAffineMapEqual(affineMap, other.affineMap); |
| 404 | } |
| 405 | |
| 406 | py::object PyAffineMap::getCapsule() { |
| 407 | return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this)); |
| 408 | } |
| 409 | |
| 410 | PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) { |
| 411 | MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr()); |
| 412 | if (mlirAffineMapIsNull(rawAffineMap)) |
| 413 | throw py::error_already_set(); |
| 414 | return PyAffineMap( |
| 415 | PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)), |
| 416 | rawAffineMap); |
| 417 | } |
| 418 | |
| 419 | //------------------------------------------------------------------------------ |
| 420 | // PyIntegerSet and utilities. |
| 421 | //------------------------------------------------------------------------------ |
| 422 | namespace { |
| 423 | |
| 424 | class PyIntegerSetConstraint { |
| 425 | public: |
| 426 | PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {} |
| 427 | |
| 428 | PyAffineExpr getExpr() { |
| 429 | return PyAffineExpr(set.getContext(), |
| 430 | mlirIntegerSetGetConstraint(set, pos)); |
| 431 | } |
| 432 | |
| 433 | bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); } |
| 434 | |
| 435 | static void bind(py::module &m) { |
Stella Laurenzo | f05ff4f | 2021-08-23 20:01:07 -0700 | [diff] [blame] | 436 | py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint", |
| 437 | py::module_local()) |
Stella Laurenzo | 436c6c9 | 2021-03-19 11:57:01 -0700 | [diff] [blame] | 438 | .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr) |
| 439 | .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq); |
| 440 | } |
| 441 | |
| 442 | private: |
| 443 | PyIntegerSet set; |
| 444 | intptr_t pos; |
| 445 | }; |
| 446 | |
| 447 | class PyIntegerSetConstraintList |
| 448 | : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> { |
| 449 | public: |
| 450 | static constexpr const char *pyClassName = "IntegerSetConstraintList"; |
| 451 | |
| 452 | PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0, |
| 453 | intptr_t length = -1, intptr_t step = 1) |
| 454 | : Sliceable(startIndex, |
| 455 | length == -1 ? mlirIntegerSetGetNumConstraints(set) : length, |
| 456 | step), |
| 457 | set(set) {} |
| 458 | |
| 459 | intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); } |
| 460 | |
| 461 | PyIntegerSetConstraint getElement(intptr_t pos) { |
| 462 | return PyIntegerSetConstraint(set, pos); |
| 463 | } |
| 464 | |
| 465 | PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length, |
| 466 | intptr_t step) { |
| 467 | return PyIntegerSetConstraintList(set, startIndex, length, step); |
| 468 | } |
| 469 | |
| 470 | private: |
| 471 | PyIntegerSet set; |
| 472 | }; |
| 473 | } // namespace |
| 474 | |
| 475 | bool PyIntegerSet::operator==(const PyIntegerSet &other) { |
| 476 | return mlirIntegerSetEqual(integerSet, other.integerSet); |
| 477 | } |
| 478 | |
| 479 | py::object PyIntegerSet::getCapsule() { |
| 480 | return py::reinterpret_steal<py::object>( |
| 481 | mlirPythonIntegerSetToCapsule(*this)); |
| 482 | } |
| 483 | |
| 484 | PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) { |
| 485 | MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr()); |
| 486 | if (mlirIntegerSetIsNull(rawIntegerSet)) |
| 487 | throw py::error_already_set(); |
| 488 | return PyIntegerSet( |
| 489 | PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)), |
| 490 | rawIntegerSet); |
| 491 | } |
| 492 | |
| 493 | void mlir::python::populateIRAffine(py::module &m) { |
| 494 | //---------------------------------------------------------------------------- |
| 495 | // Mapping of PyAffineExpr and derived classes. |
| 496 | //---------------------------------------------------------------------------- |
Stella Laurenzo | f05ff4f | 2021-08-23 20:01:07 -0700 | [diff] [blame] | 497 | py::class_<PyAffineExpr>(m, "AffineExpr", py::module_local()) |
Stella Laurenzo | 436c6c9 | 2021-03-19 11:57:01 -0700 | [diff] [blame] | 498 | .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, |
| 499 | &PyAffineExpr::getCapsule) |
| 500 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule) |
Alex Zinenko | fc7594c | 2021-11-02 14:15:25 +0100 | [diff] [blame] | 501 | .def("__add__", &PyAffineAddExpr::get) |
| 502 | .def("__add__", &PyAffineAddExpr::getRHSConstant) |
| 503 | .def("__radd__", &PyAffineAddExpr::getRHSConstant) |
| 504 | .def("__mul__", &PyAffineMulExpr::get) |
| 505 | .def("__mul__", &PyAffineMulExpr::getRHSConstant) |
| 506 | .def("__rmul__", &PyAffineMulExpr::getRHSConstant) |
| 507 | .def("__mod__", &PyAffineModExpr::get) |
| 508 | .def("__mod__", &PyAffineModExpr::getRHSConstant) |
| 509 | .def("__rmod__", |
| 510 | [](PyAffineExpr &self, intptr_t other) { |
| 511 | return PyAffineModExpr::get( |
| 512 | PyAffineConstantExpr::get(other, *self.getContext().get()), |
| 513 | self); |
Stella Laurenzo | 436c6c9 | 2021-03-19 11:57:01 -0700 | [diff] [blame] | 514 | }) |
| 515 | .def("__sub__", |
| 516 | [](PyAffineExpr &self, PyAffineExpr &other) { |
| 517 | auto negOne = |
| 518 | PyAffineConstantExpr::get(-1, *self.getContext().get()); |
| 519 | return PyAffineAddExpr::get(self, |
| 520 | PyAffineMulExpr::get(negOne, other)); |
| 521 | }) |
Alex Zinenko | fc7594c | 2021-11-02 14:15:25 +0100 | [diff] [blame] | 522 | .def("__sub__", |
| 523 | [](PyAffineExpr &self, intptr_t other) { |
| 524 | return PyAffineAddExpr::get( |
| 525 | self, |
| 526 | PyAffineConstantExpr::get(-other, *self.getContext().get())); |
| 527 | }) |
| 528 | .def("__rsub__", |
| 529 | [](PyAffineExpr &self, intptr_t other) { |
| 530 | return PyAffineAddExpr::getLHSConstant( |
| 531 | other, PyAffineMulExpr::getLHSConstant(-1, self)); |
| 532 | }) |
Stella Laurenzo | 436c6c9 | 2021-03-19 11:57:01 -0700 | [diff] [blame] | 533 | .def("__eq__", [](PyAffineExpr &self, |
| 534 | PyAffineExpr &other) { return self == other; }) |
| 535 | .def("__eq__", |
| 536 | [](PyAffineExpr &self, py::object &other) { return false; }) |
| 537 | .def("__str__", |
| 538 | [](PyAffineExpr &self) { |
| 539 | PyPrintAccumulator printAccum; |
| 540 | mlirAffineExprPrint(self, printAccum.getCallback(), |
| 541 | printAccum.getUserData()); |
| 542 | return printAccum.join(); |
| 543 | }) |
| 544 | .def("__repr__", |
| 545 | [](PyAffineExpr &self) { |
| 546 | PyPrintAccumulator printAccum; |
| 547 | printAccum.parts.append("AffineExpr("); |
| 548 | mlirAffineExprPrint(self, printAccum.getCallback(), |
| 549 | printAccum.getUserData()); |
| 550 | printAccum.parts.append(")"); |
| 551 | return printAccum.join(); |
| 552 | }) |
Alex Zinenko | fc7594c | 2021-11-02 14:15:25 +0100 | [diff] [blame] | 553 | .def("__hash__", |
| 554 | [](PyAffineExpr &self) { |
| 555 | return static_cast<size_t>(llvm::hash_value(self.get().ptr)); |
| 556 | }) |
Stella Laurenzo | 436c6c9 | 2021-03-19 11:57:01 -0700 | [diff] [blame] | 557 | .def_property_readonly( |
| 558 | "context", |
| 559 | [](PyAffineExpr &self) { return self.getContext().getObject(); }) |
Alex Zinenko | fc7594c | 2021-11-02 14:15:25 +0100 | [diff] [blame] | 560 | .def("compose", |
| 561 | [](PyAffineExpr &self, PyAffineMap &other) { |
| 562 | return PyAffineExpr(self.getContext(), |
| 563 | mlirAffineExprCompose(self, other)); |
| 564 | }) |
Stella Laurenzo | 436c6c9 | 2021-03-19 11:57:01 -0700 | [diff] [blame] | 565 | .def_static( |
| 566 | "get_add", &PyAffineAddExpr::get, |
| 567 | "Gets an affine expression containing a sum of two expressions.") |
Alex Zinenko | fc7594c | 2021-11-02 14:15:25 +0100 | [diff] [blame] | 568 | .def_static("get_add", &PyAffineAddExpr::getLHSConstant, |
| 569 | "Gets an affine expression containing a sum of a constant " |
| 570 | "and another expression.") |
| 571 | .def_static("get_add", &PyAffineAddExpr::getRHSConstant, |
| 572 | "Gets an affine expression containing a sum of an expression " |
| 573 | "and a constant.") |
Stella Laurenzo | 436c6c9 | 2021-03-19 11:57:01 -0700 | [diff] [blame] | 574 | .def_static( |
| 575 | "get_mul", &PyAffineMulExpr::get, |
| 576 | "Gets an affine expression containing a product of two expressions.") |
Alex Zinenko | fc7594c | 2021-11-02 14:15:25 +0100 | [diff] [blame] | 577 | .def_static("get_mul", &PyAffineMulExpr::getLHSConstant, |
| 578 | "Gets an affine expression containing a product of a " |
| 579 | "constant and another expression.") |
| 580 | .def_static("get_mul", &PyAffineMulExpr::getRHSConstant, |
| 581 | "Gets an affine expression containing a product of an " |
| 582 | "expression and a constant.") |
Stella Laurenzo | 436c6c9 | 2021-03-19 11:57:01 -0700 | [diff] [blame] | 583 | .def_static("get_mod", &PyAffineModExpr::get, |
| 584 | "Gets an affine expression containing the modulo of dividing " |
| 585 | "one expression by another.") |
Alex Zinenko | fc7594c | 2021-11-02 14:15:25 +0100 | [diff] [blame] | 586 | .def_static("get_mod", &PyAffineModExpr::getLHSConstant, |
| 587 | "Gets a semi-affine expression containing the modulo of " |
| 588 | "dividing a constant by an expression.") |
| 589 | .def_static("get_mod", &PyAffineModExpr::getRHSConstant, |
| 590 | "Gets an affine expression containing the module of dividing" |
| 591 | "an expression by a constant.") |
Stella Laurenzo | 436c6c9 | 2021-03-19 11:57:01 -0700 | [diff] [blame] | 592 | .def_static("get_floor_div", &PyAffineFloorDivExpr::get, |
| 593 | "Gets an affine expression containing the rounded-down " |
| 594 | "result of dividing one expression by another.") |
Alex Zinenko | fc7594c | 2021-11-02 14:15:25 +0100 | [diff] [blame] | 595 | .def_static("get_floor_div", &PyAffineFloorDivExpr::getLHSConstant, |
| 596 | "Gets a semi-affine expression containing the rounded-down " |
| 597 | "result of dividing a constant by an expression.") |
| 598 | .def_static("get_floor_div", &PyAffineFloorDivExpr::getRHSConstant, |
| 599 | "Gets an affine expression containing the rounded-down " |
| 600 | "result of dividing an expression by a constant.") |
Stella Laurenzo | 436c6c9 | 2021-03-19 11:57:01 -0700 | [diff] [blame] | 601 | .def_static("get_ceil_div", &PyAffineCeilDivExpr::get, |
| 602 | "Gets an affine expression containing the rounded-up result " |
| 603 | "of dividing one expression by another.") |
Alex Zinenko | fc7594c | 2021-11-02 14:15:25 +0100 | [diff] [blame] | 604 | .def_static("get_ceil_div", &PyAffineCeilDivExpr::getLHSConstant, |
| 605 | "Gets a semi-affine expression containing the rounded-up " |
| 606 | "result of dividing a constant by an expression.") |
| 607 | .def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant, |
| 608 | "Gets an affine expression containing the rounded-up result " |
| 609 | "of dividing an expression by a constant.") |
Stella Laurenzo | 436c6c9 | 2021-03-19 11:57:01 -0700 | [diff] [blame] | 610 | .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"), |
| 611 | py::arg("context") = py::none(), |
| 612 | "Gets a constant affine expression with the given value.") |
| 613 | .def_static( |
| 614 | "get_dim", &PyAffineDimExpr::get, py::arg("position"), |
| 615 | py::arg("context") = py::none(), |
| 616 | "Gets an affine expression of a dimension at the given position.") |
| 617 | .def_static( |
| 618 | "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"), |
| 619 | py::arg("context") = py::none(), |
| 620 | "Gets an affine expression of a symbol at the given position.") |
| 621 | .def( |
| 622 | "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); }, |
| 623 | kDumpDocstring); |
| 624 | PyAffineConstantExpr::bind(m); |
| 625 | PyAffineDimExpr::bind(m); |
| 626 | PyAffineSymbolExpr::bind(m); |
| 627 | PyAffineBinaryExpr::bind(m); |
| 628 | PyAffineAddExpr::bind(m); |
| 629 | PyAffineMulExpr::bind(m); |
| 630 | PyAffineModExpr::bind(m); |
| 631 | PyAffineFloorDivExpr::bind(m); |
| 632 | PyAffineCeilDivExpr::bind(m); |
| 633 | |
| 634 | //---------------------------------------------------------------------------- |
| 635 | // Mapping of PyAffineMap. |
| 636 | //---------------------------------------------------------------------------- |
Stella Laurenzo | f05ff4f | 2021-08-23 20:01:07 -0700 | [diff] [blame] | 637 | py::class_<PyAffineMap>(m, "AffineMap", py::module_local()) |
Stella Laurenzo | 436c6c9 | 2021-03-19 11:57:01 -0700 | [diff] [blame] | 638 | .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, |
| 639 | &PyAffineMap::getCapsule) |
| 640 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule) |
| 641 | .def("__eq__", |
| 642 | [](PyAffineMap &self, PyAffineMap &other) { return self == other; }) |
| 643 | .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; }) |
| 644 | .def("__str__", |
| 645 | [](PyAffineMap &self) { |
| 646 | PyPrintAccumulator printAccum; |
| 647 | mlirAffineMapPrint(self, printAccum.getCallback(), |
| 648 | printAccum.getUserData()); |
| 649 | return printAccum.join(); |
| 650 | }) |
| 651 | .def("__repr__", |
| 652 | [](PyAffineMap &self) { |
| 653 | PyPrintAccumulator printAccum; |
| 654 | printAccum.parts.append("AffineMap("); |
| 655 | mlirAffineMapPrint(self, printAccum.getCallback(), |
| 656 | printAccum.getUserData()); |
| 657 | printAccum.parts.append(")"); |
| 658 | return printAccum.join(); |
| 659 | }) |
Alex Zinenko | fc7594c | 2021-11-02 14:15:25 +0100 | [diff] [blame] | 660 | .def("__hash__", |
| 661 | [](PyAffineMap &self) { |
| 662 | return static_cast<size_t>(llvm::hash_value(self.get().ptr)); |
| 663 | }) |
Nicolas Vasilache | 335d2df | 2021-03-31 09:33:08 +0000 | [diff] [blame] | 664 | .def_static("compress_unused_symbols", |
| 665 | [](py::list affineMaps, DefaultingPyMlirContext context) { |
| 666 | SmallVector<MlirAffineMap> maps; |
| 667 | pyListToVector<PyAffineMap, MlirAffineMap>( |
| 668 | affineMaps, maps, "attempting to create an AffineMap"); |
| 669 | std::vector<MlirAffineMap> compressed(affineMaps.size()); |
| 670 | auto populate = [](void *result, intptr_t idx, |
| 671 | MlirAffineMap m) { |
| 672 | static_cast<MlirAffineMap *>(result)[idx] = (m); |
| 673 | }; |
| 674 | mlirAffineMapCompressUnusedSymbols( |
| 675 | maps.data(), maps.size(), compressed.data(), populate); |
| 676 | std::vector<PyAffineMap> res; |
Mehdi Amini | e2f16be | 2021-10-19 17:13:54 +0000 | [diff] [blame] | 677 | res.reserve(compressed.size()); |
Nicolas Vasilache | 335d2df | 2021-03-31 09:33:08 +0000 | [diff] [blame] | 678 | for (auto m : compressed) |
| 679 | res.push_back(PyAffineMap(context->getRef(), m)); |
| 680 | return res; |
| 681 | }) |
Stella Laurenzo | 436c6c9 | 2021-03-19 11:57:01 -0700 | [diff] [blame] | 682 | .def_property_readonly( |
| 683 | "context", |
| 684 | [](PyAffineMap &self) { return self.getContext().getObject(); }, |
| 685 | "Context that owns the Affine Map") |
| 686 | .def( |
| 687 | "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); }, |
| 688 | kDumpDocstring) |
| 689 | .def_static( |
| 690 | "get", |
| 691 | [](intptr_t dimCount, intptr_t symbolCount, py::list exprs, |
| 692 | DefaultingPyMlirContext context) { |
| 693 | SmallVector<MlirAffineExpr> affineExprs; |
| 694 | pyListToVector<PyAffineExpr, MlirAffineExpr>( |
| 695 | exprs, affineExprs, "attempting to create an AffineMap"); |
| 696 | MlirAffineMap map = |
| 697 | mlirAffineMapGet(context->get(), dimCount, symbolCount, |
| 698 | affineExprs.size(), affineExprs.data()); |
| 699 | return PyAffineMap(context->getRef(), map); |
| 700 | }, |
| 701 | py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"), |
| 702 | py::arg("context") = py::none(), |
| 703 | "Gets a map with the given expressions as results.") |
| 704 | .def_static( |
| 705 | "get_constant", |
| 706 | [](intptr_t value, DefaultingPyMlirContext context) { |
| 707 | MlirAffineMap affineMap = |
| 708 | mlirAffineMapConstantGet(context->get(), value); |
| 709 | return PyAffineMap(context->getRef(), affineMap); |
| 710 | }, |
| 711 | py::arg("value"), py::arg("context") = py::none(), |
| 712 | "Gets an affine map with a single constant result") |
| 713 | .def_static( |
| 714 | "get_empty", |
| 715 | [](DefaultingPyMlirContext context) { |
| 716 | MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get()); |
| 717 | return PyAffineMap(context->getRef(), affineMap); |
| 718 | }, |
| 719 | py::arg("context") = py::none(), "Gets an empty affine map.") |
| 720 | .def_static( |
| 721 | "get_identity", |
| 722 | [](intptr_t nDims, DefaultingPyMlirContext context) { |
| 723 | MlirAffineMap affineMap = |
| 724 | mlirAffineMapMultiDimIdentityGet(context->get(), nDims); |
| 725 | return PyAffineMap(context->getRef(), affineMap); |
| 726 | }, |
| 727 | py::arg("n_dims"), py::arg("context") = py::none(), |
| 728 | "Gets an identity map with the given number of dimensions.") |
| 729 | .def_static( |
| 730 | "get_minor_identity", |
| 731 | [](intptr_t nDims, intptr_t nResults, |
| 732 | DefaultingPyMlirContext context) { |
| 733 | MlirAffineMap affineMap = |
| 734 | mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults); |
| 735 | return PyAffineMap(context->getRef(), affineMap); |
| 736 | }, |
| 737 | py::arg("n_dims"), py::arg("n_results"), |
| 738 | py::arg("context") = py::none(), |
| 739 | "Gets a minor identity map with the given number of dimensions and " |
| 740 | "results.") |
| 741 | .def_static( |
| 742 | "get_permutation", |
| 743 | [](std::vector<unsigned> permutation, |
| 744 | DefaultingPyMlirContext context) { |
| 745 | if (!isPermutation(permutation)) |
| 746 | throw py::cast_error("Invalid permutation when attempting to " |
| 747 | "create an AffineMap"); |
| 748 | MlirAffineMap affineMap = mlirAffineMapPermutationGet( |
| 749 | context->get(), permutation.size(), permutation.data()); |
| 750 | return PyAffineMap(context->getRef(), affineMap); |
| 751 | }, |
| 752 | py::arg("permutation"), py::arg("context") = py::none(), |
| 753 | "Gets an affine map that permutes its inputs.") |
Stella Laurenzo | a6e7d02 | 2021-11-28 14:08:06 -0800 | [diff] [blame] | 754 | .def( |
| 755 | "get_submap", |
| 756 | [](PyAffineMap &self, std::vector<intptr_t> &resultPos) { |
| 757 | intptr_t numResults = mlirAffineMapGetNumResults(self); |
| 758 | for (intptr_t pos : resultPos) { |
| 759 | if (pos < 0 || pos >= numResults) |
| 760 | throw py::value_error("result position out of bounds"); |
| 761 | } |
| 762 | MlirAffineMap affineMap = mlirAffineMapGetSubMap( |
| 763 | self, resultPos.size(), resultPos.data()); |
| 764 | return PyAffineMap(self.getContext(), affineMap); |
| 765 | }, |
| 766 | py::arg("result_positions")) |
| 767 | .def( |
| 768 | "get_major_submap", |
| 769 | [](PyAffineMap &self, intptr_t nResults) { |
| 770 | if (nResults >= mlirAffineMapGetNumResults(self)) |
| 771 | throw py::value_error("number of results out of bounds"); |
| 772 | MlirAffineMap affineMap = |
| 773 | mlirAffineMapGetMajorSubMap(self, nResults); |
| 774 | return PyAffineMap(self.getContext(), affineMap); |
| 775 | }, |
| 776 | py::arg("n_results")) |
| 777 | .def( |
| 778 | "get_minor_submap", |
| 779 | [](PyAffineMap &self, intptr_t nResults) { |
| 780 | if (nResults >= mlirAffineMapGetNumResults(self)) |
| 781 | throw py::value_error("number of results out of bounds"); |
| 782 | MlirAffineMap affineMap = |
| 783 | mlirAffineMapGetMinorSubMap(self, nResults); |
| 784 | return PyAffineMap(self.getContext(), affineMap); |
| 785 | }, |
| 786 | py::arg("n_results")) |
| 787 | .def( |
| 788 | "replace", |
| 789 | [](PyAffineMap &self, PyAffineExpr &expression, |
| 790 | PyAffineExpr &replacement, intptr_t numResultDims, |
| 791 | intptr_t numResultSyms) { |
| 792 | MlirAffineMap affineMap = mlirAffineMapReplace( |
| 793 | self, expression, replacement, numResultDims, numResultSyms); |
| 794 | return PyAffineMap(self.getContext(), affineMap); |
| 795 | }, |
| 796 | py::arg("expr"), py::arg("replacement"), py::arg("n_result_dims"), |
| 797 | py::arg("n_result_syms")) |
Stella Laurenzo | 436c6c9 | 2021-03-19 11:57:01 -0700 | [diff] [blame] | 798 | .def_property_readonly( |
| 799 | "is_permutation", |
| 800 | [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); }) |
| 801 | .def_property_readonly("is_projected_permutation", |
| 802 | [](PyAffineMap &self) { |
| 803 | return mlirAffineMapIsProjectedPermutation(self); |
| 804 | }) |
| 805 | .def_property_readonly( |
| 806 | "n_dims", |
| 807 | [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); }) |
| 808 | .def_property_readonly( |
| 809 | "n_inputs", |
| 810 | [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); }) |
| 811 | .def_property_readonly( |
| 812 | "n_symbols", |
| 813 | [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); }) |
| 814 | .def_property_readonly("results", [](PyAffineMap &self) { |
| 815 | return PyAffineMapExprList(self); |
| 816 | }); |
| 817 | PyAffineMapExprList::bind(m); |
| 818 | |
| 819 | //---------------------------------------------------------------------------- |
| 820 | // Mapping of PyIntegerSet. |
| 821 | //---------------------------------------------------------------------------- |
Stella Laurenzo | f05ff4f | 2021-08-23 20:01:07 -0700 | [diff] [blame] | 822 | py::class_<PyIntegerSet>(m, "IntegerSet", py::module_local()) |
Stella Laurenzo | 436c6c9 | 2021-03-19 11:57:01 -0700 | [diff] [blame] | 823 | .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, |
| 824 | &PyIntegerSet::getCapsule) |
| 825 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule) |
| 826 | .def("__eq__", [](PyIntegerSet &self, |
| 827 | PyIntegerSet &other) { return self == other; }) |
| 828 | .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; }) |
| 829 | .def("__str__", |
| 830 | [](PyIntegerSet &self) { |
| 831 | PyPrintAccumulator printAccum; |
| 832 | mlirIntegerSetPrint(self, printAccum.getCallback(), |
| 833 | printAccum.getUserData()); |
| 834 | return printAccum.join(); |
| 835 | }) |
| 836 | .def("__repr__", |
| 837 | [](PyIntegerSet &self) { |
| 838 | PyPrintAccumulator printAccum; |
| 839 | printAccum.parts.append("IntegerSet("); |
| 840 | mlirIntegerSetPrint(self, printAccum.getCallback(), |
| 841 | printAccum.getUserData()); |
| 842 | printAccum.parts.append(")"); |
| 843 | return printAccum.join(); |
| 844 | }) |
Alex Zinenko | fc7594c | 2021-11-02 14:15:25 +0100 | [diff] [blame] | 845 | .def("__hash__", |
| 846 | [](PyIntegerSet &self) { |
| 847 | return static_cast<size_t>(llvm::hash_value(self.get().ptr)); |
| 848 | }) |
Stella Laurenzo | 436c6c9 | 2021-03-19 11:57:01 -0700 | [diff] [blame] | 849 | .def_property_readonly( |
| 850 | "context", |
| 851 | [](PyIntegerSet &self) { return self.getContext().getObject(); }) |
| 852 | .def( |
| 853 | "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); }, |
| 854 | kDumpDocstring) |
| 855 | .def_static( |
| 856 | "get", |
| 857 | [](intptr_t numDims, intptr_t numSymbols, py::list exprs, |
| 858 | std::vector<bool> eqFlags, DefaultingPyMlirContext context) { |
| 859 | if (exprs.size() != eqFlags.size()) |
| 860 | throw py::value_error( |
| 861 | "Expected the number of constraints to match " |
| 862 | "that of equality flags"); |
| 863 | if (exprs.empty()) |
| 864 | throw py::value_error("Expected non-empty list of constraints"); |
| 865 | |
| 866 | // Copy over to a SmallVector because std::vector has a |
| 867 | // specialization for booleans that packs data and does not |
| 868 | // expose a `bool *`. |
| 869 | SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end()); |
| 870 | |
| 871 | SmallVector<MlirAffineExpr> affineExprs; |
| 872 | pyListToVector<PyAffineExpr>(exprs, affineExprs, |
| 873 | "attempting to create an IntegerSet"); |
| 874 | MlirIntegerSet set = mlirIntegerSetGet( |
| 875 | context->get(), numDims, numSymbols, exprs.size(), |
| 876 | affineExprs.data(), flags.data()); |
| 877 | return PyIntegerSet(context->getRef(), set); |
| 878 | }, |
| 879 | py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"), |
| 880 | py::arg("eq_flags"), py::arg("context") = py::none()) |
| 881 | .def_static( |
| 882 | "get_empty", |
| 883 | [](intptr_t numDims, intptr_t numSymbols, |
| 884 | DefaultingPyMlirContext context) { |
| 885 | MlirIntegerSet set = |
| 886 | mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols); |
| 887 | return PyIntegerSet(context->getRef(), set); |
| 888 | }, |
| 889 | py::arg("num_dims"), py::arg("num_symbols"), |
| 890 | py::arg("context") = py::none()) |
Stella Laurenzo | a6e7d02 | 2021-11-28 14:08:06 -0800 | [diff] [blame] | 891 | .def( |
| 892 | "get_replaced", |
| 893 | [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs, |
| 894 | intptr_t numResultDims, intptr_t numResultSymbols) { |
| 895 | if (static_cast<intptr_t>(dimExprs.size()) != |
| 896 | mlirIntegerSetGetNumDims(self)) |
| 897 | throw py::value_error( |
| 898 | "Expected the number of dimension replacement expressions " |
| 899 | "to match that of dimensions"); |
| 900 | if (static_cast<intptr_t>(symbolExprs.size()) != |
| 901 | mlirIntegerSetGetNumSymbols(self)) |
| 902 | throw py::value_error( |
| 903 | "Expected the number of symbol replacement expressions " |
| 904 | "to match that of symbols"); |
Stella Laurenzo | 436c6c9 | 2021-03-19 11:57:01 -0700 | [diff] [blame] | 905 | |
Stella Laurenzo | a6e7d02 | 2021-11-28 14:08:06 -0800 | [diff] [blame] | 906 | SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs; |
| 907 | pyListToVector<PyAffineExpr>( |
| 908 | dimExprs, dimAffineExprs, |
| 909 | "attempting to create an IntegerSet by replacing dimensions"); |
| 910 | pyListToVector<PyAffineExpr>( |
| 911 | symbolExprs, symbolAffineExprs, |
| 912 | "attempting to create an IntegerSet by replacing symbols"); |
| 913 | MlirIntegerSet set = mlirIntegerSetReplaceGet( |
| 914 | self, dimAffineExprs.data(), symbolAffineExprs.data(), |
| 915 | numResultDims, numResultSymbols); |
| 916 | return PyIntegerSet(self.getContext(), set); |
| 917 | }, |
| 918 | py::arg("dim_exprs"), py::arg("symbol_exprs"), |
| 919 | py::arg("num_result_dims"), py::arg("num_result_symbols")) |
Stella Laurenzo | 436c6c9 | 2021-03-19 11:57:01 -0700 | [diff] [blame] | 920 | .def_property_readonly("is_canonical_empty", |
| 921 | [](PyIntegerSet &self) { |
| 922 | return mlirIntegerSetIsCanonicalEmpty(self); |
| 923 | }) |
| 924 | .def_property_readonly( |
| 925 | "n_dims", |
| 926 | [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); }) |
| 927 | .def_property_readonly( |
| 928 | "n_symbols", |
| 929 | [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); }) |
| 930 | .def_property_readonly( |
| 931 | "n_inputs", |
| 932 | [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); }) |
| 933 | .def_property_readonly("n_equalities", |
| 934 | [](PyIntegerSet &self) { |
| 935 | return mlirIntegerSetGetNumEqualities(self); |
| 936 | }) |
| 937 | .def_property_readonly("n_inequalities", |
| 938 | [](PyIntegerSet &self) { |
| 939 | return mlirIntegerSetGetNumInequalities(self); |
| 940 | }) |
| 941 | .def_property_readonly("constraints", [](PyIntegerSet &self) { |
| 942 | return PyIntegerSetConstraintList(self); |
| 943 | }); |
| 944 | PyIntegerSetConstraint::bind(m); |
| 945 | PyIntegerSetConstraintList::bind(m); |
| 946 | } |