| //===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include <cstdint> |
| #include <optional> |
| #include <pybind11/cast.h> |
| #include <pybind11/detail/common.h> |
| #include <pybind11/pybind11.h> |
| #include <pybind11/pytypes.h> |
| #include <string> |
| #include <utility> |
| #include <vector> |
| |
| #include "IRModule.h" |
| #include "mlir-c/BuiltinAttributes.h" |
| #include "mlir-c/IR.h" |
| #include "mlir-c/Interfaces.h" |
| #include "mlir-c/Support.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SmallVector.h" |
| |
| namespace py = pybind11; |
| |
| namespace mlir { |
| namespace python { |
| |
// Docstrings attached to the Python-visible members of the op interface
// classes bound in this file. Line breaks are spelled out explicitly so the
// text seen from Python does not depend on source formatting.

/// Docstring for the interface constructor.
static constexpr const char *constructorDoc =
    "Creates an interface from a given operation/opview object or from a\n"
    "subclass of OpView. Raises ValueError if the operation does not "
    "implement the\n"
    "interface.";

/// Docstring for the `operation` property.
static constexpr const char *operationDoc =
    "Returns an Operation for which the interface was constructed.";

/// Docstring for the `opview` property.
static constexpr const char *opviewDoc =
    "Returns an OpView subclass _instance_ for which the interface was\n"
    "constructed";

/// Docstring for `inferReturnTypes`.
static constexpr const char *inferReturnTypesDoc =
    "Given the arguments required to build an operation, attempts to infer\n"
    "its return types. Raises ValueError on failure.";

/// Docstring for `inferReturnTypeComponents`.
static constexpr const char *inferReturnTypeComponentsDoc =
    "Given the arguments required to build an operation, attempts to infer\n"
    "its return shaped type components. Raises ValueError on failure.";
| |
| namespace { |
| |
| /// Takes in an optional ist of operands and converts them into a SmallVector |
| /// of MlirVlaues. Returns an empty SmallVector if the list is empty. |
| llvm::SmallVector<MlirValue> wrapOperands(std::optional<py::list> operandList) { |
| llvm::SmallVector<MlirValue> mlirOperands; |
| |
| if (!operandList || operandList->empty()) { |
| return mlirOperands; |
| } |
| |
| // Note: as the list may contain other lists this may not be final size. |
| mlirOperands.reserve(operandList->size()); |
| for (const auto &&it : llvm::enumerate(*operandList)) { |
| if (it.value().is_none()) |
| continue; |
| |
| PyValue *val; |
| try { |
| val = py::cast<PyValue *>(it.value()); |
| if (!val) |
| throw py::cast_error(); |
| mlirOperands.push_back(val->get()); |
| continue; |
| } catch (py::cast_error &err) { |
| // Intentionally unhandled to try sequence below first. |
| (void)err; |
| } |
| |
| try { |
| auto vals = py::cast<py::sequence>(it.value()); |
| for (py::object v : vals) { |
| try { |
| val = py::cast<PyValue *>(v); |
| if (!val) |
| throw py::cast_error(); |
| mlirOperands.push_back(val->get()); |
| } catch (py::cast_error &err) { |
| throw py::value_error( |
| (llvm::Twine("Operand ") + llvm::Twine(it.index()) + |
| " must be a Value or Sequence of Values (" + err.what() + ")") |
| .str()); |
| } |
| } |
| continue; |
| } catch (py::cast_error &err) { |
| throw py::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + |
| " must be a Value or Sequence of Values (" + |
| err.what() + ")") |
| .str()); |
| } |
| |
| throw py::cast_error(); |
| } |
| |
| return mlirOperands; |
| } |
| |
| /// Takes in an optional vector of PyRegions and returns a SmallVector of |
| /// MlirRegion. Returns an empty SmallVector if the list is empty. |
| llvm::SmallVector<MlirRegion> |
| wrapRegions(std::optional<std::vector<PyRegion>> regions) { |
| llvm::SmallVector<MlirRegion> mlirRegions; |
| |
| if (regions) { |
| mlirRegions.reserve(regions->size()); |
| for (PyRegion ®ion : *regions) { |
| mlirRegions.push_back(region); |
| } |
| } |
| |
| return mlirRegions; |
| } |
| |
| } // namespace |
| |
| /// CRTP base class for Python classes representing MLIR Op interfaces. |
| /// Interface hierarchies are flat so no base class is expected here. The |
| /// derived class is expected to define the following static fields: |
| /// - `const char *pyClassName` - the name of the Python class to create; |
| /// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID |
| /// of the interface. |
| /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind |
| /// interface-specific methods. |
| /// |
| /// An interface class may be constructed from either an Operation/OpView object |
| /// or from a subclass of OpView. In the latter case, only the static interface |
| /// methods are available, similarly to calling ConcereteOp::staticMethod on the |
| /// C++ side. Implementations of concrete interfaces can use the `isStatic` |
| /// method to check whether the interface object was constructed from a class or |
| /// an operation/opview instance. The `getOpName` always succeeds and returns a |
| /// canonical name of the operation suitable for lookups. |
| template <typename ConcreteIface> |
| class PyConcreteOpInterface { |
| protected: |
| using ClassTy = py::class_<ConcreteIface>; |
| using GetTypeIDFunctionTy = MlirTypeID (*)(); |
| |
| public: |
| /// Constructs an interface instance from an object that is either an |
| /// operation or a subclass of OpView. In the latter case, only the static |
| /// methods of the interface are accessible to the caller. |
| PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context) |
| : obj(std::move(object)) { |
| try { |
| operation = &py::cast<PyOperation &>(obj); |
| } catch (py::cast_error &) { |
| // Do nothing. |
| } |
| |
| try { |
| operation = &py::cast<PyOpView &>(obj).getOperation(); |
| } catch (py::cast_error &) { |
| // Do nothing. |
| } |
| |
| if (operation != nullptr) { |
| if (!mlirOperationImplementsInterface(*operation, |
| ConcreteIface::getInterfaceID())) { |
| std::string msg = "the operation does not implement "; |
| throw py::value_error(msg + ConcreteIface::pyClassName); |
| } |
| |
| MlirIdentifier identifier = mlirOperationGetName(*operation); |
| MlirStringRef stringRef = mlirIdentifierStr(identifier); |
| opName = std::string(stringRef.data, stringRef.length); |
| } else { |
| try { |
| opName = obj.attr("OPERATION_NAME").template cast<std::string>(); |
| } catch (py::cast_error &) { |
| throw py::type_error( |
| "Op interface does not refer to an operation or OpView class"); |
| } |
| |
| if (!mlirOperationImplementsInterfaceStatic( |
| mlirStringRefCreate(opName.data(), opName.length()), |
| context.resolve().get(), ConcreteIface::getInterfaceID())) { |
| std::string msg = "the operation does not implement "; |
| throw py::value_error(msg + ConcreteIface::pyClassName); |
| } |
| } |
| } |
| |
| /// Creates the Python bindings for this class in the given module. |
| static void bind(py::module &m) { |
| py::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName, |
| py::module_local()); |
| cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"), |
| py::arg("context") = py::none(), constructorDoc) |
| .def_property_readonly("operation", |
| &PyConcreteOpInterface::getOperationObject, |
| operationDoc) |
| .def_property_readonly("opview", &PyConcreteOpInterface::getOpView, |
| opviewDoc); |
| ConcreteIface::bindDerived(cls); |
| } |
| |
| /// Hook for derived classes to add class-specific bindings. |
| static void bindDerived(ClassTy &cls) {} |
| |
| /// Returns `true` if this object was constructed from a subclass of OpView |
| /// rather than from an operation instance. |
| bool isStatic() { return operation == nullptr; } |
| |
| /// Returns the operation instance from which this object was constructed. |
| /// Throws a type error if this object was constructed from a subclass of |
| /// OpView. |
| py::object getOperationObject() { |
| if (operation == nullptr) { |
| throw py::type_error("Cannot get an operation from a static interface"); |
| } |
| |
| return operation->getRef().releaseObject(); |
| } |
| |
| /// Returns the opview of the operation instance from which this object was |
| /// constructed. Throws a type error if this object was constructed form a |
| /// subclass of OpView. |
| py::object getOpView() { |
| if (operation == nullptr) { |
| throw py::type_error("Cannot get an opview from a static interface"); |
| } |
| |
| return operation->createOpView(); |
| } |
| |
| /// Returns the canonical name of the operation this interface is constructed |
| /// from. |
| const std::string &getOpName() { return opName; } |
| |
| private: |
| PyOperation *operation = nullptr; |
| std::string opName; |
| py::object obj; |
| }; |
| |
| /// Python wrapper for InferTypeOpInterface. This interface has only static |
| /// methods. |
| class PyInferTypeOpInterface |
| : public PyConcreteOpInterface<PyInferTypeOpInterface> { |
| public: |
| using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface; |
| |
| constexpr static const char *pyClassName = "InferTypeOpInterface"; |
| constexpr static GetTypeIDFunctionTy getInterfaceID = |
| &mlirInferTypeOpInterfaceTypeID; |
| |
| /// C-style user-data structure for type appending callback. |
| struct AppendResultsCallbackData { |
| std::vector<PyType> &inferredTypes; |
| PyMlirContext &pyMlirContext; |
| }; |
| |
| /// Appends the types provided as the two first arguments to the user-data |
| /// structure (expects AppendResultsCallbackData). |
| static void appendResultsCallback(intptr_t nTypes, MlirType *types, |
| void *userData) { |
| auto *data = static_cast<AppendResultsCallbackData *>(userData); |
| data->inferredTypes.reserve(data->inferredTypes.size() + nTypes); |
| for (intptr_t i = 0; i < nTypes; ++i) { |
| data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]); |
| } |
| } |
| |
| /// Given the arguments required to build an operation, attempts to infer its |
| /// return types. Throws value_error on failure. |
| std::vector<PyType> |
| inferReturnTypes(std::optional<py::list> operandList, |
| std::optional<PyAttribute> attributes, void *properties, |
| std::optional<std::vector<PyRegion>> regions, |
| DefaultingPyMlirContext context, |
| DefaultingPyLocation location) { |
| llvm::SmallVector<MlirValue> mlirOperands = |
| wrapOperands(std::move(operandList)); |
| llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions)); |
| |
| std::vector<PyType> inferredTypes; |
| PyMlirContext &pyContext = context.resolve(); |
| AppendResultsCallbackData data{inferredTypes, pyContext}; |
| MlirStringRef opNameRef = |
| mlirStringRefCreate(getOpName().data(), getOpName().length()); |
| MlirAttribute attributeDict = |
| attributes ? attributes->get() : mlirAttributeGetNull(); |
| |
| MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes( |
| opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), |
| mlirOperands.data(), attributeDict, properties, mlirRegions.size(), |
| mlirRegions.data(), &appendResultsCallback, &data); |
| |
| if (mlirLogicalResultIsFailure(result)) { |
| throw py::value_error("Failed to infer result types"); |
| } |
| |
| return inferredTypes; |
| } |
| |
| static void bindDerived(ClassTy &cls) { |
| cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes, |
| py::arg("operands") = py::none(), |
| py::arg("attributes") = py::none(), |
| py::arg("properties") = py::none(), py::arg("regions") = py::none(), |
| py::arg("context") = py::none(), py::arg("loc") = py::none(), |
| inferReturnTypesDoc); |
| } |
| }; |
| |
| /// Wrapper around an shaped type components. |
| class PyShapedTypeComponents { |
| public: |
| PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {} |
| PyShapedTypeComponents(py::list shape, MlirType elementType) |
| : shape(std::move(shape)), elementType(elementType), ranked(true) {} |
| PyShapedTypeComponents(py::list shape, MlirType elementType, |
| MlirAttribute attribute) |
| : shape(std::move(shape)), elementType(elementType), attribute(attribute), |
| ranked(true) {} |
| PyShapedTypeComponents(PyShapedTypeComponents &) = delete; |
| PyShapedTypeComponents(PyShapedTypeComponents &&other) noexcept |
| : shape(other.shape), elementType(other.elementType), |
| attribute(other.attribute), ranked(other.ranked) {} |
| |
| static void bind(py::module &m) { |
| py::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents", |
| py::module_local()) |
| .def_property_readonly( |
| "element_type", |
| [](PyShapedTypeComponents &self) { return self.elementType; }, |
| "Returns the element type of the shaped type components.") |
| .def_static( |
| "get", |
| [](PyType &elementType) { |
| return PyShapedTypeComponents(elementType); |
| }, |
| py::arg("element_type"), |
| "Create an shaped type components object with only the element " |
| "type.") |
| .def_static( |
| "get", |
| [](py::list shape, PyType &elementType) { |
| return PyShapedTypeComponents(std::move(shape), elementType); |
| }, |
| py::arg("shape"), py::arg("element_type"), |
| "Create a ranked shaped type components object.") |
| .def_static( |
| "get", |
| [](py::list shape, PyType &elementType, PyAttribute &attribute) { |
| return PyShapedTypeComponents(std::move(shape), elementType, |
| attribute); |
| }, |
| py::arg("shape"), py::arg("element_type"), py::arg("attribute"), |
| "Create a ranked shaped type components object with attribute.") |
| .def_property_readonly( |
| "has_rank", |
| [](PyShapedTypeComponents &self) -> bool { return self.ranked; }, |
| "Returns whether the given shaped type component is ranked.") |
| .def_property_readonly( |
| "rank", |
| [](PyShapedTypeComponents &self) -> py::object { |
| if (!self.ranked) { |
| return py::none(); |
| } |
| return py::int_(self.shape.size()); |
| }, |
| "Returns the rank of the given ranked shaped type components. If " |
| "the shaped type components does not have a rank, None is " |
| "returned.") |
| .def_property_readonly( |
| "shape", |
| [](PyShapedTypeComponents &self) -> py::object { |
| if (!self.ranked) { |
| return py::none(); |
| } |
| return py::list(self.shape); |
| }, |
| "Returns the shape of the ranked shaped type components as a list " |
| "of integers. Returns none if the shaped type component does not " |
| "have a rank."); |
| } |
| |
| pybind11::object getCapsule(); |
| static PyShapedTypeComponents createFromCapsule(pybind11::object capsule); |
| |
| private: |
| py::list shape; |
| MlirType elementType; |
| MlirAttribute attribute; |
| bool ranked{false}; |
| }; |
| |
| /// Python wrapper for InferShapedTypeOpInterface. This interface has only |
| /// static methods. |
| class PyInferShapedTypeOpInterface |
| : public PyConcreteOpInterface<PyInferShapedTypeOpInterface> { |
| public: |
| using PyConcreteOpInterface< |
| PyInferShapedTypeOpInterface>::PyConcreteOpInterface; |
| |
| constexpr static const char *pyClassName = "InferShapedTypeOpInterface"; |
| constexpr static GetTypeIDFunctionTy getInterfaceID = |
| &mlirInferShapedTypeOpInterfaceTypeID; |
| |
| /// C-style user-data structure for type appending callback. |
| struct AppendResultsCallbackData { |
| std::vector<PyShapedTypeComponents> &inferredShapedTypeComponents; |
| }; |
| |
| /// Appends the shaped type components provided as unpacked shape, element |
| /// type, attribute to the user-data. |
| static void appendResultsCallback(bool hasRank, intptr_t rank, |
| const int64_t *shape, MlirType elementType, |
| MlirAttribute attribute, void *userData) { |
| auto *data = static_cast<AppendResultsCallbackData *>(userData); |
| if (!hasRank) { |
| data->inferredShapedTypeComponents.emplace_back(elementType); |
| } else { |
| py::list shapeList; |
| for (intptr_t i = 0; i < rank; ++i) { |
| shapeList.append(shape[i]); |
| } |
| data->inferredShapedTypeComponents.emplace_back(shapeList, elementType, |
| attribute); |
| } |
| } |
| |
| /// Given the arguments required to build an operation, attempts to infer the |
| /// shaped type components. Throws value_error on failure. |
| std::vector<PyShapedTypeComponents> inferReturnTypeComponents( |
| std::optional<py::list> operandList, |
| std::optional<PyAttribute> attributes, void *properties, |
| std::optional<std::vector<PyRegion>> regions, |
| DefaultingPyMlirContext context, DefaultingPyLocation location) { |
| llvm::SmallVector<MlirValue> mlirOperands = |
| wrapOperands(std::move(operandList)); |
| llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions)); |
| |
| std::vector<PyShapedTypeComponents> inferredShapedTypeComponents; |
| PyMlirContext &pyContext = context.resolve(); |
| AppendResultsCallbackData data{inferredShapedTypeComponents}; |
| MlirStringRef opNameRef = |
| mlirStringRefCreate(getOpName().data(), getOpName().length()); |
| MlirAttribute attributeDict = |
| attributes ? attributes->get() : mlirAttributeGetNull(); |
| |
| MlirLogicalResult result = mlirInferShapedTypeOpInterfaceInferReturnTypes( |
| opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), |
| mlirOperands.data(), attributeDict, properties, mlirRegions.size(), |
| mlirRegions.data(), &appendResultsCallback, &data); |
| |
| if (mlirLogicalResultIsFailure(result)) { |
| throw py::value_error("Failed to infer result shape type components"); |
| } |
| |
| return inferredShapedTypeComponents; |
| } |
| |
| static void bindDerived(ClassTy &cls) { |
| cls.def("inferReturnTypeComponents", |
| &PyInferShapedTypeOpInterface::inferReturnTypeComponents, |
| py::arg("operands") = py::none(), |
| py::arg("attributes") = py::none(), py::arg("regions") = py::none(), |
| py::arg("properties") = py::none(), py::arg("context") = py::none(), |
| py::arg("loc") = py::none(), inferReturnTypeComponentsDoc); |
| } |
| }; |
| |
| void populateIRInterfaces(py::module &m) { |
| PyInferTypeOpInterface::bind(m); |
| PyShapedTypeComponents::bind(m); |
| PyInferShapedTypeOpInterface::bind(m); |
| } |
| |
| } // namespace python |
| } // namespace mlir |