blob: 740c49f76a26c43050d8786b18763826be66e6fe [file] [log] [blame]
# RUN: %PYTHON %s | FileCheck %s
from mlir import ir
from mlir.dialects.transform import interpreter as interp
def test_in_context(f):
with ir.Context(), ir.Location.unknown():
f()
return f
print_root_module = """
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%root: !transform.any_op) {
transform.print %root { name = \"from interpreter\" }: !transform.any_op
transform.yield
}
}"""
@test_in_context
def print_self():
m = ir.Module.parse(print_root_module.replace("from interpreter", "print_self"))
interp.apply_named_sequence(m, m.body.operations[0], m)
# CHECK-LABEL: print_self
# CHECK: transform.named_sequence @__transform_main
# CHECK: transform.print
# CHECK: transform.yield
@test_in_context
def print_other():
transform = ir.Module.parse(
print_root_module.replace("from interpreter", "print_other")
)
payload = ir.Module.parse("module attributes { this.is.payload } {}")
interp.apply_named_sequence(payload, transform.body.operations[0], transform)
# CHECK-LABEL: print_other
# CHECK-NOT: transform
# CHECK: this.is.payload
@test_in_context
def failed():
payload = ir.Module.parse("module attributes { this.is.payload } {}")
try:
interp.apply_named_sequence(payload, payload, payload)
except ValueError as e:
assert (
"must implement TransformOpInterface to be used as transform root" in str(e)
)