[mlir] Extend C and Python API to support bulk loading of DenseElementsAttr.
* Part of this already existed: the C API could read the raw buffer backing a DenseElementsAttr; this change adds construction from such a buffer.
* Documented the precise expectations of the buffer layout.
* Extended the Python API to support construction from bit-cast buffers, allowing construction of all primitive element types (even those that lack a compatible representation in Python).
* Specifically, the Python API can now load all integer types at all bit widths and all floating point types (f16, f32, f64, bf16).
Differential Revision: https://reviews.llvm.org/D111284
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 2ff75ce..47f73ec 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -17,9 +17,57 @@
using namespace mlir;
using namespace mlir::python;
+using llvm::None;
+using llvm::Optional;
using llvm::SmallVector;
using llvm::Twine;
+//------------------------------------------------------------------------------
+// Docstrings (short, single-use docstrings are written inline at the binding).
+//------------------------------------------------------------------------------
+
+static const char kDenseElementsAttrGetDocstring[] =
+ R"(Gets a DenseElementsAttr from a Python buffer or array.
+
+When `type` is not provided, then some limited type inferencing is done based
+on the buffer format. Support presently exists for 8/16/32/64-bit signed and
+unsigned integers and float16/float32/float64. Non-splat DenseElementsAttrs of
+these types can also be converted back to a corresponding buffer.
+
+For conversions outside of these types, a `type=` must be explicitly provided
+and the buffer contents must be bit-castable to the MLIR internal
+representation:
+
+ * Integer types (except for i1): each element must be padded out to the
+ next byte boundary.
+ * Floating point types: Must be bit-castable to the given floating point
+ size.
+ * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
+ row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
+ this specification with: `np.packbits(ary, axis=None, bitorder='little')`.
+
+If a single element buffer is passed (or for i1, a single byte with value 0
+or 255), then a splat will be created.
+
+Args:
+ array: The array or buffer to convert.
+ signless: If inferring an appropriate MLIR type, use signless types for
+ integers (defaults True).
+ type: Skips inference of the MLIR element type and uses this instead. The
+ storage size must be consistent with the actual contents of the buffer.
+ shape: Overrides the shape of the buffer when constructing the MLIR
+ shaped type. This is needed when the physical and logical shape differ (as
+ for i1).
+ context: Explicit context, if not from context manager.
+
+Returns:
+ DenseElementsAttr on success.
+
+Raises:
+ ValueError: If the type of the buffer or array cannot be matched to an MLIR
+ type or if the buffer does not meet expectations.
+)";
+
namespace {
static MlirStringRef toMlirStringRef(const std::string &s) {
@@ -301,7 +349,6 @@
}
};
-// TODO: Support construction of bool elements.
// TODO: Support construction of string elements.
class PyDenseElementsAttribute
: public PyConcreteAttribute<PyDenseElementsAttribute> {
@@ -311,7 +358,8 @@
using PyConcreteAttribute::PyConcreteAttribute;
static PyDenseElementsAttribute
- getFromBuffer(py::buffer array, bool signless,
+ getFromBuffer(py::buffer array, bool signless, Optional<PyType> explicitType,
+ Optional<std::vector<int64_t>> explicitShape,
DefaultingPyMlirContext contextWrapper) {
// Request a contiguous view. In exotic cases, this will cause a copy.
int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
@@ -321,69 +369,95 @@
throw py::error_already_set();
}
py::buffer_info arrayInfo(view);
+ SmallVector<int64_t> shape;
+ if (explicitShape) {
+ shape.append(explicitShape->begin(), explicitShape->end());
+ } else {
+ shape.append(arrayInfo.shape.begin(),
+ arrayInfo.shape.begin() + arrayInfo.ndim);
+ }
+ MlirAttribute encodingAttr = mlirAttributeGetNull();
MlirContext context = contextWrapper->get();
- // Switch on the types that can be bulk loaded between the Python and
- // MLIR-C APIs.
- // See: https://docs.python.org/3/library/struct.html#format-characters
- if (arrayInfo.format == "f") {
+
+ // Pick the element type: an explicit `type` always wins. Otherwise, infer
+ // it from format codes suitable for bulk loading: byte aligned integer and
+ // floating point types up to 8 bytes. Notably, this excludes bool (which
+ // needs to be bit-packed) and other exotics which do not have a direct
+ // representation in the buffer protocol (e.g. complex).
+ Optional<MlirType> bulkLoadElementType;
+ if (explicitType) {
+ bulkLoadElementType = *explicitType;
+ } else if (arrayInfo.format == "f") {
// f32
assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
- return PyDenseElementsAttribute(
- contextWrapper->getRef(),
- bulkLoad(context, mlirDenseElementsAttrFloatGet,
- mlirF32TypeGet(context), arrayInfo));
+ bulkLoadElementType = mlirF32TypeGet(context);
} else if (arrayInfo.format == "d") {
// f64
assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
- return PyDenseElementsAttribute(
- contextWrapper->getRef(),
- bulkLoad(context, mlirDenseElementsAttrDoubleGet,
- mlirF64TypeGet(context), arrayInfo));
+ bulkLoadElementType = mlirF64TypeGet(context);
+ } else if (arrayInfo.format == "e") {
+ // f16
+ assert(arrayInfo.itemsize == 2 && "mismatched array itemsize");
+ bulkLoadElementType = mlirF16TypeGet(context);
} else if (isSignedIntegerFormat(arrayInfo.format)) {
if (arrayInfo.itemsize == 4) {
// i32
- MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
- : mlirIntegerTypeSignedGet(context, 32);
- return PyDenseElementsAttribute(contextWrapper->getRef(),
- bulkLoad(context,
- mlirDenseElementsAttrInt32Get,
- elementType, arrayInfo));
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
+ : mlirIntegerTypeSignedGet(context, 32);
} else if (arrayInfo.itemsize == 8) {
// i64
- MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
- : mlirIntegerTypeSignedGet(context, 64);
- return PyDenseElementsAttribute(contextWrapper->getRef(),
- bulkLoad(context,
- mlirDenseElementsAttrInt64Get,
- elementType, arrayInfo));
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
+ : mlirIntegerTypeSignedGet(context, 64);
+ } else if (arrayInfo.itemsize == 1) {
+ // i8
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
+ : mlirIntegerTypeSignedGet(context, 8);
+ } else if (arrayInfo.itemsize == 2) {
+ // i16
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
+ : mlirIntegerTypeSignedGet(context, 16);
}
} else if (isUnsignedIntegerFormat(arrayInfo.format)) {
if (arrayInfo.itemsize == 4) {
// unsigned i32
- MlirType elementType = signless
- ? mlirIntegerTypeGet(context, 32)
- : mlirIntegerTypeUnsignedGet(context, 32);
- return PyDenseElementsAttribute(contextWrapper->getRef(),
- bulkLoad(context,
- mlirDenseElementsAttrUInt32Get,
- elementType, arrayInfo));
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 32)
+ : mlirIntegerTypeUnsignedGet(context, 32);
} else if (arrayInfo.itemsize == 8) {
// unsigned i64
- MlirType elementType = signless
- ? mlirIntegerTypeGet(context, 64)
- : mlirIntegerTypeUnsignedGet(context, 64);
- return PyDenseElementsAttribute(contextWrapper->getRef(),
- bulkLoad(context,
- mlirDenseElementsAttrUInt64Get,
- elementType, arrayInfo));
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 64)
+ : mlirIntegerTypeUnsignedGet(context, 64);
+ } else if (arrayInfo.itemsize == 1) {
+ // unsigned i8
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
+ : mlirIntegerTypeUnsignedGet(context, 8);
+ } else if (arrayInfo.itemsize == 2) {
+ // unsigned i16
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 16)
+ : mlirIntegerTypeUnsignedGet(context, 16);
}
}
+ if (bulkLoadElementType) {
+ auto shapedType = mlirRankedTensorTypeGet(
+ shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
+ size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize;
+ MlirAttribute attr = mlirDenseElementsAttrRawBufferGet(
+ shapedType, rawBufferSize, arrayInfo.ptr);
+ if (mlirAttributeIsNull(attr)) {
+ throw std::invalid_argument(
+ "DenseElementsAttr could not be constructed from the given buffer. "
+ "If an explicit type or shape was given, it may not match the buffer "
+ "size; otherwise the Python and MLIR buffer layouts disagree (a bug).");
+ }
+ return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
+ }
- // TODO: Fall back to string-based get.
- std::string message = "unimplemented array format conversion from format: ";
- message.append(arrayInfo.format);
- throw SetPyError(PyExc_ValueError, message);
+ throw std::invalid_argument(
+ std::string("unimplemented array format conversion from format: ") +
+ arrayInfo.format);
}
static PyDenseElementsAttribute getSplat(PyType shapedType,
@@ -422,47 +496,82 @@
intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
py::buffer_info accessBuffer() {
+ if (mlirDenseElementsAttrIsSplat(*this)) {
+ // TODO: Raise an exception.
+ // Reported as https://github.com/pybind/pybind11/issues/3336
+ return py::buffer_info();
+ }
+
MlirType shapedType = mlirAttributeGetType(*this);
MlirType elementType = mlirShapedTypeGetElementType(shapedType);
+ // Select the C++ storage type (and buffer format) from the element type.
if (mlirTypeIsAF32(elementType)) {
// f32
- return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue);
+ return bufferInfo<float>(shapedType);
} else if (mlirTypeIsAF64(elementType)) {
// f64
- return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue);
+ return bufferInfo<double>(shapedType);
+ } else if (mlirTypeIsAF16(elementType)) {
+ // f16
+ return bufferInfo<uint16_t>(shapedType, "e");
} else if (mlirTypeIsAInteger(elementType) &&
mlirIntegerTypeGetWidth(elementType) == 32) {
if (mlirIntegerTypeIsSignless(elementType) ||
mlirIntegerTypeIsSigned(elementType)) {
// i32
- return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value);
+ return bufferInfo<int32_t>(shapedType);
} else if (mlirIntegerTypeIsUnsigned(elementType)) {
// unsigned i32
- return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value);
+ return bufferInfo<uint32_t>(shapedType);
}
} else if (mlirTypeIsAInteger(elementType) &&
mlirIntegerTypeGetWidth(elementType) == 64) {
if (mlirIntegerTypeIsSignless(elementType) ||
mlirIntegerTypeIsSigned(elementType)) {
// i64
- return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value);
+ return bufferInfo<int64_t>(shapedType);
} else if (mlirIntegerTypeIsUnsigned(elementType)) {
// unsigned i64
- return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value);
+ return bufferInfo<uint64_t>(shapedType);
+ }
+ } else if (mlirTypeIsAInteger(elementType) &&
+ mlirIntegerTypeGetWidth(elementType) == 8) {
+ if (mlirIntegerTypeIsSignless(elementType) ||
+ mlirIntegerTypeIsSigned(elementType)) {
+ // i8
+ return bufferInfo<int8_t>(shapedType);
+ } else if (mlirIntegerTypeIsUnsigned(elementType)) {
+ // unsigned i8
+ return bufferInfo<uint8_t>(shapedType);
+ }
+ } else if (mlirTypeIsAInteger(elementType) &&
+ mlirIntegerTypeGetWidth(elementType) == 16) {
+ if (mlirIntegerTypeIsSignless(elementType) ||
+ mlirIntegerTypeIsSigned(elementType)) {
+ // i16
+ return bufferInfo<int16_t>(shapedType);
+ } else if (mlirIntegerTypeIsUnsigned(elementType)) {
+ // unsigned i16
+ return bufferInfo<uint16_t>(shapedType);
}
}
- std::string message = "unimplemented array format.";
- throw SetPyError(PyExc_ValueError, message);
+ // TODO: Currently crashes the program. Just returning an empty buffer
+ // for now.
+ // Reported as https://github.com/pybind/pybind11/issues/3336
+ // throw std::invalid_argument(
+ // "unsupported data type for conversion to Python buffer");
+ return py::buffer_info();
}
static void bindDerived(ClassTy &c) {
c.def("__len__", &PyDenseElementsAttribute::dunderLen)
.def_static("get", PyDenseElementsAttribute::getFromBuffer,
py::arg("array"), py::arg("signless") = true,
+ py::arg("type") = py::none(), py::arg("shape") = py::none(),
py::arg("context") = py::none(),
- "Gets from a buffer or ndarray")
+ kDenseElementsAttrGetDocstring)
.def_static("get_splat", PyDenseElementsAttribute::getSplat,
py::arg("shaped_type"), py::arg("element_attr"),
"Gets a DenseElementsAttr where all values are the same")
@@ -474,21 +583,6 @@
}
private:
- template <typename ElementTy>
- static MlirAttribute
- bulkLoad(MlirContext context,
- MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *),
- MlirType mlirElementType, py::buffer_info &arrayInfo) {
- SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(),
- arrayInfo.shape.begin() + arrayInfo.ndim);
- MlirAttribute encodingAttr = mlirAttributeGetNull();
- auto shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
- mlirElementType, encodingAttr);
- intptr_t numElements = arrayInfo.size;
- const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
- return ctor(shapedType, numElements, contents);
- }
-
static bool isUnsignedIntegerFormat(const std::string &format) {
if (format.empty())
return false;
@@ -507,7 +601,7 @@
template <typename Type>
py::buffer_info bufferInfo(MlirType shapedType,
- Type (*value)(MlirAttribute, intptr_t)) {
+ const char *explicitFormat = nullptr) {
intptr_t rank = mlirShapedTypeGetRank(shapedType);
// Prepare the data for the buffer_info.
// Buffer is configured for read-only access below.
@@ -528,9 +622,13 @@
strides.push_back(sizeof(Type) * strideFactor);
}
strides.push_back(sizeof(Type));
- return py::buffer_info(data, sizeof(Type),
- py::format_descriptor<Type>::format(), rank, shape,
- strides, /*readonly=*/true);
+ // An explicit format lets callers describe storage whose C++ type differs
+ // from the logical element type (e.g. f16 stored as uint16_t with "e").
+ const std::string format = explicitFormat
+ ? std::string(explicitFormat)
+ : py::format_descriptor<Type>::format();
+ return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
+ /*readonly=*/true);
}
}; // namespace