blob: 17b3b34a2ea3069c8ffdd2ee7284a90870afafc8 [file] [log] [blame]
Stella Laurenzo436c6c92021-03-19 11:57:01 -07001//===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "IRModule.h"
10
11#include "PybindUtils.h"
12
13#include "mlir-c/BuiltinAttributes.h"
14#include "mlir-c/BuiltinTypes.h"
15
16namespace py = pybind11;
17using namespace mlir;
18using namespace mlir::python;
19
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -070020using llvm::None;
21using llvm::Optional;
Stella Laurenzo436c6c92021-03-19 11:57:01 -070022using llvm::SmallVector;
Stella Laurenzo436c6c92021-03-19 11:57:01 -070023using llvm::Twine;
24
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -070025//------------------------------------------------------------------------------
26// Docstrings (trivial, non-duplicated docstrings are included inline).
27//------------------------------------------------------------------------------
28
/// Python docstring for DenseElementsAttr.get. It is defined out of line
/// because of its length and attached in PyDenseElementsAttribute's
/// bindDerived().
static constexpr char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

  * Integer types (except for i1): the buffer must be byte aligned to the
    next byte boundary.
  * Floating point types: Must be bit-castable to the given floating point
    size.
  * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
    row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
    this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
  array: The array or buffer to convert.
  signless: If inferring an appropriate MLIR type, use signless types for
    integers (defaults True).
  type: Skips inference of the MLIR element type and uses this instead. The
    storage size must be consistent with the actual contents of the buffer.
  shape: Overrides the shape of the buffer when constructing the MLIR
    shaped type. This is needed when the physical and logical shape differ (as
    for i1).
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
70
Stella Laurenzo436c6c92021-03-19 11:57:01 -070071namespace {
72
73static MlirStringRef toMlirStringRef(const std::string &s) {
74 return mlirStringRefCreate(s.data(), s.size());
75}
76
Stella Laurenzo436c6c92021-03-19 11:57:01 -070077class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
78public:
79 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
80 static constexpr const char *pyClassName = "AffineMapAttr";
81 using PyConcreteAttribute::PyConcreteAttribute;
82
83 static void bindDerived(ClassTy &c) {
84 c.def_static(
85 "get",
86 [](PyAffineMap &affineMap) {
87 MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
88 return PyAffineMapAttribute(affineMap.getContext(), attr);
89 },
90 py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
91 }
92};
93
Alex Zinenkoed9e52f2021-10-04 11:38:20 +020094template <typename T>
95static T pyTryCast(py::handle object) {
96 try {
97 return object.cast<T>();
98 } catch (py::cast_error &err) {
99 std::string msg =
100 std::string(
101 "Invalid attribute when attempting to create an ArrayAttribute (") +
102 err.what() + ")";
103 throw py::cast_error(msg);
104 } catch (py::reference_cast_error &err) {
105 std::string msg = std::string("Invalid attribute (None?) when attempting "
106 "to create an ArrayAttribute (") +
107 err.what() + ")";
108 throw py::cast_error(msg);
109 }
110}
111
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700112class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
113public:
114 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
115 static constexpr const char *pyClassName = "ArrayAttr";
116 using PyConcreteAttribute::PyConcreteAttribute;
117
118 class PyArrayAttributeIterator {
119 public:
120 PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {}
121
122 PyArrayAttributeIterator &dunderIter() { return *this; }
123
124 PyAttribute dunderNext() {
125 if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) {
126 throw py::stop_iteration();
127 }
128 return PyAttribute(attr.getContext(),
129 mlirArrayAttrGetElement(attr.get(), nextIndex++));
130 }
131
132 static void bind(py::module &m) {
Stella Laurenzof05ff4f2021-08-23 20:01:07 -0700133 py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
134 py::module_local())
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700135 .def("__iter__", &PyArrayAttributeIterator::dunderIter)
136 .def("__next__", &PyArrayAttributeIterator::dunderNext);
137 }
138
139 private:
140 PyAttribute attr;
141 int nextIndex = 0;
142 };
143
Alex Zinenkoed9e52f2021-10-04 11:38:20 +0200144 PyAttribute getItem(intptr_t i) {
145 return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i));
146 }
147
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700148 static void bindDerived(ClassTy &c) {
149 c.def_static(
150 "get",
151 [](py::list attributes, DefaultingPyMlirContext context) {
152 SmallVector<MlirAttribute> mlirAttributes;
153 mlirAttributes.reserve(py::len(attributes));
154 for (auto attribute : attributes) {
Alex Zinenkoed9e52f2021-10-04 11:38:20 +0200155 mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700156 }
157 MlirAttribute attr = mlirArrayAttrGet(
158 context->get(), mlirAttributes.size(), mlirAttributes.data());
159 return PyArrayAttribute(context->getRef(), attr);
160 },
161 py::arg("attributes"), py::arg("context") = py::none(),
162 "Gets a uniqued Array attribute");
163 c.def("__getitem__",
164 [](PyArrayAttribute &arr, intptr_t i) {
165 if (i >= mlirArrayAttrGetNumElements(arr))
166 throw py::index_error("ArrayAttribute index out of range");
Alex Zinenkoed9e52f2021-10-04 11:38:20 +0200167 return arr.getItem(i);
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700168 })
169 .def("__len__",
170 [](const PyArrayAttribute &arr) {
171 return mlirArrayAttrGetNumElements(arr);
172 })
173 .def("__iter__", [](const PyArrayAttribute &arr) {
174 return PyArrayAttributeIterator(arr);
175 });
Alex Zinenkoed9e52f2021-10-04 11:38:20 +0200176 c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
177 std::vector<MlirAttribute> attributes;
178 intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
179 attributes.reserve(numOldElements + py::len(extras));
180 for (intptr_t i = 0; i < numOldElements; ++i)
181 attributes.push_back(arr.getItem(i));
182 for (py::handle attr : extras)
183 attributes.push_back(pyTryCast<PyAttribute>(attr));
184 MlirAttribute arrayAttr = mlirArrayAttrGet(
185 arr.getContext()->get(), attributes.size(), attributes.data());
186 return PyArrayAttribute(arr.getContext(), arrayAttr);
187 });
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700188 }
189};
190
191/// Float Point Attribute subclass - FloatAttr.
192class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
193public:
194 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
195 static constexpr const char *pyClassName = "FloatAttr";
196 using PyConcreteAttribute::PyConcreteAttribute;
197
198 static void bindDerived(ClassTy &c) {
199 c.def_static(
200 "get",
201 [](PyType &type, double value, DefaultingPyLocation loc) {
202 MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
203 // TODO: Rework error reporting once diagnostic engine is exposed
204 // in C API.
205 if (mlirAttributeIsNull(attr)) {
206 throw SetPyError(PyExc_ValueError,
207 Twine("invalid '") +
208 py::repr(py::cast(type)).cast<std::string>() +
209 "' and expected floating point type.");
210 }
211 return PyFloatAttribute(type.getContext(), attr);
212 },
213 py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
214 "Gets an uniqued float point attribute associated to a type");
215 c.def_static(
216 "get_f32",
217 [](double value, DefaultingPyMlirContext context) {
218 MlirAttribute attr = mlirFloatAttrDoubleGet(
219 context->get(), mlirF32TypeGet(context->get()), value);
220 return PyFloatAttribute(context->getRef(), attr);
221 },
222 py::arg("value"), py::arg("context") = py::none(),
223 "Gets an uniqued float point attribute associated to a f32 type");
224 c.def_static(
225 "get_f64",
226 [](double value, DefaultingPyMlirContext context) {
227 MlirAttribute attr = mlirFloatAttrDoubleGet(
228 context->get(), mlirF64TypeGet(context->get()), value);
229 return PyFloatAttribute(context->getRef(), attr);
230 },
231 py::arg("value"), py::arg("context") = py::none(),
232 "Gets an uniqued float point attribute associated to a f64 type");
233 c.def_property_readonly(
234 "value",
235 [](PyFloatAttribute &self) {
236 return mlirFloatAttrGetValueDouble(self);
237 },
238 "Returns the value of the float point attribute");
239 }
240};
241
242/// Integer Attribute subclass - IntegerAttr.
243class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
244public:
245 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
246 static constexpr const char *pyClassName = "IntegerAttr";
247 using PyConcreteAttribute::PyConcreteAttribute;
248
249 static void bindDerived(ClassTy &c) {
250 c.def_static(
251 "get",
252 [](PyType &type, int64_t value) {
253 MlirAttribute attr = mlirIntegerAttrGet(type, value);
254 return PyIntegerAttribute(type.getContext(), attr);
255 },
256 py::arg("type"), py::arg("value"),
257 "Gets an uniqued integer attribute associated to a type");
258 c.def_property_readonly(
259 "value",
260 [](PyIntegerAttribute &self) {
261 return mlirIntegerAttrGetValueInt(self);
262 },
263 "Returns the value of the integer attribute");
264 }
265};
266
267/// Bool Attribute subclass - BoolAttr.
268class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
269public:
270 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
271 static constexpr const char *pyClassName = "BoolAttr";
272 using PyConcreteAttribute::PyConcreteAttribute;
273
274 static void bindDerived(ClassTy &c) {
275 c.def_static(
276 "get",
277 [](bool value, DefaultingPyMlirContext context) {
278 MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
279 return PyBoolAttribute(context->getRef(), attr);
280 },
281 py::arg("value"), py::arg("context") = py::none(),
282 "Gets an uniqued bool attribute");
283 c.def_property_readonly(
284 "value",
285 [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
286 "Returns the value of the bool attribute");
287 }
288};
289
290class PyFlatSymbolRefAttribute
291 : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
292public:
293 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
294 static constexpr const char *pyClassName = "FlatSymbolRefAttr";
295 using PyConcreteAttribute::PyConcreteAttribute;
296
297 static void bindDerived(ClassTy &c) {
298 c.def_static(
299 "get",
300 [](std::string value, DefaultingPyMlirContext context) {
301 MlirAttribute attr =
302 mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
303 return PyFlatSymbolRefAttribute(context->getRef(), attr);
304 },
305 py::arg("value"), py::arg("context") = py::none(),
306 "Gets a uniqued FlatSymbolRef attribute");
307 c.def_property_readonly(
308 "value",
309 [](PyFlatSymbolRefAttribute &self) {
310 MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
311 return py::str(stringRef.data, stringRef.length);
312 },
313 "Returns the value of the FlatSymbolRef attribute as a string");
314 }
315};
316
317class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
318public:
319 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
320 static constexpr const char *pyClassName = "StringAttr";
321 using PyConcreteAttribute::PyConcreteAttribute;
322
323 static void bindDerived(ClassTy &c) {
324 c.def_static(
325 "get",
326 [](std::string value, DefaultingPyMlirContext context) {
327 MlirAttribute attr =
328 mlirStringAttrGet(context->get(), toMlirStringRef(value));
329 return PyStringAttribute(context->getRef(), attr);
330 },
331 py::arg("value"), py::arg("context") = py::none(),
332 "Gets a uniqued string attribute");
333 c.def_static(
334 "get_typed",
335 [](PyType &type, std::string value) {
336 MlirAttribute attr =
337 mlirStringAttrTypedGet(type, toMlirStringRef(value));
338 return PyStringAttribute(type.getContext(), attr);
339 },
Stella Laurenzoa6e7d022021-11-28 14:08:06 -0800340 py::arg("type"), py::arg("value"),
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700341 "Gets a uniqued string attribute associated to a type");
342 c.def_property_readonly(
343 "value",
344 [](PyStringAttribute &self) {
345 MlirStringRef stringRef = mlirStringAttrGetValue(self);
346 return py::str(stringRef.data, stringRef.length);
347 },
348 "Returns the value of the string attribute");
349 }
350};
351
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700352// TODO: Support construction of string elements.
353class PyDenseElementsAttribute
354 : public PyConcreteAttribute<PyDenseElementsAttribute> {
355public:
356 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
357 static constexpr const char *pyClassName = "DenseElementsAttr";
358 using PyConcreteAttribute::PyConcreteAttribute;
359
360 static PyDenseElementsAttribute
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -0700361 getFromBuffer(py::buffer array, bool signless, Optional<PyType> explicitType,
362 Optional<std::vector<int64_t>> explicitShape,
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700363 DefaultingPyMlirContext contextWrapper) {
364 // Request a contiguous view. In exotic cases, this will cause a copy.
365 int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
366 Py_buffer *view = new Py_buffer();
367 if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
368 delete view;
369 throw py::error_already_set();
370 }
371 py::buffer_info arrayInfo(view);
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -0700372 SmallVector<int64_t> shape;
373 if (explicitShape) {
374 shape.append(explicitShape->begin(), explicitShape->end());
375 } else {
376 shape.append(arrayInfo.shape.begin(),
377 arrayInfo.shape.begin() + arrayInfo.ndim);
378 }
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700379
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -0700380 MlirAttribute encodingAttr = mlirAttributeGetNull();
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700381 MlirContext context = contextWrapper->get();
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -0700382
383 // Detect format codes that are suitable for bulk loading. This includes
384 // all byte aligned integer and floating point types up to 8 bytes.
385 // Notably, this excludes, bool (which needs to be bit-packed) and
386 // other exotics which do not have a direct representation in the buffer
387 // protocol (i.e. complex, etc).
388 Optional<MlirType> bulkLoadElementType;
389 if (explicitType) {
390 bulkLoadElementType = *explicitType;
391 } else if (arrayInfo.format == "f") {
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700392 // f32
393 assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -0700394 bulkLoadElementType = mlirF32TypeGet(context);
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700395 } else if (arrayInfo.format == "d") {
396 // f64
397 assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -0700398 bulkLoadElementType = mlirF64TypeGet(context);
399 } else if (arrayInfo.format == "e") {
400 // f16
401 assert(arrayInfo.itemsize == 2 && "mismatched array itemsize");
402 bulkLoadElementType = mlirF16TypeGet(context);
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700403 } else if (isSignedIntegerFormat(arrayInfo.format)) {
404 if (arrayInfo.itemsize == 4) {
405 // i32
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -0700406 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
407 : mlirIntegerTypeSignedGet(context, 32);
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700408 } else if (arrayInfo.itemsize == 8) {
409 // i64
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -0700410 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
411 : mlirIntegerTypeSignedGet(context, 64);
412 } else if (arrayInfo.itemsize == 1) {
413 // i8
414 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
415 : mlirIntegerTypeSignedGet(context, 8);
416 } else if (arrayInfo.itemsize == 2) {
417 // i16
418 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
419 : mlirIntegerTypeSignedGet(context, 16);
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700420 }
421 } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
422 if (arrayInfo.itemsize == 4) {
423 // unsigned i32
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -0700424 bulkLoadElementType = signless
425 ? mlirIntegerTypeGet(context, 32)
426 : mlirIntegerTypeUnsignedGet(context, 32);
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700427 } else if (arrayInfo.itemsize == 8) {
428 // unsigned i64
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -0700429 bulkLoadElementType = signless
430 ? mlirIntegerTypeGet(context, 64)
431 : mlirIntegerTypeUnsignedGet(context, 64);
432 } else if (arrayInfo.itemsize == 1) {
433 // i8
434 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
435 : mlirIntegerTypeUnsignedGet(context, 8);
436 } else if (arrayInfo.itemsize == 2) {
437 // i16
438 bulkLoadElementType = signless
439 ? mlirIntegerTypeGet(context, 16)
440 : mlirIntegerTypeUnsignedGet(context, 16);
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700441 }
442 }
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -0700443 if (bulkLoadElementType) {
444 auto shapedType = mlirRankedTensorTypeGet(
445 shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
446 size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize;
447 MlirAttribute attr = mlirDenseElementsAttrRawBufferGet(
448 shapedType, rawBufferSize, arrayInfo.ptr);
449 if (mlirAttributeIsNull(attr)) {
450 throw std::invalid_argument(
451 "DenseElementsAttr could not be constructed from the given buffer. "
452 "This may mean that the Python buffer layout does not match that "
453 "MLIR expected layout and is a bug.");
454 }
455 return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
456 }
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700457
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -0700458 throw std::invalid_argument(
459 std::string("unimplemented array format conversion from format: ") +
460 arrayInfo.format);
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700461 }
462
463 static PyDenseElementsAttribute getSplat(PyType shapedType,
464 PyAttribute &elementAttr) {
465 auto contextWrapper =
466 PyMlirContext::forContext(mlirTypeGetContext(shapedType));
467 if (!mlirAttributeIsAInteger(elementAttr) &&
468 !mlirAttributeIsAFloat(elementAttr)) {
469 std::string message = "Illegal element type for DenseElementsAttr: ";
470 message.append(py::repr(py::cast(elementAttr)));
471 throw SetPyError(PyExc_ValueError, message);
472 }
473 if (!mlirTypeIsAShaped(shapedType) ||
474 !mlirShapedTypeHasStaticShape(shapedType)) {
475 std::string message =
476 "Expected a static ShapedType for the shaped_type parameter: ";
477 message.append(py::repr(py::cast(shapedType)));
478 throw SetPyError(PyExc_ValueError, message);
479 }
480 MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
481 MlirType attrType = mlirAttributeGetType(elementAttr);
482 if (!mlirTypeEqual(shapedElementType, attrType)) {
483 std::string message =
484 "Shaped element type and attribute type must be equal: shaped=";
485 message.append(py::repr(py::cast(shapedType)));
486 message.append(", element=");
487 message.append(py::repr(py::cast(elementAttr)));
488 throw SetPyError(PyExc_ValueError, message);
489 }
490
491 MlirAttribute elements =
492 mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
493 return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
494 }
495
496 intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
497
498 py::buffer_info accessBuffer() {
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -0700499 if (mlirDenseElementsAttrIsSplat(*this)) {
Stella Laurenzoc5f445d2021-10-07 11:47:05 -0700500 // TODO: Currently crashes the program.
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -0700501 // Reported as https://github.com/pybind/pybind11/issues/3336
Stella Laurenzoc5f445d2021-10-07 11:47:05 -0700502 throw std::invalid_argument(
503 "unsupported data type for conversion to Python buffer");
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -0700504 }
505
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700506 MlirType shapedType = mlirAttributeGetType(*this);
507 MlirType elementType = mlirShapedTypeGetElementType(shapedType);
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -0700508 std::string format;
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700509
510 if (mlirTypeIsAF32(elementType)) {
511 // f32
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -0700512 return bufferInfo<float>(shapedType);
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700513 } else if (mlirTypeIsAF64(elementType)) {
514 // f64
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -0700515 return bufferInfo<double>(shapedType);
516 } else if (mlirTypeIsAF16(elementType)) {
517 // f16
518 return bufferInfo<uint16_t>(shapedType, "e");
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700519 } else if (mlirTypeIsAInteger(elementType) &&
520 mlirIntegerTypeGetWidth(elementType) == 32) {
521 if (mlirIntegerTypeIsSignless(elementType) ||
522 mlirIntegerTypeIsSigned(elementType)) {
523 // i32
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -0700524 return bufferInfo<int32_t>(shapedType);
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700525 } else if (mlirIntegerTypeIsUnsigned(elementType)) {
526 // unsigned i32
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -0700527 return bufferInfo<uint32_t>(shapedType);
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700528 }
529 } else if (mlirTypeIsAInteger(elementType) &&
530 mlirIntegerTypeGetWidth(elementType) == 64) {
531 if (mlirIntegerTypeIsSignless(elementType) ||
532 mlirIntegerTypeIsSigned(elementType)) {
533 // i64
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -0700534 return bufferInfo<int64_t>(shapedType);
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700535 } else if (mlirIntegerTypeIsUnsigned(elementType)) {
536 // unsigned i64
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -0700537 return bufferInfo<uint64_t>(shapedType);
538 }
539 } else if (mlirTypeIsAInteger(elementType) &&
540 mlirIntegerTypeGetWidth(elementType) == 8) {
541 if (mlirIntegerTypeIsSignless(elementType) ||
542 mlirIntegerTypeIsSigned(elementType)) {
543 // i8
544 return bufferInfo<int8_t>(shapedType);
545 } else if (mlirIntegerTypeIsUnsigned(elementType)) {
546 // unsigned i8
547 return bufferInfo<uint8_t>(shapedType);
548 }
549 } else if (mlirTypeIsAInteger(elementType) &&
550 mlirIntegerTypeGetWidth(elementType) == 16) {
551 if (mlirIntegerTypeIsSignless(elementType) ||
552 mlirIntegerTypeIsSigned(elementType)) {
553 // i16
554 return bufferInfo<int16_t>(shapedType);
555 } else if (mlirIntegerTypeIsUnsigned(elementType)) {
556 // unsigned i16
557 return bufferInfo<uint16_t>(shapedType);
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700558 }
559 }
560
Stella Laurenzoc5f445d2021-10-07 11:47:05 -0700561 // TODO: Currently crashes the program.
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -0700562 // Reported as https://github.com/pybind/pybind11/issues/3336
Stella Laurenzoc5f445d2021-10-07 11:47:05 -0700563 throw std::invalid_argument(
564 "unsupported data type for conversion to Python buffer");
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700565 }
566
567 static void bindDerived(ClassTy &c) {
568 c.def("__len__", &PyDenseElementsAttribute::dunderLen)
569 .def_static("get", PyDenseElementsAttribute::getFromBuffer,
570 py::arg("array"), py::arg("signless") = true,
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -0700571 py::arg("type") = py::none(), py::arg("shape") = py::none(),
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700572 py::arg("context") = py::none(),
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -0700573 kDenseElementsAttrGetDocstring)
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700574 .def_static("get_splat", PyDenseElementsAttribute::getSplat,
575 py::arg("shaped_type"), py::arg("element_attr"),
576 "Gets a DenseElementsAttr where all values are the same")
577 .def_property_readonly("is_splat",
578 [](PyDenseElementsAttribute &self) -> bool {
579 return mlirDenseElementsAttrIsSplat(self);
580 })
581 .def_buffer(&PyDenseElementsAttribute::accessBuffer);
582 }
583
584private:
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700585 static bool isUnsignedIntegerFormat(const std::string &format) {
586 if (format.empty())
587 return false;
588 char code = format[0];
589 return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
590 code == 'Q';
591 }
592
593 static bool isSignedIntegerFormat(const std::string &format) {
594 if (format.empty())
595 return false;
596 char code = format[0];
597 return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
598 code == 'q';
599 }
600
601 template <typename Type>
602 py::buffer_info bufferInfo(MlirType shapedType,
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -0700603 const char *explicitFormat = nullptr) {
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700604 intptr_t rank = mlirShapedTypeGetRank(shapedType);
605 // Prepare the data for the buffer_info.
606 // Buffer is configured for read-only access below.
607 Type *data = static_cast<Type *>(
608 const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
609 // Prepare the shape for the buffer_info.
610 SmallVector<intptr_t, 4> shape;
611 for (intptr_t i = 0; i < rank; ++i)
612 shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
613 // Prepare the strides for the buffer_info.
614 SmallVector<intptr_t, 4> strides;
615 intptr_t strideFactor = 1;
616 for (intptr_t i = 1; i < rank; ++i) {
617 strideFactor = 1;
618 for (intptr_t j = i; j < rank; ++j) {
619 strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
620 }
621 strides.push_back(sizeof(Type) * strideFactor);
622 }
623 strides.push_back(sizeof(Type));
Stella Laurenzo5d6d30e2021-10-06 18:41:22 -0700624 std::string format;
625 if (explicitFormat) {
626 format = explicitFormat;
627 } else {
628 format = py::format_descriptor<Type>::format();
629 }
630 return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
631 /*readonly=*/true);
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700632 }
633}; // namespace
634
635/// Refinement of the PyDenseElementsAttribute for attributes containing integer
636/// (and boolean) values. Supports element access.
637class PyDenseIntElementsAttribute
638 : public PyConcreteAttribute<PyDenseIntElementsAttribute,
639 PyDenseElementsAttribute> {
640public:
641 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
642 static constexpr const char *pyClassName = "DenseIntElementsAttr";
643 using PyConcreteAttribute::PyConcreteAttribute;
644
645 /// Returns the element at the given linear position. Asserts if the index is
646 /// out of range.
647 py::int_ dunderGetItem(intptr_t pos) {
648 if (pos < 0 || pos >= dunderLen()) {
649 throw SetPyError(PyExc_IndexError,
650 "attempt to access out of bounds element");
651 }
652
653 MlirType type = mlirAttributeGetType(*this);
654 type = mlirShapedTypeGetElementType(type);
655 assert(mlirTypeIsAInteger(type) &&
656 "expected integer element type in dense int elements attribute");
657 // Dispatch element extraction to an appropriate C function based on the
658 // elemental type of the attribute. py::int_ is implicitly constructible
659 // from any C++ integral type and handles bitwidth correctly.
660 // TODO: consider caching the type properties in the constructor to avoid
661 // querying them on each element access.
662 unsigned width = mlirIntegerTypeGetWidth(type);
663 bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
664 if (isUnsigned) {
665 if (width == 1) {
666 return mlirDenseElementsAttrGetBoolValue(*this, pos);
667 }
668 if (width == 32) {
669 return mlirDenseElementsAttrGetUInt32Value(*this, pos);
670 }
671 if (width == 64) {
672 return mlirDenseElementsAttrGetUInt64Value(*this, pos);
673 }
674 } else {
675 if (width == 1) {
676 return mlirDenseElementsAttrGetBoolValue(*this, pos);
677 }
678 if (width == 32) {
679 return mlirDenseElementsAttrGetInt32Value(*this, pos);
680 }
681 if (width == 64) {
682 return mlirDenseElementsAttrGetInt64Value(*this, pos);
683 }
684 }
685 throw SetPyError(PyExc_TypeError, "Unsupported integer type");
686 }
687
688 static void bindDerived(ClassTy &c) {
689 c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
690 }
691};
692
693class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
694public:
695 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
696 static constexpr const char *pyClassName = "DictAttr";
697 using PyConcreteAttribute::PyConcreteAttribute;
698
699 intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
700
Adrian Kuegel9fb10862021-10-29 15:11:09 +0200701 bool dunderContains(const std::string &name) {
702 return !mlirAttributeIsNull(
703 mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
704 }
705
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700706 static void bindDerived(ClassTy &c) {
Adrian Kuegel9fb10862021-10-29 15:11:09 +0200707 c.def("__contains__", &PyDictAttribute::dunderContains);
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700708 c.def("__len__", &PyDictAttribute::dunderLen);
709 c.def_static(
710 "get",
711 [](py::dict attributes, DefaultingPyMlirContext context) {
712 SmallVector<MlirNamedAttribute> mlirNamedAttributes;
713 mlirNamedAttributes.reserve(attributes.size());
714 for (auto &it : attributes) {
715 auto &mlir_attr = it.second.cast<PyAttribute &>();
716 auto name = it.first.cast<std::string>();
717 mlirNamedAttributes.push_back(mlirNamedAttributeGet(
718 mlirIdentifierGet(mlirAttributeGetContext(mlir_attr),
719 toMlirStringRef(name)),
720 mlir_attr));
721 }
722 MlirAttribute attr =
723 mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
724 mlirNamedAttributes.data());
725 return PyDictAttribute(context->getRef(), attr);
726 },
Alex Zinenkoed9e52f2021-10-04 11:38:20 +0200727 py::arg("value") = py::dict(), py::arg("context") = py::none(),
Stella Laurenzo436c6c92021-03-19 11:57:01 -0700728 "Gets an uniqued dict attribute");
729 c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
730 MlirAttribute attr =
731 mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
732 if (mlirAttributeIsNull(attr)) {
733 throw SetPyError(PyExc_KeyError,
734 "attempt to access a non-existent attribute");
735 }
736 return PyAttribute(self.getContext(), attr);
737 });
738 c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
739 if (index < 0 || index >= self.dunderLen()) {
740 throw SetPyError(PyExc_IndexError,
741 "attempt to access out of bounds attribute");
742 }
743 MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
744 return PyNamedAttribute(
745 namedAttr.attribute,
746 std::string(mlirIdentifierStr(namedAttr.name).data));
747 });
748 }
749};
750
751/// Refinement of PyDenseElementsAttribute for attributes containing
752/// floating-point values. Supports element access.
753class PyDenseFPElementsAttribute
754 : public PyConcreteAttribute<PyDenseFPElementsAttribute,
755 PyDenseElementsAttribute> {
756public:
757 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
758 static constexpr const char *pyClassName = "DenseFPElementsAttr";
759 using PyConcreteAttribute::PyConcreteAttribute;
760
761 py::float_ dunderGetItem(intptr_t pos) {
762 if (pos < 0 || pos >= dunderLen()) {
763 throw SetPyError(PyExc_IndexError,
764 "attempt to access out of bounds element");
765 }
766
767 MlirType type = mlirAttributeGetType(*this);
768 type = mlirShapedTypeGetElementType(type);
769 // Dispatch element extraction to an appropriate C function based on the
770 // elemental type of the attribute. py::float_ is implicitly constructible
771 // from float and double.
772 // TODO: consider caching the type properties in the constructor to avoid
773 // querying them on each element access.
774 if (mlirTypeIsAF32(type)) {
775 return mlirDenseElementsAttrGetFloatValue(*this, pos);
776 }
777 if (mlirTypeIsAF64(type)) {
778 return mlirDenseElementsAttrGetDoubleValue(*this, pos);
779 }
780 throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
781 }
782
783 static void bindDerived(ClassTy &c) {
784 c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
785 }
786};
787
788class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
789public:
790 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
791 static constexpr const char *pyClassName = "TypeAttr";
792 using PyConcreteAttribute::PyConcreteAttribute;
793
794 static void bindDerived(ClassTy &c) {
795 c.def_static(
796 "get",
797 [](PyType value, DefaultingPyMlirContext context) {
798 MlirAttribute attr = mlirTypeAttrGet(value.get());
799 return PyTypeAttribute(context->getRef(), attr);
800 },
801 py::arg("value"), py::arg("context") = py::none(),
802 "Gets a uniqued Type attribute");
803 c.def_property_readonly("value", [](PyTypeAttribute &self) {
804 return PyType(self.getContext()->getRef(),
805 mlirTypeAttrGetValue(self.get()));
806 });
807 }
808};
809
810/// Unit Attribute subclass. Unit attributes don't have values.
811class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
812public:
813 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
814 static constexpr const char *pyClassName = "UnitAttr";
815 using PyConcreteAttribute::PyConcreteAttribute;
816
817 static void bindDerived(ClassTy &c) {
818 c.def_static(
819 "get",
820 [](DefaultingPyMlirContext context) {
821 return PyUnitAttribute(context->getRef(),
822 mlirUnitAttrGet(context->get()));
823 },
824 py::arg("context") = py::none(), "Create a Unit attribute.");
825 }
826};
827
828} // namespace
829
830void mlir::python::populateIRAttributes(py::module &m) {
831 PyAffineMapAttribute::bind(m);
832 PyArrayAttribute::bind(m);
833 PyArrayAttribute::PyArrayAttributeIterator::bind(m);
834 PyBoolAttribute::bind(m);
835 PyDenseElementsAttribute::bind(m);
836 PyDenseFPElementsAttribute::bind(m);
837 PyDenseIntElementsAttribute::bind(m);
838 PyDictAttribute::bind(m);
839 PyFlatSymbolRefAttribute::bind(m);
840 PyFloatAttribute::bind(m);
841 PyIntegerAttribute::bind(m);
842 PyStringAttribute::bind(m);
843 PyTypeAttribute::bind(m);
844 PyUnitAttribute::bind(m);
845}