# blob: e38155dcda973e8d1eb7eb2f57dc50ae731256dc [file] [log] [blame] [edit]
# RUN: %PYTHON %s 2>&1 | FileCheck %s
# REQUIRES: host-supports-jit
from mlir.ir import *
from mlir.dialects.ext import *
from mlir.rewrite import *
from mlir.passmanager import *
from mlir.execution_engine import *
from mlir.dialects import llvm, scf, func
from functools import partial
class BfDialect(Dialect, name="bf"):
    """Dialect registered under the name "bf".

    The type and operations defined in this file derive from
    ``BfDialect.Type`` / ``BfDialect.Operation`` and model the Brainfuck
    commands that ``parse`` recognises.
    """

    pass
class PtrType(BfDialect.Type, name="ptr"):
    """The ``!bf.ptr`` type carried by every bf operand and result.

    ``convert_bf_to_llvm`` maps it to the LLVM dialect pointer type.
    """

    pass
class NextOp(BfDialect.Operation, name="next"):
    """``bf.next``: emitted by ``parse`` for the ">" command.

    ``convert_bf_to_llvm`` lowers it to an LLVM GEP over ``i8`` with a
    constant index of 1.
    """

    # Pointer value the op is applied to.
    in_: Operand[PtrType]
    # Resulting pointer value; callers construct the op without passing a
    # result type (see ``parse``), relying on ``infer_type=True``.
    out: Result[PtrType[()]] = result(infer_type=True)
class PrevOp(BfDialect.Operation, name="prev"):
    """``bf.prev``: emitted by ``parse`` for the "<" command.

    ``convert_bf_to_llvm`` lowers it to an LLVM GEP over ``i8`` with a
    constant index of -1.
    """

    # Pointer value the op is applied to.
    in_: Operand[PtrType]
    # Resulting pointer value; callers construct the op without passing a
    # result type (see ``parse``), relying on ``infer_type=True``.
    out: Result[PtrType[()]] = result(infer_type=True)
class IncOp(BfDialect.Operation, name="inc"):
    """``bf.inc``: emitted by ``parse`` for the "+" command; has no result.

    ``convert_bf_to_llvm`` lowers it to an ``i8`` load, an add of the
    constant 1, and a store back through the same pointer.
    """

    # Pointer to the cell that is modified.
    in_: Operand[PtrType]
class DecOp(BfDialect.Operation, name="dec"):
    """``bf.dec``: emitted by ``parse`` for the "-" command; has no result.

    ``convert_bf_to_llvm`` lowers it to an ``i8`` load, an add of the
    constant -1, and a store back through the same pointer.
    """

    # Pointer to the cell that is modified.
    in_: Operand[PtrType]
class InputOp(BfDialect.Operation, name="input"):
    """``bf.input``: emitted by ``parse`` for the "," command; has no result.

    ``convert_bf_to_llvm`` lowers it to a call to ``bf_input`` whose ``i8``
    result is stored through the pointer.
    """

    # Pointer to the cell that receives the value.
    in_: Operand[PtrType]
class OutputOp(BfDialect.Operation, name="output"):
    """``bf.output``: emitted by ``parse`` for the "." command; has no result.

    ``convert_bf_to_llvm`` lowers it to an ``i8`` load through the pointer
    followed by a call to ``bf_output`` with the loaded value.
    """

    # Pointer to the cell whose value is written out.
    in_: Operand[PtrType]
class WhileOp(BfDialect.Operation, name="while"):
    """``bf.while``: emitted by ``parse`` for a "[" ... "]" loop.

    ``convert_bf_to_llvm`` lowers it to ``scf.while`` whose condition loads
    an ``i8`` through the pointer and compares it ``ne`` against 0.
    """

    # Pointer value on entry to the loop.
    in_: Operand[PtrType]
    # Pointer value after the loop; ``parse`` continues from this result
    # once it sees the closing "]".
    out: Result[PtrType[()]] = result(infer_type=True)
    # Loop body; ``parse`` gives it one block with a single !bf.ptr argument
    # and terminates it with bf.yield.
    body: Region
class YieldOp(BfDialect.Operation, name="yield", traits=[IsTerminatorTrait]):
    """``bf.yield``: terminator that ``parse`` places at the end of every
    ``bf.while`` body and at the end of ``bf.main``.

    ``convert_bf_to_llvm`` lowers it to ``scf.yield`` when its parent is a
    while op and to ``func.return`` otherwise.
    """

    # Pointer value handed back to the enclosing op.
    in_: Operand[PtrType]
class MainOp(BfDialect.Operation, name="main"):
    """``bf.main``: container for a whole parsed program.

    ``convert_bf_to_llvm`` replaces it with ``func.func @bf_main`` taking
    and returning one pointer.
    """

    # Program body; ``parse`` gives it one block with a single !bf.ptr
    # argument.
    body: Region
def _check_balanced_brackets(code: str) -> None:
    """Raise ValueError unless every "[" in *code* has a matching "]"."""
    depth = 0
    for pos, c in enumerate(code):
        if c == "[":
            depth += 1
        elif c == "]":
            if depth == 0:
                raise ValueError(f"unmatched ']' at offset {pos}")
            depth -= 1
    if depth:
        raise ValueError(f"{depth} unclosed '[' in program")


def parse(code: str):
    """Build a module containing one ``bf.main`` op from Brainfuck source.

    Each of the eight command characters ``> < + - . , [ ]`` is translated
    to the corresponding bf op; every other character is ignored. The
    current pointer value is threaded through the ops as an SSA value.

    Args:
        code: Brainfuck program text.

    Returns:
        The newly created module holding the ``bf.main`` op.

    Raises:
        ValueError: If the brackets in *code* are not balanced. Without this
            check a stray "]" would try to read ``.out`` from ``bf.main``
            (which declares no such result) and an unclosed "[" would leave
            ``bf.main`` without a terminator.
    """
    # Validate first so no IR is created for a malformed program.
    _check_balanced_brackets(code)

    module = Module.create()
    with InsertionPoint(module.body):
        main = MainOp()
    main.body.blocks.append()
    current_val = main.body.blocks[0].add_argument(
        PtrType.get(), Location.unknown()
    )

    # `ip` always points at the place where the next op must be appended:
    # the end of the innermost open block, or right after a closed loop.
    ip = InsertionPoint(main.body.blocks[0])
    for c in code:
        with ip:
            if c == ">":
                current_val = NextOp(current_val).out
            elif c == "<":
                current_val = PrevOp(current_val).out
            elif c == "+":
                IncOp(current_val)
            elif c == "-":
                DecOp(current_val)
            elif c == ".":
                OutputOp(current_val)
            elif c == ",":
                InputOp(current_val)
            elif c == "[":
                # Open a loop: subsequent ops go into its body and operate
                # on the body's block argument.
                loop = WhileOp(current_val)
                loop.body.blocks.append()
                current_val = loop.body.blocks[0].add_argument(
                    PtrType.get(), Location.unknown()
                )
                ip = InsertionPoint(loop.body.blocks[0])
            elif c == "]":
                # Close the loop: terminate its body, then continue after
                # the bf.while op using its result as the pointer.
                YieldOp(current_val)
                current_val = ip.block.owner.opview.out
                ip = InsertionPoint.after(current_val.owner)

    # Terminate bf.main with the final pointer value.
    with ip:
        YieldOp(current_val)

    return module
def convert_bf_to_llvm(op, pass_):
    """Pass callback that lowers bf ops to the func/scf/llvm dialects.

    Registers one conversion per bf op, marks the bf dialect illegal and
    applies a partial conversion to *op*. Afterwards it appends to the body
    of *op* the declarations of ``putchar``/``getchar``, the ``bf_output``
    and ``bf_input`` wrappers called by the lowered code, and ``bf_init``,
    which allocates 1024 ``i8`` cells, memsets them with 0 and calls
    ``bf_main``.

    Args:
        op: Operation the pass runs on; its ``opview.body`` receives the
            helper functions.
        pass_: Second callback argument; not used here.
    """
    byte_ty = IntegerType.get_signless(8)
    word_ty = IntegerType.get_signless(32)
    llvm_ptr = llvm.PointerType.get()

    patterns = RewritePatternSet()
    type_converter = TypeConverter()

    def bf_ptr_to_llvm(ty):
        # Only !bf.ptr is handled; None is returned for any other type.
        if isinstance(ty, PtrType):
            return llvm_ptr
        return None

    type_converter.add_conversion(bf_ptr_to_llvm)

    def lower_move(op, adaptor, converter, rewriter, offset=1):
        # bf.next / bf.prev -> GEP by `offset` i8 elements.
        with rewriter.ip:
            moved = llvm.GEPOp(llvm_ptr, adaptor.in_, [], [offset], byte_ty, [])
        rewriter.replace_op(op, moved)

    def lower_add(op, adaptor, converter, rewriter, cst=1):
        # bf.inc / bf.dec -> load, add `cst`, store.
        with rewriter.ip:
            cell = llvm.load(byte_ty, adaptor.in_)
            delta = llvm.mlir_constant(IntegerAttr.get(byte_ty, cst))
            updated = llvm.add(cell, delta, [])
            write_back = llvm.StoreOp(updated, adaptor.in_)
        rewriter.replace_op(op, write_back)

    def lower_main(op, adaptor, converter, rewriter):
        # bf.main -> func.func @bf_main, reusing the original entry block.
        with rewriter.ip:
            main_fn = func.FuncOp(
                "bf_main", FunctionType.get([llvm_ptr], [llvm_ptr])
            )
        op.body.blocks[0].append_to(main_fn.body)
        rewriter.convert_region_types(main_fn.body, converter)
        rewriter.replace_op(op, main_fn)

    def lower_yield(op, adaptor, converter, rewriter):
        # The terminator kind depends on whether the parent is a loop.
        in_loop = isinstance(op.parent.opview, WhileOp | scf.WhileOp)
        with rewriter.ip:
            if in_loop:
                terminator = scf.YieldOp([adaptor.in_])
            else:
                terminator = func.ReturnOp([adaptor.in_])
        rewriter.replace_op(op, terminator)

    def lower_while(op, adaptor, converter, rewriter):
        with rewriter.ip:
            scf_loop = scf.WhileOp([llvm_ptr], [adaptor.in_])

        # "before" region: keep looping while the pointed-to byte is != 0.
        scf_loop.before.blocks.append()
        cond_block = scf_loop.before.blocks[0]
        cell_ptr = cond_block.add_argument(llvm_ptr, Location.unknown())
        with InsertionPoint(cond_block):
            cell = llvm.load(byte_ty, cell_ptr)
            zero_byte = llvm.mlir_constant(IntegerAttr.get(byte_ty, 0))
            nonzero = llvm.icmp(llvm.ICmpPredicate.ne, cell, zero_byte)
            scf.ConditionOp(nonzero, [cell_ptr])

        # "after" region: the original bf.while body.
        op.body.blocks[0].append_to(scf_loop.after)
        rewriter.convert_region_types(scf_loop.after, converter)
        rewriter.replace_op(op, scf_loop)

    def lower_output(op, adaptor, converter, rewriter):
        with rewriter.ip:
            cell = llvm.load(byte_ty, adaptor.in_)
            emit = func.CallOp([], "bf_output", [cell])
        rewriter.replace_op(op, emit)

    def lower_input(op, adaptor, converter, rewriter):
        with rewriter.ip:
            received = func.call([byte_ty], "bf_input", [])
            write_back = llvm.StoreOp(received, adaptor.in_)
        rewriter.replace_op(op, write_back)

    # Registration order is kept identical to the one-call-per-op form.
    lowerings = (
        (NextOp, lower_move),
        (PrevOp, partial(lower_move, offset=-1)),
        (IncOp, lower_add),
        (DecOp, partial(lower_add, cst=-1)),
        (MainOp, lower_main),
        (YieldOp, lower_yield),
        (WhileOp, lower_while),
        (OutputOp, lower_output),
        (InputOp, lower_input),
    )
    for bf_op, lowering in lowerings:
        patterns.add_conversion(bf_op, lowering, type_converter)

    target = ConversionTarget()
    target.add_illegal_dialect(BfDialect)
    apply_partial_conversion(op, target, patterns.freeze())

    def entry_block(fn):
        # Append an empty block to `fn` and return it.
        fn.body.blocks.append()
        return fn.body.blocks[0]

    with InsertionPoint(op.opview.body):
        # Private declarations of the external functions used for I/O.
        func.FuncOp(
            "putchar", FunctionType.get([word_ty], [word_ty]), visibility="private"
        )
        func.FuncOp("getchar", FunctionType.get([], [word_ty]), visibility="private")

        # bf_output(i8): sign-extend to i32 and pass to putchar.
        output_fn = func.FuncOp("bf_output", FunctionType.get([byte_ty], []))
        output_block = entry_block(output_fn)
        out_byte = output_block.add_argument(byte_ty, Location.unknown())
        with InsertionPoint(output_block):
            widened = llvm.sext(word_ty, out_byte)
            func.call([word_ty], "putchar", [widened])
            func.ReturnOp([])

        # bf_input() -> i8: truncate the i32 returned by getchar.
        input_fn = func.FuncOp("bf_input", FunctionType.get([], [byte_ty]))
        with InsertionPoint(entry_block(input_fn)):
            read = func.call([word_ty], "getchar", [])
            narrowed = llvm.trunc(byte_ty, read, [])
            func.ReturnOp([narrowed])

        # bf_init(): entry point looked up by `execute`; allocates the
        # cells, fills them with 0 and runs bf_main on them.
        init_fn = func.FuncOp("bf_init", FunctionType.get([], []))
        init_fn.attributes["llvm.emit_c_interface"] = UnitAttr.get()
        with InsertionPoint(entry_block(init_fn)):
            cell_count = llvm.mlir_constant(IntegerAttr.get(word_ty, 1024))
            fill = llvm.mlir_constant(IntegerAttr.get(byte_ty, 0))
            cells = llvm.alloca(llvm_ptr, cell_count, byte_ty)
            llvm.intr_memset(cells, fill, cell_count, False)
            func.call([llvm_ptr], "bf_main", [cells])
            func.ReturnOp([])
def execute(code):
    """Parse, lower and JIT-run the Brainfuck program *code*.

    The program is parsed into bf ops, verified, lowered with
    ``convert_bf_to_llvm`` followed by the ``convert-scf-to-cf`` and
    ``convert-to-llvm`` passes, and then executed by looking up and calling
    ``bf_init`` in an ``ExecutionEngine``.
    """
    bf_module = parse(code)
    assert bf_module.operation.verify()

    pipeline = PassManager()
    for stage in (convert_bf_to_llvm, "convert-scf-to-cf, convert-to-llvm"):
        pipeline.add(stage)
    pipeline.run(bf_module.operation)

    engine = ExecutionEngine(bf_module)
    entry = engine.lookup("bf_init")
    # bf_init itself takes no arguments; 0 is presumably a null placeholder
    # for the looked-up callable's single argument — confirm against the
    # ExecutionEngine.lookup calling convention.
    entry(0)
def run(f):
    """Decorator that prints a ``TEST:`` banner for *f* and runs it at once.

    Args:
        f: Zero-argument test function.

    Returns:
        *f* itself, so the decorated name stays bound to the function
        instead of being rebound to ``None``.
    """
    print("TEST:", f.__name__)
    f()
    return f
# All tests share one context and an unknown default location. The comment
# lines carrying FileCheck directives below are matched, in order, against
# what this script prints; keep them and the print calls in sync.
with Context(), Location.unknown():
    # Load the bf dialect defined above before any bf op or type is created.
    BfDialect.load()

    # CHECK: TEST: test_convert_bf_to_llvm
    @run
    def test_convert_bf_to_llvm():
        """Print the program "[-]" as bf ops, then again after lowering with
        ``convert_bf_to_llvm``."""
        module = parse("[-]")
        assert module.operation.verify()
        # CHECK: "bf.main"() ({
        # CHECK: ^bb0(%arg0: !bf.ptr):
        # CHECK: %0 = "bf.while"(%arg0) ({
        # CHECK: ^bb0(%arg1: !bf.ptr):
        # CHECK: "bf.dec"(%arg1) : (!bf.ptr) -> ()
        # CHECK: "bf.yield"(%arg1) : (!bf.ptr) -> ()
        # CHECK: }) : (!bf.ptr) -> !bf.ptr
        # CHECK: "bf.yield"(%0) : (!bf.ptr) -> ()
        # CHECK: }) : () -> ()
        print(module)
        # Only the bf lowering runs here (no scf/llvm conversion passes), so
        # the second print still contains scf.while and func.func.
        pm = PassManager()
        pm.add(convert_bf_to_llvm)
        pm.run(module.operation)
        # CHECK: func.func @bf_main(%arg0: !llvm.ptr) -> !llvm.ptr {
        # CHECK: %0 = scf.while (%arg1 = %arg0) : (!llvm.ptr) -> !llvm.ptr {
        # CHECK: %1 = llvm.load %arg1 : !llvm.ptr -> i8
        # CHECK: %2 = llvm.mlir.constant(0 : i8) : i8
        # CHECK: %3 = llvm.icmp "ne" %1, %2 : i8
        # CHECK: scf.condition(%3) %arg1 : !llvm.ptr
        # CHECK: } do {
        # CHECK: ^bb0(%arg1: !llvm.ptr):
        # CHECK: %1 = llvm.load %arg1 : !llvm.ptr -> i8
        # CHECK: %2 = llvm.mlir.constant(-1 : i8) : i8
        # CHECK: %3 = llvm.add %1, %2 : i8
        # CHECK: llvm.store %3, %arg1 : i8, !llvm.ptr
        # CHECK: scf.yield %arg1 : !llvm.ptr
        # CHECK: }
        # CHECK: return %0 : !llvm.ptr
        # CHECK: }
        print(module)

    # CHECK: TEST: test_bf_e2e
    @run
    def test_bf_e2e():
        """Run a complete program through ``execute`` and expect its output
        on stdout."""
        # CHECK: Hello World!
        execute(
            "++++++++[>++++[>++>+++>+++>+<<<<-]>+>+>->>+[<]<-]>>.>---.+++++++..+++.>>.<-.<.+++.------.--------.>>+.>++."
        )