// RUN: mlir-opt %s -split-input-file --pass-pipeline='builtin.module(func.func(nvgpu-optimize-shared-memory))' | FileCheck %s

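// The pass swizzles indices into shared memory allocations to avoid bank
// conflicts. As encoded in the CHECK lines below, the column index of every
// store/load on a swizzled allocation is permuted as
//   colPerm = col XOR ((row AND mask) SHIFT amount)
// where the mask and the shift direction/amount are derived from the row
// size in bytes of the shared memory tile; the per-case comments below
// spell out the expected values.
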
// CHECK: @optimize_128x32xf16_32x128xf16([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index)
func.func @optimize_128x32xf16_32x128xf16(%arg0: memref<128x128xf16>,
    %ldRow: index, %ldCol: index,
    %stRow: index, %stCol: index,
    %fragRow: index, %fragCol: index)
    -> (vector<4x2xf16>, vector<4x2xf16>) {
  // CHECK: [[shm:%.+]] = memref.alloc
  // CHECK: [[shmB:%.+]] = memref.alloc
  %shm = memref.alloc() : memref<128x32xf16, 3>
  %shmB = memref.alloc() : memref<32x128xf16, 3>

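  // A 128x32xf16 row is 32 * 2 = 64 bytes, so the expected swizzle is
  //   stColPerm = stCol ^ ((stRow & 6) << 2)
  // i.e. row bits 1-2 pick an XOR offset in units of 8 f16 elements
  // (16 bytes); e.g. stRow = 5 gives (5 & 6) << 2 = 16.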
  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[stRow]], [[c6]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]]
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stColPerm]]]
  %0 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shm[%stRow, %stCol], 8
      : memref<128x128xf16> to memref<128x32xf16, 3>
  %1 = nvgpu.device_async_create_group %0
  nvgpu.device_async_wait %1 {numGroups = 1 : i32}

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: nvgpu.ldmatrix [[shm]][[[fragRow]], [[fragColPerm]]]
  %mat = nvgpu.ldmatrix %shm[%fragRow, %fragCol] {numTiles = 4 : i32, transpose = false}
      : memref<128x32xf16, 3> -> vector<4x2xf16>

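  // A 32x128xf16 row is 128 * 2 = 256 bytes, so the expected mask widens to
  // 15 and the shift grows to 3:
  //   stColPerm = stCol ^ ((stRow & 15) << 3)
  // e.g. stRow = 3 XORs the column with 24 elements (48 bytes).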
  // CHECK: [[c15:%.+]] = arith.constant 15 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[stRow]], [[c15]]
  // CHECK: [[c3:%.+]] = arith.constant 3 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c3]]
  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]]
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shmB]][[[stRow]], [[stColPerm]]]
  %2 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shmB[%stRow, %stCol], 8
      : memref<128x128xf16> to memref<32x128xf16, 3>
  %3 = nvgpu.device_async_create_group %2
  nvgpu.device_async_wait %3 {numGroups = 1 : i32}

  // CHECK: [[c15:%.+]] = arith.constant 15 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c15]]
  // CHECK: [[c3:%.+]] = arith.constant 3 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c3]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: nvgpu.ldmatrix [[shmB]][[[fragRow]], [[fragColPerm]]]
  %matB = nvgpu.ldmatrix %shmB[%fragRow, %fragCol] {numTiles = 4 : i32, transpose = false}
      : memref<32x128xf16, 3> -> vector<4x2xf16>

  return %mat, %matB : vector<4x2xf16>, vector<4x2xf16>
}

// -----

// CHECK: @optimize_64x16xf32_16x64xf32([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index)
func.func @optimize_64x16xf32_16x64xf32(%arg0: memref<128x128xf32>,
    %ldRow: index, %ldCol: index,
    %stRow: index, %stCol: index,
    %fragRow: index, %fragCol: index)
    -> (vector<4x1xf32>, vector<4x1xf32>, f32, vector<4xf32>, f32) {
  // CHECK: [[shm:%.+]] = memref.alloc
  // CHECK: [[shmB:%.+]] = memref.alloc
  %shm = memref.alloc() : memref<64x16xf32, 3>
  %shmB = memref.alloc() : memref<16x64xf32, 3>

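  // A 64x16xf32 row is 16 * 4 = 64 bytes, the same row size as the
  // 128x32xf16 tile above, so the expected mask is again 6; the shift drops
  // to 1 because an f32 element is twice as wide, which keeps the byte-level
  // permutation identical:
  //   stColPerm = stCol ^ ((stRow & 6) << 1)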
  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[stRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]]
  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]]
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stColPerm]]]
  %0 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shm[%stRow, %stCol], 4
      : memref<128x128xf32> to memref<64x16xf32, 3>
  %1 = nvgpu.device_async_create_group %0
  nvgpu.device_async_wait %1 {numGroups = 1 : i32}

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: nvgpu.ldmatrix [[shm]][[[fragRow]], [[fragColPerm]]]
  %mat = nvgpu.ldmatrix %shm[%fragRow, %fragCol] {numTiles = 4 : i32, transpose = false}
      : memref<64x16xf32, 3> -> vector<4x1xf32>

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: memref.load [[shm]][[[fragRow]], [[fragColPerm]]]
  %elem = memref.load %shm[%fragRow, %fragCol] : memref<64x16xf32, 3>

  // Verify vector operations.
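  // memref.load/store and vector.load/store on the swizzled allocation are
  // expected to receive the same column permutation as the ldmatrix above.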

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: vector.load [[shm]][[[fragRow]], [[fragColPerm]]]
  %elem2 = vector.load %shm[%fragRow, %fragCol] : memref<64x16xf32, 3>, vector<4xf32>

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: vector.store %{{.+}}, [[shm]][[[fragRow]], [[fragColPerm]]]
  vector.store %elem2, %shm[%fragRow, %fragCol] : memref<64x16xf32, 3>, vector<4xf32>

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: memref.store %{{.+}}, [[shm]][[[fragRow]], [[fragColPerm]]]
  memref.store %elem, %shm[%fragRow, %fragCol] : memref<64x16xf32, 3>

  // Verify the 16x64xf32 shared memory tile.
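  // A 16x64xf32 row is 64 * 4 = 256 bytes, matching the 32x128xf16 tile:
  // expected mask = 15 and shift = 2, i.e.
  //   stColPerm = stCol ^ ((stRow & 15) << 2)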

  // CHECK: [[c15:%.+]] = arith.constant 15 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[stRow]], [[c15]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]]
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shmB]][[[stRow]], [[stColPerm]]]
  %2 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shmB[%stRow, %stCol], 4
      : memref<128x128xf32> to memref<16x64xf32, 3>
  %3 = nvgpu.device_async_create_group %2
  nvgpu.device_async_wait %3 {numGroups = 1 : i32}

  // CHECK: [[c15:%.+]] = arith.constant 15 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c15]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: memref.load [[shmB]][[[fragRow]], [[fragColPerm]]]
  %elemB = memref.load %shmB[%fragRow, %fragCol] : memref<16x64xf32, 3>

  return %mat, %matB, %elem, %elem2, %elemB : vector<4x1xf32>, vector<4x1xf32>, f32, vector<4xf32>, f32
}

// -----

// Edge cases for small column sizes.

// CHECK: @small_column_size_f64([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index)
func.func @small_column_size_f64(%arg0: memref<32x32xf64>,
    %ldRow: index, %ldCol: index,
    %stRow: index, %stCol: index,
    %fragRow: index, %fragCol: index)
    -> f64 {
  // CHECK: [[shm:%.+]] = memref.alloc
  %shm = memref.alloc() : memref<32x4xf64, 3>

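  // A 32x4xf64 row is only 4 * 8 = 32 bytes, so the row bits are shifted
  // right instead of left: expected mask = 4 and shift-right = 1,
  //   stColPerm = stCol ^ ((stRow & 4) >> 1)
  // i.e. rows 4-7 of every group of 8 XOR the column with 2 elements
  // (16 bytes).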
  // CHECK: [[c4:%.+]] = arith.constant 4 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[stRow]], [[c4]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shrui [[srcBits]], [[c1]]
  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]]
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stColPerm]]]
  %0 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shm[%stRow, %stCol], 2
      : memref<32x32xf64> to memref<32x4xf64, 3>
  %1 = nvgpu.device_async_create_group %0
  nvgpu.device_async_wait %1 {numGroups = 1 : i32}

  // CHECK: [[c4:%.+]] = arith.constant 4 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c4]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shrui [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: memref.load [[shm]][[[fragRow]], [[fragColPerm]]]
  %el = memref.load %shm[%fragRow, %fragCol] : memref<32x4xf64, 3>

  return %el : f64
}

// CHECK: @too_small_column_size_f16([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index)
func.func @too_small_column_size_f16(%arg0: memref<128x128xf16>,
    %ldRow: index, %ldCol: index,
    %stRow: index, %stCol: index,
    %fragRow: index, %fragCol: index)
    -> vector<1x2xf16> {
  // CHECK: [[shm:%.+]] = memref.alloc
  %shm = memref.alloc() : memref<128x8xf16, 3>

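  // A 128x8xf16 row is only 8 * 2 = 16 bytes, a single 128-bit access, so
  // there is nothing left to swizzle; the CHECK lines below expect the
  // original, unpermuted indices.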
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stCol]]]
  %0 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shm[%stRow, %stCol], 8
      : memref<128x128xf16> to memref<128x8xf16, 3>
  %1 = nvgpu.device_async_create_group %0
  nvgpu.device_async_wait %1 {numGroups = 1 : i32}

  // CHECK: nvgpu.ldmatrix [[shm]][[[fragRow]], [[fragCol]]]
  %mat = nvgpu.ldmatrix %shm[%fragRow, %fragCol] {numTiles = 1 : i32, transpose = false}
      : memref<128x8xf16, 3> -> vector<1x2xf16>

  return %mat : vector<1x2xf16>
}

// -----

// CHECK: @abort_if_subview([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index)
func.func @abort_if_subview(%arg0: memref<128x128xf16>,
    %ldRow: index, %ldCol: index,
    %stRow: index, %stCol: index,
    %fragRow: index, %fragCol: index)
    -> vector<1x2xf16> {
  // CHECK: [[shm:%.+]] = memref.alloc
  %shm = memref.alloc() : memref<128x32xf16, 3>
  // CHECK: [[shmView:%.+]] = memref.subview
  %shmView = memref.subview %shm[0, 0][64, 32][1, 1] : memref<128x32xf16, 3> to memref<64x32xf16, 3>

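  // The allocation is also accessed through a memref.subview, so the pass
  // cannot rewrite all of its uses consistently and is expected to bail out:
  // both operations below keep their original indices.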
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stCol]]]
  %0 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shm[%stRow, %stCol], 8
      : memref<128x128xf16> to memref<128x32xf16, 3>
  %1 = nvgpu.device_async_create_group %0
  nvgpu.device_async_wait %1 {numGroups = 1 : i32}

  // CHECK: nvgpu.ldmatrix [[shmView]][[[fragRow]], [[fragCol]]]
  %mat = nvgpu.ldmatrix %shmView[%fragRow, %fragCol] {numTiles = 1 : i32, transpose = false}
      : memref<64x32xf16, 3> -> vector<1x2xf16>

  return %mat : vector<1x2xf16>
}