[mlir] Extend C and Python API to support bulk loading of DenseElementsAttr.

* This was already half supported: the raw buffer backing a DenseElementsAttr could be read.
* Documented the precise expectations of the buffer layout.
* Extended the Python API to support construction from bit-cast buffers, allowing construction of all primitive element types (even those that lack a compatible representation in Python).
* Specifically, the Python API can now load all integer types at all bit widths and all floating point types (f16, f32, f64, bf16).

Differential Revision: https://reviews.llvm.org/D111284
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 2ff75ce..47f73ec 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -17,9 +17,57 @@
 using namespace mlir;
 using namespace mlir::python;
 
+using llvm::None;
+using llvm::Optional;
 using llvm::SmallVector;
 using llvm::Twine;
 
+//------------------------------------------------------------------------------
+// Docstrings (trivial, non-duplicated docstrings are included inline).
+//------------------------------------------------------------------------------
+
+static const char kDenseElementsAttrGetDocstring[] =
+    R"(Gets a DenseElementsAttr from a Python buffer or array.
+
+When `type` is not provided, then some limited type inferencing is done based
+on the buffer format. Support presently exists for 8/16/32/64 signed and
+unsigned integers and float16/float32/float64. DenseElementsAttrs of these
+types can also be converted back to a corresponding buffer.
+
+For conversions outside of these types, a `type=` must be explicitly provided
+and the buffer contents must be bit-castable to the MLIR internal
+representation:
+
+  * Integer types (except for i1): the buffer must be byte aligned to the
+    next byte boundary.
+  * Floating point types: Must be bit-castable to the given floating point
+    size.
+  * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
+    row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
+    this specification with: `np.packbits(ary, axis=None, bitorder='little')`.
+
+If a single element buffer is passed (or for i1, a single byte with value 0
+or 255), then a splat will be created.
+
+Args:
+  array: The array or buffer to convert.
+  signless: If inferring an appropriate MLIR type, use signless types for
+    integers (defaults True).
+  type: Skips inference of the MLIR element type and uses this instead. The
+    storage size must be consistent with the actual contents of the buffer.
+  shape: Overrides the shape of the buffer when constructing the MLIR
+    shaped type. This is needed when the physical and logical shape differ (as
+    for i1).
+  context: Explicit context, if not from context manager.
+
+Returns:
+  DenseElementsAttr on success.
+
+Raises:
+  ValueError: If the type of the buffer or array cannot be matched to an MLIR
+    type or if the buffer does not meet expectations.
+)";
+
 namespace {
 
 static MlirStringRef toMlirStringRef(const std::string &s) {
@@ -301,7 +349,6 @@
   }
 };
 
-// TODO: Support construction of bool elements.
 // TODO: Support construction of string elements.
 class PyDenseElementsAttribute
     : public PyConcreteAttribute<PyDenseElementsAttribute> {
@@ -311,7 +358,8 @@
   using PyConcreteAttribute::PyConcreteAttribute;
 
   static PyDenseElementsAttribute
-  getFromBuffer(py::buffer array, bool signless,
+  getFromBuffer(py::buffer array, bool signless, Optional<PyType> explicitType,
+                Optional<std::vector<int64_t>> explicitShape,
                 DefaultingPyMlirContext contextWrapper) {
     // Request a contiguous view. In exotic cases, this will cause a copy.
     int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
@@ -321,69 +369,95 @@
       throw py::error_already_set();
     }
     py::buffer_info arrayInfo(view);
+    SmallVector<int64_t> shape;
+    if (explicitShape) {
+      shape.append(explicitShape->begin(), explicitShape->end());
+    } else {
+      shape.append(arrayInfo.shape.begin(),
+                   arrayInfo.shape.begin() + arrayInfo.ndim);
+    }
 
+    MlirAttribute encodingAttr = mlirAttributeGetNull();
     MlirContext context = contextWrapper->get();
-    // Switch on the types that can be bulk loaded between the Python and
-    // MLIR-C APIs.
-    // See: https://docs.python.org/3/library/struct.html#format-characters
-    if (arrayInfo.format == "f") {
+
+    // Detect format codes that are suitable for bulk loading. This includes
+    // all byte aligned integer and floating point types up to 8 bytes.
+    // Notably, this excludes bool (which needs to be bit-packed) and
+    // other exotics which do not have a direct representation in the buffer
+    // protocol (i.e. complex, etc).
+    Optional<MlirType> bulkLoadElementType;
+    if (explicitType) {
+      bulkLoadElementType = *explicitType;
+    } else if (arrayInfo.format == "f") {
       // f32
       assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
-      return PyDenseElementsAttribute(
-          contextWrapper->getRef(),
-          bulkLoad(context, mlirDenseElementsAttrFloatGet,
-                   mlirF32TypeGet(context), arrayInfo));
+      bulkLoadElementType = mlirF32TypeGet(context);
     } else if (arrayInfo.format == "d") {
       // f64
       assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
-      return PyDenseElementsAttribute(
-          contextWrapper->getRef(),
-          bulkLoad(context, mlirDenseElementsAttrDoubleGet,
-                   mlirF64TypeGet(context), arrayInfo));
+      bulkLoadElementType = mlirF64TypeGet(context);
+    } else if (arrayInfo.format == "e") {
+      // f16
+      assert(arrayInfo.itemsize == 2 && "mismatched array itemsize");
+      bulkLoadElementType = mlirF16TypeGet(context);
     } else if (isSignedIntegerFormat(arrayInfo.format)) {
       if (arrayInfo.itemsize == 4) {
         // i32
-        MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
-                                        : mlirIntegerTypeSignedGet(context, 32);
-        return PyDenseElementsAttribute(contextWrapper->getRef(),
-                                        bulkLoad(context,
-                                                 mlirDenseElementsAttrInt32Get,
-                                                 elementType, arrayInfo));
+        bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
+                                       : mlirIntegerTypeSignedGet(context, 32);
       } else if (arrayInfo.itemsize == 8) {
         // i64
-        MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
-                                        : mlirIntegerTypeSignedGet(context, 64);
-        return PyDenseElementsAttribute(contextWrapper->getRef(),
-                                        bulkLoad(context,
-                                                 mlirDenseElementsAttrInt64Get,
-                                                 elementType, arrayInfo));
+        bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
+                                       : mlirIntegerTypeSignedGet(context, 64);
+      } else if (arrayInfo.itemsize == 1) {
+        // i8
+        bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
+                                       : mlirIntegerTypeSignedGet(context, 8);
+      } else if (arrayInfo.itemsize == 2) {
+        // i16
+        bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
+                                       : mlirIntegerTypeSignedGet(context, 16);
       }
     } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
       if (arrayInfo.itemsize == 4) {
         // unsigned i32
-        MlirType elementType = signless
-                                   ? mlirIntegerTypeGet(context, 32)
-                                   : mlirIntegerTypeUnsignedGet(context, 32);
-        return PyDenseElementsAttribute(contextWrapper->getRef(),
-                                        bulkLoad(context,
-                                                 mlirDenseElementsAttrUInt32Get,
-                                                 elementType, arrayInfo));
+        bulkLoadElementType = signless
+                                  ? mlirIntegerTypeGet(context, 32)
+                                  : mlirIntegerTypeUnsignedGet(context, 32);
       } else if (arrayInfo.itemsize == 8) {
         // unsigned i64
-        MlirType elementType = signless
-                                   ? mlirIntegerTypeGet(context, 64)
-                                   : mlirIntegerTypeUnsignedGet(context, 64);
-        return PyDenseElementsAttribute(contextWrapper->getRef(),
-                                        bulkLoad(context,
-                                                 mlirDenseElementsAttrUInt64Get,
-                                                 elementType, arrayInfo));
+        bulkLoadElementType = signless
+                                  ? mlirIntegerTypeGet(context, 64)
+                                  : mlirIntegerTypeUnsignedGet(context, 64);
+      } else if (arrayInfo.itemsize == 1) {
+        // unsigned i8
+        bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
+                                       : mlirIntegerTypeUnsignedGet(context, 8);
+      } else if (arrayInfo.itemsize == 2) {
+        // unsigned i16
+        bulkLoadElementType = signless
+                                  ? mlirIntegerTypeGet(context, 16)
+                                  : mlirIntegerTypeUnsignedGet(context, 16);
       }
     }
+    if (bulkLoadElementType) {
+      auto shapedType = mlirRankedTensorTypeGet(
+          shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
+      size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize;
+      MlirAttribute attr = mlirDenseElementsAttrRawBufferGet(
+          shapedType, rawBufferSize, arrayInfo.ptr);
+      if (mlirAttributeIsNull(attr)) {
+        throw std::invalid_argument(
+            "DenseElementsAttr could not be constructed from the given buffer. "
+            "This may mean that the Python buffer layout does not match that "
+            "MLIR expected layout and is a bug.");
+      }
+      return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
+    }
 
-    // TODO: Fall back to string-based get.
-    std::string message = "unimplemented array format conversion from format: ";
-    message.append(arrayInfo.format);
-    throw SetPyError(PyExc_ValueError, message);
+    throw std::invalid_argument(
+        std::string("unimplemented array format conversion from format: ") +
+        arrayInfo.format);
   }
 
   static PyDenseElementsAttribute getSplat(PyType shapedType,
@@ -422,47 +496,82 @@
   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
 
   py::buffer_info accessBuffer() {
+    if (mlirDenseElementsAttrIsSplat(*this)) {
+      // TODO: Raise an exception.
+      // Reported as https://github.com/pybind/pybind11/issues/3336
+      return py::buffer_info();
+    }
+
     MlirType shapedType = mlirAttributeGetType(*this);
     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
+    std::string format;
 
     if (mlirTypeIsAF32(elementType)) {
       // f32
-      return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue);
+      return bufferInfo<float>(shapedType);
     } else if (mlirTypeIsAF64(elementType)) {
       // f64
-      return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue);
+      return bufferInfo<double>(shapedType);
+    } else if (mlirTypeIsAF16(elementType)) {
+      // f16
+      return bufferInfo<uint16_t>(shapedType, "e");
     } else if (mlirTypeIsAInteger(elementType) &&
                mlirIntegerTypeGetWidth(elementType) == 32) {
       if (mlirIntegerTypeIsSignless(elementType) ||
           mlirIntegerTypeIsSigned(elementType)) {
         // i32
-        return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value);
+        return bufferInfo<int32_t>(shapedType);
       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
         // unsigned i32
-        return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value);
+        return bufferInfo<uint32_t>(shapedType);
       }
     } else if (mlirTypeIsAInteger(elementType) &&
                mlirIntegerTypeGetWidth(elementType) == 64) {
       if (mlirIntegerTypeIsSignless(elementType) ||
           mlirIntegerTypeIsSigned(elementType)) {
         // i64
-        return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value);
+        return bufferInfo<int64_t>(shapedType);
       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
         // unsigned i64
-        return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value);
+        return bufferInfo<uint64_t>(shapedType);
+      }
+    } else if (mlirTypeIsAInteger(elementType) &&
+               mlirIntegerTypeGetWidth(elementType) == 8) {
+      if (mlirIntegerTypeIsSignless(elementType) ||
+          mlirIntegerTypeIsSigned(elementType)) {
+        // i8
+        return bufferInfo<int8_t>(shapedType);
+      } else if (mlirIntegerTypeIsUnsigned(elementType)) {
+        // unsigned i8
+        return bufferInfo<uint8_t>(shapedType);
+      }
+    } else if (mlirTypeIsAInteger(elementType) &&
+               mlirIntegerTypeGetWidth(elementType) == 16) {
+      if (mlirIntegerTypeIsSignless(elementType) ||
+          mlirIntegerTypeIsSigned(elementType)) {
+        // i16
+        return bufferInfo<int16_t>(shapedType);
+      } else if (mlirIntegerTypeIsUnsigned(elementType)) {
+        // unsigned i16
+        return bufferInfo<uint16_t>(shapedType);
       }
     }
 
-    std::string message = "unimplemented array format.";
-    throw SetPyError(PyExc_ValueError, message);
+    // TODO: Throwing here currently crashes the program, so just return an
+    // empty buffer for now.
+    // Reported as https://github.com/pybind/pybind11/issues/3336
+    // throw std::invalid_argument(
+    //     "unsupported data type for conversion to Python buffer");
+    return py::buffer_info();
   }
 
   static void bindDerived(ClassTy &c) {
     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
                     py::arg("array"), py::arg("signless") = true,
+                    py::arg("type") = py::none(), py::arg("shape") = py::none(),
                     py::arg("context") = py::none(),
-                    "Gets from a buffer or ndarray")
+                    kDenseElementsAttrGetDocstring)
         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
                     py::arg("shaped_type"), py::arg("element_attr"),
                     "Gets a DenseElementsAttr where all values are the same")
@@ -474,21 +583,6 @@
   }
 
 private:
-  template <typename ElementTy>
-  static MlirAttribute
-  bulkLoad(MlirContext context,
-           MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *),
-           MlirType mlirElementType, py::buffer_info &arrayInfo) {
-    SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(),
-                                  arrayInfo.shape.begin() + arrayInfo.ndim);
-    MlirAttribute encodingAttr = mlirAttributeGetNull();
-    auto shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
-                                              mlirElementType, encodingAttr);
-    intptr_t numElements = arrayInfo.size;
-    const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
-    return ctor(shapedType, numElements, contents);
-  }
-
   static bool isUnsignedIntegerFormat(const std::string &format) {
     if (format.empty())
       return false;
@@ -507,7 +601,7 @@
 
   template <typename Type>
   py::buffer_info bufferInfo(MlirType shapedType,
-                             Type (*value)(MlirAttribute, intptr_t)) {
+                             const char *explicitFormat = nullptr) {
     intptr_t rank = mlirShapedTypeGetRank(shapedType);
     // Prepare the data for the buffer_info.
     // Buffer is configured for read-only access below.
@@ -528,9 +622,14 @@
       strides.push_back(sizeof(Type) * strideFactor);
     }
     strides.push_back(sizeof(Type));
-    return py::buffer_info(data, sizeof(Type),
-                           py::format_descriptor<Type>::format(), rank, shape,
-                           strides, /*readonly=*/true);
+    std::string format;
+    if (explicitFormat) {
+      format = explicitFormat;
+    } else {
+      format = py::format_descriptor<Type>::format();
+    }
+    return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
+                           /*readonly=*/true);
   }
 }; // namespace