blob: 2e53aa64b999f8989e9f9b793d35e238a7b4ec51 [file] [log] [blame]
# RUN: %PYTHON %s | FileCheck %s
import gc
import io
import itertools
from mlir.ir import *
def run(f):
print("\nTEST:", f.__name__)
f()
gc.collect()
assert Context._get_live_count() == 0
# CHECK-LABEL: TEST: test_insert_at_block_end
def test_insert_at_block_end():
ctx = Context()
ctx.allow_unregistered_dialects = True
with Location.unknown(ctx):
module = Module.parse(r"""
func @foo() -> () {
"custom.op1"() : () -> ()
}
""")
entry_block = module.body.operations[0].regions[0].blocks[0]
ip = InsertionPoint(entry_block)
ip.insert(Operation.create("custom.op2"))
# CHECK: "custom.op1"
# CHECK: "custom.op2"
module.operation.print()
run(test_insert_at_block_end)
# CHECK-LABEL: TEST: test_insert_before_operation
def test_insert_before_operation():
ctx = Context()
ctx.allow_unregistered_dialects = True
with Location.unknown(ctx):
module = Module.parse(r"""
func @foo() -> () {
"custom.op1"() : () -> ()
"custom.op2"() : () -> ()
}
""")
entry_block = module.body.operations[0].regions[0].blocks[0]
ip = InsertionPoint(entry_block.operations[1])
ip.insert(Operation.create("custom.op3"))
# CHECK: "custom.op1"
# CHECK: "custom.op3"
# CHECK: "custom.op2"
module.operation.print()
run(test_insert_before_operation)
# CHECK-LABEL: TEST: test_insert_at_block_begin
def test_insert_at_block_begin():
ctx = Context()
ctx.allow_unregistered_dialects = True
with Location.unknown(ctx):
module = Module.parse(r"""
func @foo() -> () {
"custom.op2"() : () -> ()
}
""")
entry_block = module.body.operations[0].regions[0].blocks[0]
ip = InsertionPoint.at_block_begin(entry_block)
ip.insert(Operation.create("custom.op1"))
# CHECK: "custom.op1"
# CHECK: "custom.op2"
module.operation.print()
run(test_insert_at_block_begin)
# CHECK-LABEL: TEST: test_insert_at_block_begin_empty
def test_insert_at_block_begin_empty():
# TODO: Write this test case when we can create such a situation.
pass
run(test_insert_at_block_begin_empty)
# CHECK-LABEL: TEST: test_insert_at_terminator
def test_insert_at_terminator():
ctx = Context()
ctx.allow_unregistered_dialects = True
with Location.unknown(ctx):
module = Module.parse(r"""
func @foo() -> () {
"custom.op1"() : () -> ()
return
}
""")
entry_block = module.body.operations[0].regions[0].blocks[0]
ip = InsertionPoint.at_block_terminator(entry_block)
ip.insert(Operation.create("custom.op2"))
# CHECK: "custom.op1"
# CHECK: "custom.op2"
module.operation.print()
run(test_insert_at_terminator)
# CHECK-LABEL: TEST: test_insert_at_block_terminator_missing
def test_insert_at_block_terminator_missing():
ctx = Context()
ctx.allow_unregistered_dialects = True
with ctx:
module = Module.parse(r"""
func @foo() -> () {
"custom.op1"() : () -> ()
}
""")
entry_block = module.body.operations[0].regions[0].blocks[0]
try:
ip = InsertionPoint.at_block_terminator(entry_block)
except ValueError as e:
# CHECK: Block has no terminator
print(e)
else:
assert False, "Expected exception"
run(test_insert_at_block_terminator_missing)
# CHECK-LABEL: TEST: test_insert_at_end_with_terminator_errors
def test_insert_at_end_with_terminator_errors():
with Context() as ctx, Location.unknown():
ctx.allow_unregistered_dialects = True
module = Module.parse(r"""
func @foo() -> () {
return
}
""")
entry_block = module.body.operations[0].regions[0].blocks[0]
with InsertionPoint(entry_block):
try:
Operation.create("custom.op1", results=[], operands=[])
except IndexError as e:
# CHECK: ERROR: Cannot insert operation at the end of a block that already has a terminator.
print(f"ERROR: {e}")
run(test_insert_at_end_with_terminator_errors)
# CHECK-LABEL: TEST: test_insertion_point_context
def test_insertion_point_context():
ctx = Context()
ctx.allow_unregistered_dialects = True
with Location.unknown(ctx):
module = Module.parse(r"""
func @foo() -> () {
"custom.op1"() : () -> ()
}
""")
entry_block = module.body.operations[0].regions[0].blocks[0]
with InsertionPoint(entry_block):
Operation.create("custom.op2")
with InsertionPoint.at_block_begin(entry_block):
Operation.create("custom.opa")
Operation.create("custom.opb")
Operation.create("custom.op3")
# CHECK: "custom.opa"
# CHECK: "custom.opb"
# CHECK: "custom.op1"
# CHECK: "custom.op2"
# CHECK: "custom.op3"
module.operation.print()
run(test_insertion_point_context)