| //===- Rewrite.cpp - Rewrite ----------------------------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "Rewrite.h" |
| |
| #include "mlir-c/Bindings/Python/Interop.h" |
| #include "mlir-c/IR.h" |
| #include "mlir-c/Rewrite.h" |
| #include "mlir-c/Support.h" |
| #include "mlir/Bindings/Python/Globals.h" |
| #include "mlir/Bindings/Python/IRCore.h" |
| #include "mlir/Config/mlir-config.h" |
| #include "nanobind/nanobind.h" |
| #include <type_traits> |
| |
| namespace nb = nanobind; |
| using namespace mlir; |
| using namespace nb::literals; |
| using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN; |
| |
| namespace mlir { |
| namespace python { |
| namespace MLIR_BINDINGS_PYTHON_DOMAIN { |
| |
| // Convert the Python object to a boolean. |
| // If it evaluates to False, treat it as success; |
| // otherwise, treat it as failure. |
| // Note that None is considered success. |
| static MlirLogicalResult logicalResultFromObject(const nb::object &obj) { |
| if (obj.is_none()) |
| return mlirLogicalResultSuccess(); |
| |
| return nb::cast<bool>(obj) ? mlirLogicalResultFailure() |
| : mlirLogicalResultSuccess(); |
| } |
| |
| static std::string operationNameFromObject(nb::handle root) { |
| if (root.is_type()) |
| return nb::cast<std::string>(root.attr("OPERATION_NAME")); |
| if (nb::isinstance<nb::str>(root)) |
| return nb::cast<std::string>(root); |
| |
| throw nb::type_error("the root argument must be a type or a string"); |
| } |
| |
| static std::string dialectNameFromObject(nb::handle root) { |
| if (root.is_type()) |
| return nb::cast<std::string>(root.attr("DIALECT_NAMESPACE")); |
| if (nb::isinstance<nb::str>(root)) |
| return nb::cast<std::string>(root); |
| |
| throw nb::type_error("the root argument must be a type or a string"); |
| } |
| |
| class PyPatternRewriter : public PyRewriterBase<PyPatternRewriter> { |
| public: |
| static constexpr const char *pyClassName = "PatternRewriter"; |
| |
| PyPatternRewriter(MlirPatternRewriter rewriter) |
| : PyRewriterBase(mlirPatternRewriterAsBase(rewriter)) {} |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // PyRewritePatternSet |
| //===----------------------------------------------------------------------===// |
| |
| PyRewritePatternSet::PyRewritePatternSet(MlirContext ctx) |
| : patterns(mlirRewritePatternSetCreate(ctx)), owned(true) {} |
| |
| PyRewritePatternSet::PyRewritePatternSet(MlirRewritePatternSet patterns) |
| : patterns(patterns), owned(false) {} |
| |
| PyRewritePatternSet::~PyRewritePatternSet() { |
| if (owned && patterns.ptr) |
| mlirRewritePatternSetDestroy(patterns); |
| } |
| |
| MlirRewritePatternSet PyRewritePatternSet::get() const { return patterns; } |
| |
| bool PyRewritePatternSet::isOwned() const { return owned; } |
| |
| void PyRewritePatternSet::add(nb::handle root, |
| const nb::callable &matchAndRewrite, |
| unsigned benefit) { |
| std::string opName = operationNameFromObject(root); |
| MlirStringRef rootName = mlirStringRefCreate(opName.data(), opName.size()); |
| |
| MlirRewritePatternCallbacks callbacks; |
| callbacks.construct = [](void *userData) { |
| nb::handle(static_cast<PyObject *>(userData)).inc_ref(); |
| }; |
| callbacks.destruct = [](void *userData) { |
| nb::handle(static_cast<PyObject *>(userData)).dec_ref(); |
| }; |
| callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op, |
| MlirPatternRewriter rewriter, |
| void *userData) -> MlirLogicalResult { |
| nb::handle f(static_cast<PyObject *>(userData)); |
| |
| PyMlirContextRef context = |
| PyMlirContext::forContext(mlirOperationGetContext(op)); |
| nb::object opView = PyOperation::forOperation(context, op)->createOpView(); |
| |
| nb::object res = f(opView, PyPatternRewriter(rewriter)); |
| return logicalResultFromObject(res); |
| }; |
| |
| MlirRewritePattern pattern = mlirOpRewritePatternCreate( |
| rootName, benefit, mlirRewritePatternSetGetContext(patterns), callbacks, |
| matchAndRewrite.ptr(), |
| /* nGeneratedNames */ 0, |
| /* generatedNames */ nullptr); |
| mlirRewritePatternSetAdd(patterns, pattern); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // PyConversionPatternRewriter |
| //===----------------------------------------------------------------------===// |
| |
| class PyConversionPatternRewriter : public PyPatternRewriter { |
| public: |
| PyConversionPatternRewriter(MlirConversionPatternRewriter rewriter) |
| : PyPatternRewriter( |
| mlirConversionPatternRewriterAsPatternRewriter(rewriter)), |
| rewriter(rewriter) {} |
| |
| MlirConversionPatternRewriter rewriter; |
| }; |
| |
| class PyConversionTarget { |
| public: |
| PyConversionTarget(MlirContext context) |
| : target(mlirConversionTargetCreate(context)) {} |
| ~PyConversionTarget() { mlirConversionTargetDestroy(target); } |
| |
| void addLegalOp(const std::string &opName) { |
| mlirConversionTargetAddLegalOp( |
| target, mlirStringRefCreate(opName.data(), opName.size())); |
| } |
| |
| void addIllegalOp(const std::string &opName) { |
| mlirConversionTargetAddIllegalOp( |
| target, mlirStringRefCreate(opName.data(), opName.size())); |
| } |
| |
| void addLegalDialect(const std::string &dialectName) { |
| mlirConversionTargetAddLegalDialect( |
| target, mlirStringRefCreate(dialectName.data(), dialectName.size())); |
| } |
| |
| void addIllegalDialect(const std::string &dialectName) { |
| mlirConversionTargetAddIllegalDialect( |
| target, mlirStringRefCreate(dialectName.data(), dialectName.size())); |
| } |
| |
| MlirConversionTarget get() { return target; } |
| |
| private: |
| MlirConversionTarget target; |
| }; |
| |
| class PyTypeConverter { |
| public: |
| PyTypeConverter() : typeConverter(mlirTypeConverterCreate()), owner(true) {} |
| PyTypeConverter(MlirTypeConverter typeConverter) |
| : typeConverter(typeConverter), owner(false) {} |
| ~PyTypeConverter() { |
| if (owner) |
| mlirTypeConverterDestroy(typeConverter); |
| } |
| |
| void addConversion(const nb::callable &convert) { |
| mlirTypeConverterAddConversion( |
| typeConverter, |
| [](MlirType type, MlirType *converted, |
| void *userData) -> MlirLogicalResult { |
| nb::handle f = nb::handle(static_cast<PyObject *>(userData)); |
| auto ctx = PyMlirContext::forContext(mlirTypeGetContext(type)); |
| nb::object res = f(PyType(ctx, type).maybeDownCast()); |
| if (res.is_none()) |
| return mlirLogicalResultFailure(); |
| |
| *converted = nb::cast<PyType>(res).get(); |
| return mlirLogicalResultSuccess(); |
| }, |
| convert.ptr()); |
| } |
| |
| nb::typed<nb::object, std::optional<PyType>> convertType(PyType &type) { |
| MlirType converted = mlirTypeConverterConvertType(typeConverter, type); |
| if (mlirTypeIsNull(converted)) |
| return nb::none(); |
| return PyType(PyMlirContext::forContext(mlirTypeGetContext(converted)), |
| converted) |
| .maybeDownCast(); |
| } |
| |
| MlirTypeConverter get() { return typeConverter; } |
| |
| private: |
| MlirTypeConverter typeConverter; |
| bool owner; |
| }; |
| |
| class PyConversionPattern { |
| public: |
| PyConversionPattern(MlirConversionPattern pattern) : pattern(pattern) {} |
| |
| PyTypeConverter getTypeConverter() { |
| return PyTypeConverter(mlirConversionPatternGetTypeConverter(pattern)); |
| } |
| |
| private: |
| MlirConversionPattern pattern; |
| }; |
| |
| void PyRewritePatternSet::addConversion(nb::handle root, |
| const nb::callable &matchAndRewrite, |
| PyTypeConverter &typeConverter, |
| unsigned benefit) { |
| std::string opName = operationNameFromObject(root); |
| MlirStringRef rootName = mlirStringRefCreate(opName.data(), opName.size()); |
| |
| MlirConversionPatternCallbacks callbacks; |
| callbacks.construct = [](void *userData) { |
| nb::handle(static_cast<PyObject *>(userData)).inc_ref(); |
| }; |
| callbacks.destruct = [](void *userData) { |
| nb::handle(static_cast<PyObject *>(userData)).dec_ref(); |
| }; |
| callbacks.matchAndRewrite = |
| [](MlirConversionPattern pattern, MlirOperation op, intptr_t nOperands, |
| MlirValue *operands, MlirConversionPatternRewriter rewriter, |
| void *userData) -> MlirLogicalResult { |
| nb::handle f(static_cast<PyObject *>(userData)); |
| |
| PyMlirContextRef ctx = |
| PyMlirContext::forContext(mlirOperationGetContext(op)); |
| nb::object opView = PyOperation::forOperation(ctx, op)->createOpView(); |
| |
| std::vector<MlirValue> operandsVec(operands, operands + nOperands); |
| nb::object adaptorCls = |
| PyGlobals::get() |
| .lookupOpAdaptorClass([&] { |
| MlirStringRef ref = mlirIdentifierStr(mlirOperationGetName(op)); |
| return std::string_view(ref.data, ref.length); |
| }()) |
| .value_or(nb::borrow(nb::type<PyOpAdaptor>())); |
| |
| nb::object res = f(opView, adaptorCls(operandsVec, opView), |
| PyConversionPattern(pattern).getTypeConverter(), |
| PyConversionPatternRewriter(rewriter)); |
| return logicalResultFromObject(res); |
| }; |
| MlirConversionPattern pattern = mlirOpConversionPatternCreate( |
| rootName, benefit, mlirRewritePatternSetGetContext(patterns), |
| typeConverter.get(), callbacks, matchAndRewrite.ptr(), |
| /* nGeneratedNames */ 0, |
| /* generatedNames */ nullptr); |
| mlirRewritePatternSetAdd(patterns, |
| mlirConversionPatternAsRewritePattern(pattern)); |
| } |
| |
| #if MLIR_ENABLE_PDL_IN_PATTERNMATCH |
| struct PyMlirPDLResultList : MlirPDLResultList {}; |
| |
| static nb::object objectFromPDLValue(MlirPDLValue value) { |
| if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v)) |
| return nb::cast(v); |
| if (MlirOperation v = mlirPDLValueAsOperation(value); !mlirOperationIsNull(v)) |
| return nb::cast(v); |
| if (MlirAttribute v = mlirPDLValueAsAttribute(value); !mlirAttributeIsNull(v)) |
| return nb::cast(v); |
| if (MlirType v = mlirPDLValueAsType(value); !mlirTypeIsNull(v)) |
| return nb::cast(v); |
| |
| throw std::runtime_error("unsupported PDL value type"); |
| } |
| |
| static std::vector<nb::object> objectsFromPDLValues(size_t nValues, |
| MlirPDLValue *values) { |
| std::vector<nb::object> args; |
| args.reserve(nValues); |
| for (size_t i = 0; i < nValues; ++i) |
| args.push_back(objectFromPDLValue(values[i])); |
| return args; |
| } |
| |
| /// Owning Wrapper around a PDLPatternModule. |
| class PyPDLPatternModule { |
| public: |
| PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {} |
| PyPDLPatternModule(PyPDLPatternModule &&other) noexcept |
| : module(other.module) { |
| other.module.ptr = nullptr; |
| } |
| ~PyPDLPatternModule() { |
| if (module.ptr != nullptr) |
| mlirPDLPatternModuleDestroy(module); |
| } |
| MlirPDLPatternModule get() { return module; } |
| |
| void registerRewriteFunction(const std::string &name, |
| const nb::callable &fn) { |
| mlirPDLPatternModuleRegisterRewriteFunction( |
| get(), mlirStringRefCreate(name.data(), name.size()), |
| [](MlirPatternRewriter rewriter, MlirPDLResultList results, |
| size_t nValues, MlirPDLValue *values, |
| void *userData) -> MlirLogicalResult { |
| nb::handle f = nb::handle(static_cast<PyObject *>(userData)); |
| return logicalResultFromObject( |
| f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr}, |
| objectsFromPDLValues(nValues, values))); |
| }, |
| fn.ptr()); |
| } |
| |
| void registerConstraintFunction(const std::string &name, |
| const nb::callable &fn) { |
| mlirPDLPatternModuleRegisterConstraintFunction( |
| get(), mlirStringRefCreate(name.data(), name.size()), |
| [](MlirPatternRewriter rewriter, MlirPDLResultList results, |
| size_t nValues, MlirPDLValue *values, |
| void *userData) -> MlirLogicalResult { |
| nb::handle f = nb::handle(static_cast<PyObject *>(userData)); |
| return logicalResultFromObject( |
| f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr}, |
| objectsFromPDLValues(nValues, values))); |
| }, |
| fn.ptr()); |
| } |
| |
| private: |
| MlirPDLPatternModule module; |
| }; |
| #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH |
| |
| /// Owning Wrapper around a FrozenRewritePatternSet. |
| class PyFrozenRewritePatternSet { |
| public: |
| PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {} |
| PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept |
| : set(other.set) { |
| other.set.ptr = nullptr; |
| } |
| ~PyFrozenRewritePatternSet() { |
| if (set.ptr != nullptr) |
| mlirFrozenRewritePatternSetDestroy(set); |
| } |
| MlirFrozenRewritePatternSet get() { return set; } |
| |
| nb::object getCapsule() { |
| return nb::steal<nb::object>( |
| mlirPythonFrozenRewritePatternSetToCapsule(get())); |
| } |
| |
| static nb::object createFromCapsule(const nb::object &capsule) { |
| MlirFrozenRewritePatternSet rawPm = |
| mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); |
| if (rawPm.ptr == nullptr) |
| throw nb::python_error(); |
| return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move); |
| } |
| |
| private: |
| MlirFrozenRewritePatternSet set; |
| }; |
| |
| void PyRewritePatternSet::bind(nb::module_ &m) { |
| nb::class_<PyRewritePatternSet>(m, "RewritePatternSet") |
| .def( |
| "__init__", |
| [](PyRewritePatternSet &self, DefaultingPyMlirContext context) { |
| new (&self) PyRewritePatternSet(context.get()->get()); |
| }, |
| "context"_a = nb::none()) |
| .def("add", &PyRewritePatternSet::add, nb::arg("root"), nb::arg("fn"), |
| nb::arg("benefit") = 1, |
| R"(Add a new rewrite pattern on the specified root operation, using |
| the provided callable for matching and rewriting, and assign it |
| the given benefit. |
| |
| Args: |
| root: The root operation to which this pattern applies. This may |
| be either an OpView subclass or an operation name. |
| fn: The callable to use for matching and rewriting, which takes |
| an operation and a pattern rewriter. The match is considered |
| successful iff the callable returns a falsy value. |
| benefit: The benefit of the pattern, defaulting to 1.)") |
| .def("add_conversion", &PyRewritePatternSet::addConversion, |
| nb::arg("root"), nb::arg("fn"), nb::arg("type_converter"), |
| nb::arg("benefit") = 1, |
| R"( |
| Add a new conversion pattern on the specified root operation, |
| using the provided callable for matching and rewriting, |
| and assign it the given benefit. |
| |
| Args: |
| root: The root operation to which this pattern applies. |
| This may be either an OpView subclass or an operation name. |
| fn: The callable to use for matching and rewriting, which takes an |
| operation, its adaptor, the type converter and a pattern |
| rewriter. The match is considered successful iff the callable |
| returns a falsy value. |
| type_converter: The type converter to convert types in the IR. |
| benefit: The benefit of the pattern, defaulting to 1.)") |
| .def( |
| "freeze", |
| [](PyRewritePatternSet &self) { |
| if (!self.isOwned()) |
| throw std::runtime_error( |
| "cannot freeze a non-owning pattern set"); |
| MlirRewritePatternSet s = self.get(); |
| return PyFrozenRewritePatternSet(mlirFreezeRewritePattern(s)); |
| }, |
| "Freeze the pattern set into a frozen one."); |
| } |
| |
| enum class PyGreedyRewriteStrictness : std::underlying_type_t< |
| MlirGreedyRewriteStrictness> { |
| ANY_OP = MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP, |
| EXISTING_AND_NEW_OPS = MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS, |
| EXISTING_OPS = MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS, |
| }; |
| |
| enum class PyGreedySimplifyRegionLevel : std::underlying_type_t< |
| MlirGreedySimplifyRegionLevel> { |
| DISABLED = MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED, |
| NORMAL = MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL, |
| AGGRESSIVE = MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE |
| }; |
| |
| /// Owning Wrapper around a GreedyRewriteDriverConfig. |
| class PyGreedyRewriteConfig { |
| public: |
| PyGreedyRewriteConfig() |
| : config(mlirGreedyRewriteDriverConfigCreate().ptr, |
| PyGreedyRewriteConfig::customDeleter) {} |
| PyGreedyRewriteConfig(PyGreedyRewriteConfig &&other) noexcept |
| : config(std::move(other.config)) {} |
| PyGreedyRewriteConfig(const PyGreedyRewriteConfig &other) noexcept |
| : config(other.config) {} |
| |
| MlirGreedyRewriteDriverConfig get() { |
| return MlirGreedyRewriteDriverConfig{config.get()}; |
| } |
| |
| void setMaxIterations(int64_t maxIterations) { |
| mlirGreedyRewriteDriverConfigSetMaxIterations(get(), maxIterations); |
| } |
| |
| void setMaxNumRewrites(int64_t maxNumRewrites) { |
| mlirGreedyRewriteDriverConfigSetMaxNumRewrites(get(), maxNumRewrites); |
| } |
| |
| void setUseTopDownTraversal(bool useTopDownTraversal) { |
| mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(get(), |
| useTopDownTraversal); |
| } |
| |
| void enableFolding(bool enable) { |
| mlirGreedyRewriteDriverConfigEnableFolding(get(), enable); |
| } |
| |
| void setStrictness(PyGreedyRewriteStrictness strictness) { |
| mlirGreedyRewriteDriverConfigSetStrictness( |
| get(), static_cast<MlirGreedyRewriteStrictness>(strictness)); |
| } |
| |
| void setRegionSimplificationLevel(PyGreedySimplifyRegionLevel level) { |
| mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel( |
| get(), static_cast<MlirGreedySimplifyRegionLevel>(level)); |
| } |
| |
| void enableConstantCSE(bool enable) { |
| mlirGreedyRewriteDriverConfigEnableConstantCSE(get(), enable); |
| } |
| |
| int64_t getMaxIterations() { |
| return mlirGreedyRewriteDriverConfigGetMaxIterations(get()); |
| } |
| |
| int64_t getMaxNumRewrites() { |
| return mlirGreedyRewriteDriverConfigGetMaxNumRewrites(get()); |
| } |
| |
| bool getUseTopDownTraversal() { |
| return mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(get()); |
| } |
| |
| bool isFoldingEnabled() { |
| return mlirGreedyRewriteDriverConfigIsFoldingEnabled(get()); |
| } |
| |
| PyGreedyRewriteStrictness getStrictness() { |
| return static_cast<PyGreedyRewriteStrictness>( |
| mlirGreedyRewriteDriverConfigGetStrictness(get())); |
| } |
| |
| PyGreedySimplifyRegionLevel getRegionSimplificationLevel() { |
| return static_cast<PyGreedySimplifyRegionLevel>( |
| mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(get())); |
| } |
| |
| bool isConstantCSEEnabled() { |
| return mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(get()); |
| } |
| |
| private: |
| std::shared_ptr<void> config; |
| static void customDeleter(void *c) { |
| mlirGreedyRewriteDriverConfigDestroy(MlirGreedyRewriteDriverConfig{c}); |
| } |
| }; |
| |
| enum class PyDialectConversionFoldingMode : std::underlying_type_t< |
| MlirDialectConversionFoldingMode> { |
| Never = MLIR_DIALECT_CONVERSION_FOLDING_MODE_NEVER, |
| BeforePatterns = MLIR_DIALECT_CONVERSION_FOLDING_MODE_BEFORE_PATTERNS, |
| AfterPatterns = MLIR_DIALECT_CONVERSION_FOLDING_MODE_AFTER_PATTERNS, |
| }; |
| |
| class PyConversionConfig { |
| public: |
| PyConversionConfig() |
| : config(mlirConversionConfigCreate().ptr, |
| PyConversionConfig::customDeleter) {} |
| |
| MlirConversionConfig get() { return MlirConversionConfig{config.get()}; } |
| |
| void setFoldingMode(PyDialectConversionFoldingMode mode) { |
| mlirConversionConfigSetFoldingMode(get(), |
| MlirDialectConversionFoldingMode(mode)); |
| } |
| |
| PyDialectConversionFoldingMode getFoldingMode() { |
| return PyDialectConversionFoldingMode( |
| mlirConversionConfigGetFoldingMode(get())); |
| } |
| |
| void enableBuildMaterializations(bool enabled) { |
| mlirConversionConfigEnableBuildMaterializations(get(), enabled); |
| } |
| |
| bool isBuildMaterializationsEnabled() { |
| return mlirConversionConfigIsBuildMaterializationsEnabled(get()); |
| } |
| |
| private: |
| std::shared_ptr<void> config; |
| static void customDeleter(void *c) { |
| mlirConversionConfigDestroy(MlirConversionConfig{c}); |
| } |
| }; |
| |
| /// Create the `mlir.rewrite` here. |
| void populateRewriteSubmodule(nb::module_ &m) { |
| // Enum definitions |
| nb::enum_<PyGreedyRewriteStrictness>(m, "GreedyRewriteStrictness") |
| .value("ANY_OP", PyGreedyRewriteStrictness::ANY_OP) |
| .value("EXISTING_AND_NEW_OPS", |
| PyGreedyRewriteStrictness::EXISTING_AND_NEW_OPS) |
| .value("EXISTING_OPS", PyGreedyRewriteStrictness::EXISTING_OPS); |
| |
| nb::enum_<PyGreedySimplifyRegionLevel>(m, "GreedySimplifyRegionLevel") |
| .value("DISABLED", PyGreedySimplifyRegionLevel::DISABLED) |
| .value("NORMAL", PyGreedySimplifyRegionLevel::NORMAL) |
| .value("AGGRESSIVE", PyGreedySimplifyRegionLevel::AGGRESSIVE); |
| |
| nb::enum_<PyDialectConversionFoldingMode>(m, "DialectConversionFoldingMode") |
| .value("NEVER", PyDialectConversionFoldingMode::Never) |
| .value("BEFORE_PATTERNS", PyDialectConversionFoldingMode::BeforePatterns) |
| .value("AFTER_PATTERNS", PyDialectConversionFoldingMode::AfterPatterns); |
| |
| //---------------------------------------------------------------------------- |
| // Mapping of the PatternRewriter |
| //---------------------------------------------------------------------------- |
| |
| PyPatternRewriter::bind(m); |
| |
| //---------------------------------------------------------------------------- |
| // Mapping of the RewritePatternSet |
| //---------------------------------------------------------------------------- |
| PyRewritePatternSet::bind(m); |
| |
| nb::class_<PyConversionPatternRewriter, PyPatternRewriter>( |
| m, "ConversionPatternRewriter") |
| .def("convert_region_types", |
| [](PyConversionPatternRewriter &self, PyRegion ®ion, |
| PyTypeConverter &typeConverter) { |
| mlirConversionPatternRewriterConvertRegionTypes( |
| self.rewriter, region.get(), typeConverter.get()); |
| }); |
| |
| nb::class_<PyConversionTarget>(m, "ConversionTarget") |
| .def( |
| "__init__", |
| [](PyConversionTarget &self, DefaultingPyMlirContext context) { |
| new (&self) PyConversionTarget(context.get()->get()); |
| }, |
| "context"_a = nb::none()) |
| .def( |
| "add_legal_op", |
| [](PyConversionTarget &self, const nb::args &ops) { |
| for (auto op : ops) { |
| self.addLegalOp(operationNameFromObject(op)); |
| } |
| }, |
| "ops"_a, "Mark the given operations as legal.") |
| .def( |
| "add_illegal_op", |
| [](PyConversionTarget &self, const nb::args &ops) { |
| for (auto op : ops) { |
| self.addIllegalOp(operationNameFromObject(op)); |
| } |
| }, |
| "ops"_a, "Mark the given operations as illegal.") |
| .def( |
| "add_legal_dialect", |
| [](PyConversionTarget &self, const nb::args &dialects) { |
| for (auto dialect : dialects) { |
| self.addLegalDialect(dialectNameFromObject(dialect)); |
| } |
| }, |
| "dialects"_a, "Mark the given dialects as legal.") |
| .def( |
| "add_illegal_dialect", |
| [](PyConversionTarget &self, const nb::args &dialects) { |
| for (auto dialect : dialects) { |
| self.addIllegalDialect(dialectNameFromObject(dialect)); |
| } |
| }, |
| "dialects"_a, "Mark the given dialect as illegal."); |
| |
| nb::class_<PyTypeConverter>(m, "TypeConverter") |
| .def(nb::init<>(), "Create a new TypeConverter.") |
| .def("add_conversion", &PyTypeConverter::addConversion, "convert"_a, |
| nb::keep_alive<0, 1>(), "Register a type conversion function.") |
| .def("convert_type", &PyTypeConverter::convertType, "type"_a, |
| "Convert the given type. Returns None if conversion fails."); |
| |
| //---------------------------------------------------------------------------- |
| // Mapping of the PDLResultList and PDLModule |
| //---------------------------------------------------------------------------- |
| #if MLIR_ENABLE_PDL_IN_PATTERNMATCH |
| nb::class_<PyMlirPDLResultList>(m, "PDLResultList") |
| .def("append", |
| [](PyMlirPDLResultList results, const PyValue &value) { |
| mlirPDLResultListPushBackValue(results, value); |
| }) |
| .def("append", |
| [](PyMlirPDLResultList results, const PyOperation &op) { |
| mlirPDLResultListPushBackOperation(results, op); |
| }) |
| .def("append", |
| [](PyMlirPDLResultList results, const PyType &type) { |
| mlirPDLResultListPushBackType(results, type); |
| }) |
| .def("append", [](PyMlirPDLResultList results, const PyAttribute &attr) { |
| mlirPDLResultListPushBackAttribute(results, attr); |
| }); |
| nb::class_<PyPDLPatternModule>(m, "PDLModule") |
| .def( |
| "__init__", |
| [](PyPDLPatternModule &self, PyModule &module) { |
| new (&self) PyPDLPatternModule( |
| mlirPDLPatternModuleFromModule(module.get())); |
| }, |
| "module"_a, "Create a PDL module from the given module.") |
| .def( |
| "__init__", |
| [](PyPDLPatternModule &self, PyModule &module) { |
| new (&self) PyPDLPatternModule( |
| mlirPDLPatternModuleFromModule(module.get())); |
| }, |
| "module"_a, "Create a PDL module from the given module.") |
| .def( |
| "freeze", |
| [](PyPDLPatternModule &self) { |
| return PyFrozenRewritePatternSet(mlirFreezeRewritePattern( |
| mlirRewritePatternSetFromPDLPatternModule(self.get()))); |
| }, |
| nb::keep_alive<0, 1>()) |
| .def( |
| "register_rewrite_function", |
| [](PyPDLPatternModule &self, const std::string &name, |
| const nb::callable &fn) { |
| self.registerRewriteFunction(name, fn); |
| }, |
| nb::keep_alive<1, 3>()) |
| .def( |
| "register_constraint_function", |
| [](PyPDLPatternModule &self, const std::string &name, |
| const nb::callable &fn) { |
| self.registerConstraintFunction(name, fn); |
| }, |
| nb::keep_alive<1, 3>()); |
| #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH |
| |
| nb::class_<PyGreedyRewriteConfig>(m, "GreedyRewriteConfig") |
| .def(nb::init<>(), "Create a greedy rewrite driver config with defaults") |
| .def_prop_rw("max_iterations", &PyGreedyRewriteConfig::getMaxIterations, |
| &PyGreedyRewriteConfig::setMaxIterations, |
| "Maximum number of iterations") |
| .def_prop_rw("max_num_rewrites", |
| &PyGreedyRewriteConfig::getMaxNumRewrites, |
| &PyGreedyRewriteConfig::setMaxNumRewrites, |
| "Maximum number of rewrites per iteration") |
| .def_prop_rw("use_top_down_traversal", |
| &PyGreedyRewriteConfig::getUseTopDownTraversal, |
| &PyGreedyRewriteConfig::setUseTopDownTraversal, |
| "Whether to use top-down traversal") |
| .def_prop_rw("enable_folding", &PyGreedyRewriteConfig::isFoldingEnabled, |
| &PyGreedyRewriteConfig::enableFolding, |
| "Enable or disable folding") |
| .def_prop_rw("strictness", &PyGreedyRewriteConfig::getStrictness, |
| &PyGreedyRewriteConfig::setStrictness, |
| "Rewrite strictness level") |
| .def_prop_rw("region_simplification_level", |
| &PyGreedyRewriteConfig::getRegionSimplificationLevel, |
| &PyGreedyRewriteConfig::setRegionSimplificationLevel, |
| "Region simplification level") |
| .def_prop_rw("enable_constant_cse", |
| &PyGreedyRewriteConfig::isConstantCSEEnabled, |
| &PyGreedyRewriteConfig::enableConstantCSE, |
| "Enable or disable constant CSE"); |
| |
| nb::class_<PyConversionConfig>(m, "ConversionConfig") |
| .def(nb::init<>(), "Create a conversion config with defaults") |
| .def_prop_rw("folding_mode", &PyConversionConfig::getFoldingMode, |
| &PyConversionConfig::setFoldingMode, |
| "folding behavior during dialect conversion") |
| .def_prop_rw("build_materializations", |
| &PyConversionConfig::isBuildMaterializationsEnabled, |
| &PyConversionConfig::enableBuildMaterializations, |
| "Whether the dialect conversion attempts to build " |
| "source/target materializations"); |
| |
| nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet") |
| .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, |
| &PyFrozenRewritePatternSet::getCapsule) |
| .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, |
| &PyFrozenRewritePatternSet::createFromCapsule); |
| m.def( |
| "apply_patterns_and_fold_greedily", |
| [](PyModule &module, PyFrozenRewritePatternSet &set, |
| std::optional<PyGreedyRewriteConfig> config) { |
| MlirLogicalResult status = mlirApplyPatternsAndFoldGreedily( |
| module.get(), set.get(), |
| config.has_value() ? config->get() |
| : mlirGreedyRewriteDriverConfigCreate()); |
| if (mlirLogicalResultIsFailure(status)) |
| throw std::runtime_error("pattern application failed to converge"); |
| }, |
| "module"_a, "set"_a, "config"_a = nb::none(), |
| "Applys the given patterns to the given module greedily while folding " |
| "results.") |
| .def( |
| "apply_patterns_and_fold_greedily", |
| [](PyOperationBase &op, PyFrozenRewritePatternSet &set, |
| std::optional<PyGreedyRewriteConfig> config) { |
| MlirLogicalResult status = mlirApplyPatternsAndFoldGreedilyWithOp( |
| op.getOperation(), set.get(), |
| config.has_value() ? config->get() |
| : mlirGreedyRewriteDriverConfigCreate()); |
| if (mlirLogicalResultIsFailure(status)) |
| throw std::runtime_error( |
| "pattern application failed to converge"); |
| }, |
| "op"_a, "set"_a, "config"_a = nb::none(), |
| "Applys the given patterns to the given op greedily while folding " |
| "results.") |
| .def( |
| "walk_and_apply_patterns", |
| [](PyOperationBase &op, PyFrozenRewritePatternSet &set) { |
| mlirWalkAndApplyPatterns(op.getOperation(), set.get()); |
| }, |
| "op"_a, "set"_a, |
| "Applies the given patterns to the given op by a fast walk-based " |
| "driver.") |
| .def( |
| "apply_partial_conversion", |
| [](PyOperationBase &op, PyConversionTarget &target, |
| PyFrozenRewritePatternSet &set, |
| std::optional<PyConversionConfig> config) { |
| if (!config) |
| config.emplace(PyConversionConfig()); |
| PyMlirContext::ErrorCapture errors(op.getOperation().getContext()); |
| MlirLogicalResult status = mlirApplyPartialConversion( |
| op.getOperation(), target.get(), set.get(), config->get()); |
| if (mlirLogicalResultIsFailure(status)) |
| throw MLIRError("partial conversion failed", errors.take()); |
| }, |
| "op"_a, "target"_a, "set"_a, "config"_a = nb::none(), |
| "Applies a partial conversion on the given operation.") |
| .def( |
| "apply_full_conversion", |
| [](PyOperationBase &op, PyConversionTarget &target, |
| PyFrozenRewritePatternSet &set, |
| std::optional<PyConversionConfig> config) { |
| if (!config) |
| config.emplace(PyConversionConfig()); |
| PyMlirContext::ErrorCapture errors(op.getOperation().getContext()); |
| MlirLogicalResult status = mlirApplyFullConversion( |
| op.getOperation(), target.get(), set.get(), config->get()); |
| if (mlirLogicalResultIsFailure(status)) |
| throw MLIRError("full conversion failed", errors.take()); |
| }, |
| "op"_a, "target"_a, "set"_a, "config"_a = nb::none(), |
| "Applies a full conversion on the given operation."); |
| } |
| } // namespace MLIR_BINDINGS_PYTHON_DOMAIN |
| } // namespace python |
| } // namespace mlir |