# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN: %PYTHON %s | FileCheck %s

# ===----------------------------------------------------------------------===//
#  Chapter 5 : Warp Specialized GEMM with Tensor Core
# ===----------------------------------------------------------------------===//
#
# This program demonstrates a GEMM operation for `f32+=f16*f16`, using the
# warp-specialized method with a tile size of 128x128x64. The code fully
# parallelizes the two outermost loops into thread blocks. It launches two
# warpgroups (256 threads in total): one producer and one consumer. Each
# warpgroup follows a different control-flow path: the producer warpgroup loads
# data into shared memory via TMA, while the consumer warpgroup executes the
# Tensor Core GEMM operation and the epilogue.
#
#  for ti in range(M//128):   # -> blockIdx.x
#   for tj in range(N//128):  # -> blockIdx.y
#    with wg_producer:
#     for tk in range(K//64):
#        TMA_128x64_64x128...
#    with wg_consumer:
#     for tk in range(K//64):
#        MMA_128x128x64...
#     Epilogue..
#
# This chapter demonstrates:
#  2 WG (warpgroups)
#   Producer:
#    Loop
#     Wait MMA barrier
#     Load TMA with TMA barrier
#     Arrive TMA barrier with txcount
#   Consumer:
#    Loop
#     Wait TMA barrier
#     Performs Tensor Core GEMM 128x128x64 by warpgroup
#     Arrive MMA barrier
#    Epilogue
#     Store fragmented registers to shared memory
#     Store shared memory to global memory
#
# ===----------------------------------------------------------------------===//


from mlir import ir
from mlir.dialects import arith, gpu, memref, nvgpu, nvvm, scf
from mlir.extras import types as T
from tools.nvdsl import *
import numpy as np
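

# The kernel below implements the tiled loop nest sketched in the header. As a
# plain-NumPy illustration of that blocking scheme (a hypothetical helper for
# exposition only; it is not part of the NVDSL tutorial and is never called by
# this test):
def _blocked_gemm_reference(a, b, tile_m=128, tile_n=128, tile_k=64):
    """Reference f32 += f16 * f16 GEMM mirroring the kernel's tiling.

    Assumes the matrix shapes are divisible by the tile sizes.
    """
    M, K = a.shape
    _, N = b.shape
    d = np.zeros((M, N), np.float32)
    for ti in range(M // tile_m):  # maps to blockIdx.x in the kernel
        for tj in range(N // tile_n):  # maps to blockIdx.y in the kernel
            acc = np.zeros((tile_m, tile_n), np.float32)
            for tk in range(K // tile_k):  # the pipelined producer/consumer loop
                a_tile = a[ti * tile_m : (ti + 1) * tile_m, tk * tile_k : (tk + 1) * tile_k]
                b_tile = b[tk * tile_k : (tk + 1) * tile_k, tj * tile_n : (tj + 1) * tile_n]
                acc += a_tile.astype(np.float32) @ b_tile.astype(np.float32)
            d[ti * tile_m : (ti + 1) * tile_m, tj * tile_n : (tj + 1) * tile_n] = acc
    return d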


def partition_shape():
    """
    Calculate the tile offsets of this thread block based on the block IDs.

    It parallelizes the two outermost loops into thread blocks:
     for ti in range(M//128):    # -> blockIdx.x
      for tj in range(N//128):   # -> blockIdx.y
       D = 0
       for tk in range(K//64):
        for i in range(128):
         for j in range(128):
          for k in range(64):
           FMA

    Returns:
        dimX: Offset of this block's tile along the M dimension (blockIdx.x * TILE_M).
        dimY: Offset of this block's tile along the N dimension (blockIdx.y * TILE_N).
    """
    bidx = gpu.block_id(gpu.Dimension.x)
    bidy = gpu.block_id(gpu.Dimension.y)
    dimX = bidx * TILE_M
    dimY = bidy * TILE_N
    return dimX, dimY


def tma_load(
    mbar_group: Mbarriers,
    a_tma: TMA,
    b_tma: TMA,
    slot,
    stage,
    num_stages,
    p=None,
):
    """
    TMA loads two input matrices from global memory to shared memory. It performs the following operations:

       - tma.load a_shared_memory[off_a]  at coordinate [tk * 64, dimX]      (loads 128x64 of A)
       - tma.load b_shared_memory[off_b]  at coordinate [dimY, tk * 64]      (loads 64x64 of B)
       - tma.load b_shared_memory[off_b2] at coordinate [dimY + 64, tk * 64] (loads 64x64 of B)

       mbarrier.arrive ta_count = sizeof(A tile) + 2 * sizeof(B tile)
                                = 128x64x2 bytes + 2 x 64x64x2 bytes = 32768 bytes
    """
    dimX, dimY = partition_shape()

    begin_b = num_stages * get_type_size(a_tma.tma_memref)
    size_tma_a = get_type_size(a_tma.tma_memref)
    size_tma_b = get_type_size(b_tma.tma_memref)
    ta_count = size_tma_a + (size_tma_b * 2)
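
    # Dynamic shared memory layout (per-stage slots, A region followed by B region):
    #   [ A(stage 0) | ... | A(num_stages-1) | B(stage 0) | ... | B(num_stages-1) ]
    # Each A slot is 128x64xf16 (16384 B); each B slot holds two 64x64xf16 halves
    # (16384 B), so both regions use the same per-stage stride.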
    off_a = slot * size_tma_a
    off_b = (slot * size_tma_a) + begin_b
    off_b2 = off_b + size_tma_b
    a_elem_ty = a_tma.tma_memref.element_type
    b_elem_ty = b_tma.tma_memref.element_type
    a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty, off_a)
    b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
    b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)

    mbar_group[slot].arrive(ta_count, predicate=p)

    c1 = stage * 64
    a_tma.load(a, mbar_group[slot], coords=[c1, dimX], predicate=p)
    b_tma.load(b1, mbar_group[slot], coords=[dimY, c1], predicate=p)
    b_tma.load(b2, mbar_group[slot], coords=[dimY + 64, c1], predicate=p)


def initialize(a_tma: TMA, b_tma: TMA, num_stages):
    """
    Initialize mbarriers and prefetch TMA descriptors.
    """
    tidx = gpu.thread_id(gpu.Dimension.x)
    mbar_group_tma = Mbarriers(number_of_barriers=num_stages)
    mbar_group_mma = Mbarriers(number_of_barriers=num_stages)
    isThread0 = tidx == const(0)
    with ir.InsertionPoint(scf.IfOp(isThread0).then_block):
        for i in scf.for_(0, num_stages, 1):
            mbar_group_tma[i].init(1)
            mbar_group_mma[i].init(1)
            scf.yield_([])
        a_tma.prefetch()
        b_tma.prefetch()
        scf.yield_([])

    return mbar_group_tma, mbar_group_mma


def switch_phase(stage, phase, num_stages):
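    """
    Flip the mbarrier wait phase once per pipeline round-trip.

    mbarrier waits are parity-based: each time a barrier completes a phase, its
    phase bit toggles. Since the loops reuse the same `num_stages` barriers in a
    circular fashion, the phase used for waiting is flipped exactly when the
    last stage (num_stages - 1) is reached.
    """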
    p = stage == (num_stages - 1)
    phase = arith.select(
        p,
        (phase ^ const(True, ty=T.bool())),
        phase,
    )
    return phase


def producer_loop(
    mbar_tma: Mbarriers,
    mbar_mma: Mbarriers,
    a_tma: TMA,
    b_tma: TMA,
    wg_me: Warpgroup,
    num_stages,
):
    phase = const(True, ty=T.bool())

    for iv, phase in scf.for_(0, (K // TILE_K), 1, [phase]):
        stage = iv % num_stages
        # Wait MMA to be done
        mbar_mma[stage].try_wait(phase)
        # New phase for mbarrier
        phase = switch_phase(stage, phase, num_stages)
        # TMA Load
        tma_load(mbar_tma, a_tma, b_tma, stage, iv, num_stages, wg_me.is_wg_primary)
        scf.yield_([phase])


def consumer_loop(
    mbar_tma: Mbarriers,
    mbar_mma: Mbarriers,
    a_tma: TMA,
    b_tma: TMA,
    wg_me: Warpgroup,
    num_stages,
):
    begin_b = num_stages * get_type_size(a_tma.tma_memref)

    size_a = TILE_M * TILE_K * get_type_size(T.f16())

    phase = const(False, ty=T.bool())
    A = WGMMAMatrix(WGMMAType.Descriptor, [TILE_M, TILE_K], desc=a_tma)
    B = WGMMAMatrix(WGMMAType.Descriptor, [TILE_K, TILE_N], desc=b_tma)
    D = WGMMAMatrix(WGMMAType.Accumulator, shape=[TILE_M, TILE_N], ty=T.f32())

    for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [D.acc_op, phase])
    with ir.InsertionPoint(for_op.body):
        phase = for_op.inner_iter_args[1]
        iv = for_op.induction_variable
        stage = iv % num_stages

        # Wait TMA for current stage
        mbar_tma[stage].try_wait(phase)

        # Find shared memory slot
        offset_a = stage * size_a
        offset_b = offset_a + begin_b
        a_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_a)
        b_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_b)

        # Iterate input matrices, update accumulator
        A.update_smem(a_smem)
        B.update_smem(b_smem)
        D.update_accumulator(for_op.inner_iter_args[0])

        # Matrix Multiply
        D += A @ B

        # MMA Barrier Arrive
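        # Signal the producer that the previous stage's shared memory slot can
        # be reused: only the warpgroup's primary thread arrives, and only from
        # the second iteration onward (iv > 0).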
        p_arrive = (iv > 0) & wg_me.is_wg_primary
        with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
            barId = arith.select((stage == 0), const(num_stages - 1), (stage - 1))
            mbar_mma[barId].arrive()
            scf.yield_([])

        phase = switch_phase(stage, phase, num_stages)
        scf.yield_([D.acc_op, phase])

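    # Wait until all outstanding wgmma groups have completed before the
    # accumulator registers are read by the epilogue.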
    nvvm.WgmmaWaitGroupSyncOp(0)
    D.update_accumulator(for_op.results[0])
    return D


def epilogue(D: WGMMAMatrix, d_dev):
    """
    Epilogue of the GEMM kernel. It stores the fragmented registers to global memory.

    MatrixAccumulator D               # Fragmented results
    store D -> Shared Memory          # Store to shared memory
    Shared Memory -> Z[dimX][dimY]    # Store shared memory to global memory

    """
    tidx = gpu.thread_id(gpu.Dimension.x)
    dimX, dimY = partition_shape()
    # s = tidx - WARP_GROUP_SIZE
    # debug_print("[Epilogue] store to global memory @ s={}", s)

    d_smem = get_dynamic_shared_memory([TILE_M, TILE_N], T.f32())
    d_gmem = memref.subview(d_dev, [dimX, dimY], [TILE_M, TILE_N], [1, 1])

    # Store (registers -> shared memory)
    D.store_accumulator(d_smem)
    gpu.barrier()

    # Store (shared memory -> global memory)
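    # Each of the 128 consumer threads copies one column (column `tidx`) of the
    # 128x128 result tile, row by row.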
    for i in scf.for_(0, TILE_M, 1):
        val = memref.load(d_smem, [i, tidx])
        memref.store(val, d_gmem, [i, tidx])
        scf.yield_([])


@NVDSL.mlir_func
def gemm_warp_specialized(a, b, d, num_stages):
    token_ty = gpu.AsyncTokenType.get()
    t1 = gpu.wait(token_ty, [])
    a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
    b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
    d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
    t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
    t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
    t7 = gpu.wait(token_ty, [t6])

    sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
    a_tma = TMA([128, 64], a.type, swizzle=sw)
    b_tma = TMA([64, 64], b.type, swizzle=sw)
    a_tma.create_descriptor(a_dev)
    b_tma.create_descriptor(b_dev)

    grid = [(M // TILE_M), (N // TILE_N), 1]
    block = [256, 1, 1]

    size_a = get_type_size(a.type.element_type) * TILE_M * TILE_K
    size_b = get_type_size(b.type.element_type) * TILE_N * TILE_K
    smem_size_in_bytes = (size_a + size_b) * num_stages
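    # With f16 inputs this is (128*64 + 128*64) * 2 bytes = 32 KiB per stage;
    # num_stages=7 gives 224 KiB, which is assumed to fit within the dynamic
    # shared memory available per block on the targeted Hopper (sm_90) device.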

    @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=smem_size_in_bytes)
    def gemm_warp_specialized_kernel():
        # Init Warpgroups
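        # The producer warpgroup (threads 128-255) only issues TMA copies, so it
        # requests a small per-thread register budget; the consumer warpgroup
        # (threads 0-127) holds the accumulator fragments, so it requests a large
        # one. The register_size values are presumably used for Hopper's
        # register reallocation between warpgroups.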
        wg_producer = Warpgroup(primary_thread=128, register_size=40)
        wg_consumer = Warpgroup(primary_thread=0, register_size=232)

        # Initialize mbarriers and prefetch TMA descriptors
        mbar_tma, mbar_mma = initialize(a_tma, b_tma, num_stages)

        # Producer performs TMA
        with wg_producer:
            producer_loop(mbar_tma, mbar_mma, a_tma, b_tma, wg_producer, num_stages)

        # Consumer performs MMA/Tensor Core
        with wg_consumer:
            D = consumer_loop(mbar_tma, mbar_mma, a_tma, b_tma, wg_consumer, num_stages)
            epilogue(D, d_dev)

    gemm_warp_specialized_kernel()

    t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
    gpu.wait(None, [t8])


# Pass Python arguments to MLIR
N = 256
M = 512
K = 1024
TILE_M = 128
TILE_N = 128
TILE_K = 64
a = np.random.randn(M, K).astype(np.float16)
b = np.random.randn(K, N).astype(np.float16)
d = np.zeros((M, N), np.float32)

gemm_warp_specialized(a, b, d, num_stages=7)


# Verify MLIR with reference computation
ref_d = a.astype(np.float16) @ b.astype(np.float16)
np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)


print("PASS")
# CHECK-NOT: Mismatched elements