| # RUN: %PYTHON %s 2>&1 | FileCheck %s |
| |
| from mlir.dialects import arith, func, pdl |
| from mlir.dialects.builtin import module |
| from mlir.ir import * |
| from mlir.rewrite import * |
| |
| |
def construct_and_print_in_module(f):
    """Run test ``f`` immediately in a fresh MLIR context and print its result.

    Intended for use as a decorator. A ``TEST: <name>`` banner is printed
    first so the expected-output labels in this file can anchor on it. ``f``
    is then called with a newly created empty ``Module`` while that module's
    body is the active insertion point. Whatever ``f`` returns is printed,
    unless it is ``None``, while the context is still alive.

    Returns:
        ``f`` itself, so the decorated name stays bound to the function.
    """
    print("\nTEST:", f.__name__)
    with Context():
        with Location.unknown():
            fresh = Module.create()
            with InsertionPoint(fresh.body):
                produced = f(fresh)
            if produced is not None:
                print(produced)
    return f
| |
| |
def get_pdl_patterns():
    """Build and freeze a PDL pattern set rewriting ``arith.addi`` to ``arith.muli``.

    The single pattern, ``addi_to_mul`` (benefit 1), matches an operation that:
      - is named ``arith.addi``;
      - has exactly two operands, both of ``index`` type;
      - has one ``index`` result.
    It replaces the matched op with an ``arith.muli`` over the same operands
    and result type.

    No context is created here, so an MLIR context must already be active.

    Returns:
        The frozen pattern set produced by ``PDLModule(...).freeze()``.
    """
    with Location.unknown():
        pattern_module = Module.create()
        with InsertionPoint(pattern_module.body):

            @pdl.pattern(benefit=1, sym_name="addi_to_mul")
            def addi_to_muli():
                # Match side: arith.addi(index, index) -> index.
                idx_ty = pdl.TypeOp(IndexType.get())
                lhs = pdl.OperandOp(idx_ty)
                rhs = pdl.OperandOp(idx_ty)
                matched_add = pdl.OperationOp(
                    name="arith.addi", args=[lhs, rhs], types=[idx_ty]
                )

                # Rewrite side: swap in arith.muli on the same operands.
                @pdl.rewrite()
                def replace_with_muli():
                    mul = pdl.OperationOp(
                        name="arith.muli", args=[lhs, rhs], types=[idx_ty]
                    )
                    pdl.ReplaceOp(matched_add, with_op=mul)

        # Wrapping the module in a PDLModule hands ownership of it to the PDL
        # module. That transfer is not yet tracked on the Python side (sharp
        # edges), so the module and the PDL module are kept in the same scope.
        # FIXME: This should be made more robust.
        return PDLModule(pattern_module).freeze()
| |
| |
| # CHECK-LABEL: TEST: test_add_to_mul |
| # CHECK: arith.muli |
@construct_and_print_in_module
def test_add_to_mul(module_):
    """Apply the addi -> muli patterns with the ``Module`` object as the target."""
    idx = IndexType.get()

    # Build the payload IR at the active insertion point: a nested module
    # holding one function that adds its two index arguments.
    @module(sym_name="ir")
    def ir():
        @func.func(idx, idx)
        def add_func(a, b):
            return arith.addi(a, b)

    patterns = get_pdl_patterns()
    # A frozen pattern set is reusable; it could be applied repeatedly.
    apply_patterns_and_fold_greedily(module_, patterns)
    return module_
| |
| |
| # CHECK-LABEL: TEST: test_add_to_mul_with_op |
| # CHECK: arith.muli |
@construct_and_print_in_module
def test_add_to_mul_with_op(module_):
    """Apply the addi -> muli patterns with ``module_.operation`` as the target.

    Same payload as ``test_add_to_mul``, but the driver receives the module's
    underlying Operation instead of the Module wrapper.
    """
    idx = IndexType.get()

    # Payload: a nested module with one function adding two index values.
    @module(sym_name="ir")
    def ir():
        @func.func(idx, idx)
        def add_func(a, b):
            return arith.addi(a, b)

    patterns = get_pdl_patterns()
    apply_patterns_and_fold_greedily(module_.operation, patterns)
    return module_
| |
| |
# If we used arith.constant and arith.addi here, their C++-defined folders and
# canonicalization patterns would be applied implicitly by the greedy pattern
# rewrite driver, making our Python-defined folding redundant. To work around
# this, we define a new dialect instead.
def load_myint_dialect():
    """Define the ``myint`` dialect with IRDL, verify it, and load it.

    The dialect has two operations:
      - ``myint.constant``: an integer attribute named ``value`` (base
        ``#builtin.integer``) and a single i32 result named ``cst``;
      - ``myint.add``: two single i32 operands ``lhs``/``rhs`` and a single
        i32 result ``res``.

    No context is created here; the dialect is loaded into whichever MLIR
    context is currently active.
    """
    from mlir.dialects import irdl

    def is_i32():
        # Emit, at the current insertion point, an IRDL constraint that
        # accepts exactly the signless 32-bit integer type.
        return irdl.is_(TypeAttr.get(IntegerType.get_signless(32)))

    defs = Module.create()
    with InsertionPoint(defs.body):
        dialect = irdl.dialect("myint")
        with InsertionPoint(dialect.body):
            constant_op = irdl.operation_("constant")
            with InsertionPoint(constant_op.body):
                int_attr = irdl.base(base_name="#builtin.integer")
                result_ty = is_i32()
                irdl.attributes_([int_attr], ["value"])
                irdl.results_([result_ty], ["cst"], [irdl.Variadicity.single])
            add_op = irdl.operation_("add")
            with InsertionPoint(add_op.body):
                operand_ty = is_i32()
                irdl.operands_(
                    [operand_ty, operand_ty],
                    ["lhs", "rhs"],
                    [irdl.Variadicity.single] * 2,
                )
                irdl.results_([operand_ty], ["res"], [irdl.Variadicity.single])

    defs.operation.verify()
    irdl.load_dialects(defs)
| |
| |
# This PDL module folds constant additions with two patterns:
# 1. add(constant0, constant1) -> constant2,
#    where constant2 = constant0 + constant1;
# 2. add(x, 0) or add(0, x) -> x.
def get_pdl_pattern_fold():
    """Build and freeze the constant-folding patterns for ``myint.add``.

    Two patterns, each with benefit 1, are defined:

    * ``myint_add_fold``: ``add(constant(a), constant(b))`` is replaced by
      ``constant(a + b)``. The new attribute is produced by the Python native
      rewrite ``add_fold``.
    * ``myint_add_zero_fold``: ``add(x, 0)`` or ``add(0, x)`` is replaced by
      ``x``. The Python native constraint ``has_zero`` decides whether the
      pattern applies and yields the operand to forward.

    No context is created here, so an MLIR context must already be active.

    Returns:
        The frozen pattern set with both native functions registered.
    """
    m = Module.create()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(m.body):

        @pdl.pattern(benefit=1, sym_name="myint_add_fold")
        def add_constants_pattern():
            t = pdl.TypeOp(i32)
            a0 = pdl.AttributeOp()
            a1 = pdl.AttributeOp()
            c0 = pdl.OperationOp(
                name="myint.constant", attributes={"value": a0}, types=[t]
            )
            c1 = pdl.OperationOp(
                name="myint.constant", attributes={"value": a1}, types=[t]
            )
            v0 = pdl.ResultOp(c0, 0)
            v1 = pdl.ResultOp(c1, 0)
            op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t])

            @pdl.rewrite()
            def replace_with_folded_constant():
                # Named ``folded`` (not ``sum``) to avoid shadowing the builtin.
                folded = pdl.apply_native_rewrite(
                    [pdl.AttributeType.get()], "add_fold", [a0, a1]
                )
                new_op = pdl.OperationOp(
                    name="myint.constant", attributes={"value": folded}, types=[t]
                )
                pdl.ReplaceOp(op0, with_op=new_op)

        @pdl.pattern(benefit=1, sym_name="myint_add_zero_fold")
        def add_zero_pattern():
            t = pdl.TypeOp(i32)
            v0 = pdl.OperandOp()
            v1 = pdl.OperandOp()
            # When the constraint holds, its single result is the operand that
            # should replace the whole add.
            kept = pdl.apply_native_constraint(
                [pdl.ValueType.get()], "has_zero", [v0, v1]
            )
            op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t])

            @pdl.rewrite()
            def forward_other_operand():
                pdl.ReplaceOp(op0, with_values=[kept])

    def add_fold(rewriter, results, values):
        # values holds the two matched ``value`` attributes; push their sum
        # back as an i32 IntegerAttr.
        a0, a1 = values
        results.append(IntegerAttr.get(i32, a0.value + a1.value))

    def is_zero(value):
        # True iff ``value`` is the result of a myint.constant holding 0.
        # ``owner`` is not an OpView for values without a defining op
        # (presumably block arguments), hence the isinstance check.
        op = value.owner
        if isinstance(op, OpView):
            return op.name == "myint.constant" and op.attributes["value"].value == 0
        return False

    # Check whether either operand is a constant zero, and if so append the
    # other operand to the results. Note the inverted return convention of
    # native constraints: False means "constraint holds / pattern matches",
    # True means "constraint fails".
    def has_zero(rewriter, results, values):
        v0, v1 = values
        if is_zero(v0):
            results.append(v1)
            return False
        if is_zero(v1):
            results.append(v0)
            return False
        return True

    pdl_module = PDLModule(m)
    pdl_module.register_rewrite_function("add_fold", add_fold)
    pdl_module.register_constraint_function("has_zero", has_zero)
    return pdl_module.freeze()
| |
| |
| # CHECK-LABEL: TEST: test_pdl_register_function |
| # CHECK: "myint.constant"() {value = 8 : i32} : () -> i32 |
@construct_and_print_in_module
def test_pdl_register_function(module_):
    """Fold chained constant ``myint.add`` ops via the Python ``add_fold`` rewrite.

    The payload computes ``(2 + 3) + 3``. After folding, the final add should
    become a ``myint.constant`` of value 8. The IR under test is parsed from
    text rather than built into the passed-in module.
    """
    load_myint_dialect()

    source = """
    %c0 = "myint.constant"() { value = 2 }: () -> (i32)
    %c1 = "myint.constant"() { value = 3 }: () -> (i32)
    %x = "myint.add"(%c0, %c1): (i32, i32) -> (i32)
    "myint.add"(%x, %c1): (i32, i32) -> (i32)
    """
    parsed = Module.parse(source)

    apply_patterns_and_fold_greedily(parsed, get_pdl_pattern_fold())
    return parsed
| |
| |
| # CHECK-LABEL: TEST: test_pdl_register_function_constraint |
| # CHECK: return %arg0 : i32 |
@construct_and_print_in_module
def test_pdl_register_function_constraint(module_):
    """Exercise the ``has_zero`` native constraint (add-with-zero folding).

    ``%a = 1 + (-1)`` folds to a zero constant. Both ``%a + %x`` and
    ``%b + %a`` then reduce to the function argument, so ``@f`` should simply
    return its argument. The IR under test is parsed from text rather than
    built into the passed-in module.
    """
    load_myint_dialect()

    source = """
    func.func @f(%x : i32) -> i32 {
      %c0 = "myint.constant"() { value = 1 }: () -> (i32)
      %c1 = "myint.constant"() { value = -1 }: () -> (i32)
      %a = "myint.add"(%c0, %c1): (i32, i32) -> (i32)
      %b = "myint.add"(%a, %x): (i32, i32) -> (i32)
      %c = "myint.add"(%b, %a): (i32, i32) -> (i32)
      func.return %c : i32
    }
    """
    parsed = Module.parse(source)

    apply_patterns_and_fold_greedily(parsed, get_pdl_pattern_fold())
    return parsed
| |
| |
# This pattern expands a constant into additions of smaller constants until
# every constant is at most 1, e.g. 3 -> 1 + 2 -> 1 + (1 + 1).
def get_pdl_pattern_expand():
    """Build and freeze a pattern that splits ``myint.constant`` ops into adds.

    A constant ``c`` greater than 1 is replaced by
    ``add(constant(c // 2), constant(c - c // 2))``. Constants <= 1 are
    rejected by the native constraint registered as ``"is_one"``. Because the
    greedy driver revisits the new constants, a value such as 5 ends up as a
    tree of additions of 1s.

    No context is created here, so an MLIR context must already be active.

    Returns:
        The frozen pattern set with the constraint and rewrite registered.
    """
    pattern_module = Module.create()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(pattern_module.body):

        @pdl.pattern(benefit=1, sym_name="myint_constant_expand")
        def expand_constant():
            t = pdl.TypeOp(i32)
            cst_attr = pdl.AttributeOp()
            pdl.apply_native_constraint([], "is_one", [cst_attr])
            root = pdl.OperationOp(
                name="myint.constant", attributes={"value": cst_attr}, types=[t]
            )

            @pdl.rewrite()
            def replace_with_sum():
                replacement = pdl.apply_native_rewrite(
                    [pdl.OperationType.get()], "expand", [cst_attr]
                )
                pdl.ReplaceOp(root, with_op=replacement)

    def rejects_at_most_one(rewriter, results, values):
        # A truthy return means the constraint FAILS, so only constants
        # strictly greater than 1 get expanded.
        return values[0].value <= 1

    def split_constant(rewriter, results, values):
        total = values[0].value
        low = total // 2
        high = total - low

        def make_constant(v):
            # New myint.constant carrying v as an i32 attribute.
            return Operation.create(
                "myint.constant",
                results=[i32],
                attributes={"value": IntegerAttr.get(i32, v)},
            )

        # Create the replacement ops at the rewriter's insertion point.
        with rewriter.ip:
            lhs = make_constant(low)
            rhs = make_constant(high)
            summed = Operation.create(
                "myint.add", results=[i32], operands=[lhs.result, rhs.result]
            )
        results.append(summed)

    pdl_module = PDLModule(pattern_module)
    pdl_module.register_constraint_function("is_one", rejects_at_most_one)
    pdl_module.register_rewrite_function("expand", split_constant)
    return pdl_module.freeze()
| |
| |
| # CHECK-LABEL: TEST: test_pdl_register_function_expand |
| # CHECK: %0 = "myint.constant"() {value = 1 : i32} : () -> i32 |
| # CHECK: %1 = "myint.constant"() {value = 1 : i32} : () -> i32 |
| # CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32 |
| # CHECK: %3 = "myint.constant"() {value = 1 : i32} : () -> i32 |
| # CHECK: %4 = "myint.constant"() {value = 1 : i32} : () -> i32 |
| # CHECK: %5 = "myint.constant"() {value = 1 : i32} : () -> i32 |
| # CHECK: %6 = "myint.add"(%4, %5) : (i32, i32) -> i32 |
| # CHECK: %7 = "myint.add"(%3, %6) : (i32, i32) -> i32 |
| # CHECK: %8 = "myint.add"(%2, %7) : (i32, i32) -> i32 |
| # CHECK: return %8 : i32 |
@construct_and_print_in_module
def test_pdl_register_function_expand(module_):
    """Expand ``myint.constant 5`` into a tree of additions of 1s.

    This exercises a native rewrite that creates several new ops through
    ``rewriter.ip``. The IR under test is parsed from text rather than built
    into the passed-in module.
    """
    load_myint_dialect()

    source = """
    func.func @f() -> i32 {
      %0 = "myint.constant"() { value = 5 }: () -> (i32)
      return %0 : i32
    }
    """
    parsed = Module.parse(source)

    apply_patterns_and_fold_greedily(parsed, get_pdl_pattern_expand())
    return parsed