| # RUN: %PYTHON %s 2>&1 | FileCheck %s |
| |
| import gc, sys |
| from mlir.ir import * |
| from mlir.passmanager import * |
| from mlir.dialects.func import FuncOp |
| from mlir.dialects.builtin import ModuleOp |
| |
| |
# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
def log(*args):
    """Print ``args`` (space separated) to stderr and flush right away.

    Flushing after every message keeps Python output ordered relative to the
    diagnostics MLIR writes to stderr.
    """
    print(*args, file=sys.stderr, flush=True)
| |
| |
def run(f):
    """Execute test ``f`` under a banner, then assert no MLIR Context leaked.

    Also used as a decorator. It returns ``None``, so a decorated test name is
    rebound to ``None`` once the test has run.
    """
    test_name = f.__name__
    log("\nTEST:", test_name)
    f()
    # Collect garbage so contexts kept alive only by reference cycles are
    # released before counting.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0
| |
| |
# Verify capsule interop.
# CHECK-LABEL: TEST: testCapsule
def testCapsule():
    """Round-trip a PassManager through its C-API capsule (``_CAPIPtr``)."""
    with Context():
        source_pm = PassManager()
        capsule = source_pm._CAPIPtr
        capsule_text = repr(capsule)
        assert '"mlir.passmanager.PassManager._CAPIPtr"' in capsule_text
        # NOTE(review): presumably relinquishes ownership of the native pass
        # manager so the capsule can be adopted below by a second wrapper;
        # confirm against the bindings.
        source_pm._testing_release()
        adopted_pm = PassManager._CAPICreate(capsule)
        assert adopted_pm is not None  # And does not crash.


run(testCapsule)
| |
| |
# CHECK-LABEL: TEST: testConstruct
@run
def testConstruct():
    """Print a PassManager built with no anchor and one anchored on a module."""
    with Context():
        # CHECK: pm1: 'any()'
        # CHECK: pm2: 'builtin.module()'
        default_pm = PassManager()
        module_pm = PassManager("builtin.module")
        log(f"pm1: '{default_pm}'")
        log(f"pm2: '{module_pm}'")
| |
| |
# Verify that an unregistered nested pass is rejected and that a registered
# nested pipeline round-trips through parse/print unchanged.
# CHECK-LABEL: TEST: testParseSuccess
def testParseSuccess():
    """Parse a bad and a good nested pipeline; the good one must round-trip."""
    with Context():
        bad_pipeline = "builtin.module(func.func(not-existing-pass{json=false}))"
        good_pipeline = "builtin.module(func.func(print-op-stats{json=false}))"

        # An unregistered pass should not parse.
        try:
            pm = PassManager.parse(bad_pipeline)
        except ValueError as err:
            # CHECK: ValueError exception: {{.+}} 'not-existing-pass' does not refer to a registered pass
            log("ValueError exception:", err)
        else:
            log("Exception not produced")

        # A registered pass should parse successfully.
        pm = PassManager.parse(good_pipeline)
        # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false}))
        log("Roundtrip: ", pm)


run(testParseSuccess)
| |
| |
# Verify that extra whitespace in a textual pipeline does not affect parsing.
# CHECK-LABEL: TEST: testParseSpacedPipeline
def testParseSpacedPipeline():
    """A pipeline with extra spaces and newlines prints back in compact form."""
    with Context():
        # A registered pass should parse successfully even if the pipeline
        # text has extra spaces for readability.
        spaced_pipeline = """builtin.module(
func.func( print-op-stats{ json=false } )
)"""
        pm = PassManager.parse(spaced_pipeline)
        # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false}))
        log("Roundtrip: ", pm)


run(testParseSpacedPipeline)
| |
| |
# Verify failure on unregistered pass.
# CHECK-LABEL: TEST: testParseFail
def testParseFail():
    """Parsing a pipeline that names an unknown pass raises ValueError."""
    with Context():
        unknown_pipeline = "any(unknown-pass)"
        try:
            pm = PassManager.parse(unknown_pipeline)
        except ValueError as err:
            # The parser diagnostic spans several lines, including a caret
            # pointing at the offending pass name.
            # CHECK: ValueError exception: MLIR Textual PassPipeline Parser:1:1: error:
            # CHECK-SAME: 'unknown-pass' does not refer to a registered pass or pass pipeline
            # CHECK: unknown-pass
            # CHECK: ^
            log("ValueError exception:", err)
        else:
            log("Exception not produced")


run(testParseFail)
| |
| |
# Check that adding to a pass manager works
# CHECK-LABEL: TEST: testAdd
@run
def testAdd():
    """Appending passes with ``add`` grows the textual pipeline in order."""
    # The context is handed to the PassManager explicitly instead of being
    # entered with a ``with`` statement.
    explicit_ctx = Context()
    pm = PassManager("any", explicit_ctx)
    # CHECK: pm: 'any()'
    log(f"pm: '{pm}'")
    # CHECK: pm: 'any(cse)'
    # CHECK: pm: 'any(cse,cse)'
    for _ in range(2):
        pm.add("cse")
        log(f"pm: '{pm}'")
| |
| |
# Verify failure on incorrect level of nesting.
# CHECK-LABEL: TEST: testInvalidNesting
def testInvalidNesting():
    """A pass restricted to builtin.module cannot be nested under func.func."""
    with Context():
        misnested_pipeline = "func.func(normalize-memrefs)"
        try:
            pm = PassManager.parse(misnested_pipeline)
        except ValueError as err:
            # CHECK: ValueError exception: Can't add pass 'NormalizeMemRefs' restricted to 'builtin.module' on a PassManager intended to run on 'func.func', did you intend to nest?
            log("ValueError exception:", err)
        else:
            log("Exception not produced")


run(testInvalidNesting)
| |
| |
# Verify that a pass manager can execute on IR
# CHECK-LABEL: TEST: testRunPipeline
def testRunPipeline():
    """Run print-op-stats over a one-function input and report op counts."""
    with Context():
        stats_pm = PassManager.parse("any(print-op-stats{json=false})")
        func_op = FuncOp.parse(r"""func.func @successfulParse() { return }""")
        # The pass prints a per-operation tally, matched by these directives.
        # CHECK: Operations encountered:
        # CHECK: func.func , 1
        # CHECK: func.return , 1
        stats_pm.run(func_op)


run(testRunPipeline)
| |
| |
# CHECK-LABEL: TEST: testRunPipelineError
@run
def testRunPipelineError():
    """Running a pipeline on an unregistered op raises MLIRError with diagnostics."""
    with Context() as ctx:
        # "test.op" belongs to no registered dialect; allow parsing it anyway.
        ctx.allow_unregistered_dialects = True
        op = Operation.parse('"test.op"() : () -> ()')
        pm = PassManager.parse("any(cse)")
        try:
            pm.run(op)
        except MLIRError as e:
            # CHECK: Exception: <
            # CHECK: Failure while executing pass pipeline:
            # CHECK: error: "-":1:1: 'test.op' op trying to schedule a pass on an unregistered operation
            # CHECK: note: "-":1:1: see current operation: "test.op"() : () -> ()
            # CHECK: >
            log(f"Exception: <{e}>")
        else:
            # Report a missing failure explicitly, like the other negative
            # tests in this file, so it is obvious in the log rather than only
            # showing up as a pattern mismatch.
            log("Exception not produced")
| |
| |
# CHECK-LABEL: TEST: testPostPassOpInvalidation
@run
def testPostPassOpInvalidation():
    """Exercise how PassManager.run affects Python handles to nested ops.

    With ``invalidate_ops=False``, handles obtained before the run can still
    be printed afterwards. With the default, they are invalidated and
    printing them raises ``RuntimeError``.
    """
    with Context() as ctx:

        def log_op_count():
            # NOTE(review): presumably the number of live Python operation
            # objects tracked by ``ctx``; confirm against the bindings.
            log("live ops:", ctx._get_live_operation_count())

        # CHECK: invalidate_ops=False
        log("invalidate_ops=False")

        # CHECK: live ops: 0
        log_op_count()

        module = ModuleOp.parse(
            """
module {
arith.constant 10
func.func @foo() {
arith.constant 10
return
}
}
"""
        )

        # CHECK: live ops: 1
        log_op_count()

        outer_const_op = module.body.operations[0]
        # CHECK: %[[VAL0:.*]] = arith.constant 10 : i64
        log(outer_const_op)

        func_op = module.body.operations[1]
        # CHECK: func.func @[[FOO:.*]]() {
        # CHECK: %[[VAL1:.*]] = arith.constant 10 : i64
        # CHECK: return
        # CHECK: }
        log(func_op)

        inner_const_op = func_op.body.blocks[0].operations[0]
        # CHECK: %[[VAL1]] = arith.constant 10 : i64
        log(inner_const_op)

        # CHECK: live ops: 4
        log_op_count()

        PassManager.parse("builtin.module(canonicalize)").run(
            module, invalidate_ops=False
        )
        # CHECK: func.func @foo() {
        # CHECK: return
        # CHECK: }
        log(func_op)

        # CHECK: func.func @foo() {
        # CHECK: return
        # CHECK: }
        log(module)

        # CHECK: invalidate_ops=True
        log("invalidate_ops=True")

        # CHECK: live ops: 4
        log_op_count()

        module = ModuleOp.parse(
            """
module {
arith.constant 10
func.func @foo() {
arith.constant 10
return
}
}
"""
        )
        outer_const_op = module.body.operations[0]
        func_op = module.body.operations[1]
        inner_const_op = func_op.body.blocks[0].operations[0]

        # CHECK: live ops: 4
        log_op_count()

        PassManager.parse("builtin.module(canonicalize)").run(module)

        # CHECK: live ops: 1
        log_op_count()

        # Every handle obtained before the run must now refuse to print.
        # Report a handle that unexpectedly survived, consistent with the
        # other negative tests in this file.
        # CHECK: the operation has been invalidated
        # CHECK: the operation has been invalidated
        # CHECK: the operation has been invalidated
        for stale_op in (func_op, outer_const_op, inner_const_op):
            try:
                log(stale_op)
            except RuntimeError as e:
                log(e)
            else:
                log("Exception not produced")

        # CHECK: func.func @foo() {
        # CHECK: return
        # CHECK: }
        log(module)
| |
| |
# CHECK-LABEL: TEST: testPrintIrAfterAll
@run
def testPrintIrAfterAll():
    """Enable IR printing and dump the module around the canonicalize pass."""
    with Context() as ctx:
        source_ir = """
module {
func.func @main() {
%0 = arith.constant 10
return
}
}
"""
        ir_module = ModuleOp.parse(source_ir)
        canonicalize_pm = PassManager.parse("builtin.module(canonicalize)")
        # NOTE(review): multithreading is presumably disabled so the IR dumps
        # come out serially and deterministically; confirm against the
        # bindings.
        ctx.enable_multithreading(False)
        canonicalize_pm.enable_ir_printing()
        # CHECK: // -----// IR Dump Before Canonicalizer (canonicalize) ('builtin.module' operation) //----- //
        # CHECK: module {
        # CHECK: func.func @main() {
        # CHECK: %[[C10:.*]] = arith.constant 10 : i64
        # CHECK: return
        # CHECK: }
        # CHECK: }
        # CHECK: // -----// IR Dump After Canonicalizer (canonicalize) ('builtin.module' operation) //----- //
        # CHECK: module {
        # CHECK: func.func @main() {
        # CHECK: return
        # CHECK: }
        # CHECK: }
        canonicalize_pm.run(ir_module)