blob: f6b4532b1b6be4c1b5f1b161b5ce34f4c7424166 [file] [log] [blame]
//===- TransformInterpreter.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// pybind11 bindings (a class and free functions) for the transform dialect
// interpreter.
//
//===----------------------------------------------------------------------===//
#include "mlir-c/Dialect/Transform/Interpreter.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include <pybind11/detail/common.h>
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace {
struct PyMlirTransformOptions {
PyMlirTransformOptions() { options = mlirTransformOptionsCreate(); };
PyMlirTransformOptions(PyMlirTransformOptions &&other) {
options = other.options;
other.options.ptr = nullptr;
}
PyMlirTransformOptions(const PyMlirTransformOptions &) = delete;
~PyMlirTransformOptions() { mlirTransformOptionsDestroy(options); }
MlirTransformOptions options;
};
} // namespace
/// Registers the transform dialect interpreter bindings on module `m`:
///   * `TransformOptions` — module-local class wrapping the C API transform
///     options, with boolean properties `expensive_checks` and
///     `enforce_single_top_level_transform_op`;
///   * `apply_named_sequence(payload_root, transform_root, transform_module,
///     transform_options=TransformOptions())`;
///   * `copy_symbols_and_merge_into(target, other)`.
/// Both functions raise `ValueError`, including the diagnostics collected
/// during the call, when the underlying C API reports failure.
static void populateTransformInterpreterSubmodule(py::module &m) {
  // ----- TransformOptions -------------------------------------------------
  py::class_<PyMlirTransformOptions> optionsClass(m, "TransformOptions",
                                                  py::module_local());
  optionsClass.def(py::init());

  auto getExpensiveChecks = [](const PyMlirTransformOptions &opts) -> bool {
    return mlirTransformOptionsGetExpensiveChecksEnabled(opts.options);
  };
  auto setExpensiveChecks = [](PyMlirTransformOptions &opts, bool enabled) {
    mlirTransformOptionsEnableExpensiveChecks(opts.options, enabled);
  };
  optionsClass.def_property("expensive_checks", getExpensiveChecks,
                            setExpensiveChecks);

  auto getSingleTopLevel = [](const PyMlirTransformOptions &opts) -> bool {
    return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(
        opts.options);
  };
  auto setSingleTopLevel = [](PyMlirTransformOptions &opts, bool enabled) {
    mlirTransformOptionsEnforceSingleTopLevelTransformOp(opts.options,
                                                         enabled);
  };
  optionsClass.def_property("enforce_single_top_level_transform_op",
                            getSingleTopLevel, setSingleTopLevel);

  // ----- apply_named_sequence ---------------------------------------------
  auto applyNamedSequence = [](MlirOperation payloadRoot,
                               MlirOperation transformRoot,
                               MlirOperation transformModule,
                               const PyMlirTransformOptions &options) {
    // Collect diagnostics emitted in the transform root's context so they can
    // be attached to the Python exception on failure.
    mlir::python::CollectDiagnosticsToStringScope diagnostics(
        mlirOperationGetContext(transformRoot));

    // Invalidate every live Python operation under the payload root before
    // interpreting. The PyMlirContext object is not reachable from here, so
    // this goes back through Python via the payload operation's `context`.
    py::object payloadOp = py::cast(payloadRoot);
    py::object payloadContext = payloadOp.attr("context");
    payloadContext.attr("_clear_live_operations_inside")(payloadRoot);

    MlirLogicalResult status = mlirTransformApplyNamedSequence(
        payloadRoot, transformRoot, transformModule, options.options);
    if (mlirLogicalResultIsFailure(status))
      throw py::value_error(
          "Failed to apply named transform sequence.\nDiagnostic message " +
          diagnostics.takeMessage());
  };
  m.def("apply_named_sequence", applyNamedSequence,
        py::arg("payload_root"), py::arg("transform_root"),
        py::arg("transform_module"),
        py::arg("transform_options") = PyMlirTransformOptions());

  // ----- copy_symbols_and_merge_into --------------------------------------
  auto copySymbolsAndMerge = [](MlirOperation target, MlirOperation other) {
    mlir::python::CollectDiagnosticsToStringScope diagnostics(
        mlirOperationGetContext(target));
    MlirLogicalResult status = mlirMergeSymbolsIntoFromClone(target, other);
    if (mlirLogicalResultIsSuccess(status))
      return;
    throw py::value_error("Failed to merge symbols.\nDiagnostic message " +
                          diagnostics.takeMessage());
  };
  m.def("copy_symbols_and_merge_into", copySymbolsAndMerge,
        py::arg("target"), py::arg("other"));
}
// Entry point of the `_mlirTransformInterpreter` Python extension module.
PYBIND11_MODULE(_mlirTransformInterpreter, m) {
  // Module docstring (`__doc__`); equivalent to assigning through `m.doc()`.
  m.attr("__doc__") = "MLIR Transform dialect interpreter functionality.";
  // Register the TransformOptions class and the interpreter entry points.
  populateTransformInterpreterSubmodule(m);
}