blob: 272de0d7aaaf22b31a2eaa97ee6628586258dde9 [file] [log] [blame]
Stella Laurenzo436c6c92021-03-19 11:57:01 -07001//===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "IRModule.h"
10
11#include "PybindUtils.h"
12
13#include "mlir-c/AffineMap.h"
14#include "mlir-c/Bindings/Python/Interop.h"
15#include "mlir-c/IntegerSet.h"
16
17namespace py = pybind11;
18using namespace mlir;
19using namespace mlir::python;
20
21using llvm::SmallVector;
22using llvm::StringRef;
23using llvm::Twine;
24
/// Docstring attached to the Python-visible `dump` methods bound in this file.
static constexpr char kDumpDocstring[] =
    "Dumps a debug representation of the object to stderr.";
27
28/// Attempts to populate `result` with the content of `list` casted to the
29/// appropriate type (Python and C types are provided as template arguments).
30/// Throws errors in case of failure, using "action" to describe what the caller
31/// was attempting to do.
32template <typename PyType, typename CType>
33static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result,
34 StringRef action) {
35 result.reserve(py::len(list));
36 for (py::handle item : list) {
37 try {
38 result.push_back(item.cast<PyType>());
39 } catch (py::cast_error &err) {
40 std::string msg = (llvm::Twine("Invalid expression when ") + action +
41 " (" + err.what() + ")")
42 .str();
43 throw py::cast_error(msg);
44 } catch (py::reference_cast_error &err) {
45 std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
46 action + " (" + err.what() + ")")
47 .str();
48 throw py::cast_error(msg);
49 }
50 }
51}
52
/// Returns true if `permutation` contains every integer in [0, size) exactly
/// once, i.e. it is a valid permutation of its own length.
///
/// Out-of-range entries make the result false, and so do duplicated entries.
/// Negative values of a signed element type also give false: they wrap to huge
/// unsigned values and fail the range check. An empty vector is the trivial
/// permutation and yields true.
///
/// The vector is taken by const reference to avoid copying it on every call.
template <typename PermutationTy>
static bool isPermutation(const std::vector<PermutationTy> &permutation) {
  const std::size_t size = permutation.size();
  // One flag per position; std::vector<char> avoids the vector<bool> proxy.
  std::vector<char> seen(size, 0);
  for (const PermutationTy &val : permutation) {
    // Explicit conversion: negative signed values wrap to a huge index and are
    // rejected by the same bound check (matching the original semantics).
    const std::size_t idx = static_cast<std::size_t>(val);
    if (idx >= size || seen[idx])
      return false;
    seen[idx] = 1;
  }
  return true;
}
67
68namespace {
69
70/// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
71/// and should be castable from it. Intermediate hierarchy classes can be
72/// modeled by specifying BaseTy.
73template <typename DerivedTy, typename BaseTy = PyAffineExpr>
74class PyConcreteAffineExpr : public BaseTy {
75public:
76 // Derived classes must define statics for:
77 // IsAFunctionTy isaFunction
78 // const char *pyClassName
79 // and redefine bindDerived.
80 using ClassTy = py::class_<DerivedTy, BaseTy>;
81 using IsAFunctionTy = bool (*)(MlirAffineExpr);
82
83 PyConcreteAffineExpr() = default;
84 PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
85 : BaseTy(std::move(contextRef), affineExpr) {}
86 PyConcreteAffineExpr(PyAffineExpr &orig)
87 : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {}
88
89 static MlirAffineExpr castFrom(PyAffineExpr &orig) {
90 if (!DerivedTy::isaFunction(orig)) {
91 auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
92 throw SetPyError(PyExc_ValueError,
93 Twine("Cannot cast affine expression to ") +
94 DerivedTy::pyClassName + " (from " + origRepr + ")");
95 }
96 return orig;
97 }
98
99 static void bind(py::module &m) {
Stella Laurenzof05ff4f2021-08-23 20:01:07 -0700100 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
Stella Laurenzoa6e7d022021-11-28 14:08:06 -0800101 cls.def(py::init<PyAffineExpr &>(), py::arg("expr"));
102 cls.def_static(
103 "isinstance",
104 [](PyAffineExpr &otherAffineExpr) -> bool {
105 return DerivedTy::isaFunction(otherAffineExpr);
106 },
107 py::arg("other"));
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700108 DerivedTy::bindDerived(cls);
109 }
110
111 /// Implemented by derived classes to add methods to the Python subclass.
112 static void bindDerived(ClassTy &m) {}
113};
114
115class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
116public:
117 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant;
118 static constexpr const char *pyClassName = "AffineConstantExpr";
119 using PyConcreteAffineExpr::PyConcreteAffineExpr;
120
121 static PyAffineConstantExpr get(intptr_t value,
122 DefaultingPyMlirContext context) {
123 MlirAffineExpr affineExpr =
124 mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value));
125 return PyAffineConstantExpr(context->getRef(), affineExpr);
126 }
127
128 static void bindDerived(ClassTy &c) {
129 c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
130 py::arg("context") = py::none());
131 c.def_property_readonly("value", [](PyAffineConstantExpr &self) {
132 return mlirAffineConstantExprGetValue(self);
133 });
134 }
135};
136
137class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
138public:
139 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim;
140 static constexpr const char *pyClassName = "AffineDimExpr";
141 using PyConcreteAffineExpr::PyConcreteAffineExpr;
142
143 static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) {
144 MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos);
145 return PyAffineDimExpr(context->getRef(), affineExpr);
146 }
147
148 static void bindDerived(ClassTy &c) {
149 c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
150 py::arg("context") = py::none());
151 c.def_property_readonly("position", [](PyAffineDimExpr &self) {
152 return mlirAffineDimExprGetPosition(self);
153 });
154 }
155};
156
157class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
158public:
159 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol;
160 static constexpr const char *pyClassName = "AffineSymbolExpr";
161 using PyConcreteAffineExpr::PyConcreteAffineExpr;
162
163 static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) {
164 MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos);
165 return PyAffineSymbolExpr(context->getRef(), affineExpr);
166 }
167
168 static void bindDerived(ClassTy &c) {
169 c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
170 py::arg("context") = py::none());
171 c.def_property_readonly("position", [](PyAffineSymbolExpr &self) {
172 return mlirAffineSymbolExprGetPosition(self);
173 });
174 }
175};
176
177class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
178public:
179 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary;
180 static constexpr const char *pyClassName = "AffineBinaryExpr";
181 using PyConcreteAffineExpr::PyConcreteAffineExpr;
182
183 PyAffineExpr lhs() {
184 MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get());
185 return PyAffineExpr(getContext(), lhsExpr);
186 }
187
188 PyAffineExpr rhs() {
189 MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get());
190 return PyAffineExpr(getContext(), rhsExpr);
191 }
192
193 static void bindDerived(ClassTy &c) {
194 c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs);
195 c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
196 }
197};
198
199class PyAffineAddExpr
200 : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> {
201public:
202 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd;
203 static constexpr const char *pyClassName = "AffineAddExpr";
204 using PyConcreteAffineExpr::PyConcreteAffineExpr;
205
206 static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
207 MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs);
208 return PyAffineAddExpr(lhs.getContext(), expr);
209 }
210
Alex Zinenkofc7594c2021-11-02 14:15:25 +0100211 static PyAffineAddExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
212 MlirAffineExpr expr = mlirAffineAddExprGet(
213 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
214 return PyAffineAddExpr(lhs.getContext(), expr);
215 }
216
217 static PyAffineAddExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
218 MlirAffineExpr expr = mlirAffineAddExprGet(
219 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
220 return PyAffineAddExpr(rhs.getContext(), expr);
221 }
222
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700223 static void bindDerived(ClassTy &c) {
224 c.def_static("get", &PyAffineAddExpr::get);
225 }
226};
227
228class PyAffineMulExpr
229 : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> {
230public:
231 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul;
232 static constexpr const char *pyClassName = "AffineMulExpr";
233 using PyConcreteAffineExpr::PyConcreteAffineExpr;
234
235 static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
236 MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs);
237 return PyAffineMulExpr(lhs.getContext(), expr);
238 }
239
Alex Zinenkofc7594c2021-11-02 14:15:25 +0100240 static PyAffineMulExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
241 MlirAffineExpr expr = mlirAffineMulExprGet(
242 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
243 return PyAffineMulExpr(lhs.getContext(), expr);
244 }
245
246 static PyAffineMulExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
247 MlirAffineExpr expr = mlirAffineMulExprGet(
248 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
249 return PyAffineMulExpr(rhs.getContext(), expr);
250 }
251
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700252 static void bindDerived(ClassTy &c) {
253 c.def_static("get", &PyAffineMulExpr::get);
254 }
255};
256
257class PyAffineModExpr
258 : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> {
259public:
260 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod;
261 static constexpr const char *pyClassName = "AffineModExpr";
262 using PyConcreteAffineExpr::PyConcreteAffineExpr;
263
264 static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
265 MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs);
266 return PyAffineModExpr(lhs.getContext(), expr);
267 }
268
Alex Zinenkofc7594c2021-11-02 14:15:25 +0100269 static PyAffineModExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
270 MlirAffineExpr expr = mlirAffineModExprGet(
271 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
272 return PyAffineModExpr(lhs.getContext(), expr);
273 }
274
275 static PyAffineModExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
276 MlirAffineExpr expr = mlirAffineModExprGet(
277 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
278 return PyAffineModExpr(rhs.getContext(), expr);
279 }
280
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700281 static void bindDerived(ClassTy &c) {
282 c.def_static("get", &PyAffineModExpr::get);
283 }
284};
285
286class PyAffineFloorDivExpr
287 : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> {
288public:
289 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv;
290 static constexpr const char *pyClassName = "AffineFloorDivExpr";
291 using PyConcreteAffineExpr::PyConcreteAffineExpr;
292
293 static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
294 MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs);
295 return PyAffineFloorDivExpr(lhs.getContext(), expr);
296 }
297
Alex Zinenkofc7594c2021-11-02 14:15:25 +0100298 static PyAffineFloorDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
299 MlirAffineExpr expr = mlirAffineFloorDivExprGet(
300 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
301 return PyAffineFloorDivExpr(lhs.getContext(), expr);
302 }
303
304 static PyAffineFloorDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
305 MlirAffineExpr expr = mlirAffineFloorDivExprGet(
306 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
307 return PyAffineFloorDivExpr(rhs.getContext(), expr);
308 }
309
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700310 static void bindDerived(ClassTy &c) {
311 c.def_static("get", &PyAffineFloorDivExpr::get);
312 }
313};
314
315class PyAffineCeilDivExpr
316 : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> {
317public:
318 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv;
319 static constexpr const char *pyClassName = "AffineCeilDivExpr";
320 using PyConcreteAffineExpr::PyConcreteAffineExpr;
321
322 static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
323 MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs);
324 return PyAffineCeilDivExpr(lhs.getContext(), expr);
325 }
326
Alex Zinenkofc7594c2021-11-02 14:15:25 +0100327 static PyAffineCeilDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
328 MlirAffineExpr expr = mlirAffineCeilDivExprGet(
329 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
330 return PyAffineCeilDivExpr(lhs.getContext(), expr);
331 }
332
333 static PyAffineCeilDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
334 MlirAffineExpr expr = mlirAffineCeilDivExprGet(
335 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
336 return PyAffineCeilDivExpr(rhs.getContext(), expr);
337 }
338
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700339 static void bindDerived(ClassTy &c) {
340 c.def_static("get", &PyAffineCeilDivExpr::get);
341 }
342};
343
344} // namespace
345
346bool PyAffineExpr::operator==(const PyAffineExpr &other) {
347 return mlirAffineExprEqual(affineExpr, other.affineExpr);
348}
349
350py::object PyAffineExpr::getCapsule() {
351 return py::reinterpret_steal<py::object>(
352 mlirPythonAffineExprToCapsule(*this));
353}
354
355PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) {
356 MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
357 if (mlirAffineExprIsNull(rawAffineExpr))
358 throw py::error_already_set();
359 return PyAffineExpr(
360 PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)),
361 rawAffineExpr);
362}
363
364//------------------------------------------------------------------------------
365// PyAffineMap and utilities.
366//------------------------------------------------------------------------------
367namespace {
368
369/// A list of expressions contained in an affine map. Internally these are
370/// stored as a consecutive array leading to inexpensive random access. Both
371/// the map and the expression are owned by the context so we need not bother
372/// with lifetime extension.
373class PyAffineMapExprList
374 : public Sliceable<PyAffineMapExprList, PyAffineExpr> {
375public:
376 static constexpr const char *pyClassName = "AffineExprList";
377
378 PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0,
379 intptr_t length = -1, intptr_t step = 1)
380 : Sliceable(startIndex,
381 length == -1 ? mlirAffineMapGetNumResults(map) : length,
382 step),
383 affineMap(map) {}
384
385 intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); }
386
387 PyAffineExpr getElement(intptr_t pos) {
388 return PyAffineExpr(affineMap.getContext(),
389 mlirAffineMapGetResult(affineMap, pos));
390 }
391
392 PyAffineMapExprList slice(intptr_t startIndex, intptr_t length,
393 intptr_t step) {
394 return PyAffineMapExprList(affineMap, startIndex, length, step);
395 }
396
397private:
398 PyAffineMap affineMap;
399};
400} // end namespace
401
402bool PyAffineMap::operator==(const PyAffineMap &other) {
403 return mlirAffineMapEqual(affineMap, other.affineMap);
404}
405
406py::object PyAffineMap::getCapsule() {
407 return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this));
408}
409
410PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
411 MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr());
412 if (mlirAffineMapIsNull(rawAffineMap))
413 throw py::error_already_set();
414 return PyAffineMap(
415 PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)),
416 rawAffineMap);
417}
418
419//------------------------------------------------------------------------------
420// PyIntegerSet and utilities.
421//------------------------------------------------------------------------------
422namespace {
423
424class PyIntegerSetConstraint {
425public:
426 PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {}
427
428 PyAffineExpr getExpr() {
429 return PyAffineExpr(set.getContext(),
430 mlirIntegerSetGetConstraint(set, pos));
431 }
432
433 bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
434
435 static void bind(py::module &m) {
Stella Laurenzof05ff4f2021-08-23 20:01:07 -0700436 py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint",
437 py::module_local())
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700438 .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr)
439 .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
440 }
441
442private:
443 PyIntegerSet set;
444 intptr_t pos;
445};
446
447class PyIntegerSetConstraintList
448 : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
449public:
450 static constexpr const char *pyClassName = "IntegerSetConstraintList";
451
452 PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0,
453 intptr_t length = -1, intptr_t step = 1)
454 : Sliceable(startIndex,
455 length == -1 ? mlirIntegerSetGetNumConstraints(set) : length,
456 step),
457 set(set) {}
458
459 intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); }
460
461 PyIntegerSetConstraint getElement(intptr_t pos) {
462 return PyIntegerSetConstraint(set, pos);
463 }
464
465 PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
466 intptr_t step) {
467 return PyIntegerSetConstraintList(set, startIndex, length, step);
468 }
469
470private:
471 PyIntegerSet set;
472};
473} // namespace
474
475bool PyIntegerSet::operator==(const PyIntegerSet &other) {
476 return mlirIntegerSetEqual(integerSet, other.integerSet);
477}
478
479py::object PyIntegerSet::getCapsule() {
480 return py::reinterpret_steal<py::object>(
481 mlirPythonIntegerSetToCapsule(*this));
482}
483
484PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
485 MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
486 if (mlirIntegerSetIsNull(rawIntegerSet))
487 throw py::error_already_set();
488 return PyIntegerSet(
489 PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
490 rawIntegerSet);
491}
492
493void mlir::python::populateIRAffine(py::module &m) {
494 //----------------------------------------------------------------------------
495 // Mapping of PyAffineExpr and derived classes.
496 //----------------------------------------------------------------------------
Stella Laurenzof05ff4f2021-08-23 20:01:07 -0700497 py::class_<PyAffineExpr>(m, "AffineExpr", py::module_local())
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700498 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
499 &PyAffineExpr::getCapsule)
500 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
Alex Zinenkofc7594c2021-11-02 14:15:25 +0100501 .def("__add__", &PyAffineAddExpr::get)
502 .def("__add__", &PyAffineAddExpr::getRHSConstant)
503 .def("__radd__", &PyAffineAddExpr::getRHSConstant)
504 .def("__mul__", &PyAffineMulExpr::get)
505 .def("__mul__", &PyAffineMulExpr::getRHSConstant)
506 .def("__rmul__", &PyAffineMulExpr::getRHSConstant)
507 .def("__mod__", &PyAffineModExpr::get)
508 .def("__mod__", &PyAffineModExpr::getRHSConstant)
509 .def("__rmod__",
510 [](PyAffineExpr &self, intptr_t other) {
511 return PyAffineModExpr::get(
512 PyAffineConstantExpr::get(other, *self.getContext().get()),
513 self);
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700514 })
515 .def("__sub__",
516 [](PyAffineExpr &self, PyAffineExpr &other) {
517 auto negOne =
518 PyAffineConstantExpr::get(-1, *self.getContext().get());
519 return PyAffineAddExpr::get(self,
520 PyAffineMulExpr::get(negOne, other));
521 })
Alex Zinenkofc7594c2021-11-02 14:15:25 +0100522 .def("__sub__",
523 [](PyAffineExpr &self, intptr_t other) {
524 return PyAffineAddExpr::get(
525 self,
526 PyAffineConstantExpr::get(-other, *self.getContext().get()));
527 })
528 .def("__rsub__",
529 [](PyAffineExpr &self, intptr_t other) {
530 return PyAffineAddExpr::getLHSConstant(
531 other, PyAffineMulExpr::getLHSConstant(-1, self));
532 })
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700533 .def("__eq__", [](PyAffineExpr &self,
534 PyAffineExpr &other) { return self == other; })
535 .def("__eq__",
536 [](PyAffineExpr &self, py::object &other) { return false; })
537 .def("__str__",
538 [](PyAffineExpr &self) {
539 PyPrintAccumulator printAccum;
540 mlirAffineExprPrint(self, printAccum.getCallback(),
541 printAccum.getUserData());
542 return printAccum.join();
543 })
544 .def("__repr__",
545 [](PyAffineExpr &self) {
546 PyPrintAccumulator printAccum;
547 printAccum.parts.append("AffineExpr(");
548 mlirAffineExprPrint(self, printAccum.getCallback(),
549 printAccum.getUserData());
550 printAccum.parts.append(")");
551 return printAccum.join();
552 })
Alex Zinenkofc7594c2021-11-02 14:15:25 +0100553 .def("__hash__",
554 [](PyAffineExpr &self) {
555 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
556 })
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700557 .def_property_readonly(
558 "context",
559 [](PyAffineExpr &self) { return self.getContext().getObject(); })
Alex Zinenkofc7594c2021-11-02 14:15:25 +0100560 .def("compose",
561 [](PyAffineExpr &self, PyAffineMap &other) {
562 return PyAffineExpr(self.getContext(),
563 mlirAffineExprCompose(self, other));
564 })
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700565 .def_static(
566 "get_add", &PyAffineAddExpr::get,
567 "Gets an affine expression containing a sum of two expressions.")
Alex Zinenkofc7594c2021-11-02 14:15:25 +0100568 .def_static("get_add", &PyAffineAddExpr::getLHSConstant,
569 "Gets an affine expression containing a sum of a constant "
570 "and another expression.")
571 .def_static("get_add", &PyAffineAddExpr::getRHSConstant,
572 "Gets an affine expression containing a sum of an expression "
573 "and a constant.")
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700574 .def_static(
575 "get_mul", &PyAffineMulExpr::get,
576 "Gets an affine expression containing a product of two expressions.")
Alex Zinenkofc7594c2021-11-02 14:15:25 +0100577 .def_static("get_mul", &PyAffineMulExpr::getLHSConstant,
578 "Gets an affine expression containing a product of a "
579 "constant and another expression.")
580 .def_static("get_mul", &PyAffineMulExpr::getRHSConstant,
581 "Gets an affine expression containing a product of an "
582 "expression and a constant.")
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700583 .def_static("get_mod", &PyAffineModExpr::get,
584 "Gets an affine expression containing the modulo of dividing "
585 "one expression by another.")
Alex Zinenkofc7594c2021-11-02 14:15:25 +0100586 .def_static("get_mod", &PyAffineModExpr::getLHSConstant,
587 "Gets a semi-affine expression containing the modulo of "
588 "dividing a constant by an expression.")
589 .def_static("get_mod", &PyAffineModExpr::getRHSConstant,
590 "Gets an affine expression containing the module of dividing"
591 "an expression by a constant.")
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700592 .def_static("get_floor_div", &PyAffineFloorDivExpr::get,
593 "Gets an affine expression containing the rounded-down "
594 "result of dividing one expression by another.")
Alex Zinenkofc7594c2021-11-02 14:15:25 +0100595 .def_static("get_floor_div", &PyAffineFloorDivExpr::getLHSConstant,
596 "Gets a semi-affine expression containing the rounded-down "
597 "result of dividing a constant by an expression.")
598 .def_static("get_floor_div", &PyAffineFloorDivExpr::getRHSConstant,
599 "Gets an affine expression containing the rounded-down "
600 "result of dividing an expression by a constant.")
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700601 .def_static("get_ceil_div", &PyAffineCeilDivExpr::get,
602 "Gets an affine expression containing the rounded-up result "
603 "of dividing one expression by another.")
Alex Zinenkofc7594c2021-11-02 14:15:25 +0100604 .def_static("get_ceil_div", &PyAffineCeilDivExpr::getLHSConstant,
605 "Gets a semi-affine expression containing the rounded-up "
606 "result of dividing a constant by an expression.")
607 .def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant,
608 "Gets an affine expression containing the rounded-up result "
609 "of dividing an expression by a constant.")
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700610 .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"),
611 py::arg("context") = py::none(),
612 "Gets a constant affine expression with the given value.")
613 .def_static(
614 "get_dim", &PyAffineDimExpr::get, py::arg("position"),
615 py::arg("context") = py::none(),
616 "Gets an affine expression of a dimension at the given position.")
617 .def_static(
618 "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"),
619 py::arg("context") = py::none(),
620 "Gets an affine expression of a symbol at the given position.")
621 .def(
622 "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
623 kDumpDocstring);
624 PyAffineConstantExpr::bind(m);
625 PyAffineDimExpr::bind(m);
626 PyAffineSymbolExpr::bind(m);
627 PyAffineBinaryExpr::bind(m);
628 PyAffineAddExpr::bind(m);
629 PyAffineMulExpr::bind(m);
630 PyAffineModExpr::bind(m);
631 PyAffineFloorDivExpr::bind(m);
632 PyAffineCeilDivExpr::bind(m);
633
634 //----------------------------------------------------------------------------
635 // Mapping of PyAffineMap.
636 //----------------------------------------------------------------------------
Stella Laurenzof05ff4f2021-08-23 20:01:07 -0700637 py::class_<PyAffineMap>(m, "AffineMap", py::module_local())
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700638 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
639 &PyAffineMap::getCapsule)
640 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
641 .def("__eq__",
642 [](PyAffineMap &self, PyAffineMap &other) { return self == other; })
643 .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; })
644 .def("__str__",
645 [](PyAffineMap &self) {
646 PyPrintAccumulator printAccum;
647 mlirAffineMapPrint(self, printAccum.getCallback(),
648 printAccum.getUserData());
649 return printAccum.join();
650 })
651 .def("__repr__",
652 [](PyAffineMap &self) {
653 PyPrintAccumulator printAccum;
654 printAccum.parts.append("AffineMap(");
655 mlirAffineMapPrint(self, printAccum.getCallback(),
656 printAccum.getUserData());
657 printAccum.parts.append(")");
658 return printAccum.join();
659 })
Alex Zinenkofc7594c2021-11-02 14:15:25 +0100660 .def("__hash__",
661 [](PyAffineMap &self) {
662 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
663 })
Nicolas Vasilache335d2df2021-03-31 09:33:08 +0000664 .def_static("compress_unused_symbols",
665 [](py::list affineMaps, DefaultingPyMlirContext context) {
666 SmallVector<MlirAffineMap> maps;
667 pyListToVector<PyAffineMap, MlirAffineMap>(
668 affineMaps, maps, "attempting to create an AffineMap");
669 std::vector<MlirAffineMap> compressed(affineMaps.size());
670 auto populate = [](void *result, intptr_t idx,
671 MlirAffineMap m) {
672 static_cast<MlirAffineMap *>(result)[idx] = (m);
673 };
674 mlirAffineMapCompressUnusedSymbols(
675 maps.data(), maps.size(), compressed.data(), populate);
676 std::vector<PyAffineMap> res;
Mehdi Aminie2f16be2021-10-19 17:13:54 +0000677 res.reserve(compressed.size());
Nicolas Vasilache335d2df2021-03-31 09:33:08 +0000678 for (auto m : compressed)
679 res.push_back(PyAffineMap(context->getRef(), m));
680 return res;
681 })
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700682 .def_property_readonly(
683 "context",
684 [](PyAffineMap &self) { return self.getContext().getObject(); },
685 "Context that owns the Affine Map")
686 .def(
687 "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
688 kDumpDocstring)
689 .def_static(
690 "get",
691 [](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
692 DefaultingPyMlirContext context) {
693 SmallVector<MlirAffineExpr> affineExprs;
694 pyListToVector<PyAffineExpr, MlirAffineExpr>(
695 exprs, affineExprs, "attempting to create an AffineMap");
696 MlirAffineMap map =
697 mlirAffineMapGet(context->get(), dimCount, symbolCount,
698 affineExprs.size(), affineExprs.data());
699 return PyAffineMap(context->getRef(), map);
700 },
701 py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"),
702 py::arg("context") = py::none(),
703 "Gets a map with the given expressions as results.")
704 .def_static(
705 "get_constant",
706 [](intptr_t value, DefaultingPyMlirContext context) {
707 MlirAffineMap affineMap =
708 mlirAffineMapConstantGet(context->get(), value);
709 return PyAffineMap(context->getRef(), affineMap);
710 },
711 py::arg("value"), py::arg("context") = py::none(),
712 "Gets an affine map with a single constant result")
713 .def_static(
714 "get_empty",
715 [](DefaultingPyMlirContext context) {
716 MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get());
717 return PyAffineMap(context->getRef(), affineMap);
718 },
719 py::arg("context") = py::none(), "Gets an empty affine map.")
720 .def_static(
721 "get_identity",
722 [](intptr_t nDims, DefaultingPyMlirContext context) {
723 MlirAffineMap affineMap =
724 mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
725 return PyAffineMap(context->getRef(), affineMap);
726 },
727 py::arg("n_dims"), py::arg("context") = py::none(),
728 "Gets an identity map with the given number of dimensions.")
729 .def_static(
730 "get_minor_identity",
731 [](intptr_t nDims, intptr_t nResults,
732 DefaultingPyMlirContext context) {
733 MlirAffineMap affineMap =
734 mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
735 return PyAffineMap(context->getRef(), affineMap);
736 },
737 py::arg("n_dims"), py::arg("n_results"),
738 py::arg("context") = py::none(),
739 "Gets a minor identity map with the given number of dimensions and "
740 "results.")
741 .def_static(
742 "get_permutation",
743 [](std::vector<unsigned> permutation,
744 DefaultingPyMlirContext context) {
745 if (!isPermutation(permutation))
746 throw py::cast_error("Invalid permutation when attempting to "
747 "create an AffineMap");
748 MlirAffineMap affineMap = mlirAffineMapPermutationGet(
749 context->get(), permutation.size(), permutation.data());
750 return PyAffineMap(context->getRef(), affineMap);
751 },
752 py::arg("permutation"), py::arg("context") = py::none(),
753 "Gets an affine map that permutes its inputs.")
Stella Laurenzoa6e7d022021-11-28 14:08:06 -0800754 .def(
755 "get_submap",
756 [](PyAffineMap &self, std::vector<intptr_t> &resultPos) {
757 intptr_t numResults = mlirAffineMapGetNumResults(self);
758 for (intptr_t pos : resultPos) {
759 if (pos < 0 || pos >= numResults)
760 throw py::value_error("result position out of bounds");
761 }
762 MlirAffineMap affineMap = mlirAffineMapGetSubMap(
763 self, resultPos.size(), resultPos.data());
764 return PyAffineMap(self.getContext(), affineMap);
765 },
766 py::arg("result_positions"))
767 .def(
768 "get_major_submap",
769 [](PyAffineMap &self, intptr_t nResults) {
770 if (nResults >= mlirAffineMapGetNumResults(self))
771 throw py::value_error("number of results out of bounds");
772 MlirAffineMap affineMap =
773 mlirAffineMapGetMajorSubMap(self, nResults);
774 return PyAffineMap(self.getContext(), affineMap);
775 },
776 py::arg("n_results"))
777 .def(
778 "get_minor_submap",
779 [](PyAffineMap &self, intptr_t nResults) {
780 if (nResults >= mlirAffineMapGetNumResults(self))
781 throw py::value_error("number of results out of bounds");
782 MlirAffineMap affineMap =
783 mlirAffineMapGetMinorSubMap(self, nResults);
784 return PyAffineMap(self.getContext(), affineMap);
785 },
786 py::arg("n_results"))
787 .def(
788 "replace",
789 [](PyAffineMap &self, PyAffineExpr &expression,
790 PyAffineExpr &replacement, intptr_t numResultDims,
791 intptr_t numResultSyms) {
792 MlirAffineMap affineMap = mlirAffineMapReplace(
793 self, expression, replacement, numResultDims, numResultSyms);
794 return PyAffineMap(self.getContext(), affineMap);
795 },
796 py::arg("expr"), py::arg("replacement"), py::arg("n_result_dims"),
797 py::arg("n_result_syms"))
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700798 .def_property_readonly(
799 "is_permutation",
800 [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
801 .def_property_readonly("is_projected_permutation",
802 [](PyAffineMap &self) {
803 return mlirAffineMapIsProjectedPermutation(self);
804 })
805 .def_property_readonly(
806 "n_dims",
807 [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
808 .def_property_readonly(
809 "n_inputs",
810 [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
811 .def_property_readonly(
812 "n_symbols",
813 [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
814 .def_property_readonly("results", [](PyAffineMap &self) {
815 return PyAffineMapExprList(self);
816 });
817 PyAffineMapExprList::bind(m);
818
819 //----------------------------------------------------------------------------
820 // Mapping of PyIntegerSet.
821 //----------------------------------------------------------------------------
Stella Laurenzof05ff4f2021-08-23 20:01:07 -0700822 py::class_<PyIntegerSet>(m, "IntegerSet", py::module_local())
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700823 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
824 &PyIntegerSet::getCapsule)
825 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
826 .def("__eq__", [](PyIntegerSet &self,
827 PyIntegerSet &other) { return self == other; })
828 .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
829 .def("__str__",
830 [](PyIntegerSet &self) {
831 PyPrintAccumulator printAccum;
832 mlirIntegerSetPrint(self, printAccum.getCallback(),
833 printAccum.getUserData());
834 return printAccum.join();
835 })
836 .def("__repr__",
837 [](PyIntegerSet &self) {
838 PyPrintAccumulator printAccum;
839 printAccum.parts.append("IntegerSet(");
840 mlirIntegerSetPrint(self, printAccum.getCallback(),
841 printAccum.getUserData());
842 printAccum.parts.append(")");
843 return printAccum.join();
844 })
Alex Zinenkofc7594c2021-11-02 14:15:25 +0100845 .def("__hash__",
846 [](PyIntegerSet &self) {
847 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
848 })
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700849 .def_property_readonly(
850 "context",
851 [](PyIntegerSet &self) { return self.getContext().getObject(); })
852 .def(
853 "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
854 kDumpDocstring)
855 .def_static(
856 "get",
857 [](intptr_t numDims, intptr_t numSymbols, py::list exprs,
858 std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
859 if (exprs.size() != eqFlags.size())
860 throw py::value_error(
861 "Expected the number of constraints to match "
862 "that of equality flags");
863 if (exprs.empty())
864 throw py::value_error("Expected non-empty list of constraints");
865
866 // Copy over to a SmallVector because std::vector has a
867 // specialization for booleans that packs data and does not
868 // expose a `bool *`.
869 SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());
870
871 SmallVector<MlirAffineExpr> affineExprs;
872 pyListToVector<PyAffineExpr>(exprs, affineExprs,
873 "attempting to create an IntegerSet");
874 MlirIntegerSet set = mlirIntegerSetGet(
875 context->get(), numDims, numSymbols, exprs.size(),
876 affineExprs.data(), flags.data());
877 return PyIntegerSet(context->getRef(), set);
878 },
879 py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
880 py::arg("eq_flags"), py::arg("context") = py::none())
881 .def_static(
882 "get_empty",
883 [](intptr_t numDims, intptr_t numSymbols,
884 DefaultingPyMlirContext context) {
885 MlirIntegerSet set =
886 mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
887 return PyIntegerSet(context->getRef(), set);
888 },
889 py::arg("num_dims"), py::arg("num_symbols"),
890 py::arg("context") = py::none())
Stella Laurenzoa6e7d022021-11-28 14:08:06 -0800891 .def(
892 "get_replaced",
893 [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
894 intptr_t numResultDims, intptr_t numResultSymbols) {
895 if (static_cast<intptr_t>(dimExprs.size()) !=
896 mlirIntegerSetGetNumDims(self))
897 throw py::value_error(
898 "Expected the number of dimension replacement expressions "
899 "to match that of dimensions");
900 if (static_cast<intptr_t>(symbolExprs.size()) !=
901 mlirIntegerSetGetNumSymbols(self))
902 throw py::value_error(
903 "Expected the number of symbol replacement expressions "
904 "to match that of symbols");
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700905
Stella Laurenzoa6e7d022021-11-28 14:08:06 -0800906 SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
907 pyListToVector<PyAffineExpr>(
908 dimExprs, dimAffineExprs,
909 "attempting to create an IntegerSet by replacing dimensions");
910 pyListToVector<PyAffineExpr>(
911 symbolExprs, symbolAffineExprs,
912 "attempting to create an IntegerSet by replacing symbols");
913 MlirIntegerSet set = mlirIntegerSetReplaceGet(
914 self, dimAffineExprs.data(), symbolAffineExprs.data(),
915 numResultDims, numResultSymbols);
916 return PyIntegerSet(self.getContext(), set);
917 },
918 py::arg("dim_exprs"), py::arg("symbol_exprs"),
919 py::arg("num_result_dims"), py::arg("num_result_symbols"))
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700920 .def_property_readonly("is_canonical_empty",
921 [](PyIntegerSet &self) {
922 return mlirIntegerSetIsCanonicalEmpty(self);
923 })
924 .def_property_readonly(
925 "n_dims",
926 [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
927 .def_property_readonly(
928 "n_symbols",
929 [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
930 .def_property_readonly(
931 "n_inputs",
932 [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
933 .def_property_readonly("n_equalities",
934 [](PyIntegerSet &self) {
935 return mlirIntegerSetGetNumEqualities(self);
936 })
937 .def_property_readonly("n_inequalities",
938 [](PyIntegerSet &self) {
939 return mlirIntegerSetGetNumInequalities(self);
940 })
941 .def_property_readonly("constraints", [](PyIntegerSet &self) {
942 return PyIntegerSetConstraintList(self);
943 });
944 PyIntegerSetConstraint::bind(m);
945 PyIntegerSetConstraintList::bind(m);
946}