//===- MMAUtils.cpp - MLIR NVGPU dialect utils for MMA operations ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
| #include "mlir/Dialect/NVGPU/Utils/MMAUtils.h" |
| |
| #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/LLVMIR/NVVMDialect.h" |
| #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" |
| #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| |
| using namespace mlir; |
| using namespace mlir::nvgpu; |
| |
/// There are always 4 threads per [128|256|512] bit row.
static constexpr int64_t kThreadsPerRow = 4;
/// There are always 8 rows per [128|256|512] bit tile.
static constexpr int64_t kNumRowsPerTile = 8;

static bool isAccumulatorOrResult(MatMulOperandRole operandType) {
  return operandType == MatMulOperandRole::C;
}

/// Returns the number of registers that compose a matrix fragment held by a
/// single thread.
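/// For example (a sketch of the formula below, assuming a 16x16 vector of f16
/// elements used as an A operand): the tile width is 128 bits, so the result
/// is (16 / 8) * (16 * 16) / 128 = 4 registers per thread.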
static int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
  int64_t lineSize = inferTileWidthInBits(type);
  auto shape = type.vectorType.getShape();
  return (shape[0] / kNumRowsPerTile) *
         (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /
         lineSize;
}

/// Returns the number of 8 x [128|256|512] bit tiles that compose the given
/// operand shape.
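/// For example (a sketch of the formula below, assuming a 16x16 operand of
/// f16 elements and a 128-bit line): the result is {16 / 8, (16 * 16) / 128}
/// = {2, 2}, i.e. four 8 x 128-bit tiles.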
| static std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape, |
| Type elementType, |
| int64_t lineSizeBits) { |
| // For each 8x128bit square, a thread is responsible for one 32bit register. |
| return {operandShape[0] / kNumRowsPerTile, |
| (operandShape[1] * elementType.getIntOrFloatBitWidth()) / |
| lineSizeBits}; |
| } |
| |
/// Returns the first user of `op` that is a vector.contract operation, or
/// failure if no such user exists.
FailureOr<vector::ContractionOp> nvgpu::getUserContract(Operation *op) {
  for (Operation *user : op->getUsers()) {
    if (auto contractOp = dyn_cast<vector::ContractionOp>(user))
      return contractOp;
  }
  return failure();
}

| FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) { |
| WarpMatrixInfo info; |
| |
| // Determine the vector type at warp-level. |
| if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) { |
| info.vectorType = writeOp.getVectorType(); |
| } else if (isa<vector::TransferReadOp, vector::ContractionOp, |
| vector::ExtractStridedSliceOp, arith::ConstantOp>(op)) { |
| info.vectorType = cast<VectorType>(op->getResult(0).getType()); |
| } else { |
| return op->emitError() |
| << "unhandled operation type in nvgpu.mma.sync conversion path"; |
| } |
| |
| // Determine the operand role. We assume it is an accumulator/result unless it |
| // is directly consumed by a `vector.contract` op. |
| info.operandRole = MatMulOperandRole::C; |
| FailureOr<vector::ContractionOp> contractOp = getUserContract(op); |
| if (failed(contractOp)) |
| return info; |
| |
| if ((*contractOp).getLhs() == op->getResult(0)) |
| info.operandRole = MatMulOperandRole::A; |
| else if ((*contractOp).getRhs() == op->getResult(0)) |
| info.operandRole = MatMulOperandRole::B; |
| |
| return info; |
| } |
| |
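// The tile width is the size in bits of one line of the operand fragment:
// 128 bits in the common case, 256 bits for 32-bit accumulators and for
// 64-bit A/B operands, and 512 bits for 64-bit accumulators (see the cases
// below).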
int64_t nvgpu::inferTileWidthInBits(const WarpMatrixInfo &type) {
  bool isAcc = isAccumulatorOrResult(type.operandRole);
  Type elType = type.vectorType.getElementType();
  if (isAcc && elType.getIntOrFloatBitWidth() == 32) {
    return 256;
  }
  if (elType.getIntOrFloatBitWidth() == 64) {
    return isAcc ? 512 : 256;
  }
  return 128;
}

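// Each case below returns, for the given element type and operand role, the
// per-thread register type together with the number of elements held in one
// register, the register width in bits, and the number of registers that make
// up the fragment.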
| FailureOr<FragmentElementInfo> |
| nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) { |
| MLIRContext *ctx = type.vectorType.getContext(); |
| const bool isAccum = isAccumulatorOrResult(type.operandRole); |
| |
| Type elType = type.vectorType.getElementType(); |
| if (elType.isF16()) { |
| return FragmentElementInfo{VectorType::get(2, Float16Type::get(ctx)), 2, 32, |
| inferNumRegistersPerMatrixFragment(type)}; |
| } |
| |
| // f64 operand |
| Type f64Ty = Float64Type::get(ctx); |
| if (elType.isF64()) { |
| return isAccum |
| ? FragmentElementInfo{VectorType::get(2, f64Ty), 2, 128, |
| inferNumRegistersPerMatrixFragment(type)} |
| : FragmentElementInfo{f64Ty, 1, 64, |
| inferNumRegistersPerMatrixFragment(type)}; |
| } |
| |
| // int8 operand |
| if (elType.isInteger(8)) { |
| return FragmentElementInfo{VectorType::get(4, IntegerType::get(ctx, 8)), 4, |
| 32, inferNumRegistersPerMatrixFragment(type)}; |
| } |
| |
| // int4 operand |
| if (elType.isInteger(4)) { |
| return FragmentElementInfo{VectorType::get(8, IntegerType::get(ctx, 4)), 8, |
| 32, inferNumRegistersPerMatrixFragment(type)}; |
| } |
| |
| // Integer 32bit acc operands |
| if (elType.isInteger(32)) { |
| return FragmentElementInfo{VectorType::get(2, IntegerType::get(ctx, 32)), 2, |
| 64, inferNumRegistersPerMatrixFragment(type)}; |
| } |
| |
| // Floating point 32bit operands |
| if (elType.isF32()) { |
| Type f32Ty = Float32Type::get(ctx); |
| return isAccum |
| ? FragmentElementInfo{VectorType::get(2, f32Ty), 2, 64, |
| inferNumRegistersPerMatrixFragment(type)} |
| : FragmentElementInfo{f32Ty, 1, 32, |
| inferNumRegistersPerMatrixFragment(type)}; |
| } |
| return failure(); |
| } |
| |
static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize,
                                                 Type elementType,
                                                 ArrayRef<int64_t> operandShape,
                                                 bool isAccumulator,
                                                 int64_t elementsPerRegister,
                                                 AffineExpr logicalValueId) {
  const int64_t elementsPerLine =
      lineSize / elementType.getIntOrFloatBitWidth();
  const std::array<int64_t, 2> num8x128bTiles =
      getTileShape(operandShape, elementType, lineSize);
  AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);
  return AffineMap::get(
      2, 0,
      {(registerIdx % num8x128bTiles[0]) * 8,
       (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine},
      elementType.getContext());
}

| FailureOr<AffineMap> |
| nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc, |
| const WarpMatrixInfo &fragmentType) { |
| Type elementType = fragmentType.vectorType.getElementType(); |
| ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape(); |
| FailureOr<nvgpu::FragmentElementInfo> regInfo = |
| getMmaSyncRegisterType(fragmentType); |
| if (failed(regInfo)) |
| return failure(); |
| |
| const int64_t elementBitWidth = elementType.getIntOrFloatBitWidth(); |
| const int64_t elementsPerRegister = |
| regInfo->registerWidthBits / elementBitWidth; |
| const int64_t lineSize = inferTileWidthInBits(fragmentType); |
| |
| AffineExpr laneId, logicalValueIdDim; |
| bindDims(builder.getContext(), laneId, logicalValueIdDim); |
| |
| // Determine what register logicalValueId corresponds to. Use that as a |
| // linear index into the coordinate mapping `index -> (tile row, tile col)`. |
| AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap( |
| lineSize, elementType, operandShape, |
| isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister, |
| logicalValueIdDim); |
| |
| auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap { |
| return AffineMap::get(2, 0, dimExprs, builder.getContext()); |
| }; |
| |
| auto tileRow = registerIndexToTileCoord.getResult(0); |
| auto tileCol = registerIndexToTileCoord.getResult(1); |
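  // The resulting map takes (laneId, logicalValueId) to the element coordinate
  // within the warp-level fragment. For example (a sketch, evaluating the
  // expressions below for a 16x16 f16 A fragment where elementsPerRegister = 2
  // and elementsPerLine = 8): lane 0 / value 0 maps to (0, 0), lane 1 / value 0
  // maps to (0, 2), and lane 4 / value 0 maps to (1, 0).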
  return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow),
                  tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +
                      (logicalValueIdDim % elementsPerRegister)});
}

| FailureOr<nvgpu::LdMatrixParams> |
| nvgpu::getLdMatrixParams(const WarpMatrixInfo &type, bool transpose) { |
| LdMatrixParams params; |
| Type elType = type.vectorType.getElementType(); |
| params.fragmentType = type.vectorType; |
| if (type.operandRole == MatMulOperandRole::A || |
| type.operandRole == MatMulOperandRole::C) { |
| params.targetLayout = NVVM::MMALayout::row; |
| } else { |
| params.targetLayout = NVVM::MMALayout::col; |
| } |
| ArrayRef<int64_t> shape = type.vectorType.getShape(); |
| params.contiguousDimType = transpose ? vector::IteratorType::parallel |
| : vector::IteratorType::reduction; |
| |
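  // numTiles is the number of 8 x 128-bit tiles the fragment decomposes into
  // (e.g., a 16x16 f16 fragment gives (16 / 8) * (16 * 16) / 128 = 4 tiles,
  // which would correspond to an ldmatrix `.x4` load).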
  if (params.contiguousDimType == vector::IteratorType::reduction) {
    params.numTiles = (shape[0] / kNumRowsPerTile) *
                      ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
  } else {
    params.numTiles = (shape[1] / kNumRowsPerTile) *
                      ((shape[0] * elType.getIntOrFloatBitWidth()) / 128);
  }

  if (params.numTiles == 0)
    return failure();

  return params;
}

| FailureOr<AffineMap> |
| nvgpu::getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc, |
| const LdMatrixParams ¶ms) { |
| // One thread per 128b row. |
| const int bitsPerElement = static_cast<int>( |
| params.fragmentType.getElementType().getIntOrFloatBitWidth()); |
| const int kElementsPer128b = (128 / bitsPerElement); |
| ArrayRef<int64_t> operandShape = params.fragmentType.getShape(); |
| AffineExpr d0 = getAffineDimExpr(0, builder.getContext()); |
| |
| auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap { |
| return AffineMap::get(1, 0, dimExprs, builder.getContext()); |
| }; |
| |
| // Index `idx` in vectorType `operandShape` maps to the strided dimension of |
| // the `srcMemref` memory of the LdMatrixOp. |
| int idx = |
| (params.contiguousDimType == vector::IteratorType::reduction) ? 0 : 1; |
| |
| // Affine expr in strided and contiguous dimension encodes the coordinate |
| // mapping for the element a thread points to for warp-wide LdMatrixOp. |
| AffineExpr strided = d0 % (operandShape[idx]); |
| AffineExpr contiguous = d0.floorDiv(operandShape[idx]) * (kElementsPer128b); |
| |
| // This case corresponds to row-major matrixA or col-major matrixB or |
| // row-major matrixC. This is when the memory layout in `srcMemref` |
| // match mma.sync hardware vector register operand layout. |
| if (params.contiguousDimType == vector::IteratorType::reduction) |
| return makeMap({strided, contiguous}); |
| |
| // This case corresponds to col-major matrixA or row-major matrixB or |
| // col-major matrixC. This is when the memory layout in `srcMemref` does not |
| // match mma.sync hardware vector register operand layout. |
| if (params.contiguousDimType == vector::IteratorType::parallel) |
| return makeMap({contiguous, strided}); |
| |
| return failure(); |
| } |
| |
bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferReadOp op) {
  if (op.getMask() || op.hasOutOfBoundsDim())
    return false;
  VectorType type = op.getType();
  // The result type should be 2D. Note that it is possible to expand support
  // so that we are robust to extra unit dimensions that failed to fold, but
  // that would significantly increase downstream code complexity in the
  // conversion step. For now, we rely on other patterns to ensure the
  // canonical 2D form is used when targeting the `nvgpu.mma.sync` lowering
  // path.
  if (!type.hasStaticShape() || type.getRank() != 2)
    return false;

  // Currently we can't support reads on tensor types because we need stride
  // information to ensure correctness of downstream assumptions. It is
  // possible to enable this if the caller can assert that the tensor will be
  // lowered in a particular manner.
  auto sourceType = dyn_cast<MemRefType>(op.getBase().getType());
  if (!sourceType)
    return false;

  // Check that the last dimension of the read is contiguous. Note that it is
  // possible to expand support for this by scalarizing all the loads during
  // conversion.
  auto [strides, offset] = sourceType.getStridesAndOffset();
  return strides.back() == 1;
}

bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferWriteOp op) {
  if (op.getMask() || op.hasOutOfBoundsDim() || op.getTransferRank() == 0)
    return false;
  VectorType type = op.getVectorType();
  if (!type.hasStaticShape() || type.getRank() != 2)
    return false;
  // TODO: Currently we rely on lowering to a `vector.store` operation. We
  // could support the transposed write case by lowering to scalarized
  // `memref.store` operations.
  if (!op.getPermutationMap().isMinorIdentity())
    return false;
  // Currently we can't support writes to tensor types because we need stride
  // information to ensure correctness of downstream assumptions.
  auto sourceType = dyn_cast<MemRefType>(op.getBase().getType());
  if (!sourceType)
    return false;

  // Check that the last dimension of the target memref is contiguous. Note
  // that it is possible to expand support for this by scalarizing all the
  // stores during conversion.
  auto [strides, offset] = sourceType.getStridesAndOffset();
  return strides.back() == 1;
}