blob: 380dc7ca3fb2851c8da49068d5c5835df7bf89ca [file] [log] [blame]
# RUN: %PYTHON %s 2>&1 | FileCheck %s
import gc, sys
from mlir.ir import *
from mlir.passmanager import *
# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
def log(*args):
print(*args, file=sys.stderr)
sys.stderr.flush()
def run(f):
log("\nTEST:", f.__name__)
f()
gc.collect()
assert Context._get_live_count() == 0
# Verify capsule interop.
# CHECK-LABEL: TEST: testCapsule
def testCapsule():
with Context():
pm = PassManager()
pm_capsule = pm._CAPIPtr
assert '"mlir.passmanager.PassManager._CAPIPtr"' in repr(pm_capsule)
pm._testing_release()
pm1 = PassManager._CAPICreate(pm_capsule)
assert pm1 is not None # And does not crash.
run(testCapsule)
# Verify successful round-trip.
# CHECK-LABEL: TEST: testParseSuccess
def testParseSuccess():
with Context():
# A first import is expected to fail because the pass isn't registered
# until we import mlir.transforms
try:
pm = PassManager.parse("builtin.module(builtin.func(print-op-stats))")
# TODO: this error should be propagate to Python but the C API does not help right now.
# CHECK: error: 'print-op-stats' does not refer to a registered pass or pass pipeline
except ValueError as e:
# CHECK: ValueError exception: invalid pass pipeline 'builtin.module(builtin.func(print-op-stats))'.
log("ValueError exception:", e)
else:
log("Exception not produced")
# This will register the pass and round-trip should be possible now.
import mlir.transforms
pm = PassManager.parse("builtin.module(builtin.func(print-op-stats))")
# CHECK: Roundtrip: builtin.module(builtin.func(print-op-stats))
log("Roundtrip: ", pm)
run(testParseSuccess)
# Verify failure on unregistered pass.
# CHECK-LABEL: TEST: testParseFail
def testParseFail():
with Context():
try:
pm = PassManager.parse("unknown-pass")
except ValueError as e:
# CHECK: ValueError exception: invalid pass pipeline 'unknown-pass'.
log("ValueError exception:", e)
else:
log("Exception not produced")
run(testParseFail)
# Verify failure on incorrect level of nesting.
# CHECK-LABEL: TEST: testInvalidNesting
def testInvalidNesting():
with Context():
try:
pm = PassManager.parse("builtin.func(normalize-memrefs)")
except ValueError as e:
# CHECK: Can't add pass 'NormalizeMemRefs' restricted to 'builtin.module' on a PassManager intended to run on 'builtin.func', did you intend to nest?
# CHECK: ValueError exception: invalid pass pipeline 'builtin.func(normalize-memrefs)'.
log("ValueError exception:", e)
else:
log("Exception not produced")
run(testInvalidNesting)
# Verify that a pass manager can execute on IR
# CHECK-LABEL: TEST: testRun
def testRunPipeline():
with Context():
pm = PassManager.parse("print-op-stats")
module = Module.parse(r"""func @successfulParse() { return }""")
pm.run(module)
# CHECK: Operations encountered:
# CHECK: builtin.func , 1
# CHECK: builtin.module , 1
# CHECK: std.return , 1
run(testRunPipeline)