| # RUN: %PYTHON %s | FileCheck %s |
| |
| import gc |
| import io |
| import itertools |
| from mlir.ir import * |
| |
| |
| def run(f): |
| print("\nTEST:", f.__name__) |
| f() |
| gc.collect() |
| assert Context._get_live_count() == 0 |
| return f |
| |
| |
| # CHECK-LABEL: TEST: testSymbolTableInsert |
| @run |
| def testSymbolTableInsert(): |
| with Context() as ctx: |
| ctx.allow_unregistered_dialects = True |
| m1 = Module.parse(""" |
| func private @foo() |
| func private @bar()""") |
| m2 = Module.parse(""" |
| func private @qux() |
| func private @foo() |
| "foo.bar"() : () -> ()""") |
| |
| symbol_table = SymbolTable(m1.operation) |
| |
| # CHECK: func private @foo |
| # CHECK: func private @bar |
| assert "foo" in symbol_table |
| print(symbol_table["foo"]) |
| assert "bar" in symbol_table |
| bar = symbol_table["bar"] |
| print(symbol_table["bar"]) |
| |
| assert "qux" not in symbol_table |
| |
| del symbol_table["bar"] |
| try: |
| symbol_table.erase(symbol_table["bar"]) |
| except KeyError: |
| pass |
| else: |
| assert False, "expected KeyError" |
| |
| # CHECK: module |
| # CHECK: func private @foo() |
| print(m1) |
| assert "bar" not in symbol_table |
| |
| try: |
| print(bar) |
| except RuntimeError as e: |
| if "the operation has been invalidated" not in str(e): |
| raise |
| else: |
| assert False, "expected RuntimeError due to invalidated operation" |
| |
| qux = m2.body.operations[0] |
| m1.body.append(qux) |
| symbol_table.insert(qux) |
| assert "qux" in symbol_table |
| |
| # Check that insertion actually renames this symbol in the symbol table. |
| foo2 = m2.body.operations[0] |
| m1.body.append(foo2) |
| updated_name = symbol_table.insert(foo2) |
| assert foo2.name.value != "foo" |
| assert foo2.name == updated_name |
| |
| # CHECK: module |
| # CHECK: func private @foo() |
| # CHECK: func private @qux() |
| # CHECK: func private @foo{{.*}} |
| print(m1) |
| |
| try: |
| symbol_table.insert(m2.body.operations[0]) |
| except ValueError as e: |
| if "Expected operation to have a symbol name" not in str(e): |
| raise |
| else: |
| assert False, "exepcted ValueError when adding a non-symbol" |
| |
| |
| # CHECK-LABEL: testSymbolTableRAUW |
| @run |
| def testSymbolTableRAUW(): |
| with Context() as ctx: |
| m = Module.parse(""" |
| func private @foo() { |
| call @bar() : () -> () |
| return |
| } |
| func private @bar() |
| """) |
| foo, bar = list(m.operation.regions[0].blocks[0].operations)[0:2] |
| SymbolTable.set_symbol_name(bar, "bam") |
| # Note that module.operation counts as a "nested symbol table" which won't |
| # be traversed into, so it is necessary to traverse its children. |
| SymbolTable.replace_all_symbol_uses("bar", "bam", foo) |
| # CHECK: call @bam() |
| # CHECK: func private @bam |
| print(m) |
| # CHECK: Foo symbol: "foo" |
| # CHECK: Bar symbol: "bam" |
| print(f"Foo symbol: {SymbolTable.get_symbol_name(foo)}") |
| print(f"Bar symbol: {SymbolTable.get_symbol_name(bar)}") |
| |
| |
| # CHECK-LABEL: testSymbolTableVisibility |
| @run |
| def testSymbolTableVisibility(): |
| with Context() as ctx: |
| m = Module.parse(""" |
| func private @foo() { |
| return |
| } |
| """) |
| foo = m.operation.regions[0].blocks[0].operations[0] |
| # CHECK: Existing visibility: "private" |
| print(f"Existing visibility: {SymbolTable.get_visibility(foo)}") |
| SymbolTable.set_visibility(foo, "public") |
| # CHECK: func public @foo |
| print(m) |
| |
| |
| # CHECK: testWalkSymbolTables |
| @run |
| def testWalkSymbolTables(): |
| with Context() as ctx: |
| m = Module.parse(""" |
| module @outer { |
| module @inner{ |
| } |
| } |
| """) |
| def callback(symbol_table_op, uses_visible): |
| print(f"SYMBOL TABLE: {uses_visible}: {symbol_table_op}") |
| # CHECK: SYMBOL TABLE: True: module @inner |
| # CHECK: SYMBOL TABLE: True: module @outer |
| SymbolTable.walk_symbol_tables(m.operation, True, callback) |
| |
| # Make sure exceptions in the callback are handled. |
| def error_callback(symbol_table_op, uses_visible): |
| assert False, "Raised from python" |
| try: |
| SymbolTable.walk_symbol_tables(m.operation, True, error_callback) |
| except RuntimeError as e: |
| # CHECK: GOT EXCEPTION: Exception raised in callback: AssertionError: Raised from python |
| print(f"GOT EXCEPTION: {e}") |
| |