| //===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h" |
| |
| #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
| #include "mlir/Conversion/LLVMCommon/Pattern.h" |
| #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
| #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" |
| #include "mlir/Dialect/AMDGPU/Utils/Chipset.h" |
| #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
| #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/TypeUtilities.h" |
| #include "mlir/Pass/Pass.h" |
| |
| #include "../LLVMCommon/MemRefDescriptor.h" |
| |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/TypeSwitch.h" |
| #include "llvm/Support/Casting.h" |
| #include "llvm/Support/ErrorHandling.h" |
| #include <limits> |
| #include <optional> |
| |
| namespace mlir { |
| #define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS |
| #include "mlir/Conversion/Passes.h.inc" |
| } // namespace mlir |
| |
| using namespace mlir; |
| using namespace mlir::amdgpu; |
| |
| // Define commonly used chipset versions for convenience. |
| constexpr Chipset kGfx908 = Chipset(9, 0, 8); |
| constexpr Chipset kGfx90a = Chipset(9, 0, 0xa); |
| constexpr Chipset kGfx942 = Chipset(9, 4, 2); |
| constexpr Chipset kGfx950 = Chipset(9, 5, 0); |
| |
| /// Convert an unsigned number `val` to i32. |
| static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, |
| Location loc, Value val) { |
| IntegerType i32 = rewriter.getI32Type(); |
| // This cast enforces that `val` has an integer type. |
| auto valTy = cast<IntegerType>(val.getType()); |
| if (i32 == valTy) |
| return val; |
| return valTy.getWidth() > 32 |
| ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val)) |
| : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val)); |
| } |
| |
| static Value createI32Constant(ConversionPatternRewriter &rewriter, |
| Location loc, int32_t value) { |
| Type i32 = rewriter.getI32Type(); |
| return LLVM::ConstantOp::create(rewriter, loc, i32, value); |
| } |
| |
| static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc, |
| bool value) { |
| Type llvmI1 = rewriter.getI1Type(); |
| return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value); |
| } |
| |
| /// Returns the linear index used to access an element in the memref. |
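| /// For example (an illustrative shape, not taken from the surrounding code), |
| /// a memref<4x8xf32> with static strides [8, 1] and indices (i, j) yields the |
| /// linear index i * 8 + j; dynamic strides are instead loaded from the |
| /// descriptor at runtime. |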
| static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, |
| Location loc, MemRefDescriptor &memRefDescriptor, |
| ValueRange indices, ArrayRef<int64_t> strides) { |
| IntegerType i32 = rewriter.getI32Type(); |
| Value index; |
| for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) { |
| if (stride != 1) { // Skip if stride is 1. |
| Value strideValue = |
| ShapedType::isDynamic(stride) |
| ? convertUnsignedToI32(rewriter, loc, |
| memRefDescriptor.stride(rewriter, loc, i)) |
| : LLVM::ConstantOp::create(rewriter, loc, i32, stride); |
| increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue); |
| } |
| index = index ? LLVM::AddOp::create(rewriter, loc, index, increment) |
| : increment; |
| } |
| return index ? index : createI32Constant(rewriter, loc, 0); |
| } |
| |
| /// Compute the contents of the `num_records` field for a given memref |
| /// descriptor - that is, the byte offset one element past the greatest |
| /// possible valid index into the memref. |
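| /// For example (illustrative values), a memref<16x4xf32> with strides [4, 1] |
| /// and a 4-byte element width gives max(16 * 4, 4 * 1) * 4 = 256 bytes. |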
| static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, |
| MemRefType memrefType, |
| MemRefDescriptor &memrefDescriptor, |
| ArrayRef<int64_t> strides, |
| uint32_t elementByteWidth) { |
| if (memrefType.hasStaticShape() && |
| !llvm::any_of(strides, ShapedType::isDynamic)) { |
| int64_t size = memrefType.getRank() == 0 ? 1 : 0; |
| ArrayRef<int64_t> shape = memrefType.getShape(); |
| for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) |
| size = std::max(shape[i] * strides[i], size); |
| size = size * elementByteWidth; |
| assert(size < std::numeric_limits<uint32_t>::max() && |
| "the memref buffer is too large"); |
| return createI32Constant(rewriter, loc, static_cast<int32_t>(size)); |
| } |
| Value maxIndex; |
| for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) { |
| Value size = memrefDescriptor.size(rewriter, loc, i); |
| Value stride = memrefDescriptor.stride(rewriter, loc, i); |
| Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride); |
| maxIndex = maxIndex |
| ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim) |
| : maxThisDim; |
| } |
| Value maxIndexI32 = convertUnsignedToI32(rewriter, loc, maxIndex); |
| Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth); |
| return LLVM::MulOp::create(rewriter, loc, maxIndexI32, byteWidthConst); |
| } |
| |
| static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, |
| Value basePointer, Value numRecords, |
| bool boundsCheck, amdgpu::Chipset chipset, |
| Value cacheSwizzleStride = nullptr, |
| unsigned addressSpace = 8) { |
| // The stride value is generally 0. However, on MI-300 and onward, you can |
| // enable a cache swizzling mode by setting bit 14 of the stride field |
| // and setting that stride to a cache stride. |
| Type i16 = rewriter.getI16Type(); |
| Value stride; |
| if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) { |
| Value cacheStrideZext = |
| LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride); |
| Value swizzleBit = LLVM::ConstantOp::create( |
| rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14)); |
| stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit, |
| /*isDisjoint=*/true); |
| } else { |
| stride = LLVM::ConstantOp::create(rewriter, loc, i16, |
| rewriter.getI16IntegerAttr(0)); |
| } |
| // Flag word: |
| // bits 0-11: dst sel, ignored by these intrinsics |
| // bits 12-14: num format (ignored, must be nonzero, 7=float) |
| // bits 15-18: data format (ignored, must be nonzero, 4=32bit) |
| // bit 19: In nested heap (0 here) |
| // bit 20: Behavior on unmap (0 means "return 0 / ignore") |
| // bits 21-22: Index stride for swizzles (N/A) |
| // bit 23: Add thread ID (0) |
| // bit 24: Reserved to 1 (RDNA) or 0 (CDNA) |
| // bits 25-26: Reserved (0) |
| // bit 27: Buffer is non-volatile (CDNA only) |
| // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 = |
| // none, 3 = either swizzles or testing against offset field) RDNA only |
| // bits 30-31: Type (must be 0) |
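| // For example, the base value computed below is (7 << 12) | (4 << 15) = |
| // 0x27000; on RDNA with bounds checks enabled it becomes |
| // 0x27000 | (1 << 24) | (3 << 28) = 0x31027000. |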
| uint32_t flags = (7 << 12) | (4 << 15); |
| if (chipset.majorVersion >= 10) { |
| flags |= (1 << 24); |
| uint32_t oob = boundsCheck ? 3 : 2; |
| flags |= (oob << 28); |
| } |
| Value flagsConst = createI32Constant(rewriter, loc, flags); |
| Type rsrcType = |
| LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace); |
| Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>( |
| loc, rsrcType, basePointer, stride, numRecords, flagsConst); |
| return resource; |
| } |
| |
| namespace { |
| struct FatRawBufferCastLowering |
| : public ConvertOpToLLVMPattern<FatRawBufferCastOp> { |
| FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset) |
| : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter), |
| chipset(chipset) {} |
| |
| Chipset chipset; |
| |
| LogicalResult |
| matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| Value memRef = adaptor.getSource(); |
| Value unconvertedMemref = op.getSource(); |
| MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType()); |
| MemRefDescriptor descriptor(memRef); |
| |
| DataLayout dataLayout = DataLayout::closest(op); |
| int64_t elementByteWidth = |
| dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8; |
| |
| int64_t unusedOffset = 0; |
| SmallVector<int64_t, 5> strideVals; |
| if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset))) |
| return op.emitOpError("Can't lower non-stride-offset memrefs"); |
| |
| Value numRecords = adaptor.getValidBytes(); |
| if (!numRecords) |
| numRecords = getNumRecords(rewriter, loc, memrefType, descriptor, |
| strideVals, elementByteWidth); |
| |
| Value basePointer = |
| adaptor.getResetOffset() |
| ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(), |
| memrefType) |
| : descriptor.alignedPtr(rewriter, loc); |
| |
| Value offset = adaptor.getResetOffset() |
| ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(), |
| rewriter.getIndexAttr(0)) |
| : descriptor.offset(rewriter, loc); |
| |
| bool hasSizes = memrefType.getRank() > 0; |
| // No need to unpack() and pack() all the individual sizes and strides, |
| // so we'll just extract the arrays. |
| Value sizes = hasSizes |
| ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor, |
| kSizePosInMemRefDescriptor) |
| : Value{}; |
| Value strides = |
| hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor, |
| kStridePosInMemRefDescriptor) |
| : Value{}; |
| |
| Value fatPtr = makeBufferRsrc( |
| rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(), |
| chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7); |
| |
| Value result = MemRefDescriptor::poison( |
| rewriter, loc, |
| getTypeConverter()->convertType(op.getResult().getType())); |
| SmallVector<int64_t> pos{kAllocatedPtrPosInMemRefDescriptor}; |
| result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos); |
| result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, |
| kAlignedPtrPosInMemRefDescriptor); |
| result = LLVM::InsertValueOp::create(rewriter, loc, result, offset, |
| kOffsetPosInMemRefDescriptor); |
| if (hasSizes) { |
| result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes, |
| kSizePosInMemRefDescriptor); |
| result = LLVM::InsertValueOp::create(rewriter, loc, result, strides, |
| kStridePosInMemRefDescriptor); |
| } |
| rewriter.replaceOp(op, result); |
| return success(); |
| } |
| }; |
| |
| /// Define lowering patterns for raw buffer ops |
| template <typename GpuOp, typename Intrinsic> |
| struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> { |
| RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset) |
| : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {} |
| |
| Chipset chipset; |
| static constexpr uint32_t maxVectorOpWidth = 128; |
| |
| LogicalResult |
| matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Location loc = gpuOp.getLoc(); |
| Value memref = adaptor.getMemref(); |
| Value unconvertedMemref = gpuOp.getMemref(); |
| MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType()); |
| |
| if (chipset.majorVersion < 9) |
| return gpuOp.emitOpError("raw buffer ops require gfx9 or higher"); |
| |
| Value storeData = adaptor.getODSOperands(0)[0]; |
| if (storeData == memref) // no write component to this op |
| storeData = Value(); |
| Type wantedDataType; |
| if (storeData) |
| wantedDataType = storeData.getType(); |
| else |
| wantedDataType = gpuOp.getODSResults(0)[0].getType(); |
| |
| Value atomicCmpData = Value(); |
| // Operand index 1 of a load is the indices; trying to read them can crash. |
| if (storeData) { |
| Value maybeCmpData = adaptor.getODSOperands(1)[0]; |
| if (maybeCmpData != memref) |
| atomicCmpData = maybeCmpData; |
| } |
| |
| Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType); |
| |
| Type i32 = rewriter.getI32Type(); |
| |
| // Get the type size in bytes. |
| DataLayout dataLayout = DataLayout::closest(gpuOp); |
| int64_t elementByteWidth = |
| dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8; |
| Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth); |
| |
| // If we want to load a vector<NxT> with total size <= 32 |
| // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32 |
| // and the total load size is >= 32, use a vector load of |
| // (N * bitsize(T)) / 32 x i32 and bitcast. Also, the CAS intrinsic requires |
| // integer operands, so bitcast any floats to integers. |
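| // For example, a vector<2xi8> value (16 bits total) travels as an i16, |
| // while a vector<8xi8> value (64 bits total) travels as a vector<2xi32>. |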
| Type llvmBufferValType = llvmWantedDataType; |
| if (atomicCmpData) { |
| if (auto floatType = dyn_cast<FloatType>(wantedDataType)) |
| llvmBufferValType = this->getTypeConverter()->convertType( |
| rewriter.getIntegerType(floatType.getWidth())); |
| } |
| if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) { |
| uint32_t vecLen = dataVector.getNumElements(); |
| uint32_t elemBits = |
| dataLayout.getTypeSizeInBits(dataVector.getElementType()); |
| uint32_t totalBits = elemBits * vecLen; |
| bool usePackedFp16 = |
| isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2; |
| if (totalBits > maxVectorOpWidth) |
| return gpuOp.emitOpError( |
| "Total width of loads or stores must be no more than " + |
| Twine(maxVectorOpWidth) + " bits, but we call for " + |
| Twine(totalBits) + |
| " bits. This should've been caught in validation"); |
| if (!usePackedFp16 && elemBits < 32) { |
| if (totalBits > 32) { |
| if (totalBits % 32 != 0) |
| return gpuOp.emitOpError("Load or store of more than 32-bits that " |
| "doesn't fit into words. Can't happen\n"); |
| llvmBufferValType = this->typeConverter->convertType( |
| VectorType::get(totalBits / 32, i32)); |
| } else { |
| llvmBufferValType = this->typeConverter->convertType( |
| rewriter.getIntegerType(totalBits)); |
| } |
| } |
| } |
| if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) { |
| // Buffer intrinsics don't support 1-element vectors; cast them to |
| // scalars. |
| if (vecType.getNumElements() == 1) |
| llvmBufferValType = vecType.getElementType(); |
| } |
| |
| SmallVector<Value, 6> args; |
| if (storeData) { |
| if (llvmBufferValType != llvmWantedDataType) { |
| Value castForStore = LLVM::BitcastOp::create( |
| rewriter, loc, llvmBufferValType, storeData); |
| args.push_back(castForStore); |
| } else { |
| args.push_back(storeData); |
| } |
| } |
| |
| if (atomicCmpData) { |
| if (llvmBufferValType != llvmWantedDataType) { |
| Value castForCmp = LLVM::BitcastOp::create( |
| rewriter, loc, llvmBufferValType, atomicCmpData); |
| args.push_back(castForCmp); |
| } else { |
| args.push_back(atomicCmpData); |
| } |
| } |
| |
| // Construct the buffer descriptor from the memref and the op's attributes. |
| int64_t offset = 0; |
| SmallVector<int64_t, 5> strides; |
| if (failed(memrefType.getStridesAndOffset(strides, offset))) |
| return gpuOp.emitOpError("Can't lower non-stride-offset memrefs"); |
| |
| MemRefDescriptor memrefDescriptor(memref); |
| |
| Value ptr = memrefDescriptor.bufferPtr( |
| rewriter, loc, *this->getTypeConverter(), memrefType); |
| Value numRecords = getNumRecords( |
| rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth); |
| Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords, |
| adaptor.getBoundsCheck(), chipset); |
| args.push_back(resource); |
| |
| // Indexing (voffset) |
| Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor, |
| adaptor.getIndices(), strides); |
| if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset(); |
| indexOffset && *indexOffset > 0) { |
| Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset); |
| voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset, |
| extraOffsetConst) |
| : extraOffsetConst; |
| } |
| voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst); |
| args.push_back(voffset); |
| |
| // SGPR offset. |
| Value sgprOffset = adaptor.getSgprOffset(); |
| if (!sgprOffset) |
| sgprOffset = createI32Constant(rewriter, loc, 0); |
| sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst); |
| args.push_back(sgprOffset); |
| |
| // bit 0: GLC = 0 (atomics drop value, less coherency) |
| // bits 1-2: SLC, DLC = 0 (similarly) |
| // bit 3: swizzled (0 for raw) |
| args.push_back(createI32Constant(rewriter, loc, 0)); |
| |
| llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(), |
| llvmBufferValType); |
| Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args, |
| ArrayRef<NamedAttribute>()); |
| if (lowered->getNumResults() == 1) { |
| Value replacement = lowered->getResult(0); |
| if (llvmBufferValType != llvmWantedDataType) { |
| replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType, |
| replacement); |
| } |
| rewriter.replaceOp(gpuOp, replacement); |
| } else { |
| rewriter.eraseOp(gpuOp); |
| } |
| return success(); |
| } |
| }; |
| |
| // TODO: the AMDGPU backend already has all this bitpacking logic; we should |
| // move it to a common place. |
| /// Vmcnt, Expcnt and Lgkmcnt are decoded as follows: |
| /// Vmcnt = Waitcnt[3:0] (pre-gfx9) |
| /// Vmcnt = Waitcnt[15:14,3:0] (gfx9,10) |
| /// Vmcnt = Waitcnt[15:10] (gfx11) |
| /// Expcnt = Waitcnt[6:4] (pre-gfx11) |
| /// Expcnt = Waitcnt[2:0] (gfx11) |
| /// Lgkmcnt = Waitcnt[11:8] (pre-gfx10) |
| /// Lgkmcnt = Waitcnt[13:8] (gfx10) |
| /// Lgkmcnt = Waitcnt[9:4] (gfx11) |
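| /// For example, on gfx9, encodeWaitcnt(chipset, /*vmcnt=*/18, /*expcnt=*/2, |
| /// /*lgkmcnt=*/3) yields (18 & 0xF) | ((18 >> 4) << 14) | (2 << 4) | |
| /// (3 << 8) = 0x4322. |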
| static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt, |
| unsigned expcnt, unsigned lgkmcnt) { |
| if (chipset.majorVersion < 9) { |
| vmcnt = std::min(15u, vmcnt); |
| expcnt = std::min(7u, expcnt); |
| lgkmcnt = std::min(15u, lgkmcnt); |
| return vmcnt | (expcnt << 4) | (lgkmcnt << 8); |
| } |
| if (chipset.majorVersion == 9) { |
| vmcnt = std::min(63u, vmcnt); |
| expcnt = std::min(7u, expcnt); |
| lgkmcnt = std::min(15u, lgkmcnt); |
| unsigned lowBits = vmcnt & 0xF; |
| unsigned highBits = (vmcnt >> 4) << 14; |
| unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8); |
| return lowBits | highBits | otherCnts; |
| } |
| if (chipset.majorVersion == 10) { |
| vmcnt = std::min(63u, vmcnt); |
| expcnt = std::min(7u, expcnt); |
| lgkmcnt = std::min(63u, lgkmcnt); |
| unsigned lowBits = vmcnt & 0xF; |
| unsigned highBits = (vmcnt >> 4) << 14; |
| unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8); |
| return lowBits | highBits | otherCnts; |
| } |
| if (chipset.majorVersion == 11) { |
| vmcnt = std::min(63u, vmcnt); |
| expcnt = std::min(7u, expcnt); |
| lgkmcnt = std::min(63u, lgkmcnt); |
| return (vmcnt << 10) | expcnt | (lgkmcnt << 4); |
| } |
| return failure(); |
| } |
| |
| struct MemoryCounterWaitOpLowering |
| : public ConvertOpToLLVMPattern<MemoryCounterWaitOp> { |
| MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter, |
| Chipset chipset) |
| : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter), |
| chipset(chipset) {} |
| |
| Chipset chipset; |
| |
| LogicalResult |
| matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| if (chipset.majorVersion >= 12) { |
| Location loc = op.getLoc(); |
| if (std::optional<int> ds = adaptor.getDs()) |
| ROCDL::WaitDscntOp::create(rewriter, loc, *ds); |
| |
| if (std::optional<int> load = adaptor.getLoad()) |
| ROCDL::WaitLoadcntOp::create(rewriter, loc, *load); |
| |
| if (std::optional<int> store = adaptor.getStore()) |
| ROCDL::WaitStorecntOp::create(rewriter, loc, *store); |
| |
| if (std::optional<int> exp = adaptor.getExp()) |
| ROCDL::WaitExpcntOp::create(rewriter, loc, *exp); |
| |
| rewriter.eraseOp(op); |
| return success(); |
| } |
| |
| auto getVal = [](Attribute attr) -> unsigned { |
| if (attr) |
| return cast<IntegerAttr>(attr).getInt(); |
| |
| // This value will be clamped to the maximum value for the chipset. |
| return 1024; |
| }; |
| unsigned ds = getVal(adaptor.getDsAttr()); |
| unsigned exp = getVal(adaptor.getExpAttr()); |
| |
| unsigned vmcnt = 1024; |
| Attribute load = adaptor.getLoadAttr(); |
| Attribute store = adaptor.getStoreAttr(); |
| if (load && store) { |
| vmcnt = getVal(load) + getVal(store); |
| } else if (load) { |
| vmcnt = getVal(load); |
| } else if (store) { |
| vmcnt = getVal(store); |
| } |
| |
| FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds); |
| if (failed(waitcnt)) |
| return op.emitOpError("unsupported chipset"); |
| |
| rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt); |
| return success(); |
| } |
| }; |
| |
| struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> { |
| LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset) |
| : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {} |
| |
| Chipset chipset; |
| |
| LogicalResult |
| matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11; |
| |
| if (requiresInlineAsm) { |
| auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(), |
| LLVM::AsmDialect::AD_ATT); |
| const char *asmStr = |
| ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier"; |
| const char *constraints = ""; |
| rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>( |
| op, |
| /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(), |
| /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true, |
| /*is_align_stack=*/false, LLVM::TailCallKind::None, |
| /*asm_dialect=*/asmDialectAttr, |
| /*operand_attrs=*/ArrayAttr()); |
| return success(); |
| } |
| if (chipset.majorVersion < 12) { |
| constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8); |
| constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8); |
| // Left in place in case someone disables the inline ASM path or future |
| // chipsets use the same bit pattern. |
| constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4); |
| |
| int32_t ldsOnlyBits; |
| if (chipset.majorVersion == 11) |
| ldsOnlyBits = ldsOnlyBitsGfx11; |
| else if (chipset.majorVersion == 10) |
| ldsOnlyBits = ldsOnlyBitsGfx10; |
| else if (chipset.majorVersion <= 9) |
| ldsOnlyBits = ldsOnlyBitsGfx6789; |
| else |
| return op.emitOpError( |
| "don't know how to lower this for chipset major version") |
| << chipset.majorVersion; |
| |
| Location loc = op->getLoc(); |
| ROCDL::SWaitcntOp::create(rewriter, loc, ldsOnlyBits); |
| rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op); |
| } else { |
| Location loc = op->getLoc(); |
| ROCDL::WaitDscntOp::create(rewriter, loc, 0); |
| ROCDL::BarrierSignalOp::create(rewriter, loc, -1); |
| rewriter.replaceOpWithNewOp<ROCDL::BarrierWaitOp>(op, -1); |
| } |
| |
| return success(); |
| } |
| }; |
| |
| struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> { |
| SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset) |
| : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {} |
| |
| Chipset chipset; |
| |
| LogicalResult |
| matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op, |
| (uint32_t)op.getOpts()); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| /// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL |
| /// and LLVM AMDGPU intrinsics convention. |
| /// |
| /// Specifically: |
| /// 1. If the element type is bfloat16, bitcast it to i16 unless the rocdl |
| /// intrinsic allows bf16. Newer MFMAs support bf16 operands; see |
| /// IntrinsicsAMDGPU.td for reference. |
| /// 2. If we instead have a more-than-64-bit quantity, use a <N / 4 x i32> |
| /// vector, which is what the f8f6f4 intrinsics use. |
| /// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit |
| /// integer. |
| /// |
| /// Note that the type of `input` has already been LLVM type converted: |
| /// therefore 8-bit and smaller floats are represented as their corresponding |
| /// `iN` integers. |
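| /// For example, a vector<4xi8> input becomes an i32, while a vector<32xi8> |
| /// input (a more-than-64-bit quantity) becomes a vector<8xi32>. |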
| static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter, |
| Location loc, Value input, |
| bool allowBf16 = true) { |
| Type inputType = input.getType(); |
| if (auto vectorType = dyn_cast<VectorType>(inputType)) { |
| if (vectorType.getElementType().isBF16() && !allowBf16) |
| return LLVM::BitcastOp::create( |
| rewriter, loc, vectorType.clone(rewriter.getI16Type()), input); |
| if (vectorType.getElementType().isInteger(8) && |
| vectorType.getNumElements() <= 8) |
| return LLVM::BitcastOp::create( |
| rewriter, loc, |
| rewriter.getIntegerType(vectorType.getNumElements() * 8), input); |
| if (isa<IntegerType>(vectorType.getElementType()) && |
| vectorType.getElementTypeBitWidth() <= 8) { |
| int64_t numWords = llvm::divideCeil( |
| vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), |
| 32); |
| return LLVM::BitcastOp::create( |
| rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()), |
| input); |
| } |
| } |
| return input; |
| } |
| |
| /// Converts the scaled MFMA operands, `scalesA` and `scalesB`, from MLIR AMDGPU |
| /// dialect convention to ROCDL and LLVM AMDGPU intrinsics convention. |
| /// |
| /// Specifically: |
| /// 1. If `input` is an i8 value, zero extend it to i32 |
| /// 2. If `input` is a vector of length 4 and type i8, bitcast it to i32 |
| /// |
| /// Note that the type of `input` has already been LLVM type converted: |
| /// therefore 8-bit and smaller floats are represented as their corresponding |
| /// `iN` integers. |
| static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter, |
| Location loc, Value input) { |
| Type inputType = input.getType(); |
| Type outputType = rewriter.getI32Type(); |
| if (auto intType = dyn_cast<IntegerType>(inputType)) |
| return LLVM::ZExtOp::create(rewriter, loc, outputType, input); |
| return LLVM::BitcastOp::create(rewriter, loc, outputType, input); |
| } |
| |
| /// Push an input operand. If it is a float type, there is nothing to do. If |
| /// it is an integer type, we also need to push its signedness (1 for signed, |
| /// 0 for unsigned) and pack the input 16xi8 vector into a 4xi32 vector (or |
| /// the 8xi8 vector into a 2xi32 one for gfx12+). |
| /// We also need to convert bfloat inputs to i16 to account for the bfloat |
| /// intrinsics having been defined before the AMD backend supported bfloat. We |
| /// similarly need to pack 8-bit float types into integers as if they were i8 |
| /// (which they are for the backend's purposes). |
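| /// For example, a vector<16xi8> input with the unsigned flag unset pushes an |
| /// i1 `true` (signed) followed by the input bitcast to a vector<4xi32>. |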
| static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, |
| Location loc, |
| const TypeConverter *typeConverter, |
| bool isUnsigned, Value llvmInput, |
| Value mlirInput, |
| SmallVector<Value, 4> &operands) { |
| Type inputType = llvmInput.getType(); |
| auto vectorType = dyn_cast<VectorType>(inputType); |
| if (!vectorType) { |
| operands.push_back(llvmInput); |
| return; |
| } |
| Type elemType = vectorType.getElementType(); |
| |
| if (elemType.isBF16()) |
| llvmInput = LLVM::BitcastOp::create( |
| rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput); |
| if (elemType.getIntOrFloatBitWidth() > 8) { |
| operands.push_back(llvmInput); |
| return; |
| } |
| |
| // We need to check the type of the input before conversion to properly test |
| // for int8. This is because, in LLVM, fp8 type is converted to int8, so the |
| // fp8/int8 information is lost during the conversion process. |
| auto mlirInputType = cast<VectorType>(mlirInput.getType()); |
| bool isInputInteger = mlirInputType.getElementType().isInteger(); |
| if (isInputInteger) { |
| // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag |
| bool localIsUnsigned = isUnsigned; |
| if (elemType.isUnsignedInteger()) { |
| localIsUnsigned = true; |
| } else if (elemType.isSignedInteger()) { |
| localIsUnsigned = false; |
| } |
| Value sign = createI1Constant(rewriter, loc, !localIsUnsigned); |
| operands.push_back(sign); |
| } |
| |
| int64_t numBits = |
| vectorType.getNumElements() * elemType.getIntOrFloatBitWidth(); |
| Type i32 = rewriter.getI32Type(); |
| Type intrinsicInType = numBits <= 32 |
| ? (Type)rewriter.getIntegerType(numBits) |
| : (Type)VectorType::get(numBits / 32, i32); |
| auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType); |
| Value castInput = rewriter.createOrFold<LLVM::BitcastOp>( |
| loc, llvmIntrinsicInType, llvmInput); |
| // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need |
| // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments. |
| // Add in the zeros here. |
| if (numBits < 32) |
| castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput); |
| operands.push_back(castInput); |
| } |
| |
| /// Push the output operand. In many cases this only pushes the output onto |
| /// the operand list. But for the f16 -> f16 and bf16 -> bf16 intrinsics, |
| /// since the same number of VGPRs is used, we need to decide whether to store |
| /// the result in the upper or the lower 16 bits of the VGPRs. To store the |
| /// result in the lower 16 bits, set subwordOffset to 1; otherwise the result |
| /// will be stored in the upper part. subwordOffset must not be set for gfx12, |
| /// as the instructions have been changed to return fewer registers instead. |
| static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, |
| Location loc, |
| const TypeConverter *typeConverter, |
| Value output, int32_t subwordOffset, |
| bool clamp, SmallVector<Value, 4> &operands) { |
| Type inputType = output.getType(); |
| auto vectorType = dyn_cast<VectorType>(inputType); |
| Type elemType = vectorType.getElementType(); |
| if (elemType.isBF16()) |
| output = LLVM::BitcastOp::create( |
| rewriter, loc, vectorType.clone(rewriter.getI16Type()), output); |
| operands.push_back(output); |
| if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) { |
| operands.push_back(createI1Constant(rewriter, loc, subwordOffset)); |
| } else if (elemType.isInteger(32)) { |
| operands.push_back(createI1Constant(rewriter, loc, clamp)); |
| } |
| } |
| |
| /// Return true if `type` is the E5M2 variant of an 8-bit float that is |
| /// supported by the `_bf8` instructions on the given `chipset`. |
| static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) { |
| return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) || |
| (hasOcpFp8(chipset) && isa<Float8E5M2Type>(type)); |
| } |
| |
| /// Return true if `type` is the E4M3FN variant of an 8-bit float that is |
| /// supported by the `_fp8` instructions on the given `chipset`. |
| static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) { |
| return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) || |
| (hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type)); |
| } |
| |
| /// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma` |
| /// if one exists. This includes checking to ensure the intrinsic is supported |
| /// on the architecture you are compiling for. |
| static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma, |
| Chipset chipset) { |
| uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(), |
| b = mfma.getBlocks(); |
| Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType()); |
| Type destElem = getElementTypeOrSelf(mfma.getDestC().getType()); |
| |
| if (sourceElem.isF32() && destElem.isF32()) { |
| if (mfma.getReducePrecision() && chipset >= kGfx942) { |
| if (m == 32 && n == 32 && k == 4 && b == 1) |
| return ROCDL::mfma_f32_32x32x4_xf32::getOperationName(); |
| if (m == 16 && n == 16 && k == 8 && b == 1) |
| return ROCDL::mfma_f32_16x16x8_xf32::getOperationName(); |
| } |
| if (m == 32 && n == 32 && k == 1 && b == 2) |
| return ROCDL::mfma_f32_32x32x1f32::getOperationName(); |
| if (m == 16 && n == 16 && k == 1 && b == 4) |
| return ROCDL::mfma_f32_16x16x1f32::getOperationName(); |
| if (m == 4 && n == 4 && k == 1 && b == 16) |
| return ROCDL::mfma_f32_4x4x1f32::getOperationName(); |
| if (m == 32 && n == 32 && k == 2 && b == 1) |
| return ROCDL::mfma_f32_32x32x2f32::getOperationName(); |
| if (m == 16 && n == 16 && k == 4 && b == 1) |
| return ROCDL::mfma_f32_16x16x4f32::getOperationName(); |
| } |
| |
| if (sourceElem.isF16() && destElem.isF32()) { |
| if (chipset >= kGfx950) { |
| if (m == 32 && n == 32 && k == 16 && b == 1) |
| return ROCDL::mfma_f32_32x32x16_f16::getOperationName(); |
| if (m == 16 && n == 16 && k == 32 && b == 1) |
| return ROCDL::mfma_f32_16x16x32_f16::getOperationName(); |
| } |
| if (m == 32 && n == 32 && k == 4 && b == 2) |
| return ROCDL::mfma_f32_32x32x4f16::getOperationName(); |
| if (m == 16 && n == 16 && k == 4 && b == 4) |
| return ROCDL::mfma_f32_16x16x4f16::getOperationName(); |
| if (m == 4 && n == 4 && k == 4 && b == 16) |
| return ROCDL::mfma_f32_4x4x4f16::getOperationName(); |
| if (m == 32 && n == 32 && k == 8 && b == 1) |
| return ROCDL::mfma_f32_32x32x8f16::getOperationName(); |
| if (m == 16 && n == 16 && k == 16 && b == 1) |
| return ROCDL::mfma_f32_16x16x16f16::getOperationName(); |
| } |
| |
| if (sourceElem.isBF16() && destElem.isF32()) { |
| if (chipset >= kGfx950) { |
| if (m == 32 && n == 32 && k == 16 && b == 1) |
| return ROCDL::mfma_f32_32x32x16_bf16::getOperationName(); |
| if (m == 16 && n == 16 && k == 32 && b == 1) |
| return ROCDL::mfma_f32_16x16x32_bf16::getOperationName(); |
| } |
| if (chipset >= kGfx90a) { |
| if (m == 32 && n == 32 && k == 4 && b == 2) |
| return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName(); |
| if (m == 16 && n == 16 && k == 4 && b == 4) |
| return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName(); |
| if (m == 4 && n == 4 && k == 4 && b == 16) |
| return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName(); |
| if (m == 32 && n == 32 && k == 8 && b == 1) |
| return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName(); |
| if (m == 16 && n == 16 && k == 16 && b == 1) |
| return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName(); |
| } |
| if (m == 32 && n == 32 && k == 2 && b == 2) |
| return ROCDL::mfma_f32_32x32x2bf16::getOperationName(); |
| if (m == 16 && n == 16 && k == 2 && b == 4) |
| return ROCDL::mfma_f32_16x16x2bf16::getOperationName(); |
| if (m == 4 && n == 4 && k == 2 && b == 16) |
| return ROCDL::mfma_f32_4x4x2bf16::getOperationName(); |
| if (m == 32 && n == 32 && k == 4 && b == 1) |
| return ROCDL::mfma_f32_32x32x4bf16::getOperationName(); |
| if (m == 16 && n == 16 && k == 8 && b == 1) |
| return ROCDL::mfma_f32_16x16x8bf16::getOperationName(); |
| } |
| |
| if (sourceElem.isInteger(8) && destElem.isInteger(32)) { |
| if (chipset >= kGfx950) { |
| if (m == 32 && n == 32 && k == 32 && b == 1) |
| return ROCDL::mfma_i32_32x32x32_i8::getOperationName(); |
| if (m == 16 && n == 16 && k == 64 && b == 1) |
| return ROCDL::mfma_i32_16x16x64_i8::getOperationName(); |
| } |
| if (m == 32 && n == 32 && k == 4 && b == 2) |
| return ROCDL::mfma_i32_32x32x4i8::getOperationName(); |
| if (m == 16 && n == 16 && k == 4 && b == 4) |
| return ROCDL::mfma_i32_16x16x4i8::getOperationName(); |
| if (m == 4 && n == 4 && k == 4 && b == 16) |
| return ROCDL::mfma_i32_4x4x4i8::getOperationName(); |
| if (m == 32 && n == 32 && k == 8 && b == 1) |
| return ROCDL::mfma_i32_32x32x8i8::getOperationName(); |
| if (m == 16 && n == 16 && k == 16 && b == 1) |
| return ROCDL::mfma_i32_16x16x16i8::getOperationName(); |
| if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942) |
| return ROCDL::mfma_i32_32x32x16_i8::getOperationName(); |
| if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942) |
| return ROCDL::mfma_i32_16x16x32_i8::getOperationName(); |
| } |
| |
| if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) { |
| if (m == 16 && n == 16 && k == 4 && b == 1) |
| return ROCDL::mfma_f64_16x16x4f64::getOperationName(); |
| if (m == 4 && n == 4 && k == 4 && b == 4) |
| return ROCDL::mfma_f64_4x4x4f64::getOperationName(); |
| } |
| |
| if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) { |
| // Known to be correct because there are no scalar f8 instructions and |
| // because a length mismatch will have been caught by the verifier. |
| Type sourceBElem = |
| cast<VectorType>(mfma.getSourceB().getType()).getElementType(); |
| if (m == 16 && n == 16 && k == 32 && b == 1) { |
| if (typeIsExpectedBf8ForChipset(chipset, sourceBElem)) |
| return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName(); |
| if (typeIsExpectedFp8ForChipset(chipset, sourceBElem)) |
| return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName(); |
| } |
| if (m == 32 && n == 32 && k == 16 && b == 1) { |
| if (typeIsExpectedBf8ForChipset(chipset, sourceBElem)) |
| return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName(); |
| if (typeIsExpectedFp8ForChipset(chipset, sourceBElem)) |
| return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName(); |
| } |
| } |
| |
| if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) { |
| Type sourceBElem = |
| cast<VectorType>(mfma.getSourceB().getType()).getElementType(); |
| if (m == 16 && n == 16 && k == 32 && b == 1) { |
| if (typeIsExpectedBf8ForChipset(chipset, sourceBElem)) |
| return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName(); |
| if (typeIsExpectedFp8ForChipset(chipset, sourceBElem)) |
| return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName(); |
| } |
| if (m == 32 && n == 32 && k == 16 && b == 1) { |
| if (typeIsExpectedBf8ForChipset(chipset, sourceBElem)) |
| return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName(); |
| if (typeIsExpectedFp8ForChipset(chipset, sourceBElem)) |
| return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName(); |
| } |
| } |
| |
| return std::nullopt; |
| } |
| |
| static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) { |
| return llvm::TypeSwitch<Type, std::optional<uint32_t>>(mlirElemType) |
| .Case([](Float8E4M3FNType) { return 0u; }) |
| .Case([](Float8E5M2Type) { return 1u; }) |
| .Case([](Float6E2M3FNType) { return 2u; }) |
| .Case([](Float6E3M2FNType) { return 3u; }) |
| .Case([](Float4E2M1FNType) { return 4u; }) |
| .Default([](Type) { return std::nullopt; }); |
| } |
| |
| /// If there is a scaled MFMA instruction for the input element types `aType` |
| /// and `bType`, output type `destType`, problem size M, N, K, and B (number of |
| /// blocks) on the given `chipset`, return a tuple consisting of the |
| /// OperationName of the intrinsic and the type codes that need to be passed to |
| /// that intrinsic. Note that this is also used to implement some un-scaled |
| /// MFMAs, since the compiler represents the ordinary instruction as a "scaled" |
| /// MFMA with a scale of 0. |
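| /// For example (assuming gfx950), f8E4M3FN inputs with an f32 output and |
| /// M = N = 32, K = 64, B = 1 map to the mfma_scale_f32_32x32x64_f8f6f4 |
| /// intrinsic with type codes (0, 0). |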
| static std::optional<std::tuple<StringRef, uint32_t, uint32_t>> |
| mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m, |
| uint32_t n, uint32_t k, uint32_t b, Chipset chipset) { |
| aType = getElementTypeOrSelf(aType); |
| bType = getElementTypeOrSelf(bType); |
| destType = getElementTypeOrSelf(destType); |
| |
| if (chipset < kGfx950) |
| return std::nullopt; |
| if (!isa<Float32Type>(destType)) |
| return std::nullopt; |
| |
| std::optional<uint32_t> aTypeCode = mfmaTypeSelectCode(aType); |
| std::optional<uint32_t> bTypeCode = mfmaTypeSelectCode(bType); |
| if (!aTypeCode || !bTypeCode) |
| return std::nullopt; |
| |
| if (m == 32 && n == 32 && k == 64 && b == 1) |
| return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(), |
| *aTypeCode, *bTypeCode}; |
| if (m == 16 && n == 16 && k == 128 && b == 1) |
| return std::tuple{ |
| ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode, |
| *bTypeCode}; |
| |
| return std::nullopt; |
| } |
| |
| static std::optional<std::tuple<StringRef, uint32_t, uint32_t>> |
| mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) { |
| return mfmaOpToScaledIntrinsic( |
| mfma.getSourceA().getType(), mfma.getSourceB().getType(), |
| mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(), |
| mfma.getBlocks(), chipset); |
| } |
| |
| static std::optional<std::tuple<StringRef, uint32_t, uint32_t>> |
| mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) { |
| return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(), |
| smfma.getSourceB().getType(), |
| smfma.getDestC().getType(), smfma.getM(), |
| smfma.getN(), smfma.getK(), 1u, chipset); |
| } |
| |
| /// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma` |
| /// if one exists. This includes checking to ensure the intrinsic is supported |
| /// on the architecture you are compiling for. |
| static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma, |
| Chipset chipset) { |
| auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType()); |
| auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType()); |
| auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType()); |
| auto elemSourceType = sourceVectorType.getElementType(); |
| auto elemBSourceType = sourceBVectorType.getElementType(); |
| auto elemDestType = destVectorType.getElementType(); |
| |
| if (elemSourceType.isF16() && elemDestType.isF32()) |
| return ROCDL::wmma_f32_16x16x16_f16::getOperationName(); |
| if (elemSourceType.isBF16() && elemDestType.isF32()) |
| return ROCDL::wmma_f32_16x16x16_bf16::getOperationName(); |
| if (elemSourceType.isF16() && elemDestType.isF16()) |
| return ROCDL::wmma_f16_16x16x16_f16::getOperationName(); |
| if (elemSourceType.isBF16() && elemDestType.isBF16()) |
| return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName(); |
| if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) |
| return ROCDL::wmma_i32_16x16x16_iu8::getOperationName(); |
| if (chipset.majorVersion == 11) { |
| if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) |
| return ROCDL::wmma_i32_16x16x16_iu4::getOperationName(); |
| } |
| if (chipset.majorVersion >= 12) { |
| if (isa<Float8E4M3FNType>(elemSourceType) && |
| isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32()) |
| return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName(); |
| if (isa<Float8E4M3FNType>(elemSourceType) && |
| isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32()) |
| return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName(); |
| if (isa<Float8E5M2Type>(elemSourceType) && |
| isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32()) |
| return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName(); |
| if (isa<Float8E5M2Type>(elemSourceType) && |
| isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32()) |
| return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName(); |
| if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) { |
| bool isWave64 = destVectorType.getNumElements() == 4; |
| // This is the ambiguous case. 8 inputs to the wave64 version means that |
| // we want the 16x16x32 version, but for wave32 they mean the short form. |
| bool has8Inputs = sourceVectorType.getNumElements() == 8; |
| if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs)) |
| return ROCDL::wmma_i32_16x16x32_iu4::getOperationName(); |
| return ROCDL::wmma_i32_16x16x16_iu4::getOperationName(); |
| } |
| } |
| return std::nullopt; |
| } |
| |
| namespace { |
| struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> { |
| MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset) |
| : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {} |
| |
| Chipset chipset; |
| |
| LogicalResult |
| matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| Type outType = typeConverter->convertType(op.getDestD().getType()); |
| Type intrinsicOutType = outType; |
| if (auto outVecType = dyn_cast<VectorType>(outType)) |
| if (outVecType.getElementType().isBF16()) |
| intrinsicOutType = outVecType.clone(rewriter.getI16Type()); |
| |
| if (chipset.majorVersion != 9 || chipset < kGfx908) |
| return op->emitOpError("MFMA only supported on gfx908+"); |
| uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp()); |
| if (op.getNegateA() || op.getNegateB() || op.getNegateC()) { |
| if (chipset < kGfx942) |
| return op.emitOpError("negation unsupported on older than gfx942"); |
| getBlgpField |= |
| op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2); |
| } |
| std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset); |
| std::optional<std::tuple<StringRef, uint32_t, uint32_t>> |
| maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset); |
| if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value()) |
| return op.emitOpError("no intrinsic matching MFMA size on given chipset"); |
| |
| bool isScaled = |
| !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value(); |
| if (isScaled && |
| (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) { |
| return op.emitOpError( |
| "non-default abid, blgp, and cbsz aren't supported on MFMAs that can " |
| "be scaled as those fields are used for type information"); |
| } |
| |
| StringRef intrinsicName = |
| isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic; |
| // Determine if we can use bf16 in the intrinsic. Newer MFMAs in gfx950+ |
| // allow bf16 as the input; see IntrinsicsAMDGPU.td for reference. |
| bool allowBf16 = [&]() { |
| if (chipset < kGfx950) |
| return false; |
| if (isScaled) |
| return true; |
| return intrinsicName.contains("16x16x32.bf16") || |
| intrinsicName.contains("32x32x16.bf16"); |
| }(); |
| OperationState loweredOp(loc, intrinsicName); |
| loweredOp.addTypes(intrinsicOutType); |
| loweredOp.addOperands({convertMFMAVectorOperand( |
| rewriter, loc, adaptor.getSourceA(), allowBf16), |
| convertMFMAVectorOperand( |
| rewriter, loc, adaptor.getSourceB(), allowBf16), |
| adaptor.getDestC()}); |
| if (isScaled) { |
| Value zero = createI32Constant(rewriter, loc, 0); |
| auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic; |
| loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode), |
| createI32Constant(rewriter, loc, bTypeCode), |
| /*scale A byte=*/zero, /*scale A=*/zero, |
| /*scale B byte=*/zero, /*scale B=*/zero}); |
| } else { |
| loweredOp.addOperands({createI32Constant(rewriter, loc, op.getCbsz()), |
| createI32Constant(rewriter, loc, op.getAbid()), |
| createI32Constant(rewriter, loc, getBlgpField)}); |
| } |
| Value lowered = rewriter.create(loweredOp)->getResult(0); |
| if (outType != intrinsicOutType) |
| lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered); |
| rewriter.replaceOp(op, lowered); |
| return success(); |
| } |
| }; |
| |
| struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> { |
| ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset) |
| : ConvertOpToLLVMPattern(converter), chipset(chipset) {} |
| |
| Chipset chipset; |
| |
| LogicalResult |
| matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType()); |
| |
| if (chipset.majorVersion != 9 || chipset < kGfx950) |
| return op->emitOpError("scaled MFMA only supported on gfx950+"); |
| std::optional<std::tuple<StringRef, uint32_t, uint32_t>> |
| maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset); |
| if (!maybeScaledIntrinsic.has_value()) |
| return op.emitOpError( |
| "no intrinsic matching scaled MFMA size on given chipset"); |
| |
| auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic; |
| OperationState loweredOp(loc, intrinsicName); |
| loweredOp.addTypes(intrinsicOutType); |
| loweredOp.addOperands( |
| {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()), |
| convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()), |
| adaptor.getDestC()}); |
| Value scalesIdxA = |
| createI32Constant(rewriter, loc, adaptor.getScalesIdxA()); |
| Value scalesIdxB = |
| createI32Constant(rewriter, loc, adaptor.getScalesIdxB()); |
| loweredOp.addOperands( |
| {createI32Constant(rewriter, loc, aTypeCode), |
| createI32Constant(rewriter, loc, bTypeCode), |
| /*scales idx A=*/scalesIdxA, |
| /*scales A*/ |
| castMFMAScaleOperand(rewriter, loc, adaptor.getScalesA()), |
| /*scales idx B=*/scalesIdxB, |
| /*scales B*/ |
| castMFMAScaleOperand(rewriter, loc, adaptor.getScalesB())}); |
| Value lowered = rewriter.create(loweredOp)->getResult(0); |
| rewriter.replaceOp(op, lowered); |
| return success(); |
| } |
| }; |
| |
| struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> { |
| WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset) |
| : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {} |
| |
| Chipset chipset; |
| |
| LogicalResult |
| matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| auto outType = |
| typeConverter->convertType<VectorType>(op.getDestD().getType()); |
| if (!outType) |
| return rewriter.notifyMatchFailure(op, "type conversion failed"); |
| |
| if (chipset.majorVersion != 11 && chipset.majorVersion != 12) |
| return op->emitOpError("WMMA only supported on gfx11 and gfx12"); |
| |
| // The WMMA operations represent vectors of bf16s as vectors of i16s, so we |
| // need to bitcast bfloats to i16 and then bitcast them back. |
| VectorType rawOutType = outType; |
| if (outType.getElementType().isBF16()) |
| rawOutType = outType.clone(rewriter.getI16Type()); |
| |
| std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset); |
| |
| if (!maybeIntrinsic.has_value()) |
| return op.emitOpError("no intrinsic matching WMMA on the given chipset"); |
| |
| if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0) |
| return op.emitOpError("subwordOffset not supported on gfx12+"); |
| |
| OperationState loweredOp(loc, *maybeIntrinsic); |
| loweredOp.addTypes(rawOutType); |
| |
| SmallVector<Value, 4> operands; |
| wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), |
| adaptor.getSourceA(), op.getSourceA(), operands); |
| wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), |
| adaptor.getSourceB(), op.getSourceB(), operands); |
| wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(), |
| op.getSubwordOffset(), op.getClamp(), operands); |
| |
| loweredOp.addOperands(operands); |
| Operation *lowered = rewriter.create(loweredOp); |
| |
| Operation *maybeCastBack = lowered; |
| if (rawOutType != outType) |
| maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType, |
| lowered->getResult(0)); |
| rewriter.replaceOp(op, maybeCastBack->getResults()); |
| |
| return success(); |
| } |
| }; |
| |
| struct TransposeLoadOpLowering |
| : public ConvertOpToLLVMPattern<TransposeLoadOp> { |
| TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset) |
| : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {} |
| |
| Chipset chipset; |
| |
| LogicalResult |
| matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| if (chipset != kGfx950) |
| return op.emitOpError("Non-gfx950 chipset not supported"); |
| |
| Location loc = op.getLoc(); |
| auto srcMemRefType = cast<MemRefType>(op.getSrc().getType()); |
| |
| // Elements in sub-byte memrefs are stored non-contiguously, so reject |
| // sub-byte source memrefs; use emulated memrefs instead. |
| size_t srcElementSize = |
| srcMemRefType.getElementType().getIntOrFloatBitWidth(); |
| if (srcElementSize < 8) |
| return op.emitOpError("Expect source memref to have at least 8 bits " |
| "element size, got ") |
| << srcElementSize; |
| |
| auto resultType = cast<VectorType>(op.getResult().getType()); |
| Value srcPtr = |
| getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(), |
| (adaptor.getSrcIndices())); |
| |
| size_t numElements = resultType.getNumElements(); |
| size_t elementTypeSize = |
| resultType.getElementType().getIntOrFloatBitWidth(); |
| |
| // ROCDL transpose load intrinsics return vectors of 32-bit integers if |
| // the element size is smaller than 16 bits. |
| Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32, |
| rewriter.getIntegerType(32)); |
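| // For example, 16 4-bit elements (64 bits total) come back as a |
| // vector<2xi32>, which is then bitcast to the converted result type. |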
| Type llvmResultType = typeConverter->convertType(resultType); |
| |
| switch (elementTypeSize) { |
| case 4: { |
| assert(numElements == 16); |
| auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc, |
| rocdlResultType, srcPtr); |
| rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp); |
| break; |
| } |
| case 6: { |
| assert(numElements == 16); |
| auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc, |
| rocdlResultType, srcPtr); |
| rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp); |
| break; |
| } |
| case 8: { |
| assert(numElements == 8); |
| auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc, |
| rocdlResultType, srcPtr); |
| rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp); |
| break; |
| } |
| case 16: { |
| assert(numElements == 4); |
| rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType, |
| srcPtr); |
| break; |
| } |
| default: |
| return op.emitOpError("Unsupported element size for transpose load"); |
| } |
| return success(); |
| } |
| }; |
| |
| struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> { |
| GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset) |
| : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {} |
| |
| Chipset chipset; |
| |
| LogicalResult |
| matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| if (chipset.majorVersion < 9 || chipset.majorVersion > 10) |
| return op.emitOpError("pre-gfx9 and post-gfx10 not supported"); |
| |
| Location loc = op.getLoc(); |
| |
| auto srcMemRefType = cast<MemRefType>(op.getSrc().getType()); |
| auto dstMemRefType = cast<MemRefType>(op.getDst().getType()); |
| |
| // TODO: instead of only transferring one element per thread, we could |
| // augment it to transfer multiple elements per thread by issuing multiple |
| // `global_load_lds` instructions. |
| Type transferType = op.getTransferType(); |
| int loadWidth = [&]() -> int { |
| if (auto transferVectorType = dyn_cast<VectorType>(transferType)) { |
| return (transferVectorType.getNumElements() * |
| transferVectorType.getElementTypeBitWidth()) / |
| 8; |
| } |
| return transferType.getIntOrFloatBitWidth() / 8; |
| }(); |
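| // For example, a vector<4xf32> transfer type yields a 16-byte load width, |
| // which the checks below restrict to gfx950. |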
| |
| // Currently only 1, 2, 4, 12 and 16 byte loads are supported. |
| if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth)) |
| return op.emitOpError("chipset unsupported element size"); |
| |
| if (chipset != kGfx950 && llvm::is_contained({12, 16}, loadWidth)) |
| return op.emitOpError("Gather to LDS instructions with 12-byte and " |
| "16-byte load widths are only supported on gfx950"); |
| |
| Value srcPtr = |
| getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(), |
| (adaptor.getSrcIndices())); |
| Value dstPtr = |
| getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(), |
| (adaptor.getDstIndices())); |
| |
| rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>( |
| op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth), |
| /*offset=*/rewriter.getI32IntegerAttr(0), |
| /*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{}, |
| ArrayAttr{}); |
| |
| return success(); |
| } |
| }; |
| } // namespace |
| |
| namespace { |
| struct ExtPackedFp8OpLowering final |
| : public ConvertOpToLLVMPattern<ExtPackedFp8Op> { |
| ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset) |
| : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter), |
| chipset(chipset) {} |
| Chipset chipset; |
| |
| LogicalResult |
| matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| struct PackedTrunc2xFp8OpLowering final |
| : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> { |
| PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter, |
| Chipset chipset) |
| : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter), |
| chipset(chipset) {} |
| Chipset chipset; |
| |
| LogicalResult |
| matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| struct PackedStochRoundFp8OpLowering final |
| : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> { |
| PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter, |
| Chipset chipset) |
| : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter), |
| chipset(chipset) {} |
| Chipset chipset; |
| |
| LogicalResult |
| matchAndRewrite(PackedStochRoundFp8Op op, |
| PackedStochRoundFp8OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| struct ScaledExtPackedOpLowering final |
| : public ConvertOpToLLVMPattern<ScaledExtPackedOp> { |
| ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset) |
| : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter), |
| chipset(chipset) {} |
| Chipset chipset; |
| |
| LogicalResult |
| matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| struct PackedScaledTruncOpLowering final |
| : public ConvertOpToLLVMPattern<PackedScaledTruncOp> { |
| PackedScaledTruncOpLowering(const LLVMTypeConverter &converter, |
| Chipset chipset) |
| : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter), |
| chipset(chipset) {} |
| Chipset chipset; |
| |
| LogicalResult |
| matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| } // end namespace |
| |
| LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( |
| ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| Location loc = op.getLoc(); |
| if (!(chipset == kGfx942 || hasOcpFp8(chipset))) |
| return rewriter.notifyMatchFailure( |
| loc, "Fp8 conversion instructions are not available on target " |
| "architecture and their emulation is not implemented"); |
| Type v4i8 = |
| getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type())); |
| Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); |
| Type f32 = getTypeConverter()->convertType(op.getResult().getType()); |
| |
| Value source = adaptor.getSource(); |
| auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType()); |
| auto resultVecType = dyn_cast<VectorType>(op.getResult().getType()); |
| Type sourceElemType = getElementTypeOrSelf(op.getSource()); |
| // Extend to a v4i8 |
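| // For example, a lone f8 scalar (or a shorter f8 vector) gets its elements
| // inserted into an undef vector<4xi8> so the bitcast below can reinterpret
| // the packed bytes as the single i32 the conversion intrinsics expect.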
| if (!sourceVecType || sourceVecType.getNumElements() < 4) { |
| Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8); |
| if (!sourceVecType) { |
| longVec = LLVM::InsertElementOp::create( |
| rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0)); |
| } else { |
| for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) { |
| Value idx = createI32Constant(rewriter, loc, i); |
| Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx); |
| longVec = |
| LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx); |
| } |
| } |
| source = longVec; |
| } |
| Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source); |
| if (resultVecType) { |
| if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) { |
| rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source, |
| op.getIndex()); |
| } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) { |
| rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source, |
| op.getIndex()); |
| } |
| } else { |
| if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) { |
| rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source, |
| op.getIndex()); |
| } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) { |
| rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source, |
| op.getIndex()); |
| } |
| } |
| return success(); |
| } |
| |
| LogicalResult ScaledExtPackedOpLowering::matchAndRewrite( |
| ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| Location loc = op.getLoc(); |
| if (chipset != kGfx950) |
| return rewriter.notifyMatchFailure( |
| loc, "Scaled fp conversion instructions are not available on target " |
| "architecture and their emulation is not implemented"); |
| Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); |
| |
| Value source = adaptor.getSource(); |
| Value scale = adaptor.getScale(); |
| |
| VectorType sourceVecType = cast<VectorType>(op.getSource().getType()); |
| Type sourceElemType = sourceVecType.getElementType(); |
| VectorType destVecType = cast<VectorType>(op.getResult().getType()); |
| Type destElemType = destVecType.getElementType(); |
| |
| VectorType packedVecType; |
| if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) { |
| VectorType v4i8 = VectorType::get(4, rewriter.getI8Type()); |
| packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8)); |
| } else if (isa<Float4E2M1FNType>(sourceElemType)) { |
| VectorType v8i4 = VectorType::get(8, rewriter.getI4Type()); |
| packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4)); |
| } else { |
| llvm_unreachable("invalid element type for scaled ext"); |
| } |
| |
| // Extend to a packedVectorType |
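| // For example, a vector<2xf4E2M1FN> source is widened to the 8-element
| // packed form (v8i4) so the bitcast below can reinterpret it as i32.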
| if (sourceVecType.getNumElements() < packedVecType.getNumElements()) { |
| Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType); |
| // `sourceVecType` is non-null here (the `cast` above would have asserted
| // otherwise), so copy the source elements into the zeroed packed vector.
| for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
| Value idx = createI32Constant(rewriter, loc, i);
| Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
| longVec =
| LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
| }
| source = longVec; |
| } |
| Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source); |
| |
| if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32()) |
| rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>( |
| op, destVecType, i32Source, scale, op.getIndex()); |
| else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16()) |
| rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>( |
| op, destVecType, i32Source, scale, op.getIndex()); |
| else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16()) |
| rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>( |
| op, destVecType, i32Source, scale, op.getIndex()); |
| else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32()) |
| rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>( |
| op, destVecType, i32Source, scale, op.getIndex()); |
| else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16()) |
| rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>( |
| op, destVecType, i32Source, scale, op.getIndex()); |
| else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16()) |
| rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>( |
| op, destVecType, i32Source, scale, op.getIndex()); |
| else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32()) |
| rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>( |
| op, destVecType, i32Source, scale, op.getIndex()); |
| else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16()) |
| rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>( |
| op, destVecType, i32Source, scale, op.getIndex()); |
| else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16()) |
| rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>( |
| op, destVecType, i32Source, scale, op.getIndex()); |
| else |
| return failure(); |
| |
| return success(); |
| } |
| |
| LogicalResult PackedScaledTruncOpLowering::matchAndRewrite( |
| PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| Location loc = op.getLoc(); |
| if (chipset != kGfx950) |
| return rewriter.notifyMatchFailure( |
| loc, "Scaled fp conversion instructions are not available on target " |
| "architecture and their emulation is not implemented"); |
| Type v2i16 = getTypeConverter()->convertType( |
| VectorType::get(2, rewriter.getI16Type())); |
| Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); |
| |
| Type resultType = op.getResult().getType(); |
| Type resultElemType = getElementTypeOrSelf(resultType); |
| VectorType sourceVecType = cast<VectorType>(op.getSource().getType()); |
| Type sourceElemType = sourceVecType.getElementType(); |
| |
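| // The scaled-truncation intrinsics write their packed result into a 32-bit
| // container: fp4 results fill a whole i32 (eight 4-bit values), while the
| // fp8/bf8 forms are modeled as v2i16, with `op.getIndex()` selecting (per
| // the intrinsic semantics) which 16-bit half receives the converted bytes.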
| Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16; |
| |
| Value source = adaptor.getSource(); |
| Value scale = adaptor.getScale(); |
| Value existing = adaptor.getExisting(); |
| if (existing) |
| existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing); |
| else |
| existing = LLVM::ZeroOp::create(rewriter, loc, intResultType); |
| |
| if (sourceVecType.getNumElements() < 2) { |
| Value c0 = createI32Constant(rewriter, loc, 0); |
| Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0); |
| VectorType v2 = VectorType::get(2, sourceElemType); |
| source = LLVM::ZeroOp::create(rewriter, loc, v2); |
| source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0); |
| } |
| |
| Value sourceA, sourceB; |
| if (sourceElemType.isF32()) { |
| Value c0 = createI32Constant(rewriter, loc, 0); |
| Value c1 = createI32Constant(rewriter, loc, 1); |
| sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0); |
| sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1); |
| } |
| |
| Value result; |
| if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType)) |
| result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType, |
| existing, sourceA, sourceB, |
| scale, op.getIndex()); |
| else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType)) |
| result = ROCDL::CvtScaleF32PkBf8F16Op::create( |
| rewriter, loc, intResultType, existing, source, scale, op.getIndex()); |
| else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType)) |
| result = ROCDL::CvtScaleF32PkBf8Bf16Op::create( |
| rewriter, loc, intResultType, existing, source, scale, op.getIndex()); |
| else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType)) |
| result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType, |
| existing, sourceA, sourceB, |
| scale, op.getIndex()); |
| else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType)) |
| result = ROCDL::CvtScaleF32PkFp8F16Op::create( |
| rewriter, loc, intResultType, existing, source, scale, op.getIndex()); |
| else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType)) |
| result = ROCDL::CvtScaleF32PkFp8Bf16Op::create( |
| rewriter, loc, intResultType, existing, source, scale, op.getIndex()); |
| else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType)) |
| result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType, |
| existing, sourceA, sourceB, |
| scale, op.getIndex()); |
| else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType)) |
| result = ROCDL::CvtScaleF32PkFp4F16Op::create( |
| rewriter, loc, intResultType, existing, source, scale, op.getIndex()); |
| else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType)) |
| result = ROCDL::CvtScaleF32PkFp4Bf16Op::create( |
| rewriter, loc, intResultType, existing, source, scale, op.getIndex()); |
| else |
| return failure(); |
| |
| rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
| op, getTypeConverter()->convertType(resultType), result);
| return success(); |
| } |
| |
| LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite( |
| PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| Location loc = op.getLoc(); |
| if (!(chipset == kGfx942 || hasOcpFp8(chipset))) |
| return rewriter.notifyMatchFailure( |
| loc, "Fp8 conversion instructions are not available on target " |
| "architecture and their emulation is not implemented"); |
| Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); |
| |
| Type resultType = op.getResult().getType(); |
| Type resultElemType = getElementTypeOrSelf(resultType); |
| |
| Value sourceA = adaptor.getSourceA(); |
| Value sourceB = adaptor.getSourceB(); |
| if (!sourceB) |
| sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType()); |
| Value existing = adaptor.getExisting(); |
| if (existing) |
| existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing); |
| else |
| existing = LLVM::UndefOp::create(rewriter, loc, i32); |
| |
| Value result; |
| if (typeIsExpectedBf8ForChipset(chipset, resultElemType)) |
| result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB, |
| existing, op.getWordIndex()); |
| else if (typeIsExpectedFp8ForChipset(chipset, resultElemType)) |
| result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB, |
| existing, op.getWordIndex()); |
| |
| rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
| op, getTypeConverter()->convertType(resultType), result);
| return success(); |
| } |
| |
| LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite( |
| PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| Location loc = op.getLoc(); |
| if (!(chipset == kGfx942 || hasOcpFp8(chipset))) |
| return rewriter.notifyMatchFailure( |
| loc, "Fp8 conversion instructions are not available on target " |
| "architecture and their emulation is not implemented"); |
| Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); |
| |
| Type resultType = op.getResult().getType(); |
| Type resultElemType = getElementTypeOrSelf(resultType); |
| |
| Value source = adaptor.getSource(); |
| Value stoch = adaptor.getStochiasticParam(); |
| Value existing = adaptor.getExisting(); |
| if (existing) |
| existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing); |
| else |
| existing = LLVM::UndefOp::create(rewriter, loc, i32); |
| |
| Value result; |
| if (typeIsExpectedBf8ForChipset(chipset, resultElemType)) |
| result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch, |
| existing, op.getStoreIndex()); |
| else if (typeIsExpectedFp8ForChipset(chipset, resultElemType)) |
| result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch, |
| existing, op.getStoreIndex()); |
| |
| rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
| op, getTypeConverter()->convertType(resultType), result);
| return success(); |
| } |
| |
| // Lowers the amdgpu.dpp operation to the corresponding ROCDL DPP update
| // instruction.
| struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> { |
| AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset) |
| : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {} |
| Chipset chipset; |
| |
| LogicalResult |
| matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| |
| // Convert the source operand to the corresponding LLVM type |
| Location loc = DppOp.getLoc(); |
| Value src = adaptor.getSrc(); |
| Value old = adaptor.getOld(); |
| Type srcType = src.getType(); |
| Type oldType = old.getType(); |
| Type llvmType = nullptr; |
| if (srcType.getIntOrFloatBitWidth() < 32) { |
| llvmType = rewriter.getI32Type(); |
| } else if (isa<FloatType>(srcType)) { |
| llvmType = (srcType.getIntOrFloatBitWidth() == 32) |
| ? rewriter.getF32Type() |
| : rewriter.getF64Type(); |
| } else if (isa<IntegerType>(srcType)) { |
| llvmType = (srcType.getIntOrFloatBitWidth() == 32) |
| ? rewriter.getI32Type() |
| : rewriter.getI64Type(); |
| } |
| auto llvmSrcIntType = typeConverter->convertType( |
| rewriter.getIntegerType(srcType.getIntOrFloatBitWidth())); |
| |
| // If the operand is narrower than 32 bits, bitcast and widen it into the
| // 32-bit value the DPP intrinsic operates on.
| auto convertOperand = [&](Value operand, Type operandType) { |
| if (operandType.getIntOrFloatBitWidth() <= 16) { |
| if (llvm::isa<FloatType>(operandType)) { |
| operand = |
| LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand); |
| } |
| auto llvmVecType = typeConverter->convertType(mlir::VectorType::get( |
| 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType)); |
| Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType); |
| operand = |
| LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand, |
| createI32Constant(rewriter, loc, 0)); |
| operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand); |
| } |
| return operand; |
| }; |
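| // For example, an f16 operand is bitcast to i16 above, inserted into
| // element 0 of an undef vector<2xi16>, and that vector is bitcast to the
| // i32 the DPP intrinsic operates on.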
| |
| src = convertOperand(src, srcType); |
| old = convertOperand(old, oldType); |
| |
| // These DPP control values are taken from
| // llvm/lib/Target/AMDGPU/SIDefines.h.
| enum DppCtrl : unsigned { |
| ROW_SHL0 = 0x100, |
| ROW_SHR0 = 0x110, |
| ROW_ROR0 = 0x120, |
| WAVE_SHL1 = 0x130, |
| WAVE_ROL1 = 0x134, |
| WAVE_SHR1 = 0x138, |
| WAVE_ROR1 = 0x13C, |
| ROW_MIRROR = 0x140, |
| ROW_HALF_MIRROR = 0x141, |
| BCAST15 = 0x142, |
| BCAST31 = 0x143, |
| }; |
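| // Worked encoding examples: row_shl by 5 encodes as ROW_SHL0 + 5 = 0x105,
| // and quad_perm [3, 2, 1, 0] packs two bits per selector into
| // 0b00'01'10'11 = 0x1b (see the quad_perm case below).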
| |
| auto kind = DppOp.getKind(); |
| auto permArgument = DppOp.getPermArgument(); |
| uint32_t DppCtrl = 0; |
| |
| switch (kind) { |
| |
| case DPPPerm::quad_perm: {
| // `cast` cannot return null, so read the array attribute directly.
| auto quadPermAttr = cast<ArrayAttr>(*permArgument);
| int32_t i = 0;
| for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
| uint32_t num = elem.getInt();
| DppCtrl |= num << (i * 2);
| i++;
| }
| break;
| }
| case DPPPerm::row_shl:
| DppCtrl = cast<IntegerAttr>(*permArgument).getInt() + DppCtrl::ROW_SHL0;
| break;
| case DPPPerm::row_shr:
| DppCtrl = cast<IntegerAttr>(*permArgument).getInt() + DppCtrl::ROW_SHR0;
| break;
| case DPPPerm::row_ror:
| DppCtrl = cast<IntegerAttr>(*permArgument).getInt() + DppCtrl::ROW_ROR0;
| break;
| case DPPPerm::wave_shl: |
| DppCtrl = DppCtrl::WAVE_SHL1; |
| break; |
| case DPPPerm::wave_shr: |
| DppCtrl = DppCtrl::WAVE_SHR1; |
| break; |
| case DPPPerm::wave_rol: |
| DppCtrl = DppCtrl::WAVE_ROL1; |
| break; |
| case DPPPerm::wave_ror: |
| DppCtrl = DppCtrl::WAVE_ROR1; |
| break; |
| case DPPPerm::row_mirror: |
| DppCtrl = DppCtrl::ROW_MIRROR; |
| break; |
| case DPPPerm::row_half_mirror: |
| DppCtrl = DppCtrl::ROW_HALF_MIRROR; |
| break; |
| case DPPPerm::row_bcast_15: |
| DppCtrl = DppCtrl::BCAST15; |
| break; |
| case DPPPerm::row_bcast_31: |
| DppCtrl = DppCtrl::BCAST31; |
| break; |
| } |
| |
| // Read the row_mask, bank_mask, and bound_ctrl attributes.
| auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt(); |
| auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt(); |
| bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue(); |
| |
| // Create a ROCDL DPPUpdateOp with the computed control value and masks.
| auto dppMovOp = |
| ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl, |
| rowMask, bankMask, boundCtrl); |
| |
| Value result = dppMovOp.getRes(); |
| if (srcType.getIntOrFloatBitWidth() < 32) { |
| result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result); |
| if (!llvm::isa<IntegerType>(srcType)) { |
| result = LLVM::BitcastOp::create(rewriter, loc, srcType, result); |
| } |
| } |
| |
| // Replace the amdgpu.dpp operation with the ROCDL DPP update result.
| rewriter.replaceOp(DppOp, ValueRange(result)); |
| return success(); |
| } |
| }; |
| |
| struct AMDGPUSwizzleBitModeLowering |
| : public ConvertOpToLLVMPattern<SwizzleBitModeOp> { |
| using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
| |
| LogicalResult |
| matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| Type i32 = rewriter.getI32Type(); |
| Value src = adaptor.getSrc(); |
| SmallVector<Value> decomposed = |
| LLVM::decomposeValue(rewriter, loc, src, i32); |
| unsigned andMask = op.getAndMask(); |
| unsigned orMask = op.getOrMask(); |
| unsigned xorMask = op.getXorMask(); |
| |
| // Bit 15 is 0 for the BitMode swizzle; lane i then reads from lane
| // ((i & and_mask) | or_mask) ^ xor_mask.
| // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
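| // For example, andMask = 0x1f, orMask = 0, and xorMask = 1 give
| // mask = 0x41f, making each lane read from its xor-1 neighbor, i.e.
| // swapping adjacent even/odd lanes.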
| unsigned mask = andMask | (orMask << 5) | (xorMask << 10); |
| Value maskValue = createI32Constant(rewriter, loc, mask); |
| SmallVector<Value> swizzled; |
| for (Value v : decomposed) { |
| Value res = |
| ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue); |
| swizzled.emplace_back(res); |
| } |
| |
| Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType()); |
| rewriter.replaceOp(op, result); |
| return success(); |
| } |
| }; |
| |
| struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> { |
| AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset) |
| : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {} |
| Chipset chipset; |
| |
| LogicalResult |
| matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| if (chipset < kGfx950) |
| return op->emitOpError("permlane_swap is only supported on gfx950+"); |
| |
| Location loc = op.getLoc(); |
| Type i32 = rewriter.getI32Type(); |
| Value src = adaptor.getSrc(); |
| unsigned rowLength = op.getRowLength(); |
| bool fi = op.getFetchInactive(); |
| bool boundctrl = op.getBoundCtrl(); |
| |
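| // Sources wider than 32 bits (e.g. f64 or vectors) are decomposed into
| // i32 words; each word is swapped across lanes independently and the
| // results are recomposed into the original type below.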
| SmallVector<Value> decomposed = |
| LLVM::decomposeValue(rewriter, loc, src, i32); |
| |
| SmallVector<Value> permuted; |
| for (Value v : decomposed) { |
| Value res; |
| Type i32pair = LLVM::LLVMStructType::getLiteral( |
| rewriter.getContext(), {v.getType(), v.getType()}); |
| |
| if (rowLength == 16) |
| res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi, |
| boundctrl); |
| else if (rowLength == 32) |
| res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi, |
| boundctrl); |
| else |
| llvm_unreachable("unsupported row length"); |
| |
| Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
| Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
| 
| Value isEqual = LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::eq,
| vdst0, v);
| 
| // Per `permlane(16|32)` semantics: if the first extracted element equals
| // `v`, the result is the second element; otherwise it is the first.
| Value vdstNew = LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
| permuted.emplace_back(vdstNew); |
| } |
| |
| Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType()); |
| rewriter.replaceOp(op, result); |
| return success(); |
| } |
| }; |
| |
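| // The wrapper pass; it is typically invoked from mlir-opt, e.g. (option
| // syntax may vary by MLIR version):
| //   mlir-opt --convert-amdgpu-to-rocdl=chipset=gfx942 input.mlir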
| struct ConvertAMDGPUToROCDLPass |
| : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> { |
| using Base::Base; |
| |
| void runOnOperation() override { |
| MLIRContext *ctx = &getContext(); |
| FailureOr<Chipset> maybeChipset = Chipset::parse(chipset); |
| if (failed(maybeChipset)) { |
| emitError(UnknownLoc::get(ctx), "invalid chipset name: " + chipset);
| return signalPassFailure(); |
| } |
| |
| RewritePatternSet patterns(ctx); |
| LLVMTypeConverter converter(ctx); |
| populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset); |
| LLVMConversionTarget target(getContext()); |
| target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>(); |
| target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); |
| target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>(); |
| if (failed(applyPartialConversion(getOperation(), target, |
| std::move(patterns)))) |
| signalPassFailure(); |
| } |
| }; |
| } // namespace |
| |
| void mlir::populateAMDGPUMemorySpaceAttributeConversions( |
| TypeConverter &typeConverter) { |
| typeConverter.addTypeAttributeConversion( |
| [](BaseMemRefType type, amdgpu::AddressSpaceAttr as) |
| -> TypeConverter::AttributeConversionResult { |
| MLIRContext *ctx = as.getContext(); |
| Type i64 = IntegerType::get(ctx, 64); |
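| // These values mirror the LLVM AMDGPU address-space numbering: 7 is the
| // buffer fat pointer, 8 the buffer resource, and 9 the buffer strided
| // pointer address space.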
| switch (as.getValue()) { |
| case amdgpu::AddressSpace::FatRawBuffer: |
| return IntegerAttr::get(i64, 7); |
| case amdgpu::AddressSpace::BufferRsrc: |
| return IntegerAttr::get(i64, 8); |
| case amdgpu::AddressSpace::FatStructuredBuffer: |
| return IntegerAttr::get(i64, 9); |
| } |
| return TypeConverter::AttributeConversionResult::abort(); |
| }); |
| } |
| |
| void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, |
| RewritePatternSet &patterns, |
| Chipset chipset) { |
| populateAMDGPUMemorySpaceAttributeConversions(converter); |
| patterns |
| .add<FatRawBufferCastLowering, |
| RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>, |
| RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>, |
| RawBufferOpLowering<RawBufferAtomicFaddOp, |
| ROCDL::RawPtrBufferAtomicFaddOp>, |
| RawBufferOpLowering<RawBufferAtomicFmaxOp, |
| ROCDL::RawPtrBufferAtomicFmaxOp>, |
| RawBufferOpLowering<RawBufferAtomicSmaxOp, |
| ROCDL::RawPtrBufferAtomicSmaxOp>, |
| RawBufferOpLowering<RawBufferAtomicUminOp, |
| ROCDL::RawPtrBufferAtomicUminOp>, |
| RawBufferOpLowering<RawBufferAtomicCmpswapOp, |
| ROCDL::RawPtrBufferAtomicCmpSwap>, |
| AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering, |
| SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering, |
| WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering, |
| PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering, |
| PackedStochRoundFp8OpLowering, GatherToLDSOpLowering, |
| TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset); |
| patterns.add<AMDGPUSwizzleBitModeLowering>(converter); |
| } |