| # RUN: %PYTHON %s | FileCheck %s |
| |
| from typing import Callable |
| from mlir import ir |
| from mlir.dialects import scf, pdl |
| from mlir.dialects.transform import ( |
| structured, |
| get_parent_op, |
| apply_patterns_canonicalization, |
| apply_cse, |
| any_op_t, |
| ) |
| from mlir.dialects.transform import FailurePropagationMode |
| from mlir.dialects.transform.structured import structured_match |
| from mlir.dialects.transform.loop import loop_unroll |
| from mlir.dialects.transform.extras import ( |
| constant_param, |
| OpHandle, |
| insert_transform_script, |
| sequence, |
| apply_patterns, |
| ) |
| from mlir.extras import types as T |
| |
| |
def construct_and_print_in_module(f):
    """Decorator that runs ``f`` to build IR in a fresh module, then prints it.

    A ``TEST: <function name>`` header is printed first so the FileCheck
    label directives in this file can anchor on it. ``f`` is invoked with no
    arguments while the insertion point is the new module's body, and is
    returned unchanged so the decorated name stays bound to the function.
    """
    print("\nTEST:", f.__name__)
    with ir.Context(), ir.Location.unknown():
        test_module = ir.Module.create()
        body_ip = ir.InsertionPoint(test_module.body)
        with body_ip:
            f()
        # Print while the context is still active.
        print(test_module)
    return f
| |
| |
def _emit_transform_script(
    script: Callable[[OpHandle], None],
    target_of: Callable[[ir.Module], object],
):
    """Emit ``script`` as a transform named sequence into a fresh module.

    Prints a ``TEST: <script name>`` header (the FileCheck label anchor),
    creates a module tagged with ``transform.with_named_sequence``, asks
    ``insert_transform_script`` to build and dump the script at the location
    returned by ``target_of(module)``, and verifies the resulting module.

    Args:
        script: Callback receiving the root ``OpHandle`` of the sequence.
        target_of: Maps the newly created module to the block or insertion
            point handed to ``insert_transform_script``. It is called inside
            the active context, after the module attribute has been set.

    Returns:
        ``script`` itself, so decorated names stay bound to their functions.
    """
    print("\nTEST:", script.__name__)
    with ir.Context(), ir.Location.unknown():
        module = ir.Module.create()
        module.operation.attributes["transform.with_named_sequence"] = (
            ir.UnitAttr.get()
        )
        insert_transform_script(target_of(module), script=script, dump_script=True)
        module.operation.verify()
    return script


def build_transform_script(script: Callable[[OpHandle], None]):
    """Decorator: emit ``script`` into the module body and verify it.

    The module's body block itself is passed to ``insert_transform_script``.
    Returns ``script`` (previously ``None``), matching the behaviour of
    ``construct_and_print_in_module``.
    """
    return _emit_transform_script(script, lambda module: module.body)


def build_transform_script_at_insertion_point(script: Callable[[OpHandle], None]):
    """Decorator: like ``build_transform_script``, but via an insertion point.

    An explicit ``ir.InsertionPoint`` at the beginning of the module body is
    passed to ``insert_transform_script`` instead of the block itself,
    exercising that overload of the API. Returns ``script``.
    """
    return _emit_transform_script(
        script, lambda module: ir.InsertionPoint.at_block_begin(module.body)
    )
| |
| |
# CHECK-LABEL: TEST: test_build_script_at_insertion_point
@build_transform_script_at_insertion_point
def test_build_script_at_insertion_point(op: OpHandle):
    """Emit an empty script through an explicit ``ir.InsertionPoint``.

    The expected dump is a named sequence containing only its
    ``transform.yield`` terminator.
    """
    pass
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: transform.yield
    # CHECK-NEXT: }
| |
| |
# CHECK-LABEL: TEST: test_constant_param_int
@build_transform_script
def test_constant_param_int(_: OpHandle):
    """Build a parameter constant from an explicit ``i32`` ``IntegerAttr``.

    The expected result type is the typed ``!transform.param<i32>``.
    """
    constant_param(ir.IntegerAttr.get(T.i32(), 42))
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant 42 : i32
    # CHECK-SAME: !transform.param<i32>
| |
| |
# CHECK-LABEL: TEST: test_constant_param_py_int
@build_transform_script
def test_constant_param_py_int(_: OpHandle):
    """Build a parameter constant from a plain Python ``int``.

    The expected dump shows the value materialized as an ``i64`` constant
    with a ``!transform.param<i64>`` result.
    """
    constant_param(42)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant 42 : i64
    # CHECK-SAME: !transform.param<i64>
| |
| |
# CHECK-LABEL: TEST: test_constant_param_symbol_attr
@build_transform_script
def test_constant_param_symbol_attr(_: OpHandle):
    """Build a parameter constant from a ``SymbolRefAttr``.

    The expected result type is the untyped ``!transform.any_param``.
    """
    constant_param(ir.SymbolRefAttr.get(["symbol"]))
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant @symbol
    # CHECK-SAME: !transform.any_param
| |
| |
# CHECK-LABEL: TEST: test_constant_param_type
@build_transform_script
def test_constant_param_type(_: OpHandle):
    """Build a parameter constant from a ``TypeAttr`` wrapping ``i32``.

    The expected dump prints the bare type ``i32`` as the value, with an
    untyped ``!transform.any_param`` result.
    """
    constant_param(ir.TypeAttr.get(T.i32()))
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant i32
    # CHECK-SAME: !transform.any_param
| |
| |
# CHECK-LABEL: TEST: test_get_defining_op
@build_transform_script
def test_get_defining_op(op: OpHandle):
    """Chain ``get_result()`` and ``get_defining_op()`` on the root handle.

    With no index argument, the expected dump reads result 0 into a
    ``!transform.any_value`` handle and then takes its defining op.
    """
    op.get_result().get_defining_op()
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.get_result %[[VAL_0]][0]
    # CHECK-SAME: !transform.any_value
    # CHECK-NEXT: %[[VAL_2:.*]] = transform.get_defining_op %[[VAL_1]]
| |
| |
# CHECK-LABEL: TEST: test_get_result
@build_transform_script
def test_get_result(op: OpHandle):
    """``op.get_result()`` with no argument is expected to select index 0."""
    op.get_result()
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.get_result %[[VAL_0]][0]
| |
| |
# CHECK-LABEL: TEST: test_match_ops_single
@build_transform_script
def test_match_ops_single(op: OpHandle):
    """Match on a single op class (``scf.ForOp``).

    The expected dump names ``scf.for`` in the ``ops`` list and gives the
    result the typed handle ``!transform.op<"scf.for">``.
    """
    op.match_ops(scf.ForOp)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["scf.for"]}
    # CHECK-SAME: in %[[VAL_0]]
    # CHECK-SAME: -> !transform.op<"scf.for">
| |
| |
# CHECK-LABEL: TEST: test_match_ops_string_name
@build_transform_script
def test_match_ops_string_name(op: OpHandle):
    """Match by the op-name string ``"linalg.matmul"``.

    The string is expected to appear verbatim in the ``ops`` list.
    """
    op.match_ops("linalg.matmul")
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME: ops{["linalg.matmul"]} in %[[VAL_0]]
| |
| |
# CHECK-LABEL: TEST: test_match_ops_string_iface
@build_transform_script
def test_match_ops_string_iface(op: OpHandle):
    """Match by the string ``"LinalgOp"``.

    Unlike an op-name string, this is expected to be emitted as an
    ``interface{LinalgOp}`` match.
    """
    op.match_ops("LinalgOp")
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME: interface{LinalgOp} in %[[VAL_0]]
| |
| |
# CHECK-LABEL: TEST: test_match_ops_iface
@build_transform_script
def test_match_ops_iface(op: OpHandle):
    """Match via the ``structured.MatchInterfaceEnum.LinalgOp`` enum value.

    The expected dump is an ``interface{LinalgOp}`` match.
    """
    op.match_ops(structured.MatchInterfaceEnum.LinalgOp)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME: interface{LinalgOp} in %[[VAL_0]]
| |
| |
# CHECK-LABEL: TEST: test_match_ops_multiple
@build_transform_script
def test_match_ops_multiple(op: OpHandle):
    """Match on a list of op classes.

    The expected dump lists the op names in the given order. The result is
    the generic ``!transform.any_op`` rather than a typed handle.
    """
    op.match_ops([scf.ForOp, scf.ForallOp])
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME: ops{["scf.for", "scf.forall"]} in %[[VAL_0]]
    # CHECK-SAME: -> !transform.any_op
| |
| |
# CHECK-LABEL: TEST: test_match_ops_mixed
@build_transform_script
def test_match_ops_mixed(op: OpHandle):
    """Match on a list mixing op classes and an op-name string.

    The expected dump preserves the list order. The result is the generic
    ``!transform.any_op``.
    """
    op.match_ops([scf.ForOp, "linalg.matmul", scf.ForallOp])
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME: ops{["scf.for", "linalg.matmul", "scf.forall"]} in %[[VAL_0]]
    # CHECK-SAME: -> !transform.any_op
| |
| |
# CHECK-LABEL: TEST: test_print_message
@build_transform_script
def test_print_message(op: OpHandle):
    """``op.print`` with a message is expected to attach it as ``name``."""
    op.print("message")
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: transform.print %[[VAL_0]] {name = "message"}
| |
| |
# CHECK-LABEL: TEST: test_print_plain
@build_transform_script
def test_print_plain(op: OpHandle):
    """``op.print()`` with no message emits ``transform.print`` on the target."""
    op.print()
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: transform.print %[[VAL_0]]
| |
| |
# CHECK-LABEL: TEST: test_sequence_region
@construct_and_print_in_module
def test_sequence_region():
    """Build a ``transform.sequence`` that unrolls loops around additions.

    The sequence matches every ``arith.addi``, climbs to the enclosing
    ``scf.for`` and unrolls that loop by a factor of 4.
    """
    # CHECK: transform.sequence failures(propagate) {
    # CHECK: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
    # CHECK: %[[VAL_1:.*]] = transform.structured.match ops{["arith.addi"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK: %[[VAL_2:.*]] = get_parent_op %[[VAL_1]] {op_name = "scf.for"} : (!transform.any_op) -> !pdl.operation
    # CHECK: transform.loop.unroll %[[VAL_2]] {factor = 4 : i64} : !pdl.operation
    # CHECK: }
    @sequence([], FailurePropagationMode.Propagate, [])
    def basic(target: any_op_t()):
        addi_ops = structured_match(any_op_t(), target, ops=["arith.addi"])
        # Requesting pdl.op_t() makes the parent handle print as !pdl.operation.
        enclosing_loop = get_parent_op(pdl.op_t(), addi_ops, op_name="scf.for")
        loop_unroll(enclosing_loop, 4)
| |
| |
# CHECK-LABEL: TEST: test_apply_patterns
@construct_and_print_in_module
def test_apply_patterns():
    """Build a sequence that canonicalizes and CSEs around matmuls.

    Canonicalization patterns are applied to the ``func.func`` enclosing the
    matched ``linalg.matmul``. CSE is then applied to every ``func.func``
    matched from the root handle.
    """
    # CHECK: transform.sequence failures(propagate) {
    # CHECK: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
    # CHECK: %[[VAL_1:.*]] = transform.structured.match ops{["linalg.matmul"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK: %[[VAL_2:.*]] = get_parent_op %[[VAL_1]] {op_name = "func.func"} : (!transform.any_op) -> !pdl.operation
    # CHECK: apply_patterns to %[[VAL_2]] {
    # CHECK: transform.apply_patterns.canonicalization
    # CHECK: } : !pdl.operation
    # CHECK: %[[VAL_3:.*]] = transform.structured.match ops{["func.func"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK: apply_cse to %[[VAL_3]] : !transform.any_op
    # CHECK: }
    @sequence([], FailurePropagationMode.Propagate, [])
    def basic(variant_op: any_op_t()):
        matmul_ops = structured_match(any_op_t(), variant_op, ops=["linalg.matmul"])
        enclosing_func = get_parent_op(pdl.op_t(), matmul_ops, op_name="func.func")

        # Presumably the decorator runs this body right away to fill the
        # apply_patterns region; the function object itself is never used.
        @apply_patterns(enclosing_func)
        def canonicalization_patterns():
            apply_patterns_canonicalization()

        every_func = structured_match(any_op_t(), variant_op, ops=["func.func"])
        apply_cse(every_func)