blob: 983eee732e2afe10dda36035c942ea1b9893aaf9 [file] [log] [blame]
// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(amdgpu-optimize-shared-memory))' | FileCheck %s
// CHECK: @optimize_shmem([[arg0:%.+]]: memref<{{.*}}>, [[readRow:%.+]]: index, [[readCol:%.+]]: index, [[writeRow:%.+]]: index, [[writeCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index, [[fragColPerm:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index)
func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
%readRow: index, %readCol: index,
%writeRow: index, %writeCol: index,
%fragRow: index, %fragCol: index,
%fragColPerm: index,
%stRow: index, %stCol: index) {
// CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f16
%cst = arith.constant 0.000000e+00 : f16
// CHECK: [[shmA:%.+]] = memref.alloc
// CHECK: [[shmB:%.+]] = memref.alloc
%shmA = memref.alloc() {alignment = 64 : i64} : memref<128x32xf16, 3>
%shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>
%0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
// CHECK: [[c6:%.+]] = arith.constant 6 : index
// CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
// CHECK: [[c2:%.+]] = arith.constant 2 : index
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
// CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
vector.transfer_write %0, %shmB[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
gpu.barrier
gpu.barrier
// CHECK: [[c6:%.+]] = arith.constant 6 : index
// CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
// CHECK: [[c2:%.+]] = arith.constant 2 : index
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
// CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
%1 = vector.load %shmB[%fragRow, %fragColPerm] : memref<256x32xf16, 3>, vector<8xf16>
%2 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
// CHECK: [[c6:%.+]] = arith.constant 6 : index
// CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
// CHECK: [[c2:%.+]] = arith.constant 2 : index
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
// CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
vector.transfer_write %2, %shmA[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
gpu.barrier
gpu.barrier
// CHECK: [[c6:%.+]] = arith.constant 6 : index
// CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
// CHECK: [[c2:%.+]] = arith.constant 2 : index
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
// CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
%3 = vector.load %shmA[%fragRow, %fragColPerm] : memref<128x32xf16, 3>, vector<8xf16>
return
}