| # RUN: %PYTHON %s | FileCheck %s --enable-var-scope=false |
| |
| import gc |
| from mlir.ir import * |
| from mlir.dialects import func |
| |
| |
def run(f):
    """Execute the test callable ``f`` and verify no MLIR context outlives it.

    A ``TEST: <name>`` banner is printed first so that the output belonging to
    each test can be located. After ``f`` returns, a garbage collection pass is
    forced and the live ``Context`` count is asserted to be zero.

    Returns ``f`` unchanged, which allows this function to be used as a
    decorator.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    # Collect eagerly so that contexts referenced only by garbage are released
    # before the live count is inspected.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0
    return f
| |
| |
| # CHECK-LABEL: TEST: testCapsuleConversions |
@run
def testCapsuleConversions():
    """Round-trip a Value through its C API capsule and compare the results."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        result_type = IntegerType.get_signless(32)
        producer = Operation.create("custom.op1", results=[result_type])
        value = producer.result

        # The capsule repr is expected to carry the Value capsule name.
        capsule = value._CAPIPtr
        capsule_repr = repr(capsule)
        assert '"mlir.ir.Value._CAPIPtr"' in capsule_repr

        # Rebuilding a Value from the capsule must compare equal to the source.
        rebuilt = Value._CAPICreate(capsule)
        assert rebuilt == value
| |
| |
| # CHECK-LABEL: TEST: testOpResultOwner |
@run
def testOpResultOwner():
    """Check that the owner of an operation result is the creating operation."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        int_type = IntegerType.get_signless(32)
        op = Operation.create("custom.op1", results=[int_type])
        owner = op.result.owner
        assert owner == op
| |
| |
| # CHECK-LABEL: TEST: testBlockArgOwner |
@run
def testBlockArgOwner():
    """Check that the owner of a block argument is the block declaring it."""
    context = Context()
    context.allow_unregistered_dialects = True
    asm = r"""
    func.func @foo(%arg0: f32) {
      return
    }"""
    module = Module.parse(asm, context)

    # Named `func_op` so the imported `func` dialect module is not shadowed.
    func_op = module.body.operations[0]
    entry = func_op.regions[0].blocks[0]
    first_arg = entry.arguments[0]
    assert first_arg.owner == entry
| |
| |
| # CHECK-LABEL: TEST: testValueIsInstance |
@run
def testValueIsInstance():
    """Check ``BlockArgument.isinstance`` / ``OpResult.isinstance`` on both kinds of value."""
    context = Context()
    context.allow_unregistered_dialects = True
    asm = r"""
    func.func @foo(%arg0: f32) {
      %0 = "some_dialect.some_op"() : () -> f64
      return
    }"""
    module = Module.parse(asm, context)
    func_op = module.body.operations[0]
    entry = func_op.regions[0].blocks[0]

    # A function argument is a block argument, not an operation result.
    argument = entry.arguments[0]
    assert BlockArgument.isinstance(argument)
    assert not OpResult.isinstance(argument)

    # The result of the first operation in the body is the opposite case.
    op = entry.operations[0]
    result = op.results[0]
    assert not BlockArgument.isinstance(result)
    assert OpResult.isinstance(result)
| |
| |
| # CHECK-LABEL: TEST: testValueHash |
@run
def testValueHash():
    """Check that a value and the operand referring to it hash identically."""
    context = Context()
    context.allow_unregistered_dialects = True
    asm = r"""
    func.func @foo(%arg0: f32) -> f32 {
      %0 = "some_dialect.some_op"(%arg0) : (f32) -> f32
      return %0 : f32
    }"""
    module = Module.parse(asm, context)

    # Unpacking doubles as a check that the module holds exactly one operation
    # and that its entry block holds exactly two.
    (func_op,) = module.body.operations
    entry = func_op.entry_block
    op, ret = entry.operations

    arg_hash = hash(entry.arguments[0])
    assert arg_hash == hash(op.operands[0])
    result_hash = hash(op.result)
    assert result_hash == hash(ret.operands[0])
| |
| |
| # CHECK-LABEL: TEST: testValueUses |
@run
def testValueUses():
    """Iterate over the uses of a value consumed by two operations."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        int_type = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            producer = Operation.create("custom.op1", results=[int_type])
            value = producer.results[0]
            op1 = Operation.create("custom.op2", operands=[value])
            op2 = Operation.create("custom.op2", operands=[value])

    users = (op1, op2)
    # CHECK: Use owner: "custom.op2"
    # CHECK: Use operand_number: 0
    # CHECK: Use owner: "custom.op2"
    # CHECK: Use operand_number: 0
    for use in value.uses:
        assert use.owner in users
        print(f"Use owner: {use.owner}")
        print(f"Use operand_number: {use.operand_number}")
| |
| |
| # CHECK-LABEL: TEST: testValueReplaceAllUsesWith |
@run
def testValueReplaceAllUsesWith():
    """Redirect every use of one value to another and inspect both use lists."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        int_type = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            value = Operation.create("custom.op1", results=[int_type]).results[0]
            op1 = Operation.create("custom.op2", operands=[value])
            op2 = Operation.create("custom.op2", operands=[value])
            value2 = Operation.create("custom.op3", results=[int_type]).results[0]
            value.replace_all_uses_with(value2)

    # The original value must have no users left.
    assert not list(value.uses)

    users = (op1, op2)
    # CHECK: Use owner: "custom.op2"
    # CHECK: Use operand_number: 0
    # CHECK: Use owner: "custom.op2"
    # CHECK: Use operand_number: 0
    for use in value2.uses:
        assert use.owner in users
        print(f"Use owner: {use.owner}")
        print(f"Use operand_number: {use.operand_number}")
| |
| |
| # CHECK-LABEL: TEST: testValueReplaceAllUsesWithExcept |
@run
def testValueReplaceAllUsesWithExcept():
    """Redirect the uses of a value while exempting a single operation."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        int_type = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            value = Operation.create("custom.op1", results=[int_type]).results[0]
            op1 = Operation.create("custom.op1", operands=[value])
            op2 = Operation.create("custom.op2", operands=[value])
            value2 = Operation.create("custom.op3", results=[int_type]).results[0]
            value.replace_all_uses_except(value2, op1)

    # Only the exempted operation keeps using the original value.
    remaining_use_count = len(list(value.uses))
    assert remaining_use_count == 1

    # CHECK: Use owner: "custom.op2"
    # CHECK: Use operand_number: 0
    for use in value2.uses:
        assert use.owner in (op2,)
        print(f"Use owner: {use.owner}")
        print(f"Use operand_number: {use.operand_number}")

    # CHECK: Use owner: "custom.op1"
    # CHECK: Use operand_number: 0
    for use in value.uses:
        assert use.owner in (op1,)
        print(f"Use owner: {use.owner}")
        print(f"Use operand_number: {use.operand_number}")
| |
| |
| # CHECK-LABEL: TEST: testValueReplaceAllUsesWithMultipleExceptions |
@run
def testValueReplaceAllUsesWithMultipleExceptions():
    """Redirect the uses of a value while exempting a list of operations."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        int_type = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            value = Operation.create("custom.op1", results=[int_type]).results[0]
            op1 = Operation.create("custom.op1", operands=[value])
            op2 = Operation.create("custom.op2", operands=[value])
            op3 = Operation.create("custom.op3", operands=[value])
            value2 = Operation.create("custom.op4", results=[int_type]).results[0]

            # `op1` and `op2` are exempted, so only the use in `op3` is
            # redirected from `value` to `value2`.
            value.replace_all_uses_except(value2, [op1, op2])

    # `value` keeps its two exempted users; `value2` gains the single user `op3`.
    old_use_count = len(list(value.uses))
    new_use_count = len(list(value2.uses))
    assert old_use_count == 2
    assert new_use_count == 1

    # CHECK: Use owner: "custom.op3"
    # CHECK: Use operand_number: 0
    for use in value2.uses:
        assert use.owner in (op3,)
        print(f"Use owner: {use.owner}")
        print(f"Use operand_number: {use.operand_number}")

    exempted = (op1, op2)
    # CHECK: Use owner: "custom.op2"
    # CHECK: Use operand_number: 0
    # CHECK: Use owner: "custom.op1"
    # CHECK: Use operand_number: 0
    for use in value.uses:
        assert use.owner in exempted
        print(f"Use owner: {use.owner}")
        print(f"Use operand_number: {use.operand_number}")
| |
| |
| # CHECK-LABEL: TEST: testValuePrintAsOperand |
@run
def testValuePrintAsOperand():
    """Print values via ``get_name`` with and without local-scope numbering.

    Two values are created at module level and two more inside a function
    body. Their names are printed with default arguments, with an explicit
    ``AsmState`` and with ``use_local_scope=True``; finally one producer is
    detached from its parent and its result name is printed again.
    """
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        int_type = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            value = Operation.create("custom.op1", results=[int_type]).results[0]
            # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
            print(value)

            value2 = Operation.create("custom.op2", results=[int_type]).results[0]
            # CHECK: Value(%[[VAL2:.*]] = "custom.op2"() : () -> i32)
            print(value2)

            arg_types = [int_type, int_type]
            top_fn = func.FuncOp("test", (arg_types, []))
            body_region = top_fn.operation.regions[0]
            entry_block = Block.create_at_start(body_region, [int_type, int_type])

            with InsertionPoint(entry_block):
                value3 = Operation.create("custom.op3", results=[int_type]).results[0]
                # CHECK: Value(%[[VAL3:.*]] = "custom.op3"() : () -> i32)
                print(value3)
                value4 = Operation.create("custom.op4", results=[int_type]).results[0]
                # CHECK: Value(%[[VAL4:.*]] = "custom.op4"() : () -> i32)
                print(value4)
                func.ReturnOp([])

        # CHECK: %[[VAL1]]
        print(value.get_name())
        # CHECK: %[[VAL2]]
        print(value2.get_name())
        # CHECK: %[[VAL3]]
        print(value3.get_name())
        # CHECK: %[[VAL4]]
        print(value4.get_name())

        print("With AsmState")
        # CHECK-LABEL: With AsmState
        state = AsmState(top_fn.operation, use_local_scope=True)
        # CHECK: %0
        print(value3.get_name(state=state))
        # CHECK: %1
        print(value4.get_name(state=state))

        print("With use_local_scope")
        # CHECK-LABEL: With use_local_scope
        # CHECK: %0
        print(value3.get_name(use_local_scope=True))
        # CHECK: %1
        print(value4.get_name(use_local_scope=True))

        # CHECK: %[[ARG0:.*]]
        print(entry_block.arguments[0].get_name())
        # CHECK: %[[ARG1:.*]]
        print(entry_block.arguments[1].get_name())

        # CHECK: module {
        # CHECK:   %[[VAL1]] = "custom.op1"() : () -> i32
        # CHECK:   %[[VAL2]] = "custom.op2"() : () -> i32
        # CHECK:   func.func @test(%[[ARG0]]: i32, %[[ARG1]]: i32) {
        # CHECK:     %[[VAL3]] = "custom.op3"() : () -> i32
        # CHECK:     %[[VAL4]] = "custom.op4"() : () -> i32
        # CHECK:     return
        # CHECK:   }
        # CHECK: }
        print(module)

        # Once its producer is detached from the module, the value is expected
        # to print as %0.
        detached_owner = value2.owner
        detached_owner.detach_from_parent()
        # CHECK: %0
        print(value2.get_name())
| |
| |
| # CHECK-LABEL: TEST: testValueSetType |
@run
def testValueSetType():
    """Change the type of a value in place and print it and its owner afterwards."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        narrow_type = IntegerType.get_signless(32)
        wide_type = IntegerType.get_signless(64)
        module = Module.create()
        with InsertionPoint(module.body):
            producer = Operation.create("custom.op1", results=[narrow_type])
            value = producer.results[0]
        # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
        print(value)

        value.set_type(wide_type)
        # CHECK: Value(%[[VAL1]] = "custom.op1"() : () -> i64)
        print(value)

        # The owning operation is printed to show the new result type as well.
        # CHECK: %[[VAL1]] = "custom.op1"() : () -> i64
        print(value.owner)
| |
| |
| # CHECK-LABEL: TEST: testValueCasters |
@run
def testValueCasters():
    """Exercise value casters registered for the integer type id.

    A caster wrapping values into local ``NOP*`` subclasses is registered
    first and observed through results, result slices, operands and function
    arguments. A second registration without ``replace=True`` is expected to
    raise ``RuntimeError``; a third one with ``replace=True`` installs a
    caster that returns its argument unchanged.
    """

    class NOPResult(OpResult):
        def __init__(self, value):
            super().__init__(value)

        def __str__(self):
            text = super().__str__()
            return text.replace(Value.__name__, NOPResult.__name__)

    class NOPValue(Value):
        def __init__(self, value):
            super().__init__(value)

        def __str__(self):
            text = super().__str__()
            return text.replace(Value.__name__, NOPValue.__name__)

    class NOPBlockArg(BlockArgument):
        def __init__(self, value):
            super().__init__(value)

        def __str__(self):
            text = super().__str__()
            return text.replace(Value.__name__, NOPBlockArg.__name__)

    @register_value_caster(IntegerType.static_typeid)
    def cast_int(v) -> Value:
        print("in caster", v.__class__.__name__)
        # The more specific classes are tested first; the plain Value test is
        # the last one. Anything else falls through and yields None.
        if isinstance(v, OpResult):
            return NOPResult(v)
        if isinstance(v, BlockArgument):
            return NOPBlockArg(v)
        if isinstance(v, Value):
            return NOPValue(v)

    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        int_type = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            values = Operation.create(
                "custom.op1", results=[int_type, int_type]
            ).results
            # CHECK: in caster OpResult
            # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("result", values[0].result_number, values[0])
            # CHECK: in caster OpResult
            # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("result", values[1].result_number, values[1])

            # CHECK: results slice 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("results slice", values[:1][0].result_number, values[:1][0])

            value0, value1 = values
            # CHECK: in caster OpResult
            # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("result", value0.result_number, values[0])
            # CHECK: in caster OpResult
            # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("result", value1.result_number, values[1])

            op1 = Operation.create("custom.op2", operands=[value0, value1])
            # CHECK: "custom.op2"(%0#0, %0#1) : (i32, i32) -> ()
            print(op1)

            # CHECK: in caster Value
            # CHECK: operand 0 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("operand 0", op1.operands[0])
            # CHECK: in caster Value
            # CHECK: operand 1 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("operand 1", op1.operands[1])

            # CHECK: in caster BlockArgument
            # CHECK: in caster BlockArgument
            @func.FuncOp.from_py_func(int_type, int_type)
            def reduction(arg0, arg1):
                # CHECK: as func arg 0 NOPBlockArg
                print("as func arg", arg0.arg_number, arg0.__class__.__name__)
                # CHECK: as func arg 1 NOPBlockArg
                print("as func arg", arg1.arg_number, arg1.__class__.__name__)

            # CHECK: args slice 0 NOPBlockArg(<block argument> of type 'i32' at index: 0)
            print(
                "args slice",
                reduction.func_op.arguments[:1][0].arg_number,
                reduction.func_op.arguments[:1][0],
            )

    # Registering a second caster for the same type id without `replace=True`
    # is expected to be rejected; the error message is printed for checking.
    try:

        @register_value_caster(IntegerType.static_typeid)
        def dont_cast_int_shouldnt_register(v):
            pass

    except RuntimeError as err:
        # CHECK: Value caster is already registered: {{.*}}cast_int
        print(err)

    @register_value_caster(IntegerType.static_typeid, replace=True)
    def dont_cast_int(v) -> OpResult:
        assert isinstance(v, OpResult)
        print("don't cast", v.result_number, v)
        return v

    # The same local names are rebound on purpose, so that the first module is
    # released at the same point as before.
    with Location.unknown(context):
        int_type = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            # CHECK: don't cast 0 Value(%0 = "custom.op1"() : () -> i32)
            new_value = Operation.create("custom.op1", results=[int_type]).result
            # CHECK: result 0 Value(%0 = "custom.op1"() : () -> i32)
            print("result", new_value.result_number, new_value)

            # CHECK: don't cast 0 Value(%1 = "custom.op2"() : () -> i32)
            new_value = Operation.create("custom.op2", results=[int_type]).results[0]
            # CHECK: result 0 Value(%1 = "custom.op2"() : () -> i32)
            print("result", new_value.result_number, new_value)