[MLIR][Transform][Python] Expose applying named_sequences as a method (#168223)
Makes it so that a NamedSequenceOp can be directly applied to a Module,
via a method `apply(...)`.
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index de414dc..b3dd79c 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -7,6 +7,7 @@
from .._transform_ops_gen import _Dialect
from ..._mlir_libs._mlirDialectsTransform import *
from ..._mlir_libs._mlirDialectsTransform import AnyOpType, OperationType
+from . import interpreter
try:
from ...ir import *
@@ -324,6 +325,25 @@
def bodyExtraArgs(self) -> BlockArgumentList:
return self.body.arguments[1:]
+ def apply(
+ self,
+ payload: Module,
+ transform_options: Optional[interpreter.TransformOptions] = None,
+ ) -> Module:
+ assert self.parent
+ assert "transform.with_named_sequence" in self.parent.attributes
+ assert isinstance(
+ self.parent.attributes["transform.with_named_sequence"], UnitAttr
+ )
+
+ interpreter.apply_named_sequence(
+ payload_root=payload,
+ transform_root=self,
+ transform_module=self.parent,
+ transform_options=transform_options,
+ )
+ return payload # NB: was modified in-place (if any transformation happened)
+
def named_sequence(
sym_name: Union[str, SymbolRefAttr],
diff --git a/mlir/test/python/dialects/transform_interpreter.py b/mlir/test/python/dialects/transform_interpreter.py
index 819a3be..ca9ce5d 100644
--- a/mlir/test/python/dialects/transform_interpreter.py
+++ b/mlir/test/python/dialects/transform_interpreter.py
@@ -32,6 +32,20 @@
@test_in_context
+def print_self_via_apply_method():
+ m = ir.Module.parse(
+ print_root_module.replace("from interpreter", "print_self_via_apply_method")
+ )
+ m.body.operations[0].apply(m)
+
+
+# CHECK-LABEL: print_self_via_apply_method
+# CHECK: transform.named_sequence @__transform_main
+# CHECK: transform.print
+# CHECK: transform.yield
+
+
+@test_in_context
def print_other():
transform = ir.Module.parse(
print_root_module.replace("from interpreter", "print_other")