# RUN: %PYTHON %s | FileCheck %s

import functools
from typing import Callable

from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects import pdl
from mlir.dialects.transform import structured
from mlir.dialects.transform import pdl as transform_pdl
from mlir.dialects.transform.extras import constant_param


def run(f):
    """Execute ``f`` once inside a fresh module, then verify and print it.

    Intended as a decorator: ``f`` runs at decoration time with a new
    ``Context``, an unknown ``Location`` and an insertion point at the end of
    a newly created module body. A ``TEST: <name>`` header is printed before
    ``f`` runs so the printed IR can be attributed to the test. ``f`` itself
    is returned unchanged.
    """
    with Context(), Location.unknown():
        module = Module.create()
        body_ip = InsertionPoint(module.body)
        with body_ip:
            print("\nTEST:", f.__name__)
            f()
        module.operation.verify()
        print(module)
    return f


def create_sequence(func: Callable) -> Callable:
    """Wrap ``func`` so that it fills the body of a new ``transform.sequence``.

    The returned zero-argument callable builds a ``SequenceOp`` with
    ``Propagate`` failure mode, no extra bindings and an ``AnyOpType`` root,
    calls ``func`` with the sequence's body target, and finally appends a
    ``transform.yield`` terminator.
    """

    @functools.wraps(func)
    def decorated() -> None:
        any_op = transform.AnyOpType.get()
        seq = transform.SequenceOp(
            transform.FailurePropagationMode.Propagate, [], any_op
        )
        with InsertionPoint(seq.body):
            func(seq.bodyTarget)
            transform.YieldOp()

    return decorated


@run
@create_sequence
def testBufferizeToAllocationOpCompact(target):
    """Build ``bufferize_to_allocation`` given nothing but the target handle."""
    # No optional attributes: only the op name is expected in the output.
    structured.BufferizeToAllocationOp(target)
    # CHECK-LABEL: TEST: testBufferizeToAllocationOpCompact
    # CHECK: transform.sequence
    # CHECK: transform.structured.bufferize_to_allocation


@run
@create_sequence
def testBufferizeToAllocationOpArgs(target):
    """Build ``bufferize_to_allocation`` with every optional keyword set."""
    options = {
        "memory_space": 3,
        "memcpy_op": "memref.copy",
        "alloc_op": "memref.alloca",
        "bufferize_destination_only": True,
    }
    structured.BufferizeToAllocationOp(target, **options)
    # CHECK-LABEL: TEST: testBufferizeToAllocationOpArgs
    # CHECK: transform.sequence
    # CHECK: transform.structured.bufferize_to_allocation
    # CHECK-SAME: alloc_op = "memref.alloca"
    # CHECK-SAME: bufferize_destination_only
    # CHECK-SAME: memcpy_op = "memref.copy"
    # CHECK-SAME: memory_space = 3


@run
@create_sequence
def testDecompose(target):
    """Build ``transform.structured.decompose`` on the sequence target."""
    decompose_op = structured.DecomposeOp
    decompose_op(target)
    # CHECK-LABEL: TEST: testDecompose
    # CHECK: transform.sequence
    # CHECK: transform.structured.decompose


@run
@create_sequence
def testFuseIntoContainingOpTypes(target):
    """Build ``fuse_into_containing_op`` with explicit result types.

    Both results are given the same ``!transform.op<"test.dummy">`` type;
    operand handles come from two ``structured.match`` ops.
    """
    producer = structured.MatchOp.match_op_names(target, ["test.dummy"])
    container = structured.MatchOp.match_op_names(target, ["test.dummy"])
    dummy_type = transform.OperationType.get("test.dummy")
    structured.FuseIntoContainingOp(dummy_type, dummy_type, producer, container)
    # CHECK-LABEL: TEST: testFuseIntoContainingOpTypes
    # CHECK: = transform.structured.fuse_into_containing_op
    # CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.op<"test.dummy">, !transform.op<"test.dummy">)


@run
@create_sequence
def testFuseIntoContainingOpCompact(target):
    """Build ``fuse_into_containing_op`` letting result types default."""
    producer = structured.MatchOp.match_op_names(target, ["test.dummy"])
    container = structured.MatchOp.match_op_names(target, ["test.dummy"])
    structured.FuseIntoContainingOp(producer, container)
    # CHECK-LABEL: TEST: testFuseIntoContainingOpCompact
    # CHECK: = transform.structured.fuse_into_containing_op
    # CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)


@run
@create_sequence
def testFuseOpCompact(target):
    """Build ``transform.structured.fuse`` from plain Python integer lists."""
    sizes = [4, 8]
    permutation = [0, 1]
    structured.FuseOp(
        target,
        tile_sizes=sizes,
        tile_interchange=permutation,
        apply_cleanup=True,
    )
    # CHECK-LABEL: TEST: testFuseOpCompact
    # CHECK: transform.sequence
    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
    # CHECK-SAME: interchange [0, 1] apply_cleanup = true
    # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)


@run
@create_sequence
def testFuseOpNoArg(target):
    """Build ``transform.structured.fuse`` with no tile sizes or interchange."""
    # With nothing to tile, a single result handle is expected.
    structured.FuseOp(target)
    # CHECK-LABEL: TEST: testFuseOpNoArg
    # CHECK: transform.sequence
    # CHECK: %{{.+}} = transform.structured.fuse %{{.*}} :
    # CHECK-SAME: (!transform.any_op) -> !transform.any_op


@run
@create_sequence
def testFuseOpAttributes(target):
    """Build ``transform.structured.fuse`` from pre-built ``DenseI64ArrayAttr``s."""
    tile_sizes_attr = DenseI64ArrayAttr.get([4, 8])
    interchange_attr = DenseI64ArrayAttr.get([0, 1])
    structured.FuseOp(
        target, tile_sizes=tile_sizes_attr, tile_interchange=interchange_attr
    )
    # CHECK-LABEL: TEST: testFuseOpAttributes
    # CHECK: transform.sequence
    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
    # CHECK-SAME: interchange [0, 1]
    # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)


@run
@create_sequence
def testGeneralize(target):
    """Build ``transform.structured.generalize`` on the sequence target."""
    generalize_op = structured.GeneralizeOp
    generalize_op(target)
    # CHECK-LABEL: TEST: testGeneralize
    # CHECK: transform.sequence
    # CHECK: transform.structured.generalize


@run
@create_sequence
def testInterchange(target):
    """Build ``transform.structured.interchange`` swapping two iterators."""
    permutation = [1, 0]
    structured.InterchangeOp(target, iterator_interchange=permutation)
    # CHECK-LABEL: TEST: testInterchange
    # CHECK: transform.sequence
    # CHECK: transform.structured.interchange
    # CHECK: iterator_interchange = [1, 0]


@run
@create_sequence
def testMapCopyToThreadsOpCompact(target):
    """Build ``gpu.map_copy_to_threads`` letting both result types default."""
    thread_count = 32
    alignment_bits = 128
    structured.MapCopyToThreadsOp(
        target,
        total_num_threads=thread_count,
        desired_bit_alignment=alignment_bits,
    )
    # CHECK-LABEL: TEST: testMapCopyToThreadsOpCompact
    # CHECK: = transform.structured.gpu.map_copy_to_threads
    # CHECK-SAME: total_num_threads = 32
    # CHECK-SAME: desired_bit_alignment = 128
    # CHECK-SAME:  (!transform.any_op) -> (!transform.any_op, !transform.any_op)


@run
@create_sequence
def testMapCopyToThreadsOpTypes(target):
    """Build ``gpu.map_copy_to_threads`` with explicit types for both results."""
    first_result_type = transform.OperationType.get("test.opA")
    second_result_type = transform.OperationType.get("test.opB")
    structured.MapCopyToThreadsOp(
        first_result_type,
        second_result_type,
        target,
        total_num_threads=32,
        desired_bit_alignment=128,
    )
    # CHECK-LABEL: TEST: testMapCopyToThreadsOpTypes
    # CHECK: = transform.structured.gpu.map_copy_to_threads
    # CHECK-SAME: total_num_threads = 32
    # CHECK-SAME: desired_bit_alignment = 128
    # CHECK-SAME:  (!transform.any_op) -> (!transform.op<"test.opA">, !transform.op<"test.opB">)


@run
@create_sequence
def testMatchOpNamesString(target):
    """Match by op name given as a single bare string rather than a list."""
    op_name = "test.dummy"
    structured.MatchOp.match_op_names(target, op_name)
    # CHECK-LABEL: TEST: testMatchOpNamesString
    # CHECK: transform.structured.match ops
    # CHECK-SAME: ["test.dummy"]
    # CHECK-SAME: (!transform.any_op) -> !transform.any_op


@run
@create_sequence
def testMatchOpNamesList(target):
    """Match by op name given as a one-element list."""
    op_names = ["test.dummy"]
    structured.MatchOp.match_op_names(target, op_names)
    # CHECK-LABEL: TEST: testMatchOpNamesList
    # CHECK: transform.structured.match ops
    # CHECK-SAME: ["test.dummy"]
    # CHECK-SAME: (!transform.any_op) -> !transform.any_op


@run
@create_sequence
def testVectorizeNoArgs(target):
    """Build ``transform.structured.vectorize`` without vector sizes."""
    vectorize_op = structured.VectorizeOp
    vectorize_op(target)
    # CHECK-LABEL: TEST: testVectorizeNoArgs
    # CHECK: transform.sequence
    # CHECK: transform.structured.vectorize
    # CHECK-NOT:     vector_sizes


@run
@create_sequence
def testVectorizeStatic(target):
    """Build ``transform.structured.vectorize`` with static integer sizes."""
    static_sizes = [16, 4]
    structured.VectorizeOp(target, static_sizes)
    # CHECK-LABEL: TEST: testVectorizeStatic
    # CHECK: transform.sequence
    # CHECK: transform.structured.vectorize
    # CHECK-SAME:     vector_sizes [16, 4]


@run
@create_sequence
def testVectorizeArray(target):
    """Build ``transform.structured.vectorize`` with sizes from a parsed array attribute."""
    structured.VectorizeOp(target, Attribute.parse("[16, 4]"))
    # CHECK-LABEL: TEST: testVectorizeArray
    # CHECK: transform.sequence
    # CHECK: transform.structured.vectorize
    # CHECK-SAME:     vector_sizes [16, 4]


@run
@create_sequence
def testVectorizeMixed(target):
    """Mix a handle-valued size with a static integer attribute size."""
    dynamic_size = structured.MatchOp.match_op_names(target, ["arith.constant"])
    static_size = Attribute.parse("4")
    structured.VectorizeOp(target, [dynamic_size, static_size])
    # CHECK-LABEL: TEST: testVectorizeMixed
    # CHECK: transform.sequence
    # CHECK: %[[V0:.*]] = transform.structured.match
    # CHECK: transform.structured.vectorize
    # CHECK-SAME:     vector_sizes [%[[V0]], 4]


@run
@create_sequence
def testVectorizeEmpty(target):
    """An explicitly empty size list prints the same as passing no sizes."""
    no_sizes = []
    structured.VectorizeOp(target, no_sizes)
    # CHECK-LABEL: TEST: testVectorizeEmpty
    # CHECK: transform.sequence
    # CHECK: transform.structured.vectorize
    # CHECK-NOT:     vector_sizes


@run
@create_sequence
def testVectorizeScalable(target):
    """Vectorize with scalable sizes, written as single-element nested lists.

    Sizes mix a static int, a handle, an integer attribute and another int;
    every nested entry is expected to print in brackets.
    """
    sz1 = structured.MatchOp.match_op_names(target, ["arith.constant"])
    sz2 = Attribute.parse("4")
    structured.VectorizeOp(target, [16, [sz1], [sz2], [8]])
    # The match op is created before the vectorize op, so the two lines are
    # checked in order. This matches testVectorizeMixed, and the SAME line
    # below is anchored to an ordered match.
    # CHECK-LABEL: TEST: testVectorizeScalable
    # CHECK: transform.sequence
    # CHECK: %[[V0:.*]] = transform.structured.match
    # CHECK: transform.structured.vectorize
    # CHECK-SAME:     vector_sizes [16, [%[[V0]]], [4], [8]]


@run
@create_sequence
def testVectorizeArgs(target):
    """Build ``transform.structured.vectorize`` with the ``vectorize_nd_extract`` flag."""
    sizes = [16, 4]
    structured.VectorizeOp(target, sizes, vectorize_nd_extract=True)
    # CHECK-LABEL: TEST: testVectorizeArgs
    # CHECK: transform.sequence
    # CHECK: transform.structured.vectorize
    # CHECK-SAME: vectorize_nd_extract


@run
@create_sequence
def testMatchOpNamesTyped(target):
    """Match by op name while requesting a specific result handle type."""
    result_type = transform.OperationType.get("test.dummy")
    structured.MatchOp.match_op_names(result_type, target, ["test.dummy"])
    # CHECK-LABEL: TEST: testMatchOpNamesTyped
    # CHECK: transform.structured.match ops
    # CHECK-SAME: ["test.dummy"]
    # CHECK-SAME: (!transform.any_op) -> !transform.op<"test.dummy">


@run
@create_sequence
def testMultitileSizesCompact(target):
    """Build ``multitile_sizes`` without ``divisor``; it must not be printed."""
    structured.MultiTileSizesOp(
        transform.AnyOpType.get(), target, dimension=1, target_size=42
    )
    # The label uses the full test name. Otherwise it would also be a prefix
    # of testMultitileSizesAllArgs.
    # CHECK-LABEL: TEST: testMultitileSizesCompact
    # CHECK: transform.sequence
    # CHECK-NOT: divisor
    # CHECK: transform.structured.multitile_sizes
    # CHECK-NOT: divisor
    # CHECK-DAG: dimension = 1
    # CHECK-NOT: divisor
    # CHECK-DAG: target_size = 42
    # CHECK-NOT: divisor


@run
@create_sequence
def testMultitileSizesAllArgs(target):
    """Build ``multitile_sizes`` with every attribute, including ``divisor``."""
    structured.MultiTileSizesOp(
        transform.AnyOpType.get(),
        target,
        dimension=1,
        target_size=42,
        divisor=2,
    )
    # The label uses the full test name, so it cannot be confused with the
    # label of testMultitileSizesCompact.
    # CHECK-LABEL: TEST: testMultitileSizesAllArgs
    # CHECK: transform.sequence
    # CHECK: transform.structured.multitile_sizes
    # CHECK-DAG: dimension = 1
    # CHECK-DAG: divisor = 2
    # CHECK-DAG: target_size = 42


@run
@create_sequence
def testPadOpNoArgs(target):
    """Build ``transform.structured.pad`` with no optional attributes at all."""
    pad_op = structured.PadOp
    pad_op(target)
    # CHECK-LABEL: TEST: testPadOpNoArgs
    # CHECK: transform.sequence
    # CHECK: transform.structured.pad
    # CHECK-NOT: copy_back_op
    # CHECK-NOT: nofold_flags
    # CHECK-NOT: pad_to_multiple_of
    # CHECK-NOT: padding_dimensions
    # CHECK-NOT: padding_values
    # CHECK-NOT: transpose_paddings


@run
@create_sequence
def testPadOpArgs(target):
    """Build ``transform.structured.pad`` with every optional attribute set.

    Values mix plain Python lists, typed attributes and parsed attributes,
    including a nested list whose inner element is itself an attribute.
    """
    pad_options = {
        "pad_to_multiple_of": [128],
        "padding_values": [FloatAttr.get_f32(42.0), StringAttr.get("0")],
        "padding_dimensions": Attribute.parse("[1]"),
        "nofold_flags": [0],
        "transpose_paddings": [
            [1, Attribute.parse("0")],
            Attribute.parse("[0, 1]"),
        ],
        "copy_back_op": "linalg.copy",
    }
    structured.PadOp(target, **pad_options)
    # CHECK-LABEL: TEST: testPadOpArgs
    # CHECK: transform.sequence
    # CHECK: transform.structured.pad
    # CHECK-DAG: pad_to_multiple_of [128]
    # CHECK-DAG: copy_back_op = "linalg.copy"
    # CHECK-DAG: nofold_flags = [0]
    # CHECK-DAG: padding_dimensions = [1]
    # CHECK-DAG: padding_values = [4.200000e+01 : f32, "0"]
    # CHECK-DAG: transpose_paddings = {{\[}}[1, 0], [0, 1]]


@run
@create_sequence
def testPadOpArgsParam(target):
    """Mix a ``transform.param`` value with static entries in ``pad_to_multiple_of``."""
    # constant_param creates an op, so it is built before the pad op, as in
    # the inline-argument form.
    param_multiple = constant_param(128)
    multiples = [param_multiple, Attribute.parse("2"), 10]
    structured.PadOp(
        target,
        pad_to_multiple_of=multiples,
        padding_dimensions=Attribute.parse("[0, 1, 2]"),
    )
    # CHECK-LABEL: TEST: testPadOpArgsParam
    # CHECK: transform.sequence
    # CHECK-DAG: %[[P:.*]] = transform.param.constant 128
    # CHECK: transform.structured.pad
    # CHECK-DAG: pad_to_multiple_of [%[[P]], 2, 10]
    # CHECK-DAG: padding_dimensions = [0, 1, 2]


@run
@create_sequence
def testScalarize(target):
    """Build ``transform.structured.scalarize`` on the sequence target."""
    scalarize_op = structured.ScalarizeOp
    scalarize_op(target)
    # CHECK-LABEL: TEST: testScalarize
    # CHECK: transform.structured.scalarize


@run
@create_sequence
def testSplit(target):
    """Split statically, then split again using a handle as the chunk size.

    The first split's result is divided into two handles. The first handle is
    the target of the second split, and the second handle is its dynamic
    ``chunk_sizes`` operand.
    """
    first_split = structured.SplitOp(target, dimension=1, chunk_sizes=42)
    any_op = transform.AnyOpType.get()
    parts = transform.SplitHandleOp([any_op, any_op], first_split)
    piece, chunk = parts.results[0], parts.results[1]
    structured.SplitOp(piece, dimension=3, chunk_sizes=chunk)
    # CHECK-LABEL: TEST: testSplit
    # CHECK: %[[G:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
    # CHECK: %[[F:.+]]:2 = split_handle %[[G]]
    # CHECK: transform.structured.split %[[F]]#0 after %[[F]]#1 {dimension = 3


@run
@create_sequence
def testTileCompact(target):
    """Build ``tile_using_for`` from plain Python integer lists."""
    tile_sizes = [4, 8]
    loop_order = [0, 1]
    structured.TileUsingForOp(target, sizes=tile_sizes, interchange=loop_order)
    # CHECK-LABEL: TEST: testTileCompact
    # CHECK: transform.sequence
    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile_using_for %{{.*}}[4, 8]
    # CHECK: interchange = [0, 1]


@run
@create_sequence
def testTileAttributes(target):
    """Build ``tile_using_for`` from pre-built ``DenseI64ArrayAttr``s."""
    sizes_attr = DenseI64ArrayAttr.get([4, 8])
    interchange_attr = DenseI64ArrayAttr.get([0, 1])
    structured.TileUsingForOp(target, sizes=sizes_attr, interchange=interchange_attr)
    # CHECK-LABEL: TEST: testTileAttributes
    # CHECK: transform.sequence
    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile_using_for %{{.*}}[4, 8]
    # CHECK: interchange = [0, 1]


@run
@create_sequence
def testTileZero(target):
    """Zero tile sizes are kept in the printed size list.

    Only the two non-zero sizes produce loop handles (``:2`` below).
    """
    tile_sizes = [4, 0, 2, 0]
    loop_order = [0, 1, 2, 3]
    structured.TileUsingForOp(target, sizes=tile_sizes, interchange=loop_order)
    # CHECK-LABEL: TEST: testTileZero
    # CHECK: transform.sequence
    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile_using_for %{{.*}}[4, 0, 2, 0]
    # CHECK: interchange = [0, 1, 2, 3]


@run
def testTileDynamic():
    """Tile with sizes that mix PDL-matched handles and static integers.

    Unlike the other tests, this one does not use ``create_sequence``. It
    nests its own ``SequenceOp`` inside a ``WithPDLPatternsOp`` so that
    ``PDLMatchOp`` can reference the named patterns "first" and "second".
    """
    pdl_op_type = pdl.OperationType.get()
    pdl_wrapper = transform_pdl.WithPDLPatternsOp(pdl_op_type)
    with InsertionPoint(pdl_wrapper.body):
        seq = transform.SequenceOp(
            transform.FailurePropagationMode.Propagate, [], pdl_wrapper.bodyTarget
        )
        with InsertionPoint(seq.body):
            root = seq.bodyTarget
            first = transform_pdl.PDLMatchOp(pdl_op_type, root, "first")
            second = transform_pdl.PDLMatchOp(pdl_op_type, root, "second")
            structured.TileUsingForOp(root, sizes=[first, 3, second, 0])
            transform.YieldOp()
    # CHECK-LABEL: TEST: testTileDynamic
    # CHECK: %[[FIRST:.+]] = pdl_match
    # CHECK: %[[SECOND:.+]] = pdl_match
    # CHECK: %{{.+}}, %{{.+}}:3 = transform.structured.tile_using_for %{{.*}}[%[[FIRST]], 3, %[[SECOND]], 0]


@run
@create_sequence
def testTileExplicitLoopTypeSingle(target):
    """A single loop type given to ``tile_using_for`` applies to every loop handle."""
    for_type = transform.OperationType.get("scf.for")
    structured.TileUsingForOp(for_type, target, sizes=[2, 3, 4])
    # CHECK-LABEL: TEST: testTileExplicitLoopTypeSingle
    # CHECK: = transform.structured.tile_using_for %{{.*}} : (!{{.*}}) ->
    # CHECK-COUNT-3: !transform.op<"scf.for">


@run
@create_sequence
def testTileExplicitLoopTypeAll(target):
    """Give ``tile_using_for`` a distinct result type for each loop handle."""
    types = [
        transform.OperationType.get(x)
        for x in ["scf.for", "scf.parallel", "scf.forall"]
    ]
    structured.TileUsingForOp(types, target, sizes=[2, 3, 4])
    # Match the full op name, not just its "tile" prefix, like the other
    # tiling tests do.
    # CHECK-LABEL: TEST: testTileExplicitLoopTypeAll
    # CHECK: = transform.structured.tile_using_for
    # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">,
    # CHECK-SAME: !transform.op<"scf.parallel">, !transform.op<"scf.forall">


@run
@create_sequence
def testTileScalable(target):
    """Tile with a scalable size, written as the single-element list ``[2]``."""
    scalable_sizes = [4, [2]]
    structured.TileUsingForOp(target, sizes=scalable_sizes)
    # CHECK-LABEL: TEST: testTileScalable
    # CHECK: transform.sequence
    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile_using_for %{{.*}}[4, [2]]


@run
@create_sequence
def testTileToForallCompact(target):
    """Tile a typed ``linalg.matmul`` handle to ``scf.forall`` with default result types."""
    matmul_type = transform.OperationType.get("linalg.matmul")
    matmul_handle = transform.CastOp(matmul_type, target)
    structured.TileUsingForallOp(matmul_handle, num_threads=[2, 3, 4])
    # CHECK-LABEL: TEST: testTileToForallCompact
    # CHECK: = transform.structured.tile_using_forall
    # CHECK-SAME: num_threads [2, 3, 4]
    # CHECK-SAME: (!transform.op<"linalg.matmul">) -> (!transform.any_op, !transform.any_op)


@run
@create_sequence
def testTileToForallLoopsAndTileOpTypes(target):
    """Pass explicit loop and tiled-op result types to ``tile_using_forall``."""
    loops_type = transform.OperationType.get("scf.forall")
    tiled_op_type = transform.OperationType.get("linalg.matmul")
    structured.TileUsingForallOp(
        loops_type,
        tiled_op_type,
        target,
        num_threads=[2, 3, 4],
    )
    # CHECK-LABEL: TEST: testTileToForallLoopsAndTileOpTypes
    # CHECK: = transform.structured.tile_using_forall
    # CHECK-SAME: num_threads [2, 3, 4]
    # CHECK-SAME: (!transform.any_op) -> (!transform.op<"scf.forall">, !transform.op<"linalg.matmul">)


@run
@create_sequence
def testTileToForallTileSizes(target):
    """Drive ``tile_using_forall`` with ``tile_sizes`` instead of ``num_threads``."""
    sizes = [2, 3, 4]
    structured.TileUsingForallOp(target, tile_sizes=sizes)
    # CHECK-LABEL: TEST: testTileToForallTileSizes
    # CHECK: = transform.structured.tile_using_forall
    # CHECK-SAME: tile_sizes [2, 3, 4]


@run
@create_sequence
def testTileToForallMixedDynamic(target):
    """Mix a handle and static integers in ``num_threads``."""
    n = structured.MatchOp.match_op_names(target, ["test.dummy"])
    structured.TileUsingForallOp(target, num_threads=[n, 3, 4])
    # Capture the match result so the test checks that the handle built above
    # is the dynamic entry, not just any SSA value.
    # CHECK-LABEL: TEST: testTileToForallMixedDynamic
    # CHECK: %[[N:.*]] = transform.structured.match
    # CHECK: = transform.structured.tile_using_forall
    # CHECK-SAME: num_threads [%[[N]], 3, 4] : (!transform.any_op, !transform.any_op)


@run
@create_sequence
def testTileToForallPackedDynamic(target):
    """Pass a single handle as packed ``num_threads``; it prints as ``*(...)``."""
    n = structured.MatchOp.match_op_names(target, ["test.dummy"])
    structured.TileUsingForallOp(target, num_threads=n)
    # Capture the match result instead of hard-coding its printed SSA name,
    # which depends on how many values were numbered before it.
    # CHECK-LABEL: TEST: testTileToForallPackedDynamic
    # CHECK: %[[N:.*]] = transform.structured.match
    # CHECK: = transform.structured.tile_using_forall
    # CHECK-SAME: num_threads *(%[[N]]) : (!transform.any_op, !transform.any_op)


@run
@create_sequence
def testTileToForallMapping(target):
    """Attach a GPU thread ``mapping`` attribute to ``tile_using_forall``."""
    thread_mapping = Attribute.parse("[ #gpu.thread<y>, #gpu.thread<x> ]")
    thread_counts = [2, 3]
    structured.TileUsingForallOp(
        target, num_threads=thread_counts, mapping=thread_mapping
    )
    # CHECK-LABEL: TEST: testTileToForallMapping
    # CHECK: = transform.structured.tile_using_forall
    # CHECK-SAME: mapping = [#gpu.thread<y>, #gpu.thread<x>]


@run
@create_sequence
def testVectorizeChildrenAndApplyPatternsAllAttrs(target):
    """With every flag set to True, each unit attribute is printed."""
    structured.VectorizeChildrenAndApplyPatternsOp(
        target,
        disable_multi_reduction_to_contract_patterns=True,
        disable_transfer_permutation_map_lowering_patterns=True,
        vectorize_nd_extract=True,
        vectorize_padding=True,
    )
    # Match the full op name. A bare "vectorize" prefix would also match the
    # distinct transform.structured.vectorize op.
    # CHECK-LABEL: TEST: testVectorizeChildrenAndApplyPatternsAllAttrs
    # CHECK: transform.sequence
    # CHECK: = transform.structured.vectorize_children_and_apply_patterns
    # CHECK-SAME: disable_multi_reduction_to_contract_patterns
    # CHECK-SAME: disable_transfer_permutation_map_lowering_patterns
    # CHECK-SAME: vectorize_nd_extract
    # CHECK-SAME: vectorize_padding


@run
@create_sequence
def testVectorizeChildrenAndApplyPatternsNoAttrs(target):
    """With every flag explicitly False, no unit attribute is printed."""
    structured.VectorizeChildrenAndApplyPatternsOp(
        target,
        disable_multi_reduction_to_contract_patterns=False,
        disable_transfer_permutation_map_lowering_patterns=False,
        vectorize_nd_extract=False,
        vectorize_padding=False,
    )
    # Match the full op name so the negative checks below are anchored to
    # this op, not to a plain transform.structured.vectorize.
    # CHECK-LABEL: TEST: testVectorizeChildrenAndApplyPatternsNoAttrs
    # CHECK: transform.sequence
    # CHECK: = transform.structured.vectorize_children_and_apply_patterns
    # CHECK-NOT: disable_multi_reduction_to_contract_patterns
    # CHECK-NOT: disable_transfer_permutation_map_lowering_patterns
    # CHECK-NOT: vectorize_nd_extract
    # CHECK-NOT: vectorize_padding


@run
@create_sequence
def testMatchInterfaceEnum(target):
    """Set ``structured.match``'s ``interface`` from the generated enum class.

    ``structured.MatchOp.__base__`` is the parent class of the Python
    ``MatchOp`` wrapper. Calling it directly uses the (result type, target,
    keyword attributes) constructor instead of the ``match_op_names`` helper.
    """
    names = ArrayAttr.get([StringAttr.get("test.dummy")])
    result_type = transform.AnyOpType.get()
    # Building the op inserts it at the current insertion point; the returned
    # object is not needed, so it is not bound to a name.
    structured.MatchOp.__base__(
        result_type,
        target,
        ops=names,
        interface=structured.MatchInterfaceEnum.LinalgOp,
    )
    # CHECK-LABEL: TEST: testMatchInterfaceEnum
    # CHECK: transform.sequence
    # CHECK: = transform.structured.match
    # CHECK: interface{LinalgOp}


@run
@create_sequence
def testMatchInterfaceEnumReplaceAttributeBuilder(target):
    """Replace the ``MatchInterfaceEnum`` attribute builder with a string-based one.

    After registration (``replace=True``), the ``interface`` keyword accepts
    the case name as a plain string.
    """

    @register_attribute_builder("MatchInterfaceEnum", replace=True)
    def match_interface_enum(x, context):
        # Case name -> i32 value used by this test's builder.
        cases = {"LinalgOp": 0, "TilingInterface": 1}
        try:
            y = cases[x]
        except KeyError:
            # Previously an unknown name fell through both branches and failed
            # with an UnboundLocalError on `y`; report the bad input instead.
            raise ValueError(f"unknown MatchInterfaceEnum case: {x!r}") from None
        return IntegerAttr.get(IntegerType.get_signless(32, context=context), y)

    names = ArrayAttr.get([StringAttr.get("test.dummy")])
    result_type = transform.AnyOpType.get()
    # The op is inserted as a side effect of construction; the returned
    # object is not needed.
    structured.MatchOp.__base__(
        result_type,
        target,
        ops=names,
        interface="TilingInterface",
    )
    # CHECK-LABEL: TEST: testMatchInterfaceEnumReplaceAttributeBuilder
    # CHECK: transform.sequence
    # CHECK: = transform.structured.match
    # CHECK: interface{TilingInterface}
