blob: b1bb91ffc29721400e7e769ce33ff8b84a1c2106 [file] [log] [blame]
// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
// Payload for transform.amdgpu.optimize_shared_memory_reads_and_writes
// (applied below with sharedMemoryLineSizeBytes = 128 and
// defaultVectorSizeBits = 128).
//
// Every access to the two memory-space-3 buffers %shmA and %shmB is expected
// to have its column index rewritten as
//   permutedCol = col xor ((row & 6) << 2)
// The directives pin the row / column operands of that arithmetic to the
// actual function arguments used by each access. They also require the
// rewritten access to consume the permuted column. The test therefore fails
// if the wrong index is permuted or the result is dropped.
//
// Note: the write uses %writeRow / %writeCol. The loads use %fragRow as the
// row and %fragColPerm (not %fragCol) as the column. %stRow and %stCol are
// not referenced by any op and only appear in the signature. No directive
// constrains the vector.transfer_read ops on %arg0.
// CHECK: @optimize_shmem([[arg0:%.+]]: memref<{{.*}}>, [[readRow:%.+]]: index, [[readCol:%.+]]: index, [[writeRow:%.+]]: index, [[writeCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index, [[fragColPerm:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index)
func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
                          %readRow: index, %readCol: index,
                          %writeRow: index, %writeCol: index,
                          %fragRow: index, %fragCol: index,
                          %fragColPerm: index,
                          %stRow: index, %stCol: index) {
  %cst = arith.constant 0.000000e+00 : f16
  %shmA = memref.alloc() {alignment = 64 : i64} : memref<128x32xf16, 3>
  %shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>

  // Write into %shmB: the row is %writeRow and the permuted column is %writeCol.
  %0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[writeRow]], [[c6]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
  // CHECK: [[writeColPermB:%.+]] = arith.xori [[writeCol]], [[xorBits]]
  // CHECK: vector.transfer_write {{.+}}{{\[}}[[writeRow]], [[writeColPermB]]]
  vector.transfer_write %0, %shmB[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
  gpu.barrier
  gpu.barrier

  // Load from %shmB: the row is %fragRow and the permuted column is the
  // %fragColPerm argument.
  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
  // CHECK: [[loadColPermB:%.+]] = arith.xori [[fragColPerm]], [[xorBits]]
  // CHECK: vector.load {{.+}}{{\[}}[[fragRow]], [[loadColPermB]]]
  %1 = vector.load %shmB[%fragRow, %fragColPerm] : memref<256x32xf16, 3>, vector<8xf16>

  // Write into %shmA: same index arguments as the %shmB write.
  %2 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[writeRow]], [[c6]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
  // CHECK: [[writeColPermA:%.+]] = arith.xori [[writeCol]], [[xorBits]]
  // CHECK: vector.transfer_write {{.+}}{{\[}}[[writeRow]], [[writeColPermA]]]
  vector.transfer_write %2, %shmA[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
  gpu.barrier
  gpu.barrier

  // Load from %shmA: same index arguments as the %shmB load.
  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
  // CHECK: [[loadColPermA:%.+]] = arith.xori [[fragColPerm]], [[xorBits]]
  // CHECK: vector.load {{.+}}{{\[}}[[fragRow]], [[loadColPermA]]]
  %3 = vector.load %shmA[%fragRow, %fragColPerm] : memref<128x32xf16, 3>, vector<8xf16>
  return
}
// Transform script run by the interpreter: match every func.func in the
// payload and apply the AMDGPU shared-memory read/write optimization to the
// matched functions. It uses 128-byte shared-memory lines and a 128-bit
// default vector size.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(
      %payload: !transform.any_op {transform.readonly}) {
    // Collect the function ops to rewrite.
    %funcs = transform.structured.match ops{["func.func"]} in %payload
        : (!transform.any_op) -> !transform.any_op
    // Swizzle the shared-memory accesses inside those functions.
    transform.amdgpu.optimize_shared_memory_reads_and_writes %funcs
        {defaultVectorSizeBits = 128, sharedMemoryLineSizeBytes = 128}
        : (!transform.any_op) -> ()
    transform.yield
  }
}