blob: ea47f170cb632122cd381d92527c4395f2614822 [file] [log] [blame]
# RUN: %PYTHON %s | FileCheck %s
from typing import Callable
from mlir import ir
from mlir.dialects import scf, pdl
from mlir.dialects.transform import (
structured,
get_parent_op,
apply_patterns_canonicalization,
apply_cse,
any_op_t,
)
from mlir.dialects.transform import FailurePropagationMode
from mlir.dialects.transform.structured import structured_match
from mlir.dialects.transform.loop import loop_unroll
from mlir.dialects.transform.extras import (
constant_param,
OpHandle,
insert_transform_script,
sequence,
apply_patterns,
)
from mlir.extras import types as T
def construct_and_print_in_module(f):
    """Decorator: run ``f`` inside a freshly created module and print it.

    Prints a ``TEST: <name>`` banner, opens a new context with an unknown
    location, creates a module, calls ``f`` with the insertion point set to
    the module body, prints the resulting module, and returns ``f`` unchanged.
    """
    print("\nTEST:", f.__name__)
    with ir.Context():
        with ir.Location.unknown():
            test_module = ir.Module.create()
            body_ip = ir.InsertionPoint(test_module.body)
            with body_ip:
                f()
            print(test_module)
    return f
def build_transform_script(script: Callable[[OpHandle], None]):
    """Decorator: insert ``script`` into a new module and dump the result.

    Prints a ``TEST: <name>`` banner, creates a module carrying the
    ``transform.with_named_sequence`` unit attribute, inserts the script into
    the module body with ``dump_script=True``, and verifies the module.

    Nothing is returned, so the decorated name is rebound to ``None``.
    """
    print("\nTEST:", script.__name__)
    with ir.Context():
        with ir.Location.unknown():
            host = ir.Module.create()
            host.operation.attributes[
                "transform.with_named_sequence"
            ] = ir.UnitAttr.get()
            insertion_target = host.body
            insert_transform_script(
                insertion_target,
                script=script,
                dump_script=True,
            )
            host.operation.verify()
def build_transform_script_at_insertion_point(script: Callable[[OpHandle], None]):
    """Decorator: like ``build_transform_script`` but via an insertion point.

    Prints a ``TEST: <name>`` banner, creates a module carrying the
    ``transform.with_named_sequence`` unit attribute, and inserts the script
    at an ``ir.InsertionPoint`` placed at the beginning of the module body
    (instead of passing the block itself), with ``dump_script=True``. The
    module is verified afterwards.

    Nothing is returned, so the decorated name is rebound to ``None``.
    """
    print("\nTEST:", script.__name__)
    with ir.Context():
        with ir.Location.unknown():
            host = ir.Module.create()
            host.operation.attributes[
                "transform.with_named_sequence"
            ] = ir.UnitAttr.get()
            block_start = ir.InsertionPoint.at_block_begin(host.body)
            insert_transform_script(block_start, script=script, dump_script=True)
            host.operation.verify()
# CHECK-LABEL: TEST: test_build_script_at_insertion_point
@build_transform_script_at_insertion_point
def test_build_script_at_insertion_point(op: OpHandle):
    """Empty script inserted through an explicit insertion point.

    The body is intentionally empty: the FileCheck directives following this
    function expect a named sequence that contains only the yield.
    """
    pass
# CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
# CHECK-NEXT: transform.yield
# CHECK-NEXT: }
# CHECK-LABEL: TEST: test_constant_param_int
@build_transform_script
def test_constant_param_int(_: OpHandle):
    """Create a constant param from an explicit ``i32`` integer attribute."""
    forty_two_i32 = ir.IntegerAttr.get(T.i32(), 42)
    constant_param(forty_two_i32)
# CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
# CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant 42 : i32
# CHECK-SAME: !transform.param<i32>
# CHECK-LABEL: TEST: test_constant_param_py_int
@build_transform_script
def test_constant_param_py_int(_: OpHandle):
    """Create a constant param from a plain Python ``int``.

    The FileCheck directives following this function expect the value to be
    emitted as an ``i64`` constant.
    """
    constant_param(42)
# CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
# CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant 42 : i64
# CHECK-SAME: !transform.param<i64>
# CHECK-LABEL: TEST: test_constant_param_symbol_attr
@build_transform_script
def test_constant_param_symbol_attr(_: OpHandle):
    """Create a constant param from a symbol reference attribute."""
    symbol_ref = ir.SymbolRefAttr.get(["symbol"])
    constant_param(symbol_ref)
# CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
# CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant @symbol
# CHECK-SAME: !transform.any_param
# CHECK-LABEL: TEST: test_constant_param_type
@build_transform_script
def test_constant_param_type(_: OpHandle):
    """Create a constant param from a type attribute wrapping ``i32``."""
    i32_type_attr = ir.TypeAttr.get(T.i32())
    constant_param(i32_type_attr)
# CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
# CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant i32
# CHECK-SAME: !transform.any_param
# CHECK-LABEL: TEST: test_get_defining_op
@build_transform_script
def test_get_defining_op(op: OpHandle):
    """Take a result handle from ``op``, then ask for its defining op."""
    result_handle = op.get_result()
    result_handle.get_defining_op()
# CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
# CHECK-NEXT: %[[VAL_1:.*]] = transform.get_result %[[VAL_0]][0]
# CHECK-SAME: !transform.any_value
# CHECK-NEXT: %[[VAL_2:.*]] = transform.get_defining_op %[[VAL_1]]
# CHECK-LABEL: TEST: test_get_result
@build_transform_script
def test_get_result(op: OpHandle):
    """Request a result handle from ``op`` using the default arguments.

    The FileCheck directives following this function expect result index 0.
    """
    op.get_result()
# CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
# CHECK-NEXT: %[[VAL_1:.*]] = transform.get_result %[[VAL_0]][0]
# CHECK-LABEL: TEST: test_match_ops_single
@build_transform_script
def test_match_ops_single(op: OpHandle):
    """Match a single op kind given as an op class (``scf.ForOp``).

    The FileCheck directives following this function expect the resulting
    handle to be typed for ``scf.for`` specifically.
    """
    op.match_ops(scf.ForOp)
# CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
# CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["scf.for"]}
# CHECK-SAME: in %[[VAL_0]]
# CHECK-SAME: -> !transform.op<"scf.for">
# CHECK-LABEL: TEST: test_match_ops_string_name
@build_transform_script
def test_match_ops_string_name(op: OpHandle):
    """Match ops by an operation name supplied as a string."""
    op.match_ops("linalg.matmul")
# CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
# CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
# CHECK-SAME: ops{["linalg.matmul"]} in %[[VAL_0]]
# CHECK-LABEL: TEST: test_match_ops_string_iface
@build_transform_script
def test_match_ops_string_iface(op: OpHandle):
    """Match ops by an interface name supplied as a string.

    The FileCheck directives following this function expect ``"LinalgOp"``
    to be emitted as an interface match rather than an op-name match.
    """
    op.match_ops("LinalgOp")
# CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
# CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
# CHECK-SAME: interface{LinalgOp} in %[[VAL_0]]
# CHECK-LABEL: TEST: test_match_ops_iface
@build_transform_script
def test_match_ops_iface(op: OpHandle):
    """Match ops by interface, passing the interface enum member directly."""
    linalg_iface = structured.MatchInterfaceEnum.LinalgOp
    op.match_ops(linalg_iface)
# CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
# CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
# CHECK-SAME: interface{LinalgOp} in %[[VAL_0]]
# CHECK-LABEL: TEST: test_match_ops_multiple
@build_transform_script
def test_match_ops_multiple(op: OpHandle):
    """Match several op kinds at once, all given as op classes."""
    loop_kinds = [scf.ForOp, scf.ForallOp]
    op.match_ops(loop_kinds)
# CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
# CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
# CHECK-SAME: ops{["scf.for", "scf.forall"]} in %[[VAL_0]]
# CHECK-SAME: -> !transform.any_op
# CHECK-LABEL: TEST: test_match_ops_mixed
@build_transform_script
def test_match_ops_mixed(op: OpHandle):
    """Match several op kinds given as a mix of op classes and name strings."""
    wanted = [scf.ForOp, "linalg.matmul", scf.ForallOp]
    op.match_ops(wanted)
# CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
# CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
# CHECK-SAME: ops{["scf.for", "linalg.matmul", "scf.forall"]} in %[[VAL_0]]
# CHECK-SAME: -> !transform.any_op
# CHECK-LABEL: TEST: test_print_message
@build_transform_script
def test_print_message(op: OpHandle):
    """Print the handle with a message.

    The FileCheck directives following this function expect the message to
    appear as the ``name`` attribute of the print op.
    """
    op.print("message")
# CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
# CHECK-NEXT: transform.print %[[VAL_0]] {name = "message"}
# CHECK-LABEL: TEST: test_print_plain
@build_transform_script
def test_print_plain(op: OpHandle):
    """Print the handle without passing a message."""
    op.print()
# CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
# CHECK-NEXT: transform.print %[[VAL_0]]
# CHECK-LABEL: TEST: test_sequence_region
@construct_and_print_in_module
def test_sequence_region():
    """Build a transform sequence with the ``sequence`` region decorator.

    The sequence matches ``arith.addi`` ops, takes their ``scf.for`` parent
    and unrolls it with factor 4.
    """
    # CHECK: transform.sequence failures(propagate) {
    # CHECK: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
    # CHECK: %[[VAL_1:.*]] = transform.structured.match ops{["arith.addi"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK: %[[VAL_2:.*]] = get_parent_op %[[VAL_1]] {op_name = "scf.for"} : (!transform.any_op) -> !pdl.operation
    # CHECK: transform.loop.unroll %[[VAL_2]] {factor = 4 : i64} : !pdl.operation
    # CHECK: }
    @sequence([], FailurePropagationMode.Propagate, [])
    def basic(target: any_op_t()):
        addi_ops = structured_match(any_op_t(), target, ops=["arith.addi"])
        enclosing_for = get_parent_op(
            pdl.op_t(),
            addi_ops,
            op_name="scf.for",
        )
        loop_unroll(enclosing_for, 4)
# CHECK-LABEL: TEST: test_apply_patterns
@construct_and_print_in_module
def test_apply_patterns():
    """Build a sequence that applies patterns and then runs CSE.

    Canonicalization patterns are applied to the ``func.func`` parent of the
    matched ``linalg.matmul`` ops; afterwards ``func.func`` ops are matched
    again from the sequence argument and CSE is applied to them.
    """
    # CHECK: transform.sequence failures(propagate) {
    # CHECK: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
    # CHECK: %[[VAL_1:.*]] = transform.structured.match ops{["linalg.matmul"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK: %[[VAL_2:.*]] = get_parent_op %[[VAL_1]] {op_name = "func.func"} : (!transform.any_op) -> !pdl.operation
    # CHECK: apply_patterns to %[[VAL_2]] {
    # CHECK: transform.apply_patterns.canonicalization
    # CHECK: } : !pdl.operation
    # CHECK: %[[VAL_3:.*]] = transform.structured.match ops{["func.func"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK: apply_cse to %[[VAL_3]] : !transform.any_op
    # CHECK: }
    @sequence([], FailurePropagationMode.Propagate, [])
    def basic(variant_op: any_op_t()):
        matmul_ops = structured_match(
            any_op_t(),
            variant_op,
            ops=["linalg.matmul"],
        )
        matmul_parent_func = get_parent_op(
            pdl.op_t(),
            matmul_ops,
            op_name="func.func",
        )

        @apply_patterns(matmul_parent_func)
        def pats():
            apply_patterns_canonicalization()

        # Matched afresh from the sequence argument, giving an any_op handle.
        matched_funcs = structured_match(
            any_op_t(),
            variant_op,
            ops=["func.func"],
        )
        apply_cse(matched_funcs)