# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects.pdl import *


def constructAndPrintInModule(f):
    """Decorator: run ``f`` to populate a fresh module, then print that module.

    Prints a ``TEST:`` banner containing ``f.__name__``. Then, inside a new
    ``Context`` with ``Location.unknown()`` active, it creates an empty
    ``Module`` and calls ``f()`` with the module body as the active insertion
    point. Finally it prints the module. ``f`` itself is returned unchanged,
    so the decorated name still refers to the original function; the
    decorator's effect is the side effect performed at decoration time.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    with Context(), Location.unknown():
        mod = Module.create()
        body_ip = InsertionPoint(mod.body)
        with body_ip:
            f()
        print(mod)
    return f


# CHECK: module  {
# CHECK:   pdl.pattern @operations : benefit(1)  {
# CHECK:     %0 = attribute
# CHECK:     %1 = type
# CHECK:     %2 = operation  {"attr" = %0} -> (%1 : !pdl.type)
# CHECK:     %3 = result 0 of %2
# CHECK:     %4 = operand
# CHECK:     %5 = operation(%3, %4 : !pdl.value, !pdl.value)
# CHECK:     rewrite %5 with "rewriter"
# CHECK:   }
# CHECK: }
@constructAndPrintInModule
def test_operations():
    """Match a two-op chain and hand its root to the rewriter named "rewriter".

    ``op0`` carries an attribute ``"attr"`` and one result type. Result 0 of
    ``op0``, together with a free operand, feeds the root operation.
    """
    pattern = PatternOp(1, "operations")
    with InsertionPoint(pattern.body):
        attr = AttributeOp()
        ty = TypeOp()
        op0 = OperationOp(attributes={"attr": attr}, types=[ty])
        op0_result = ResultOp(op0, 0)
        # Named `operand` (not `input`) so the builtin is not shadowed.
        operand = OperandOp()
        root = OperationOp(args=[op0_result, operand])
        RewriteOp(root, "rewriter")


# CHECK: module  {
# CHECK:   pdl.pattern @rewrite_with_args : benefit(1)  {
# CHECK:     %0 = operand
# CHECK:     %1 = operation(%0 : !pdl.value)
# CHECK:     rewrite %1 with "rewriter"(%0 : !pdl.value)
# CHECK:   }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_with_args():
    """Pass the matched operand as an extra argument to the "rewriter" rewrite."""
    pattern = PatternOp(1, "rewrite_with_args")
    with InsertionPoint(pattern.body):
        # Named `operand` (not `input`) so the builtin is not shadowed.
        operand = OperandOp()
        root = OperationOp(args=[operand])
        RewriteOp(root, "rewriter", args=[operand])


# CHECK: module  {
# CHECK:   pdl.pattern @rewrite_multi_root_optimal : benefit(1)  {
# CHECK:     %0 = operand
# CHECK:     %1 = operand
# CHECK:     %2 = type
# CHECK:     %3 = operation(%0 : !pdl.value)  -> (%2 : !pdl.type)
# CHECK:     %4 = result 0 of %3
# CHECK:     %5 = operation(%4 : !pdl.value)
# CHECK:     %6 = operation(%1 : !pdl.value)  -> (%2 : !pdl.type)
# CHECK:     %7 = result 0 of %6
# CHECK:     %8 = operation(%4, %7 : !pdl.value, !pdl.value)
# CHECK:     rewrite with "rewriter"(%5, %8 : !pdl.operation, !pdl.operation)
# CHECK:   }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_multi_root_optimal():
    """Build a rewrite with no explicit root; both candidate roots go as args.

    Ops are created in a fixed order (first producer and its consumer, then
    the second producer and the consumer that uses both results), because
    the order determines the printed SSA numbering.
    """
    pattern = PatternOp(1, "rewrite_multi_root_optimal")
    with InsertionPoint(pattern.body):
        first_in = OperandOp()
        second_in = OperandOp()
        shared_ty = TypeOp()
        first_producer = OperationOp(args=[first_in], types=[shared_ty])
        first_val = ResultOp(first_producer, 0)
        first_root = OperationOp(args=[first_val])
        second_producer = OperationOp(args=[second_in], types=[shared_ty])
        second_val = ResultOp(second_producer, 0)
        second_root = OperationOp(args=[first_val, second_val])
        # No root operand: the rewrite receives both roots as arguments.
        RewriteOp(name="rewriter", args=[first_root, second_root])


# CHECK: module  {
# CHECK:   pdl.pattern @rewrite_multi_root_forced : benefit(1)  {
# CHECK:     %0 = operand
# CHECK:     %1 = operand
# CHECK:     %2 = type
# CHECK:     %3 = operation(%0 : !pdl.value)  -> (%2 : !pdl.type)
# CHECK:     %4 = result 0 of %3
# CHECK:     %5 = operation(%4 : !pdl.value)
# CHECK:     %6 = operation(%1 : !pdl.value)  -> (%2 : !pdl.type)
# CHECK:     %7 = result 0 of %6
# CHECK:     %8 = operation(%4, %7 : !pdl.value, !pdl.value)
# CHECK:     rewrite %5 with "rewriter"(%8 : !pdl.operation)
# CHECK:   }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_multi_root_forced():
    """Same graph as the "optimal" variant, but with an explicit rewrite root.

    The first consumer is given as the root, and the second consumer is
    passed as an argument to "rewriter". Creation order is kept fixed so
    the printed SSA numbering is stable.
    """
    pattern = PatternOp(1, "rewrite_multi_root_forced")
    with InsertionPoint(pattern.body):
        first_in = OperandOp()
        second_in = OperandOp()
        shared_ty = TypeOp()
        first_producer = OperationOp(args=[first_in], types=[shared_ty])
        first_val = ResultOp(first_producer, 0)
        first_root = OperationOp(args=[first_val])
        second_producer = OperationOp(args=[second_in], types=[shared_ty])
        second_val = ResultOp(second_producer, 0)
        second_root = OperationOp(args=[first_val, second_val])
        RewriteOp(first_root, name="rewriter", args=[second_root])


# CHECK: module  {
# CHECK:   pdl.pattern @rewrite_add_body : benefit(1)  {
# CHECK:     %0 = type : i32
# CHECK:     %1 = type
# CHECK:     %2 = operation  -> (%0, %1 : !pdl.type, !pdl.type)
# CHECK:     rewrite %2  {
# CHECK:       %3 = type
# CHECK:       %4 = operation "foo.op"  -> (%0, %3 : !pdl.type, !pdl.type)
# CHECK:       replace %2 with %4
# CHECK:     }
# CHECK:   }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_add_body():
    """Give the rewrite an inline body that replaces the root with "foo.op".

    The replacement reuses the matched ``i32`` type and introduces a fresh
    ``type`` op inside the rewrite region.
    """
    pattern = PatternOp(1, "rewrite_add_body")
    with InsertionPoint(pattern.body):
        i32_ty = TypeOp(IntegerType.get_signless(32))
        any_ty = TypeOp()
        root = OperationOp(types=[i32_ty, any_ty])
        rewrite = RewriteOp(root)
        rewrite_body = rewrite.add_body()
        with InsertionPoint(rewrite_body):
            fresh_ty = TypeOp()
            replacement = OperationOp(name="foo.op", types=[i32_ty, fresh_ty])
            ReplaceOp(root, with_op=replacement)


# CHECK: module  {
# CHECK:   pdl.pattern @rewrite_type : benefit(1)  {
# CHECK:     %0 = type : i32
# CHECK:     %1 = type
# CHECK:     %2 = operation  -> (%0, %1 : !pdl.type, !pdl.type)
# CHECK:     rewrite %2  {
# CHECK:       %3 = operation "foo.op"  -> (%0, %1 : !pdl.type, !pdl.type)
# CHECK:     }
# CHECK:   }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_type():
    """Create "foo.op" in the rewrite body, reusing both matched result types."""
    pattern = PatternOp(1, "rewrite_type")
    with InsertionPoint(pattern.body):
        ty1 = TypeOp(IntegerType.get_signless(32))
        ty2 = TypeOp()
        root = OperationOp(types=[ty1, ty2])
        rewrite = RewriteOp(root)
        with InsertionPoint(rewrite.add_body()):
            # The op is inserted at the active insertion point; the Python
            # handle is not needed afterwards, so it is not bound.
            OperationOp(name="foo.op", types=[ty1, ty2])


# CHECK: module  {
# CHECK:   pdl.pattern @rewrite_types : benefit(1)  {
# CHECK:     %0 = types
# CHECK:     %1 = operation  -> (%0 : !pdl.range<type>)
# CHECK:     rewrite %1  {
# CHECK:       %2 = types : [i32, i64]
# CHECK:       %3 = operation "foo.op"  -> (%0, %2 : !pdl.range<type>, !pdl.range<type>)
# CHECK:     }
# CHECK:   }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_types():
    """Create "foo.op" whose results combine a matched and a constant type range.

    The rewrite body builds a ``types`` op fixed to ``[i32, i64]`` and uses it
    alongside the type range matched on the root.
    """
    pattern = PatternOp(1, "rewrite_types")
    with InsertionPoint(pattern.body):
        types = TypesOp()
        root = OperationOp(types=[types])
        rewrite = RewriteOp(root)
        with InsertionPoint(rewrite.add_body()):
            other_types = TypesOp(
                [IntegerType.get_signless(32), IntegerType.get_signless(64)]
            )
            # Inserted at the active insertion point; handle not needed.
            OperationOp(name="foo.op", types=[types, other_types])


# CHECK: module  {
# CHECK:   pdl.pattern @rewrite_operands : benefit(1)  {
# CHECK:     %0 = types
# CHECK:     %1 = operands : %0
# CHECK:     %2 = operation(%1 : !pdl.range<value>)
# CHECK:     rewrite %2  {
# CHECK:       %3 = operation "foo.op"  -> (%0 : !pdl.range<type>)
# CHECK:     }
# CHECK:   }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_operands():
    """Match an op via an operand range and create "foo.op" from its types."""
    pattern = PatternOp(1, "rewrite_operands")
    with InsertionPoint(pattern.body):
        types = TypesOp()
        operands = OperandsOp(types)
        root = OperationOp(args=[operands])
        rewrite = RewriteOp(root)
        with InsertionPoint(rewrite.add_body()):
            # Inserted at the active insertion point; handle not needed.
            OperationOp(name="foo.op", types=[types])


# CHECK: module  {
# CHECK:   pdl.pattern @native_rewrite : benefit(1)  {
# CHECK:     %0 = operation
# CHECK:     rewrite %0  {
# CHECK:       apply_native_rewrite "NativeRewrite"(%0 : !pdl.operation)
# CHECK:     }
# CHECK:   }
# CHECK: }
@constructAndPrintInModule
def test_native_rewrite():
    """Call the native rewrite "NativeRewrite" on the matched root.

    The call produces no results and takes the root op as its argument.
    """
    pattern = PatternOp(1, "native_rewrite")
    with InsertionPoint(pattern.body):
        matched = OperationOp()
        rewrite_region = RewriteOp(matched).add_body()
        with InsertionPoint(rewrite_region):
            ApplyNativeRewriteOp([], "NativeRewrite", args=[matched])


# CHECK: module  {
# CHECK:   pdl.pattern @attribute_with_value : benefit(1)  {
# CHECK:     %0 = operation
# CHECK:     rewrite %0  {
# CHECK:       %1 = attribute = "value"
# CHECK:       apply_native_rewrite "NativeRewrite"(%1 : !pdl.attribute)
# CHECK:     }
# CHECK:   }
# CHECK: }
@constructAndPrintInModule
def test_attribute_with_value():
    """Build a constant string attribute in a rewrite and pass it to a native call."""
    pattern = PatternOp(1, "attribute_with_value")
    with InsertionPoint(pattern.body):
        matched = OperationOp()
        rewrite_region = RewriteOp(matched).add_body()
        with InsertionPoint(rewrite_region):
            string_value = Attribute.parse('"value"')
            const_attr = AttributeOp(value=string_value)
            ApplyNativeRewriteOp([], "NativeRewrite", args=[const_attr])


# CHECK: module  {
# CHECK:   pdl.pattern @erase : benefit(1)  {
# CHECK:     %0 = operation
# CHECK:     rewrite %0  {
# CHECK:       erase %0
# CHECK:     }
# CHECK:   }
# CHECK: }
@constructAndPrintInModule
def test_erase():
    """Build a rewrite whose body erases the matched root operation."""
    pattern = PatternOp(1, "erase")
    with InsertionPoint(pattern.body):
        matched = OperationOp()
        rewrite_region = RewriteOp(matched).add_body()
        with InsertionPoint(rewrite_region):
            EraseOp(matched)


# CHECK: module  {
# CHECK:   pdl.pattern @operation_results : benefit(1)  {
# CHECK:     %0 = types
# CHECK:     %1 = operation  -> (%0 : !pdl.range<type>)
# CHECK:     %2 = results of %1
# CHECK:     %3 = operation(%2 : !pdl.range<value>)
# CHECK:     rewrite %3 with "rewriter"
# CHECK:   }
# CHECK: }
@constructAndPrintInModule
def test_operation_results():
    """Feed the full result range of one op into the root of the pattern.

    ``ResultsOp`` is given the explicit ``!pdl.range<value>`` type, and the
    root is passed to the rewriter named "rewriter".
    """
    value_range_ty = RangeType.get(ValueType.get())
    pattern = PatternOp(1, "operation_results")
    with InsertionPoint(pattern.body):
        result_types = TypesOp()
        producer = OperationOp(types=[result_types])
        all_results = ResultsOp(value_range_ty, producer)
        root = OperationOp(args=[all_results])
        RewriteOp(root, name="rewriter")


# CHECK: module  {
# CHECK:   pdl.pattern : benefit(1)  {
# CHECK:     %0 = type
# CHECK:     apply_native_constraint "typeConstraint"(%0 : !pdl.type)
# CHECK:     %1 = operation  -> (%0 : !pdl.type)
# CHECK:     rewrite %1 with "rewrite"
# CHECK:   }
# CHECK: }
@constructAndPrintInModule
def test_apply_native_constraint():
    """Constrain a result type with a native constraint in an unnamed pattern."""
    pattern = PatternOp(1)
    with InsertionPoint(pattern.body):
        result_ty = TypeOp()
        ApplyNativeConstraintOp([], "typeConstraint", args=[result_ty])
        root = OperationOp(types=[result_ty])
        RewriteOp(root, name="rewrite")
