| # RUN: %PYTHON %s | FileCheck %s |
| |
| from mlir.ir import * |
| from mlir.dialects import transform |
| from mlir.dialects.transform import xegpu |
| from mlir.dialects.transform import structured, AnyValueType |
| |
| |
| def run(f): |
| with Context(), Location.unknown(): |
| module = Module.create() |
| with InsertionPoint(module.body): |
| print("\nTEST:", f.__name__) |
| f() |
| print(module) |
| return f |
| |
| |
| @run |
| def getDescOpDefaultIndex(): |
| sequence = transform.SequenceOp( |
| transform.FailurePropagationMode.Propagate, |
| [], |
| transform.OperationType.get("xegpu.dpas"), |
| ) |
| with InsertionPoint(sequence.body): |
| operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0]) |
| desc_handle = xegpu.get_desc_op(operand) |
| transform.YieldOp() |
| # CHECK-LABEL: TEST: getDescOpDefaultIndex |
| # CHECK: transform.xegpu.get_desc_op % |
| |
| |
| @run |
| def setDescLayoutMinimal(): |
| sequence = transform.SequenceOp( |
| transform.FailurePropagationMode.Propagate, |
| [], |
| transform.OperationType.get("xegpu.create_nd_tdesc"), |
| ) |
| with InsertionPoint(sequence.body): |
| xegpu.set_desc_layout(sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16]) |
| transform.YieldOp() |
| # CHECK-LABEL: TEST: setDescLayoutMinimal |
| # CHECK: %0 = transform.xegpu.set_desc_layout % |
| # CHECK: sg_layout = [6, 4] |
| # CHECK: sg_data = [32, 16] |
| |
| |
| @run |
| def setDescLayoutInstData(): |
| sequence = transform.SequenceOp( |
| transform.FailurePropagationMode.Propagate, |
| [], |
| transform.OperationType.get("xegpu.create_nd_tdesc"), |
| ) |
| with InsertionPoint(sequence.body): |
| xegpu.set_desc_layout( |
| sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], inst_data=[8, 16] |
| ) |
| transform.YieldOp() |
| # CHECK-LABEL: TEST: setDescLayoutInstData |
| # CHECK: %0 = transform.xegpu.set_desc_layout % |
| # CHECK: sg_layout = [6, 4] |
| # CHECK: sg_data = [32, 16] |
| # CHECK: inst_data = [8, 16] |
| |
| |
| @run |
| def setDescLayoutSlice(): |
| sequence = transform.SequenceOp( |
| transform.FailurePropagationMode.Propagate, |
| [], |
| transform.OperationType.get("xegpu.create_nd_tdesc"), |
| ) |
| with InsertionPoint(sequence.body): |
| xegpu.set_desc_layout( |
| sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], slice_dims=[0] |
| ) |
| transform.YieldOp() |
| # CHECK-LABEL: TEST: setDescLayoutSlice |
| # CHECK: %0 = transform.xegpu.set_desc_layout % |
| # CHECK: sg_layout = [6, 4] |
| # CHECK: sg_data = [32, 16] |
| # CHECK: slice_dims = [0] |
| |
| |
| @run |
| def setOpLayoutAttrOperandMinimal(): |
| sequence = transform.SequenceOp( |
| transform.FailurePropagationMode.Propagate, |
| [], |
| transform.OperationType.get("xegpu.dpas"), |
| ) |
| with InsertionPoint(sequence.body): |
| xegpu.set_op_layout_attr( |
| sequence.bodyTarget, |
| sg_layout=[6, 4], |
| sg_data=[32, 16], |
| ) |
| transform.YieldOp() |
| # CHECK-LABEL: TEST: setOpLayoutAttr |
| # CHECK: transform.xegpu.set_op_layout_attr % |
| # NO-CHECK: index = 0 |
| # NO-CHECK: result |
| # CHECK: sg_layout = [6, 4] |
| # CHECK: sg_data = [32, 16] |
| # NO-CHECK: inst_data |
| |
| |
| @run |
| def setOpLayoutAttrResult(): |
| sequence = transform.SequenceOp( |
| transform.FailurePropagationMode.Propagate, |
| [], |
| transform.OperationType.get("xegpu.dpas"), |
| ) |
| with InsertionPoint(sequence.body): |
| xegpu.set_op_layout_attr( |
| sequence.bodyTarget, |
| index=0, |
| sg_layout=[6, 4], |
| sg_data=[32, 16], |
| inst_data=[8, 16], |
| result=True, |
| ) |
| transform.YieldOp() |
| # CHECK-LABEL: TEST: setOpLayoutAttrResult |
| # CHECK: transform.xegpu.set_op_layout_attr % |
| # NO-CHECK: index = 0 |
| # CHECK: result |
| # CHECK: sg_layout = [6, 4] |
| # CHECK: sg_data = [32, 16] |
| # CHECK: inst_data = [8, 16] |
| |
| |
| @run |
| def setOpLayoutAttrResultSlice(): |
| sequence = transform.SequenceOp( |
| transform.FailurePropagationMode.Propagate, |
| [], |
| transform.OperationType.get("xegpu.dpas"), |
| ) |
| with InsertionPoint(sequence.body): |
| xegpu.set_op_layout_attr( |
| sequence.bodyTarget, |
| index=0, |
| sg_layout=[6, 4], |
| sg_data=[32, 16], |
| inst_data=[8, 16], |
| slice_dims=[0], |
| result=True, |
| ) |
| transform.YieldOp() |
| # CHECK-LABEL: TEST: setOpLayoutAttrResultSlice |
| # CHECK: transform.xegpu.set_op_layout_attr % |
| # NO-CHECK: index = 0 |
| # CHECK: result |
| # CHECK: sg_layout = [6, 4] |
| # CHECK: sg_data = [32, 16] |
| # CHECK: inst_data = [8, 16] |
| # CHECK: slice_dims = [0] |
| |
| |
| @run |
| def setGPULaunchThreadsOp(): |
| sequence = transform.SequenceOp( |
| transform.FailurePropagationMode.Propagate, |
| [], |
| transform.OperationType.get("gpu.launch"), |
| ) |
| with InsertionPoint(sequence.body): |
| xegpu.set_gpu_launch_threads(sequence.bodyTarget, threads=[8, 4, 1]) |
| transform.YieldOp() |
| # CHECK-LABEL: TEST: setGPULaunchThreadsOp |
| # CHECK: transform.xegpu.set_gpu_launch_threads |
| # CHECK: threads = [8, 4, 1] |
| |
| |
| @run |
| def insertPrefetch0(): |
| sequence = transform.SequenceOp( |
| transform.FailurePropagationMode.Propagate, |
| [], |
| transform.OperationType.get("xegpu.dpas"), |
| ) |
| with InsertionPoint(sequence.body): |
| operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0]) |
| xegpu.insert_prefetch( |
| operand, |
| ) |
| transform.YieldOp() |
| # CHECK-LABEL: TEST: insertPrefetch0 |
| # CHECK: %[[OPR:.*]] = get_operand |
| # CHECK: transform.xegpu.insert_prefetch %[[OPR]] |
| |
| |
| @run |
| def insertPrefetchNbPrefetch(): |
| sequence = transform.SequenceOp( |
| transform.FailurePropagationMode.Propagate, |
| [], |
| transform.OperationType.get("xegpu.dpas"), |
| ) |
| with InsertionPoint(sequence.body): |
| operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0]) |
| xegpu.insert_prefetch( |
| operand, |
| nb_prefetch=2, |
| ) |
| transform.YieldOp() |
| # CHECK-LABEL: TEST: insertPrefetchNbPrefetch |
| # CHECK: %[[OPR:.*]] = get_operand |
| # CHECK: transform.xegpu.insert_prefetch %[[OPR]] |
| # CHECK-SAME: nb_prefetch = 2 |
| |
| |
| @run |
| def insertPrefetchNbPrefetchParam(): |
| sequence = transform.SequenceOp( |
| transform.FailurePropagationMode.Propagate, |
| [], |
| transform.OperationType.get("xegpu.dpas"), |
| ) |
| with InsertionPoint(sequence.body): |
| operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0]) |
| int32_t = IntegerType.get_signless(32) |
| param_int32_t = transform.ParamType.get(int32_t) |
| nb_param = transform.ParamConstantOp( |
| param_int32_t, |
| IntegerAttr.get(int32_t, 2), |
| ) |
| xegpu.insert_prefetch( |
| operand, |
| nb_prefetch=nb_param, |
| ) |
| transform.YieldOp() |
| # CHECK-LABEL: TEST: insertPrefetchNbPrefetchParam |
| # CHECK: %[[OPR:.*]] = get_operand |
| # CHECK: %[[PARAM_OP:.*]] = transform.param.constant 2 |
| # CHECK: transform.xegpu.insert_prefetch %[[OPR]] |
| # CHECK-SAME: nb_prefetch = %[[PARAM_OP]] |
| |
| |
| @run |
| def ConvertLayoutMinimal(): |
| sequence = transform.SequenceOp( |
| transform.FailurePropagationMode.Propagate, |
| [], |
| transform.OperationType.get("xegpu.dpas"), |
| ) |
| with InsertionPoint(sequence.body): |
| operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0]) |
| xegpu.convert_layout( |
| operand, |
| input_sg_layout=[6, 4], |
| input_sg_data=[32, 16], |
| target_sg_layout=[6, 4], |
| target_sg_data=[8, 16], |
| ) |
| transform.YieldOp() |
| # CHECK-LABEL: TEST: ConvertLayoutMinimal |
| # CHECK: transform.xegpu.convert_layout % |
| # CHECK: input_sg_layout = [6, 4] |
| # CHECK: input_sg_data = [32, 16] |
| # CHECK: target_sg_layout = [6, 4] |
| # CHECK: target_sg_data = [8, 16] |
| |
| |
| @run |
| def ConvertLayout(): |
| sequence = transform.SequenceOp( |
| transform.FailurePropagationMode.Propagate, |
| [], |
| transform.OperationType.get("xegpu.dpas"), |
| ) |
| with InsertionPoint(sequence.body): |
| operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [1]) |
| xegpu.convert_layout( |
| operand, |
| input_sg_layout=[6, 4], |
| input_sg_data=[32, 32], |
| input_inst_data=[32, 16], |
| target_sg_layout=[6, 4], |
| target_sg_data=[32, 32], |
| target_inst_data=[8, 16], |
| ) |
| transform.YieldOp() |
| # CHECK-LABEL: TEST: ConvertLayout |
| # CHECK: transform.xegpu.convert_layout % |
| # CHECK: input_sg_layout = [6, 4] |
| # CHECK: input_sg_data = [32, 32] |
| # CHECK: input_inst_data = [32, 16] |
| # CHECK: target_sg_layout = [6, 4] |
| # CHECK: target_sg_data = [32, 32] |
| # CHECK: target_inst_data = [8, 16] |