| //===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h" |
| |
| #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
| #include "mlir/Conversion/LLVMCommon/Pattern.h" |
| #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
| #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" |
| #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/TypeUtilities.h" |
| #include "mlir/Pass/Pass.h" |
| |
| #include "llvm/ADT/STLExtras.h" |
| #include <optional> |
| |
| namespace mlir { |
| #define GEN_PASS_DEF_CONVERTAMDGPUTOROCDL |
| #include "mlir/Conversion/Passes.h.inc" |
| } // namespace mlir |
| |
| using namespace mlir; |
| using namespace mlir::amdgpu; |
| |
| static Value createI32Constant(ConversionPatternRewriter &rewriter, |
| Location loc, int32_t value) { |
| Type llvmI32 = rewriter.getI32Type(); |
| return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, value); |
| } |
| |
| static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc, |
| bool value) { |
| Type llvmI1 = rewriter.getI1Type(); |
| return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value); |
| } |
| |
| namespace { |
| /// Define lowering patterns for raw buffer ops |
| template <typename GpuOp, typename Intrinsic> |
| struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> { |
| RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset) |
| : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {} |
| |
| Chipset chipset; |
| static constexpr uint32_t maxVectorOpWidth = 128; |
| |
| LogicalResult |
| matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Location loc = gpuOp.getLoc(); |
| Value memref = adaptor.getMemref(); |
| Value unconvertedMemref = gpuOp.getMemref(); |
| MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType()); |
| |
| if (chipset.majorVersion < 9) |
| return gpuOp.emitOpError("raw buffer ops require GCN or higher"); |
| |
| Value storeData = adaptor.getODSOperands(0)[0]; |
| if (storeData == memref) // no write component to this op |
| storeData = Value(); |
| Type wantedDataType; |
| if (storeData) |
| wantedDataType = storeData.getType(); |
| else |
| wantedDataType = gpuOp.getODSResults(0)[0].getType(); |
| |
| Value atomicCmpData = Value(); |
| // Operand index 1 of a load is the indices, trying to read them can crash. |
| if (storeData) { |
| Value maybeCmpData = adaptor.getODSOperands(1)[0]; |
| if (maybeCmpData != memref) |
| atomicCmpData = maybeCmpData; |
| } |
| |
| Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType); |
| |
| Type i32 = rewriter.getI32Type(); |
| Type llvmI32 = this->typeConverter->convertType(i32); |
| Type llvmI16 = this->typeConverter->convertType(rewriter.getI16Type()); |
| |
| int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8; |
| Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth); |
| |
| // If we want to load a vector<NxT> with total size <= 32 |
| // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32 |
| // and the total load size is >= 32, use a vector load of N / (bitsize(T) / |
| // 32) x i32 and bitcast. Also, the CAS intrinsic requires integer operands, |
| // so bitcast any floats to integers. On top of all this, cast bfloat |
| // (vectors) to i16 since the backend doesn't currently support bfloat on |
| // these operations. |
| Type llvmBufferValType = llvmWantedDataType; |
| if (wantedDataType.isBF16()) |
| llvmBufferValType = rewriter.getI16Type(); |
| if (auto wantedVecType = dyn_cast<VectorType>(wantedDataType)) |
| if (wantedVecType.getElementType().isBF16()) |
| llvmBufferValType = wantedVecType.clone(rewriter.getI16Type()); |
| if (atomicCmpData) { |
| if (isa<VectorType>(wantedDataType)) |
| return gpuOp.emitOpError("vector compare-and-swap does not exist"); |
| if (auto floatType = dyn_cast<FloatType>(wantedDataType)) |
| llvmBufferValType = this->getTypeConverter()->convertType( |
| rewriter.getIntegerType(floatType.getWidth())); |
| } |
| if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) { |
| uint32_t elemBits = dataVector.getElementTypeBitWidth(); |
| uint32_t totalBits = elemBits * dataVector.getNumElements(); |
| if (totalBits > maxVectorOpWidth) |
| return gpuOp.emitOpError( |
| "Total width of loads or stores must be no more than " + |
| Twine(maxVectorOpWidth) + " bits, but we call for " + |
| Twine(totalBits) + |
| " bits. This should've been caught in validation"); |
| if (elemBits < 32) { |
| if (totalBits > 32) { |
| if (totalBits % 32 != 0) |
| return gpuOp.emitOpError("Load or store of more than 32-bits that " |
| "doesn't fit into words. Can't happen\n"); |
| llvmBufferValType = this->typeConverter->convertType( |
| VectorType::get(totalBits / 32, i32)); |
| } else { |
| llvmBufferValType = this->typeConverter->convertType( |
| rewriter.getIntegerType(totalBits)); |
| } |
| } |
| } |
| |
| SmallVector<Value, 6> args; |
| if (storeData) { |
| if (llvmBufferValType != llvmWantedDataType) { |
| Value castForStore = |
| rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData); |
| args.push_back(castForStore); |
| } else { |
| args.push_back(storeData); |
| } |
| } |
| |
| if (atomicCmpData) { |
| if (llvmBufferValType != llvmWantedDataType) { |
| Value castForCmp = rewriter.create<LLVM::BitcastOp>( |
| loc, llvmBufferValType, atomicCmpData); |
| args.push_back(castForCmp); |
| } else { |
| args.push_back(atomicCmpData); |
| } |
| } |
| |
| // Construct buffer descriptor from memref, attributes |
| int64_t offset = 0; |
| SmallVector<int64_t, 5> strides; |
| if (failed(getStridesAndOffset(memrefType, strides, offset))) |
| return gpuOp.emitOpError("Can't lower non-stride-offset memrefs"); |
| |
| MemRefDescriptor memrefDescriptor(memref); |
| |
| Value ptr = memrefDescriptor.alignedPtr(rewriter, loc); |
| // The stride value is always 0 for raw buffers. This also disables |
| // swizling. |
| Value stride = rewriter.create<LLVM::ConstantOp>( |
| loc, llvmI16, rewriter.getI16IntegerAttr(0)); |
| Value numRecords; |
| if (memrefType.hasStaticShape()) { |
| numRecords = createI32Constant( |
| rewriter, loc, |
| static_cast<int32_t>(memrefType.getNumElements() * elementByteWidth)); |
| } else { |
| Value maxIndex; |
| for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) { |
| Value size = memrefDescriptor.size(rewriter, loc, i); |
| Value stride = memrefDescriptor.stride(rewriter, loc, i); |
| stride = rewriter.create<LLVM::MulOp>(loc, stride, byteWidthConst); |
| Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride); |
| maxIndex = maxIndex ? rewriter.create<LLVM::MaximumOp>(loc, maxIndex, |
| maxThisDim) |
| : maxThisDim; |
| } |
| numRecords = rewriter.create<LLVM::TruncOp>(loc, llvmI32, maxIndex); |
| } |
| |
| // Flag word: |
| // bits 0-11: dst sel, ignored by these intrinsics |
| // bits 12-14: data format (ignored, must be nonzero, 7=float) |
| // bits 15-18: data format (ignored, must be nonzero, 4=32bit) |
| // bit 19: In nested heap (0 here) |
| // bit 20: Behavior on unmap (0 means "return 0 / ignore") |
| // bits 21-22: Index stride for swizzles (N/A) |
| // bit 23: Add thread ID (0) |
| // bit 24: Reserved to 1 (RDNA) or 0 (CDNA) |
| // bits 25-26: Reserved (0) |
| // bit 27: Buffer is non-volatile (CDNA only) |
| // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 = |
| // none, 3 = either swizzles or testing against offset field) RDNA only |
| // bits 30-31: Type (must be 0) |
| uint32_t flags = (7 << 12) | (4 << 15); |
| if (chipset.majorVersion >= 10) { |
| flags |= (1 << 24); |
| uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2; |
| flags |= (oob << 28); |
| } |
| Value flagsConst = createI32Constant(rewriter, loc, flags); |
| Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8); |
| Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>( |
| loc, rsrcType, ptr, stride, numRecords, flagsConst); |
| args.push_back(resource); |
| |
| // Indexing (voffset) |
| Value voffset = createI32Constant(rewriter, loc, 0); |
| for (auto pair : llvm::enumerate(adaptor.getIndices())) { |
| size_t i = pair.index(); |
| Value index = pair.value(); |
| Value strideOp; |
| if (ShapedType::isDynamic(strides[i])) { |
| strideOp = rewriter.create<LLVM::MulOp>( |
| loc, memrefDescriptor.stride(rewriter, loc, i), byteWidthConst); |
| } else { |
| strideOp = |
| createI32Constant(rewriter, loc, strides[i] * elementByteWidth); |
| } |
| index = rewriter.create<LLVM::MulOp>(loc, index, strideOp); |
| voffset = rewriter.create<LLVM::AddOp>(loc, voffset, index); |
| } |
| if (adaptor.getIndexOffset()) { |
| int32_t indexOffset = *gpuOp.getIndexOffset() * elementByteWidth; |
| Value extraOffsetConst = createI32Constant(rewriter, loc, indexOffset); |
| voffset = |
| voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst) |
| : extraOffsetConst; |
| } |
| args.push_back(voffset); |
| |
| Value sgprOffset = adaptor.getSgprOffset(); |
| if (!sgprOffset) |
| sgprOffset = createI32Constant(rewriter, loc, 0); |
| if (ShapedType::isDynamic(offset)) |
| sgprOffset = rewriter.create<LLVM::AddOp>( |
| loc, memrefDescriptor.offset(rewriter, loc), sgprOffset); |
| else if (offset > 0) |
| sgprOffset = rewriter.create<LLVM::AddOp>( |
| loc, sgprOffset, createI32Constant(rewriter, loc, offset)); |
| args.push_back(sgprOffset); |
| |
| // bit 0: GLC = 0 (atomics drop value, less coherency) |
| // bits 1-2: SLC, DLC = 0 (similarly) |
| // bit 3: swizzled (0 for raw) |
| args.push_back(createI32Constant(rewriter, loc, 0)); |
| |
| llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(), |
| llvmBufferValType); |
| Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args, |
| ArrayRef<NamedAttribute>()); |
| if (lowered->getNumResults() == 1) { |
| Value replacement = lowered->getResult(0); |
| if (llvmBufferValType != llvmWantedDataType) { |
| replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType, |
| replacement); |
| } |
| rewriter.replaceOp(gpuOp, replacement); |
| } else { |
| rewriter.eraseOp(gpuOp); |
| } |
| return success(); |
| } |
| }; |
| |
| struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> { |
| LDSBarrierOpLowering(LLVMTypeConverter &converter, Chipset chipset) |
| : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {} |
| |
| Chipset chipset; |
| |
| LogicalResult |
| matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| bool requiresInlineAsm = |
| chipset.majorVersion < 9 || |
| (chipset.majorVersion == 9 && chipset.minorVersion < 0x0a) || |
| (chipset.majorVersion == 11); |
| |
| if (requiresInlineAsm) { |
| auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(), |
| LLVM::AsmDialect::AD_ATT); |
| const char *asmStr = |
| ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier"; |
| const char *constraints = ""; |
| rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>( |
| op, |
| /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(), |
| /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true, |
| /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr, |
| /*operand_attrs=*/ArrayAttr()); |
| return success(); |
| } |
| constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8); |
| constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8); |
| // Left in place in case someone disables the inline ASM path or future |
| // chipsets use the same bit pattern. |
| constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4); |
| |
| int32_t ldsOnlyBits; |
| if (chipset.majorVersion == 11) |
| ldsOnlyBits = ldsOnlyBitsGfx11; |
| else if (chipset.majorVersion == 10) |
| ldsOnlyBits = ldsOnlyBitsGfx10; |
| else if (chipset.majorVersion <= 9) |
| ldsOnlyBits = ldsOnlyBitsGfx6789; |
| else |
| return op.emitOpError( |
| "don't know how to lower this for chipset major version") |
| << chipset.majorVersion; |
| |
| Location loc = op->getLoc(); |
| rewriter.create<ROCDL::WaitcntOp>(loc, ldsOnlyBits); |
| rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op); |
| return success(); |
| } |
| }; |
| |
| struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> { |
| SchedBarrierOpLowering(LLVMTypeConverter &converter, Chipset chipset) |
| : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {} |
| |
| Chipset chipset; |
| |
| LogicalResult |
| matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op, |
| (uint32_t)op.getOpts()); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| /// If `input` is a vector of bytes, concatentate those bytes in little-endian |
| /// order to form a single integer of size 8 * [vector length]. This works |
| /// around a wart in the AMDGPU intrinsics where operations that logically take |
| /// vectors of bytes instead integers. Since we do not want to expose this |
| /// implementation detail to MLIR, we correct for it here. |
| /// |
| /// In addition, convert vectors of LLVM bfloats to vectors of i16, since AMDGPU |
| /// MFMA intrinsics pre-date the bfloat type. |
| static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter, |
| Location loc, Value input) { |
| Type inputType = input.getType(); |
| if (auto vectorType = dyn_cast<VectorType>(inputType)) { |
| if (vectorType.getElementType().isBF16()) |
| return rewriter.create<LLVM::BitcastOp>( |
| loc, vectorType.clone(rewriter.getI16Type()), input); |
| |
| if (!vectorType.getElementType().isInteger(8)) |
| return input; |
| int64_t numBytes = vectorType.getNumElements(); |
| Type destType = rewriter.getIntegerType(numBytes * 8); |
| Value result = rewriter.create<LLVM::ConstantOp>( |
| loc, destType, rewriter.getIntegerAttr(destType, 0)); |
| for (int64_t i = 0; i < numBytes; ++i) { |
| Value idxConst = createI32Constant(rewriter, loc, i); |
| Value element = |
| rewriter.create<LLVM::ExtractElementOp>(loc, input, idxConst); |
| Value extended = rewriter.create<LLVM::ZExtOp>(loc, destType, element); |
| Value shiftConst = rewriter.create<LLVM::ConstantOp>( |
| loc, destType, rewriter.getIntegerAttr(destType, i * 8)); |
| Value shifted = rewriter.create<LLVM::ShlOp>(loc, extended, shiftConst); |
| result = rewriter.create<LLVM::OrOp>(loc, result, shifted); |
| } |
| return result; |
| } |
| return input; |
| } |
| |
| /// Push an input operand. If it is a float type, nothing to do. If it is |
| /// an integer type, then we need to also push its signdness (1 for signed, 0 |
| /// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32 |
| /// vector. We also need to convert bfloat inputs to i16 to account for the lack |
| /// of bfloat support in the WMMA intrinsics themselves. |
| static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, |
| Location loc, |
| const TypeConverter *typeConverter, |
| bool isUnsigned, Value llvmInput, |
| SmallVector<Value, 4> &operands) { |
| Type inputType = llvmInput.getType(); |
| auto vectorType = dyn_cast<VectorType>(inputType); |
| Type elemType = vectorType.getElementType(); |
| |
| if (elemType.isBF16()) |
| llvmInput = rewriter.create<LLVM::BitcastOp>( |
| loc, vectorType.clone(rewriter.getI16Type()), llvmInput); |
| if (!elemType.isInteger(8)) { |
| operands.push_back(llvmInput); |
| return; |
| } |
| |
| int64_t numBytes = vectorType.getNumElements(); |
| Type i32 = rewriter.getI32Type(); |
| VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32); |
| auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits); |
| |
| Value result = rewriter.createOrFold<LLVM::BitcastOp>( |
| loc, llvmVectorType32bits, llvmInput); |
| |
| // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag |
| bool localIsUnsigned = isUnsigned; |
| if (elemType.isUnsignedInteger(8)) { |
| localIsUnsigned = true; |
| } else if (elemType.isSignedInteger(8)) { |
| localIsUnsigned = false; |
| } |
| Value sign = createI1Constant(rewriter, loc, !localIsUnsigned); |
| operands.push_back(sign); |
| operands.push_back(result); |
| } |
| |
| /// Push the output operand. For many cases this is only pushing the output in |
| /// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics, |
| /// since the same numbers of VGPRs is used, we need to decide if to store the |
| /// result in the upper 16 bits of the VGPRs or in the lower part. To store the |
| /// result in the lower 16 bits, set subwordOffset to 1, otherwise result will |
| /// be stored it in the upper part |
| static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, |
| Location loc, |
| const TypeConverter *typeConverter, |
| Value output, int32_t subwordOffset, |
| bool clamp, SmallVector<Value, 4> &operands) { |
| Type inputType = output.getType(); |
| auto vectorType = dyn_cast<VectorType>(inputType); |
| Type elemType = vectorType.getElementType(); |
| if (elemType.isBF16()) |
| output = rewriter.create<LLVM::BitcastOp>( |
| loc, vectorType.clone(rewriter.getI16Type()), output); |
| operands.push_back(output); |
| if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) { |
| operands.push_back(createI1Constant(rewriter, loc, subwordOffset)); |
| } else if (elemType.isInteger(32)) { |
| operands.push_back(createI1Constant(rewriter, loc, clamp)); |
| } |
| } |
| |
/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
///
/// Selection is keyed on the (sourceA element type, destC element type) pair
/// and on the (m, n, k, blocks) shape. The checks run top to bottom and the
/// first match wins, so their order matters. For example, the xf32 variants
/// (reducePrecision, minorVersion >= 0x40) are tried before the plain f32
/// ones, and the bf16 `_1k` variants (minorVersion >= 0x0a) are tried before
/// the original bf16 ones. Only `chipset.minorVersion` is inspected here; the
/// MFMA lowering pattern rejects chipsets whose major version is not 9 before
/// calling this.
///
/// Returns std::nullopt when no intrinsic matches.
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                  Chipset chipset) {
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  // Reduce scalar-or-vector operand types to their element types.
  Type sourceElem = mfma.getSourceA().getType();
  if (auto sourceType = dyn_cast<VectorType>(sourceElem))
    sourceElem = sourceType.getElementType();
  Type destElem = mfma.getDestC().getType();
  if (auto destType = dyn_cast<VectorType>(destElem))
    destElem = destType.getElementType();

  // f32 sources, f32 accumulator. With reducePrecision on minorVersion >= 0x40
  // the xf32 variants are preferred for the shapes they cover.
  if (sourceElem.isF32() && destElem.isF32()) {
    if (mfma.getReducePrecision() && chipset.minorVersion >= 0x40) {
      if (m == 32 && n == 32 && k == 4 && b == 1)
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (m == 16 && n == 16 && k == 8 && b == 1)
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
    }
    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
  }

  // f16 sources, f32 accumulator.
  if (sourceElem.isF16() && destElem.isF32()) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
  }

  // bf16 sources, f32 accumulator, `_1k` variants (minorVersion >= 0x0a).
  // Checked first so that they take precedence over the older bf16 forms.
  if (sourceElem.isBF16() && destElem.isF32() && chipset.minorVersion >= 0x0a) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
  }

  // bf16 sources, f32 accumulator, original (non-`_1k`) variants.
  if (sourceElem.isBF16() && destElem.isF32()) {
    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
  }

  // Integer sources (any integer element type is routed to the i8
  // intrinsics), i32 accumulator. The k=16 / k=32 forms need
  // minorVersion >= 0x40.
  if (isa<IntegerType>(sourceElem) && destElem.isInteger(32)) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset.minorVersion >= 0x40)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset.minorVersion >= 0x40)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
  }

  // f64 sources, f64 accumulator (minorVersion >= 0x0a).
  if (sourceElem.isF64() && destElem.isF64() && chipset.minorVersion >= 0x0a) {
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
  }

  // bf8 (f8E5M2FNUZ) sourceA, f32 accumulator (minorVersion >= 0x40). The
  // sourceB element type selects the bf8 or fp8 second operand.
  if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() &&
      chipset.minorVersion >= 0x40) {
    // Known to be correct because there are no scalar f8 instructions and
    // because a length mismatch will have been caught by the verifier.
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
    }
  }

  // fp8 (f8E4M3FNUZ) sourceA, f32 accumulator (minorVersion >= 0x40). The
  // sourceB element type selects the bf8 or fp8 second operand.
  if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() &&
      chipset.minorVersion >= 0x40) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
    }
  }

  return std::nullopt;
}
| |
| /// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma` |
| /// if one exists. This includes checking to ensure the intrinsic is supported |
| /// on the architecture you are compiling for. |
| static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma, |
| Chipset chipset) { |
| auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType()); |
| auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType()); |
| auto elemSourceType = sourceVectorType.getElementType(); |
| auto elemDestType = destVectorType.getElementType(); |
| |
| if (elemSourceType.isF16() && elemDestType.isF32()) { |
| return ROCDL::wmma_f32_16x16x16_f16::getOperationName(); |
| } |
| if (elemSourceType.isBF16() && elemDestType.isF32()) { |
| return ROCDL::wmma_f32_16x16x16_bf16::getOperationName(); |
| } else if (elemSourceType.isF16() && elemDestType.isF16()) { |
| return ROCDL::wmma_f16_16x16x16_f16::getOperationName(); |
| } else if (elemSourceType.isBF16() && elemDestType.isBF16()) { |
| return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName(); |
| } else if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) { |
| return ROCDL::wmma_i32_16x16x16_iu8::getOperationName(); |
| } |
| return std::nullopt; |
| } |
| |
| namespace { |
| struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> { |
| MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset) |
| : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {} |
| |
| Chipset chipset; |
| |
| LogicalResult |
| matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| Type outType = typeConverter->convertType(op.getDestD().getType()); |
| Type intrinsicOutType = outType; |
| if (auto outVecType = dyn_cast<VectorType>(outType)) |
| if (outVecType.getElementType().isBF16()) |
| intrinsicOutType = outVecType.clone(rewriter.getI16Type()); |
| |
| if (chipset.majorVersion != 9 || chipset.minorVersion < 0x08) |
| return op->emitOpError("MFMA only supported on gfx908+"); |
| uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp()); |
| if (op.getNegateA() || op.getNegateB() || op.getNegateC()) { |
| if (chipset.minorVersion < 0x40) |
| return op.emitOpError("negation unsupported on older than gfx840"); |
| getBlgpField |= |
| op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2); |
| } |
| std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset); |
| if (!maybeIntrinsic.has_value()) |
| return op.emitOpError("no intrinsic matching MFMA size on given chipset"); |
| OperationState loweredOp(loc, *maybeIntrinsic); |
| loweredOp.addTypes(intrinsicOutType); |
| loweredOp.addOperands( |
| {mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceA()), |
| mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceB()), |
| adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()), |
| createI32Constant(rewriter, loc, op.getAbid()), |
| createI32Constant(rewriter, loc, getBlgpField)}); |
| Value lowered = rewriter.create(loweredOp)->getResult(0); |
| if (outType != intrinsicOutType) |
| lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered); |
| rewriter.replaceOp(op, lowered); |
| return success(); |
| } |
| }; |
| |
| struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> { |
| WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset) |
| : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {} |
| |
| Chipset chipset; |
| |
| LogicalResult |
| matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| Type outType = typeConverter->convertType(op.getDestD().getType()); |
| |
| if (chipset.majorVersion != 11) |
| return op->emitOpError("WMMA only supported on gfx11"); |
| |
| std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset); |
| |
| if (!maybeIntrinsic.has_value()) |
| return op.emitOpError("no intrinsic matching WMMA on the given chipset"); |
| |
| OperationState loweredOp(loc, *maybeIntrinsic); |
| loweredOp.addTypes(outType); |
| |
| SmallVector<Value, 4> operands; |
| wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), |
| adaptor.getSourceA(), operands); |
| wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), |
| adaptor.getSourceB(), operands); |
| wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(), |
| op.getSubwordOffset(), op.getClamp(), operands); |
| |
| loweredOp.addOperands(operands); |
| Operation *lowered = rewriter.create(loweredOp); |
| rewriter.replaceOp(op, lowered->getResults()); |
| |
| return success(); |
| } |
| }; |
| |
| namespace { |
| struct ExtPackedFp8OpLowering final |
| : public ConvertOpToLLVMPattern<ExtPackedFp8Op> { |
| ExtPackedFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset) |
| : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter), |
| chipset(chipset) {} |
| Chipset chipset; |
| |
| LogicalResult |
| matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| struct PackedTrunc2xFp8OpLowering final |
| : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> { |
| PackedTrunc2xFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset) |
| : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter), |
| chipset(chipset) {} |
| Chipset chipset; |
| |
| LogicalResult |
| matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| struct PackedStochRoundFp8OpLowering final |
| : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> { |
| PackedStochRoundFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset) |
| : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter), |
| chipset(chipset) {} |
| Chipset chipset; |
| |
| LogicalResult |
| matchAndRewrite(PackedStochRoundFp8Op op, |
| PackedStochRoundFp8OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| } // end namespace |
| |
| LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( |
| ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| Location loc = op.getLoc(); |
| if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40) |
| return rewriter.notifyMatchFailure( |
| loc, "Fp8 conversion instructions are not available on target " |
| "architecture and their emulation is not implemented"); |
| Type v4i8 = |
| getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type())); |
| Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); |
| Type f32 = getTypeConverter()->convertType(op.getResult().getType()); |
| |
| Value source = adaptor.getSource(); |
| auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType()); |
| Type sourceElemType = getElementTypeOrSelf(op.getSource()); |
| // Extend to a v4i8 |
| if (!sourceVecType || sourceVecType.getNumElements() < 4) { |
| Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8); |
| if (!sourceVecType) { |
| longVec = rewriter.create<LLVM::InsertElementOp>( |
| loc, longVec, source, createI32Constant(rewriter, loc, 0)); |
| } else { |
| for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) { |
| Value idx = createI32Constant(rewriter, loc, i); |
| Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx); |
| longVec = |
| rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx); |
| } |
| } |
| source = longVec; |
| } |
| Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source); |
| Value wordSel = createI32Constant(rewriter, loc, op.getIndex()); |
| if (sourceElemType.isFloat8E5M2FNUZ()) { |
| rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source, |
| wordSel); |
| } else if (sourceElemType.isFloat8E4M3FNUZ()) { |
| rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source, |
| wordSel); |
| } |
| return success(); |
| } |
| |
| LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite( |
| PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| Location loc = op.getLoc(); |
| if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40) |
| return rewriter.notifyMatchFailure( |
| loc, "Fp8 conversion instructions are not available on target " |
| "architecture and their emulation is not implemented"); |
| Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); |
| |
| Type resultType = op.getResult().getType(); |
| Type resultElemType = getElementTypeOrSelf(resultType); |
| |
| Value sourceA = adaptor.getSourceA(); |
| Value sourceB = adaptor.getSourceB(); |
| if (!sourceB) |
| sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType()); |
| Value existing = adaptor.getExisting(); |
| if (existing) |
| existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing); |
| else |
| existing = rewriter.create<LLVM::UndefOp>(loc, i32); |
| Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex()); |
| |
| Value result; |
| if (resultElemType.isFloat8E5M2FNUZ()) |
| result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB, |
| existing, wordSel); |
| else if (resultElemType.isFloat8E4M3FNUZ()) |
| result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB, |
| existing, wordSel); |
| |
| result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>( |
| op, getTypeConverter()->convertType(resultType), result); |
| return success(); |
| } |
| |
| LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite( |
| PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| Location loc = op.getLoc(); |
| if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40) |
| return rewriter.notifyMatchFailure( |
| loc, "Fp8 conversion instructions are not available on target " |
| "architecture and their emulation is not implemented"); |
| Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); |
| |
| Type resultType = op.getResult().getType(); |
| Type resultElemType = getElementTypeOrSelf(resultType); |
| |
| Value source = adaptor.getSource(); |
| Value stoch = adaptor.getStochiasticParam(); |
| Value existing = adaptor.getExisting(); |
| if (existing) |
| existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing); |
| else |
| existing = rewriter.create<LLVM::UndefOp>(loc, i32); |
| Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex()); |
| |
| Value result; |
| if (resultElemType.isFloat8E5M2FNUZ()) |
| result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch, |
| existing, byteSel); |
| else if (resultElemType.isFloat8E4M3FNUZ()) |
| result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch, |
| existing, byteSel); |
| |
| result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>( |
| op, getTypeConverter()->convertType(resultType), result); |
| return success(); |
| } |
| |
| struct ConvertAMDGPUToROCDLPass |
| : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> { |
| ConvertAMDGPUToROCDLPass() = default; |
| |
| void runOnOperation() override { |
| MLIRContext *ctx = &getContext(); |
| FailureOr<Chipset> maybeChipset = Chipset::parse(chipset); |
| if (failed(maybeChipset)) { |
| emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset); |
| return signalPassFailure(); |
| } |
| |
| RewritePatternSet patterns(ctx); |
| LLVMTypeConverter converter(ctx); |
| populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset); |
| LLVMConversionTarget target(getContext()); |
| target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>(); |
| target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); |
| target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>(); |
| if (failed(applyPartialConversion(getOperation(), target, |
| std::move(patterns)))) |
| signalPassFailure(); |
| } |
| }; |
| } // namespace |
| |
| void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, |
| RewritePatternSet &patterns, |
| Chipset chipset) { |
| converter.addConversion([](BFloat16Type t) -> Type { |
| return IntegerType::get(t.getContext(), 16); |
| }); |
| converter.addConversion([&converter](VectorType t) -> std::optional<Type> { |
| if (!t.getElementType().isBF16()) |
| return std::nullopt; |
| return converter.convertType(t.clone(IntegerType::get(t.getContext(), 16))); |
| }); |
| |
| patterns |
| .add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>, |
| RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>, |
| RawBufferOpLowering<RawBufferAtomicFaddOp, |
| ROCDL::RawPtrBufferAtomicFaddOp>, |
| RawBufferOpLowering<RawBufferAtomicFmaxOp, |
| ROCDL::RawPtrBufferAtomicFmaxOp>, |
| RawBufferOpLowering<RawBufferAtomicSmaxOp, |
| ROCDL::RawPtrBufferAtomicSmaxOp>, |
| RawBufferOpLowering<RawBufferAtomicUminOp, |
| ROCDL::RawPtrBufferAtomicUminOp>, |
| RawBufferOpLowering<RawBufferAtomicCmpswapOp, |
| ROCDL::RawPtrBufferAtomicCmpSwap>, |
| LDSBarrierOpLowering, SchedBarrierOpLowering, MFMAOpLowering, |
| WMMAOpLowering, ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering, |
| PackedStochRoundFp8OpLowering>(converter, chipset); |
| } |
| |
| std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() { |
| return std::make_unique<ConvertAMDGPUToROCDLPass>(); |
| } |