blob: 6427807e944a1edd3dfb6cf9d17b472d54ec3c3e [file] [log] [blame] [edit]
//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "../LLVMCommon/MemRefDescriptor.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include <optional>
namespace mlir {
#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
using namespace mlir::amdgpu;
// Define commonly used chipsets versions for convenience.
constexpr Chipset kGfx908 = Chipset(9, 0, 8);
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
constexpr Chipset kGfx950 = Chipset(9, 5, 0);
constexpr Chipset kGfx1250 = Chipset(12, 5, 0);
/// Convert an unsigned number `val` to i32.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
Location loc, Value val) {
IntegerType i32 = rewriter.getI32Type();
// Force check that `val` is of int type.
auto valTy = cast<IntegerType>(val.getType());
if (i32 == valTy)
return val;
return valTy.getWidth() > 32
? Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
: Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
}
static Value createI32Constant(ConversionPatternRewriter &rewriter,
Location loc, int32_t value) {
return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), value);
}
/// Convert an unsigned number `val` to i64.
static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter,
Location loc, Value val) {
IntegerType i64 = rewriter.getI64Type();
// Force check that `val` is of int type.
auto valTy = cast<IntegerType>(val.getType());
if (i64 == valTy)
return val;
return valTy.getWidth() > 64
? Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
: Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));
}
static Value createI64Constant(ConversionPatternRewriter &rewriter,
Location loc, int64_t value) {
return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
}
/// Returns the linear index used to access an element in the memref.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
Location loc, MemRefDescriptor &memRefDescriptor,
ValueRange indices, ArrayRef<int64_t> strides) {
IntegerType i32 = rewriter.getI32Type();
Value index;
for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
if (stride != 1) { // Skip if stride is 1.
Value strideValue =
ShapedType::isDynamic(stride)
? convertUnsignedToI32(rewriter, loc,
memRefDescriptor.stride(rewriter, loc, i))
: LLVM::ConstantOp::create(rewriter, loc, i32, stride);
increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
}
index = index ? LLVM::AddOp::create(rewriter, loc, index, increment)
: increment;
}
return index ? index : createI32Constant(rewriter, loc, 0);
}
/// Compute the contents of the `num_records` field for a given memref
/// descriptor - that is, the number of bytes that's one element past the
/// greatest possible valid index into the memref.
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
MemRefType memrefType,
MemRefDescriptor &memrefDescriptor,
ArrayRef<int64_t> strides, int64_t elementByteWidth,
amdgpu::Chipset chipset, bool boundsCheck) {
if (chipset >= kGfx1250 && !boundsCheck) {
constexpr int64_t first45bits = (1ll << 45) - 1;
return createI64Constant(rewriter, loc, first45bits);
}
if (memrefType.hasStaticShape() &&
!llvm::any_of(strides, ShapedType::isDynamic)) {
int64_t size = memrefType.getRank() == 0 ? 1 : 0;
ArrayRef<int64_t> shape = memrefType.getShape();
for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
size = std::max(shape[i] * strides[i], size);
size = size * elementByteWidth;
return createI64Constant(rewriter, loc, size);
}
Value maxIndex;
for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
Value size = memrefDescriptor.size(rewriter, loc, i);
Value stride = memrefDescriptor.stride(rewriter, loc, i);
Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
maxIndex = maxIndex
? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
: maxThisDim;
}
Value maxIndexI64 = convertUnsignedToI64(rewriter, loc, maxIndex);
Value byteWidthConst = createI64Constant(rewriter, loc, elementByteWidth);
return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
}
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
Value basePointer, Value numRecords,
bool boundsCheck, amdgpu::Chipset chipset,
Value cacheSwizzleStride = nullptr,
unsigned addressSpace = 8) {
// The stride value is generally 0. However, on MI-300 and onward, you can
// enable a cache swizzling mode by setting bit 14 of the stride field
// and setting that stride to a cache stride.
Type i16 = rewriter.getI16Type();
Value stride;
if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
Value cacheStrideZext =
LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
Value swizzleBit = LLVM::ConstantOp::create(
rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
/*isDisjoint=*/true);
} else {
stride = LLVM::ConstantOp::create(rewriter, loc, i16,
rewriter.getI16IntegerAttr(0));
}
uint32_t flags = 0;
if (chipset >= kGfx1250) {
// Flag word:
// bit 0: swizzle
// bit 1: 0 means (total_offset + payload > numRecords)
// 1 means ((total_offset + payload >) numRecords) || ((offset +
// payload) > stride) only applied when swizzle_enable = 0. keep at
// zero.
// whether oob is done depends on numRecords.
// bits 2-3: Type (must be 0)
} else {
// Get the number of elements.
// Flag word:
// bits 0-11: dst sel, ignored by these intrinsics
// bits 12-14: data format (ignored, must be nonzero, 7=float)
// bits 15-18: data format (ignored, must be nonzero, 4=32bit)
// bit 19: In nested heap (0 here)
// bit 20: Behavior on unmap (0 means "return 0 / ignore")
// bits 21-22: Index stride for swizzles (N/A)
// bit 23: Add thread ID (0)
// bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
// bits 25-26: Reserved (0)
// bit 27: Buffer is non-volatile (CDNA only)
// bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
// none, 3 = either swizzles or testing against offset field) RDNA only
// bits 30-31: Type (must be 0)
flags |= (7 << 12) | (4 << 15);
if (chipset.majorVersion >= 10) {
flags |= (1 << 24);
uint32_t oob = boundsCheck ? 3 : 2;
flags |= (oob << 28);
}
}
Value flagsConst = createI32Constant(rewriter, loc, flags);
Type rsrcType =
LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
loc, rsrcType, basePointer, stride, numRecords, flagsConst);
return resource;
}
namespace {
struct FatRawBufferCastLowering
: public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
chipset(chipset) {}
Chipset chipset;
LogicalResult
matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value memRef = adaptor.getSource();
Value unconvertedMemref = op.getSource();
MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
MemRefDescriptor descriptor(memRef);
DataLayout dataLayout = DataLayout::closest(op);
int64_t elementByteWidth =
dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
int64_t unusedOffset = 0;
SmallVector<int64_t, 5> strideVals;
if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
return op.emitOpError("Can't lower non-stride-offset memrefs");
Value numRecords = adaptor.getValidBytes();
if (!numRecords)
numRecords =
getNumRecords(rewriter, loc, memrefType, descriptor, strideVals,
elementByteWidth, chipset, adaptor.getBoundsCheck());
Value basePointer =
adaptor.getResetOffset()
? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
memrefType)
: descriptor.alignedPtr(rewriter, loc);
Value offset = adaptor.getResetOffset()
? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
rewriter.getIndexAttr(0))
: descriptor.offset(rewriter, loc);
bool hasSizes = memrefType.getRank() > 0;
// No need to unpack() and pack() all the individual sizes and strides,
// so we'll just extract the arrays.
Value sizes = hasSizes
? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
kSizePosInMemRefDescriptor)
: Value{};
Value strides =
hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
kStridePosInMemRefDescriptor)
: Value{};
Value fatPtr = makeBufferRsrc(
rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);
Value result = MemRefDescriptor::poison(
rewriter, loc,
getTypeConverter()->convertType(op.getResult().getType()));
SmallVector<int64_t> pos{kAllocatedPtrPosInMemRefDescriptor};
result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos);
result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
kAlignedPtrPosInMemRefDescriptor);
result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
kOffsetPosInMemRefDescriptor);
if (hasSizes) {
result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
kSizePosInMemRefDescriptor);
result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
kStridePosInMemRefDescriptor);
}
rewriter.replaceOp(op, result);
return success();
}
};
/// Define lowering patterns for raw buffer ops
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
Chipset chipset;
static constexpr uint32_t maxVectorOpWidth = 128;
LogicalResult
matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = gpuOp.getLoc();
Value memref = adaptor.getMemref();
Value unconvertedMemref = gpuOp.getMemref();
MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
if (chipset.majorVersion < 9)
return gpuOp.emitOpError("raw buffer ops require GCN or higher");
Value storeData = adaptor.getODSOperands(0)[0];
if (storeData == memref) // no write component to this op
storeData = Value();
Type wantedDataType;
if (storeData)
wantedDataType = storeData.getType();
else
wantedDataType = gpuOp.getODSResults(0)[0].getType();
Value atomicCmpData = Value();
// Operand index 1 of a load is the indices, trying to read them can crash.
if (storeData) {
Value maybeCmpData = adaptor.getODSOperands(1)[0];
if (maybeCmpData != memref)
atomicCmpData = maybeCmpData;
}
Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
Type i32 = rewriter.getI32Type();
// Get the type size in bytes.
DataLayout dataLayout = DataLayout::closest(gpuOp);
int64_t elementByteWidth =
dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
// If we want to load a vector<NxT> with total size <= 32
// bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
// and the total load size is >= 32, use a vector load of N / (bitsize(T) /
// 32) x i32 and bitcast. Also, the CAS intrinsic requires integer operands,
// so bitcast any floats to integers.
Type llvmBufferValType = llvmWantedDataType;
if (atomicCmpData) {
if (auto floatType = dyn_cast<FloatType>(wantedDataType))
llvmBufferValType = this->getTypeConverter()->convertType(
rewriter.getIntegerType(floatType.getWidth()));
}
if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
uint32_t vecLen = dataVector.getNumElements();
uint32_t elemBits =
dataLayout.getTypeSizeInBits(dataVector.getElementType());
uint32_t totalBits = elemBits * vecLen;
bool usePackedFp16 =
isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
if (totalBits > maxVectorOpWidth)
return gpuOp.emitOpError(
"Total width of loads or stores must be no more than " +
Twine(maxVectorOpWidth) + " bits, but we call for " +
Twine(totalBits) +
" bits. This should've been caught in validation");
if (!usePackedFp16 && elemBits < 32) {
if (totalBits > 32) {
if (totalBits % 32 != 0)
return gpuOp.emitOpError("Load or store of more than 32-bits that "
"doesn't fit into words. Can't happen\n");
llvmBufferValType = this->typeConverter->convertType(
VectorType::get(totalBits / 32, i32));
} else {
llvmBufferValType = this->typeConverter->convertType(
rewriter.getIntegerType(totalBits));
}
}
}
if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
// Buffer intrinsics doesn't support 1-element vectors, cast them to
// scalars.
if (vecType.getNumElements() == 1)
llvmBufferValType = vecType.getElementType();
}
SmallVector<Value, 6> args;
if (storeData) {
if (llvmBufferValType != llvmWantedDataType) {
Value castForStore = LLVM::BitcastOp::create(
rewriter, loc, llvmBufferValType, storeData);
args.push_back(castForStore);
} else {
args.push_back(storeData);
}
}
if (atomicCmpData) {
if (llvmBufferValType != llvmWantedDataType) {
Value castForCmp = LLVM::BitcastOp::create(
rewriter, loc, llvmBufferValType, atomicCmpData);
args.push_back(castForCmp);
} else {
args.push_back(atomicCmpData);
}
}
// Construct buffer descriptor from memref, attributes
int64_t offset = 0;
SmallVector<int64_t, 5> strides;
if (failed(memrefType.getStridesAndOffset(strides, offset)))
return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");
MemRefDescriptor memrefDescriptor(memref);
Value ptr = memrefDescriptor.bufferPtr(
rewriter, loc, *this->getTypeConverter(), memrefType);
Value numRecords =
getNumRecords(rewriter, loc, memrefType, memrefDescriptor, strides,
elementByteWidth, chipset, adaptor.getBoundsCheck());
Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
adaptor.getBoundsCheck(), chipset);
args.push_back(resource);
// Indexing (voffset)
Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
adaptor.getIndices(), strides);
if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
indexOffset && *indexOffset > 0) {
Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
extraOffsetConst)
: extraOffsetConst;
}
voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
args.push_back(voffset);
// SGPR offset.
Value sgprOffset = adaptor.getSgprOffset();
if (!sgprOffset)
sgprOffset = createI32Constant(rewriter, loc, 0);
sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
args.push_back(sgprOffset);
// bit 0: GLC = 0 (atomics drop value, less coherency)
// bits 1-2: SLC, DLC = 0 (similarly)
// bit 3: swizzled (0 for raw)
args.push_back(createI32Constant(rewriter, loc, 0));
llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
llvmBufferValType);
Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
ArrayRef<NamedAttribute>());
if (lowered->getNumResults() == 1) {
Value replacement = lowered->getResult(0);
if (llvmBufferValType != llvmWantedDataType) {
replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
replacement);
}
rewriter.replaceOp(gpuOp, replacement);
} else {
rewriter.eraseOp(gpuOp);
}
return success();
}
};
// TODO: AMDGPU backend already have all this bitpacking logic, we should move
// it to some common place.
/// Vmcnt, Expcnt and Lgkmcnt are decoded as follows:
/// Vmcnt = Waitcnt[3:0] (pre-gfx9)
/// Vmcnt = Waitcnt[15:14,3:0] (gfx9,10)
/// Vmcnt = Waitcnt[15:10] (gfx11)
/// Expcnt = Waitcnt[6:4] (pre-gfx11)
/// Expcnt = Waitcnt[2:0] (gfx11)
/// Lgkmcnt = Waitcnt[11:8] (pre-gfx10)
/// Lgkmcnt = Waitcnt[13:8] (gfx10)
/// Lgkmcnt = Waitcnt[9:4] (gfx11)
static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
unsigned expcnt, unsigned lgkmcnt) {
if (chipset.majorVersion < 9) {
vmcnt = std::min(15u, vmcnt);
expcnt = std::min(7u, expcnt);
lgkmcnt = std::min(15u, lgkmcnt);
return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
}
if (chipset.majorVersion == 9) {
vmcnt = std::min(63u, vmcnt);
expcnt = std::min(7u, expcnt);
lgkmcnt = std::min(15u, lgkmcnt);
unsigned lowBits = vmcnt & 0xF;
unsigned highBits = (vmcnt >> 4) << 14;
unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
return lowBits | highBits | otherCnts;
}
if (chipset.majorVersion == 10) {
vmcnt = std::min(63u, vmcnt);
expcnt = std::min(7u, expcnt);
lgkmcnt = std::min(63u, lgkmcnt);
unsigned lowBits = vmcnt & 0xF;
unsigned highBits = (vmcnt >> 4) << 14;
unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
return lowBits | highBits | otherCnts;
}
if (chipset.majorVersion == 11) {
vmcnt = std::min(63u, vmcnt);
expcnt = std::min(7u, expcnt);
lgkmcnt = std::min(63u, lgkmcnt);
return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
}
return failure();
}
struct MemoryCounterWaitOpLowering
: public ConvertOpToLLVMPattern<MemoryCounterWaitOp> {
MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter,
Chipset chipset)
: ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
chipset(chipset) {}
Chipset chipset;
LogicalResult
matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (chipset.majorVersion >= 12) {
Location loc = op.getLoc();
if (std::optional<int> ds = adaptor.getDs())
ROCDL::WaitDscntOp::create(rewriter, loc, *ds);
if (std::optional<int> load = adaptor.getLoad())
ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);
if (std::optional<int> store = adaptor.getStore())
ROCDL::WaitStorecntOp::create(rewriter, loc, *store);
if (std::optional<int> exp = adaptor.getExp())
ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);
if (std::optional<int> tensor = adaptor.getTensor())
ROCDL::WaitTensorcntOp::create(rewriter, loc, *tensor);
rewriter.eraseOp(op);
return success();
}
if (adaptor.getTensor())
return op.emitOpError("unsupported chipset");
auto getVal = [](Attribute attr) -> unsigned {
if (attr)
return cast<IntegerAttr>(attr).getInt();
// This value will be clamped to the maximum value for the chipset.
return 1024;
};
unsigned ds = getVal(adaptor.getDsAttr());
unsigned exp = getVal(adaptor.getExpAttr());
unsigned vmcnt = 1024;
Attribute load = adaptor.getLoadAttr();
Attribute store = adaptor.getStoreAttr();
if (load && store) {
vmcnt = getVal(load) + getVal(store);
} else if (load) {
vmcnt = getVal(load);
} else if (store) {
vmcnt = getVal(store);
}
FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
if (failed(waitcnt))
return op.emitOpError("unsupported chipset");
rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
return success();
}
};
struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
Chipset chipset;
LogicalResult
matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
// This ensures that waits on global memory aren't introduced on
// chips that don't have the BackOffBarrier feature enabled in LLVM.
bool requiresInlineAsm = chipset < kGfx90a;
Attribute mmra =
rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as", "local");
// Note: while there *is* a workgroup-one-as scope, this, when combined with
// the MMRA, will lead to the fence having no effect. This is because the
// codepaths for an atomic load or store will observe that a
// one-address-space atomic to LDS requires no synchronization because
// operations on LDS are totally ordered with respect to each other, and so
// will not emit the correct waitcnt operations that these fences are
// intended to produce. Therefore, we use a broader type of fence and rely
// on the MMRA to relax it to the semantics we want.
StringRef scope = "workgroup";
auto relFence = LLVM::FenceOp::create(rewriter, loc,
LLVM::AtomicOrdering::release, scope);
relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
if (requiresInlineAsm) {
auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
LLVM::AsmDialect::AD_ATT);
const char *asmStr = ";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
const char *constraints = "";
LLVM::InlineAsmOp::create(
rewriter, loc,
/*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
/*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
/*is_align_stack=*/false, LLVM::TailCallKind::None,
/*asm_dialect=*/asmDialectAttr,
/*operand_attrs=*/ArrayAttr());
} else if (chipset.majorVersion < 12) {
ROCDL::SBarrierOp::create(rewriter, loc);
} else {
ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
}
auto acqFence = LLVM::FenceOp::create(rewriter, loc,
LLVM::AtomicOrdering::acquire, scope);
acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
rewriter.replaceOp(op, acqFence);
return success();
}
};
struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}
Chipset chipset;
LogicalResult
matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
(uint32_t)op.getOpts());
return success();
}
};
} // namespace
/// Pack small float vector operands (fp4/fp6/fp8/bf16) into the format
/// expected by scaled matrix multiply intrinsics (MFMA/WMMA).
///
/// Specifically:
/// 1. If the element type is bfloat16, bitcast it to i16 unless rocdl intrinsic
/// allows bf16. Newer MFMAs support bf16 types on operand, check
/// IntrinsicsAMDGPU.td file for reference.
/// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
/// instead, which is what the f8f6f4 intrinsics use.
/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
/// integer.
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter,
Location loc, Value input,
bool allowBf16 = true) {
Type inputType = input.getType();
if (auto vectorType = dyn_cast<VectorType>(inputType)) {
if (vectorType.getElementType().isBF16() && !allowBf16)
return LLVM::BitcastOp::create(
rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
if (vectorType.getElementType().isInteger(8) &&
vectorType.getNumElements() <= 8)
return LLVM::BitcastOp::create(
rewriter, loc,
rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
if (isa<IntegerType>(vectorType.getElementType()) &&
vectorType.getElementTypeBitWidth() <= 8) {
int64_t numWords = llvm::divideCeil(
vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
32);
return LLVM::BitcastOp::create(
rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
input);
}
}
return input;
}
/// Converts sparse MFMA (smfmac) operands to the expected ROCDL types.
static Value convertSparseMFMAVectorOperand(ConversionPatternRewriter &rewriter,
Location loc, Value input,
bool allowBf16 = true) {
Type inputType = input.getType();
auto vectorType = cast<VectorType>(inputType);
// bf16 -> i16 when not allowed (pre-gfx950).
if (vectorType.getElementType().isBF16() && !allowBf16)
return LLVM::BitcastOp::create(
rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
// i8/fp8 vectors -> vector<Nxi32>.
if (isa<IntegerType>(vectorType.getElementType()) &&
vectorType.getElementTypeBitWidth() <= 8) {
int64_t numWords = llvm::divideCeil(
vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32);
return LLVM::BitcastOp::create(
rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()), input);
}
return input;
}
/// Converts the scaled MFMA/WMMA operands, `scalesA` and `scalesB`, from MLIR
/// AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
/// 1. If `input` is a i8 value, zero extend it to i32
/// 2. If `input` is a vector of length 4 or 8 and type i8, cast it to i32
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc,
Value input) {
return TypeSwitch<Type, Value>(input.getType())
.Case([&](IntegerType) {
// Handle scalar i8: zero extend to i32.
return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(),
input);
})
.Case([&](VectorType vectorType) {
// Handle vector<4xi8> -> i32 or vector<8xi8> -> i64.
int64_t numElements = vectorType.getNumElements();
assert((numElements == 4 || numElements == 8) &&
"scale operand must be a vector of length 4 or 8");
IntegerType outputType =
(numElements == 4) ? rewriter.getI32Type() : rewriter.getI64Type();
return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
})
.DefaultUnreachable("unexpected input type for scale operand");
}
/// Maps f8 scale element types to WMMA scale format codes.
static std::optional<uint32_t> getWmmaScaleFormat(Type elemType) {
return TypeSwitch<Type, std::optional<uint32_t>>(elemType)
.Case([](Float8E8M0FNUType) { return 0; })
.Case([](Float8E4M3FNType) { return 2; })
.Default(std::nullopt);
}
/// Determines the ROCDL intrinsic name for scaled WMMA based on dimensions
/// and scale block size (16 or 32).
static std::optional<StringRef>
getScaledWmmaIntrinsicName(int64_t m, int64_t n, int64_t k, bool isScale16) {
if (m == 16 && n == 16 && k == 128)
return isScale16
? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
: ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
if (m == 32 && n == 16 && k == 128)
return isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
: ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
return std::nullopt;
}
/// Push an input operand. If it is a float type, nothing to do. If it is
/// an integer type, then we need to also push its signdness (1 for signed, 0
/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
/// vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
/// We also need to convert bfloat inputs to i16 to account for the bfloat
/// intrinsics having been defined before the AMD backend supported bfloat. We
/// similarly need to pack 8-bit float types into integers as if they were i8
/// (which they are for the backend's purposes).
static void wmmaPushInputOperand(
ConversionPatternRewriter &rewriter, Location loc,
const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput,
Value mlirInput, SmallVectorImpl<Value> &operands,
SmallVectorImpl<NamedAttribute> &attrs, StringRef attrName) {
Type inputType = llvmInput.getType();
auto vectorType = dyn_cast<VectorType>(inputType);
if (!vectorType) {
operands.push_back(llvmInput);
return;
}
Type elemType = vectorType.getElementType();
if (elemType.getIntOrFloatBitWidth() > 8) {
operands.push_back(llvmInput);
return;
}
// We need to check the type of the input before conversion to properly test
// for int8. This is because, in LLVM, fp8 type is converted to int8, so the
// fp8/int8 information is lost during the conversion process.
auto mlirInputType = cast<VectorType>(mlirInput.getType());
bool isInputInteger = mlirInputType.getElementType().isInteger();
if (isInputInteger) {
// if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
bool localIsUnsigned = isUnsigned;
if (elemType.isUnsignedInteger()) {
localIsUnsigned = true;
} else if (elemType.isSignedInteger()) {
localIsUnsigned = false;
}
attrs.push_back(
NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
}
int64_t numBits =
vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
Type i32 = rewriter.getI32Type();
Type intrinsicInType = numBits <= 32
? (Type)rewriter.getIntegerType(numBits)
: (Type)VectorType::get(numBits / 32, i32);
auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
loc, llvmIntrinsicInType, llvmInput);
// The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
// (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
// Add in the zeros here.
if (numBits < 32)
castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
operands.push_back(castInput);
}
/// Push the output operand. For many cases this is only pushing the output in
/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
/// since the same numbers of VGPRs is used, we need to decide if to store the
/// result in the upper 16 bits of the VGPRs or in the lower part. To store the
/// result in the lower 16 bits, set subwordOffset to 1, otherwise result will
/// be stored it in the upper part. The subwordOffset must not be set for gfx12,
/// as the instructions have been changed to return fewer registers instead.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
Location loc,
const TypeConverter *typeConverter,
Value output, int32_t subwordOffset,
bool clamp, SmallVectorImpl<Value> &operands,
SmallVectorImpl<NamedAttribute> &attrs) {
Type inputType = output.getType();
auto vectorType = dyn_cast<VectorType>(inputType);
Type elemType = vectorType.getElementType();
operands.push_back(output);
if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
attrs.push_back(
NamedAttribute("opsel", rewriter.getBoolAttr(subwordOffset)));
} else if (elemType.isInteger(32)) {
attrs.push_back(NamedAttribute("clamp", rewriter.getBoolAttr(clamp)));
}
}
/// Return true if `type` is the E5M2 variant of an 8-bit float that is
/// supported by the `_bf8` instructions on the given `chipset`.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
(hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
}
/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
/// supported by the `_fp8` instructions on the given `chipset`.
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
(hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
}
/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
Chipset chipset) {
uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
b = mfma.getBlocks();
Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType());
Type destElem = getElementTypeOrSelf(mfma.getDestC().getType());
if (sourceElem.isF32() && destElem.isF32()) {
if (mfma.getReducePrecision() && chipset >= kGfx942) {
if (m == 32 && n == 32 && k == 4 && b == 1)
return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
if (m == 16 && n == 16 && k == 8 && b == 1)
return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
}
if (m == 32 && n == 32 && k == 1 && b == 2)
return ROCDL::mfma_f32_32x32x1f32::getOperationName();
if (m == 16 && n == 16 && k == 1 && b == 4)
return ROCDL::mfma_f32_16x16x1f32::getOperationName();
if (m == 4 && n == 4 && k == 1 && b == 16)
return ROCDL::mfma_f32_4x4x1f32::getOperationName();
if (m == 32 && n == 32 && k == 2 && b == 1)
return ROCDL::mfma_f32_32x32x2f32::getOperationName();
if (m == 16 && n == 16 && k == 4 && b == 1)
return ROCDL::mfma_f32_16x16x4f32::getOperationName();
}
if (sourceElem.isF16() && destElem.isF32()) {
if (chipset >= kGfx950) {
if (m == 32 && n == 32 && k == 16 && b == 1)
return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
if (m == 16 && n == 16 && k == 32 && b == 1)
return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
}
if (m == 32 && n == 32 && k == 4 && b == 2)
return ROCDL::mfma_f32_32x32x4f16::getOperationName();
if (m == 16 && n == 16 && k == 4 && b == 4)
return ROCDL::mfma_f32_16x16x4f16::getOperationName();
if (m == 4 && n == 4 && k == 4 && b == 16)
return ROCDL::mfma_f32_4x4x4f16::getOperationName();
if (m == 32 && n == 32 && k == 8 && b == 1)
return ROCDL::mfma_f32_32x32x8f16::getOperationName();
if (m == 16 && n == 16 && k == 16 && b == 1)
return ROCDL::mfma_f32_16x16x16f16::getOperationName();
}
if (sourceElem.isBF16() && destElem.isF32()) {
if (chipset >= kGfx950) {
if (m == 32 && n == 32 && k == 16 && b == 1)
return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
if (m == 16 && n == 16 && k == 32 && b == 1)
return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
}
if (chipset >= kGfx90a) {
if (m == 32 && n == 32 && k == 4 && b == 2)
return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
if (m == 16 && n == 16 && k == 4 && b == 4)
return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
if (m == 4 && n == 4 && k == 4 && b == 16)
return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
if (m == 32 && n == 32 && k == 8 && b == 1)
return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
if (m == 16 && n == 16 && k == 16 && b == 1)
return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
}
if (m == 32 && n == 32 && k == 2 && b == 2)
return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
if (m == 16 && n == 16 && k == 2 && b == 4)
return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
if (m == 4 && n == 4 && k == 2 && b == 16)
return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
if (m == 32 && n == 32 && k == 4 && b == 1)
return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
if (m == 16 && n == 16 && k == 8 && b == 1)
return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
}
if (sourceElem.isInteger(8) && destElem.isInteger(32)) {
if (chipset >= kGfx950) {
if (m == 32 && n == 32 && k == 32 && b == 1)
return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
if (m == 16 && n == 16 && k == 64 && b == 1)
return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
}
if (m == 32 && n == 32 && k == 4 && b == 2)
return ROCDL::mfma_i32_32x32x4i8::getOperationName();
if (m == 16 && n == 16 && k == 4 && b == 4)
return ROCDL::mfma_i32_16x16x4i8::getOperationName();
if (m == 4 && n == 4 && k == 4 && b == 16)
return ROCDL::mfma_i32_4x4x4i8::getOperationName();
if (m == 32 && n == 32 && k == 8 && b == 1)
return ROCDL::mfma_i32_32x32x8i8::getOperationName();
if (m == 16 && n == 16 && k == 16 && b == 1)
return ROCDL::mfma_i32_16x16x16i8::getOperationName();
if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
}
if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
if (m == 16 && n == 16 && k == 4 && b == 1)
return ROCDL::mfma_f64_16x16x4f64::getOperationName();
if (m == 4 && n == 4 && k == 4 && b == 4)
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
}
if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
// Known to be correct because there are no scalar f8 instructions and
// because a length mismatch will have been caught by the verifier.
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
}
if (m == 32 && n == 32 && k == 16 && b == 1) {
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
}
}
if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
}
if (m == 32 && n == 32 && k == 16 && b == 1) {
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
}
}
return std::nullopt;
}
static std::optional<uint32_t> smallFloatTypeToFormatCode(Type mlirElemType) {
return llvm::TypeSwitch<Type, std::optional<uint32_t>>(mlirElemType)
.Case([](Float8E4M3FNType) { return 0u; })
.Case([](Float8E5M2Type) { return 1u; })
.Case([](Float6E2M3FNType) { return 2u; })
.Case([](Float6E3M2FNType) { return 3u; })
.Case([](Float4E2M1FNType) { return 4u; })
.Default(std::nullopt);
}
/// If there is a scaled MFMA instruction for the input element types `aType`
/// and `bType`, output type `destType`, problem size M, N, K, and B (number of
/// blocks) on the given `chipset`, return a tuple consisting of the
/// OperationName of the intrinsic and the type codes that need to be passed to
/// that intrinsic. Note that this is also used to implement some un-scaled
/// MFMAs, since the compiler represents the ordinary instruction as a "scaled"
/// MFMA with a scale of 0.
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
aType = getElementTypeOrSelf(aType);
bType = getElementTypeOrSelf(bType);
destType = getElementTypeOrSelf(destType);
if (chipset < kGfx950)
return std::nullopt;
if (!isa<Float32Type>(destType))
return std::nullopt;
std::optional<uint32_t> aTypeCode = smallFloatTypeToFormatCode(aType);
std::optional<uint32_t> bTypeCode = smallFloatTypeToFormatCode(bType);
if (!aTypeCode || !bTypeCode)
return std::nullopt;
if (m == 32 && n == 32 && k == 64 && b == 1)
return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
*aTypeCode, *bTypeCode};
if (m == 16 && n == 16 && k == 128 && b == 1)
return std::tuple{
ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
*bTypeCode};
return std::nullopt;
}
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
return mfmaOpToScaledIntrinsic(
mfma.getSourceA().getType(), mfma.getSourceB().getType(),
mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
mfma.getBlocks(), chipset);
}
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
smfma.getSourceB().getType(),
smfma.getDestC().getType(), smfma.getM(),
smfma.getN(), smfma.getK(), 1u, chipset);
}
/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// for RDNA3/4 architectures.
static std::optional<StringRef>
wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType,
Type elemDestType, uint32_t k, bool isRDNA3) {
using fp8 = Float8E4M3FNType;
using bf8 = Float8E5M2Type;
// Handle k == 16 for RDNA3/4.
if (k == 16) {
// Common patterns for RDNA3 and RDNA4.
if (elemSourceType.isF16() && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
if (elemSourceType.isBF16() && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
if (elemSourceType.isF16() && elemDestType.isF16())
return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
if (elemSourceType.isBF16() && elemDestType.isBF16())
return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
// RDNA3 specific patterns.
if (isRDNA3) {
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
return std::nullopt;
}
// RDNA4 specific patterns (fp8/bf8).
if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
return std::nullopt;
}
// Handle k == 32 for RDNA4.
if (k == 32 && !isRDNA3) {
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
}
return std::nullopt;
}
/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// for the gfx1250 architecture.
static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
Type elemBSourceType,
Type elemDestType,
uint32_t k) {
using fp8 = Float8E4M3FNType;
using bf8 = Float8E5M2Type;
if (k == 4) {
if (elemSourceType.isF32() && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
return std::nullopt;
}
if (k == 32) {
if (elemSourceType.isF16() && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
if (elemSourceType.isBF16() && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
if (elemSourceType.isF16() && elemDestType.isF16())
return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
if (elemSourceType.isBF16() && elemDestType.isBF16())
return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
return std::nullopt;
}
if (k == 64) {
if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
if (elemDestType.isF32())
return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
if (elemDestType.isF16())
return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
}
if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
if (elemDestType.isF32())
return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
if (elemDestType.isF16())
return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
}
if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
if (elemDestType.isF32())
return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
if (elemDestType.isF16())
return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
}
if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
if (elemDestType.isF32())
return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
if (elemDestType.isF16())
return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
}
if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
return std::nullopt;
}
if (k == 128) {
if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
if (elemDestType.isF32())
return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
if (elemDestType.isF16())
return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
}
if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
if (elemDestType.isF32())
return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
if (elemDestType.isF16())
return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
}
if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
if (elemDestType.isF32())
return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
if (elemDestType.isF16())
return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
}
if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
if (elemDestType.isF32())
return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
if (elemDestType.isF16())
return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
}
return std::nullopt;
}
return std::nullopt;
}
/// Returns the `rocdl` intrinsic corresponding to a SparseMFMA (smfmac)
/// operation if one exists. This includes checking to ensure the intrinsic is
/// supported on the architecture you are compiling for.
static std::optional<StringRef> smfmacOpToIntrinsic(SparseMFMAOp op,
Chipset chipset) {
bool isGfx950 = chipset >= kGfx950;
auto isFp8 = [&](Type t) { return typeIsExpectedFp8ForChipset(chipset, t); };
auto isBf8 = [&](Type t) { return typeIsExpectedBf8ForChipset(chipset, t); };
uint32_t m = op.getM(), n = op.getN(), k = op.getK();
Type sourceAElem = getElementTypeOrSelf(op.getSourceA().getType());
Type sourceBElem = getElementTypeOrSelf(op.getSourceB().getType());
Type destElem = getElementTypeOrSelf(op.getDestC().getType());
if (m == 16 && n == 16 && k == 32) {
if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
return ROCDL::smfmac_f32_16x16x32_f16::getOperationName();
if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
return ROCDL::smfmac_f32_16x16x32_bf16::getOperationName();
}
if (m == 16 && n == 16 && k == 64) {
if (isGfx950) {
if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
return ROCDL::smfmac_f32_16x16x64_f16::getOperationName();
if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
return ROCDL::smfmac_f32_16x16x64_bf16::getOperationName();
}
if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
destElem.isInteger(32))
return ROCDL::smfmac_i32_16x16x64_i8::getOperationName();
if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_16x16x64_fp8_fp8::getOperationName();
if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_16x16x64_fp8_bf8::getOperationName();
if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_16x16x64_bf8_fp8::getOperationName();
if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_16x16x64_bf8_bf8::getOperationName();
}
if (m == 16 && n == 16 && k == 128 && isGfx950) {
if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
destElem.isInteger(32))
return ROCDL::smfmac_i32_16x16x128_i8::getOperationName();
if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_16x16x128_fp8_fp8::getOperationName();
if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_16x16x128_fp8_bf8::getOperationName();
if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_16x16x128_bf8_fp8::getOperationName();
if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_16x16x128_bf8_bf8::getOperationName();
}
if (m == 32 && n == 32 && k == 16) {
if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
return ROCDL::smfmac_f32_32x32x16_f16::getOperationName();
if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
return ROCDL::smfmac_f32_32x32x16_bf16::getOperationName();
}
if (m == 32 && n == 32 && k == 32) {
if (isGfx950) {
if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
return ROCDL::smfmac_f32_32x32x32_f16::getOperationName();
if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
return ROCDL::smfmac_f32_32x32x32_bf16::getOperationName();
}
if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
destElem.isInteger(32))
return ROCDL::smfmac_i32_32x32x32_i8::getOperationName();
if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_32x32x32_fp8_fp8::getOperationName();
if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_32x32x32_fp8_bf8::getOperationName();
if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_32x32x32_bf8_fp8::getOperationName();
if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_32x32x32_bf8_bf8::getOperationName();
}
if (m == 32 && n == 32 && k == 64 && isGfx950) {
if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
destElem.isInteger(32))
return ROCDL::smfmac_i32_32x32x64_i8::getOperationName();
if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_32x32x64_fp8_fp8::getOperationName();
if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_32x32x64_fp8_bf8::getOperationName();
if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_32x32x64_bf8_fp8::getOperationName();
if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
return ROCDL::smfmac_f32_32x32x64_bf8_bf8::getOperationName();
}
return std::nullopt;
}
/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
Chipset chipset) {
auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
Type elemSourceType = sourceVectorType.getElementType();
Type elemBSourceType = sourceBVectorType.getElementType();
Type elemDestType = destVectorType.getElementType();
const uint32_t k = wmma.getK();
const bool isRDNA3 = chipset.majorVersion == 11;
const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
// Handle RDNA3 and RDNA4.
if (isRDNA3 || isRDNA4)
return wmmaOpToIntrinsicRDNA(elemSourceType, elemBSourceType, elemDestType,
k, isRDNA3);
// Handle gfx1250.
if (chipset == kGfx1250)
return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
elemDestType, k);
return std::nullopt;
}
namespace {
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}
Chipset chipset;
LogicalResult
matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Type outType = typeConverter->convertType(op.getDestD().getType());
Type intrinsicOutType = outType;
if (auto outVecType = dyn_cast<VectorType>(outType))
if (outVecType.getElementType().isBF16())
intrinsicOutType = outVecType.clone(rewriter.getI16Type());
if (chipset.majorVersion != 9 || chipset < kGfx908)
return op->emitOpError("MFMA only supported on gfx908+");
uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
if (chipset < kGfx942)
return op.emitOpError("negation unsupported on older than gfx942");
getBlgpField |=
op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
}
std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
return op.emitOpError("no intrinsic matching MFMA size on given chipset");
bool isScaled =
!maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
if (isScaled &&
(adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
return op.emitOpError(
"non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
"be scaled as those fields are used for type information");
}
StringRef intrinsicName =
isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
// Determine if we can use bf16 in the intrinsic. Newer MFMAs in gfx950+
// allows bf16 as the input. For reference check IntrinsicsAMDGPU.td file.
bool allowBf16 = [&]() {
if (chipset < kGfx950)
return false;
if (isScaled)
return true;
return intrinsicName.contains("16x16x32.bf16") ||
intrinsicName.contains("32x32x16.bf16");
}();
OperationState loweredOp(loc, intrinsicName);
loweredOp.addTypes(intrinsicOutType);
loweredOp.addOperands({packSmallFloatVectorOperand(
rewriter, loc, adaptor.getSourceA(), allowBf16),
packSmallFloatVectorOperand(
rewriter, loc, adaptor.getSourceB(), allowBf16),
adaptor.getDestC()});
if (isScaled) {
Value zero = createI32Constant(rewriter, loc, 0);
auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
createI32Constant(rewriter, loc, bTypeCode),
/*scale A byte=*/zero, /*scale A=*/zero,
/*scale B byte=*/zero, /*scale B=*/zero});
} else {
loweredOp.addOperands({createI32Constant(rewriter, loc, op.getCbsz()),
createI32Constant(rewriter, loc, op.getAbid()),
createI32Constant(rewriter, loc, getBlgpField)});
};
Value lowered = rewriter.create(loweredOp)->getResult(0);
if (outType != intrinsicOutType)
lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
rewriter.replaceOp(op, lowered);
return success();
}
};
struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern(converter), chipset(chipset) {}
Chipset chipset;
LogicalResult
matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
if (chipset.majorVersion != 9 || chipset < kGfx950)
return op->emitOpError("scaled MFMA only supported on gfx908+");
std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
if (!maybeScaledIntrinsic.has_value())
return op.emitOpError(
"no intrinsic matching scaled MFMA size on given chipset");
auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
OperationState loweredOp(loc, intrinsicName);
loweredOp.addTypes(intrinsicOutType);
loweredOp.addOperands(
{packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA()),
packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB()),
adaptor.getDestC()});
Value scalesIdxA =
createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
Value scalesIdxB =
createI32Constant(rewriter, loc, adaptor.getScalesIdxB());
loweredOp.addOperands(
{createI32Constant(rewriter, loc, aTypeCode),
createI32Constant(rewriter, loc, bTypeCode),
/*scales idx A=*/scalesIdxA,
/*scales A*/
castScaleOperand(rewriter, loc, adaptor.getScalesA()),
/*scales idx B=*/scalesIdxB,
/*scales B*/
castScaleOperand(rewriter, loc, adaptor.getScalesB())});
Value lowered = rewriter.create(loweredOp)->getResult(0);
rewriter.replaceOp(op, lowered);
return success();
}
};
struct SparseMFMAOpLowering : public ConvertOpToLLVMPattern<SparseMFMAOp> {
SparseMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<SparseMFMAOp>(converter), chipset(chipset) {}
Chipset chipset;
LogicalResult
matchAndRewrite(SparseMFMAOp op, SparseMFMAOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto outType =
typeConverter->convertType<VectorType>(op.getDestC().getType());
if (!outType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
// smfmac is supported on gfx942 and gfx950.
if (chipset.majorVersion != 9 || chipset < kGfx942)
return op->emitOpError("sparse MFMA (smfmac) only supported on gfx942+");
bool isGfx950 = chipset >= kGfx950;
Value a = convertSparseMFMAVectorOperand(rewriter, loc,
adaptor.getSourceA(), isGfx950);
Value b = convertSparseMFMAVectorOperand(rewriter, loc,
adaptor.getSourceB(), isGfx950);
Value c = adaptor.getDestC();
std::optional<StringRef> maybeIntrinsic = smfmacOpToIntrinsic(op, chipset);
if (!maybeIntrinsic.has_value())
return op.emitOpError(
"no intrinsic matching sparse MFMA on the given chipset");
// Bitcast sparse indices from vector<4xi8> or vector<2xi16> to i32.
Value sparseIdx = LLVM::BitcastOp::create(
rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx());
OperationState loweredOp(loc, maybeIntrinsic.value());
loweredOp.addTypes(outType);
loweredOp.addOperands({a, b, c, sparseIdx,
createI32Constant(rewriter, loc, op.getCbsz()),
createI32Constant(rewriter, loc, op.getAbid())});
Value lowered = rewriter.create(loweredOp)->getResult(0);
rewriter.replaceOp(op, lowered);
return success();
}
};
struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
Chipset chipset;
LogicalResult
matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto outType =
typeConverter->convertType<VectorType>(op.getDestD().getType());
if (!outType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
return op->emitOpError("WMMA only supported on gfx11 and gfx12");
bool isGFX1250 = chipset >= kGfx1250;
// The WMMA operations represent vectors of bf16s as vectors of i16s
// (except on gfx1250), so we need to bitcast bfloats to i16 and then
// bitcast them back.
auto aType = cast<VectorType>(adaptor.getSourceA().getType());
auto bType = cast<VectorType>(adaptor.getSourceB().getType());
auto destCType = cast<VectorType>(adaptor.getDestC().getType());
bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250;
VectorType rawOutType = outType;
if (castOutToI16)
rawOutType = outType.clone(rewriter.getI16Type());
Value a = adaptor.getSourceA();
if (castAToI16)
a = LLVM::BitcastOp::create(rewriter, loc,
aType.clone(rewriter.getI16Type()), a);
Value b = adaptor.getSourceB();
if (castBToI16)
b = LLVM::BitcastOp::create(rewriter, loc,
bType.clone(rewriter.getI16Type()), b);
Value destC = adaptor.getDestC();
if (castDestCToI16)
destC = LLVM::BitcastOp::create(
rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);
std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
if (!maybeIntrinsic.has_value())
return op.emitOpError("no intrinsic matching WMMA on the given chipset");
if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
return op.emitOpError("subwordOffset not supported on gfx12+");
SmallVector<Value, 4> operands;
SmallVector<NamedAttribute, 4> attrs;
wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a,
op.getSourceA(), operands, attrs, "signA");
wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b,
op.getSourceB(), operands, attrs, "signB");
wmmaPushOutputOperand(rewriter, loc, typeConverter, destC,
op.getSubwordOffset(), op.getClamp(), operands,
attrs);
OperationState loweredOp(loc, *maybeIntrinsic);
loweredOp.addTypes(rawOutType);
loweredOp.addOperands(operands);
loweredOp.addAttributes(attrs);
Operation *lowered = rewriter.create(loweredOp);
Operation *maybeCastBack = lowered;
if (rawOutType != outType)
maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
lowered->getResult(0));
rewriter.replaceOp(op, maybeCastBack->getResults());
return success();
}
};
struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
ScaledWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}
Chipset chipset;
LogicalResult
matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto outType =
typeConverter->convertType<VectorType>(op.getDestD().getType());
if (!outType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
if (chipset < kGfx1250)
return op->emitOpError("WMMA scale only supported on gfx1250+");
int64_t m = op.getM();
int64_t n = op.getN();
int64_t k = op.getK();
Type aElemType = getElementTypeOrSelf(op.getSourceA().getType());
Type bElemType = getElementTypeOrSelf(op.getSourceB().getType());
std::optional<uint32_t> aFmtCode = smallFloatTypeToFormatCode(aElemType);
std::optional<uint32_t> bFmtCode = smallFloatTypeToFormatCode(bElemType);
if (!aFmtCode || !bFmtCode)
return op.emitOpError("unsupported element types for scaled_wmma");
// Get scale vector types and determine variant (scale vs scale16).
auto scaleAVecType = cast<VectorType>(op.getScaleA().getType());
auto scaleBVecType = cast<VectorType>(op.getScaleB().getType());
if (scaleAVecType.getNumElements() != scaleBVecType.getNumElements())
return op.emitOpError("scaleA and scaleB must have equal vector length");
// Extract scale format from element types.
Type scaleAElemType = scaleAVecType.getElementType();
Type scaleBElemType = scaleBVecType.getElementType();
std::optional<uint32_t> scaleAFmt = getWmmaScaleFormat(scaleAElemType);
std::optional<uint32_t> scaleBFmt = getWmmaScaleFormat(scaleBElemType);
if (!scaleAFmt || !scaleBFmt)
return op.emitOpError("unsupported scale element types");
// Determine which intrinsic to use based on dimensions.
bool isScale16 = (scaleAVecType.getNumElements() == 8);
std::optional<StringRef> intrinsicName =
getScaledWmmaIntrinsicName(m, n, k, isScale16);
if (!intrinsicName)
return op.emitOpError("unsupported scaled_wmma dimensions: ")
<< m << "x" << n << "x" << k;
SmallVector<NamedAttribute, 8> attrs;
// The f4 variant does not have fmtA and fmtB attributes.
bool is32x16 = (m == 32 && n == 16 && k == 128);
if (!is32x16) {
attrs.emplace_back("fmtA", rewriter.getI32IntegerAttr(*aFmtCode));
attrs.emplace_back("fmtB", rewriter.getI32IntegerAttr(*bFmtCode));
}
// modC uses default value of 0.
attrs.emplace_back("modC", rewriter.getI16IntegerAttr(0));
// Scale attributes. Convert user-facing firstScaleLane (0 or 16) to the
// half of the wave that is being selected (0 or 1).
attrs.emplace_back(
"scaleAType", rewriter.getI32IntegerAttr(op.getAFirstScaleLane() / 16));
attrs.emplace_back("fmtScaleA", rewriter.getI32IntegerAttr(*scaleAFmt));
attrs.emplace_back(
"scaleBType", rewriter.getI32IntegerAttr(op.getBFirstScaleLane() / 16));
attrs.emplace_back("fmtScaleB", rewriter.getI32IntegerAttr(*scaleBFmt));
// Reuse flags use default value of false.
attrs.emplace_back("reuseA", rewriter.getBoolAttr(false));
attrs.emplace_back("reuseB", rewriter.getBoolAttr(false));
// Convert typed float vectors to packed format.
Value sourceA =
packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA());
Value sourceB =
packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB());
// Pack scale vectors into i32/i64.
Value packedScaleA = castScaleOperand(rewriter, loc, adaptor.getScaleA());
Value packedScaleB = castScaleOperand(rewriter, loc, adaptor.getScaleB());
// Create the intrinsic call.
OperationState loweredOp(loc, *intrinsicName);
loweredOp.addTypes(outType);
loweredOp.addOperands(
{sourceA, sourceB, adaptor.getDestC(), packedScaleA, packedScaleB});
loweredOp.addAttributes(attrs);
Operation *lowered = rewriter.create(loweredOp);
rewriter.replaceOp(op, lowered->getResults());
return success();
}
};
struct TransposeLoadOpLowering
: public ConvertOpToLLVMPattern<TransposeLoadOp> {
TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
Chipset chipset;
LogicalResult
matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (chipset != kGfx950)
return op.emitOpError("Non-gfx950 chipset not supported");
Location loc = op.getLoc();
auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
// Elements in subbyte memrefs are stored non-contiguously,
// reject if source is sub-byte memref. Use emulated memrefs instead.
size_t srcElementSize =
srcMemRefType.getElementType().getIntOrFloatBitWidth();
if (srcElementSize < 8)
return op.emitOpError("Expect source memref to have at least 8 bits "
"element size, got ")
<< srcElementSize;
auto resultType = cast<VectorType>(op.getResult().getType());
Value srcPtr =
getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
(adaptor.getSrcIndices()));
size_t numElements = resultType.getNumElements();
size_t elementTypeSize =
resultType.getElementType().getIntOrFloatBitWidth();
// ROCDL transpose load intrinsics return vectors of 32-bit integers, if
// the element size is smaller than 16 bits.
Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
rewriter.getIntegerType(32));
Type llvmResultType = typeConverter->convertType(resultType);
switch (elementTypeSize) {
case 4: {
assert(numElements == 16);
auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
rocdlResultType, srcPtr);
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
break;
}
case 6: {
assert(numElements == 16);
auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
rocdlResultType, srcPtr);
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
break;
}
case 8: {
assert(numElements == 8);
auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
rocdlResultType, srcPtr);
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
break;
}
case 16: {
assert(numElements == 4);
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
srcPtr);
break;
}
default:
return op.emitOpError("Unsupported element size for transpose load");
}
return success();
}
};
struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
Chipset chipset;
LogicalResult
matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
return op.emitOpError("pre-gfx9 and post-gfx10 not supported");
Location loc = op.getLoc();
auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
// TODO: instead of only transfering one element per thread, we could
// augment it to transfer multiple elements per thread by issuing multiple
// `global_load_lds` instructions.
Type transferType = op.getTransferType();
int loadWidth = [&]() -> int {
if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
return (transferVectorType.getNumElements() *
transferVectorType.getElementTypeBitWidth()) /
8;
}
return transferType.getIntOrFloatBitWidth() / 8;
}();
// Currently only 1, 2, 4, 12 and 16 byte loads are supported.
if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
return op.emitOpError("chipset unsupported element size");
if (chipset != kGfx950 && llvm::is_contained({12, 16}, loadWidth))
return op.emitOpError("Gather to LDS instructions with 12-byte and "
"16-byte load widths are only supported on gfx950");
Value srcPtr =
getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
(adaptor.getSrcIndices()));
Value dstPtr =
getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
(adaptor.getDstIndices()));
rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
/*offset=*/rewriter.getI32IntegerAttr(0),
/*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
ArrayAttr{});
return success();
}
};
namespace {
struct ExtPackedFp8OpLowering final
: public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
chipset(chipset) {}
Chipset chipset;
LogicalResult
matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
struct ScaledExtPackedMatrixOpLowering final
: public ConvertOpToLLVMPattern<ScaledExtPackedMatrixOp> {
ScaledExtPackedMatrixOpLowering(const LLVMTypeConverter &converter,
Chipset chipset)
: ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter),
chipset(chipset) {}
Chipset chipset;
LogicalResult
matchAndRewrite(ScaledExtPackedMatrixOp op,
ScaledExtPackedMatrixOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
struct PackedTrunc2xFp8OpLowering final
: public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
Chipset chipset)
: ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
chipset(chipset) {}
Chipset chipset;
LogicalResult
matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
struct PackedStochRoundFp8OpLowering final
: public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
Chipset chipset)
: ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
chipset(chipset) {}
Chipset chipset;
LogicalResult
matchAndRewrite(PackedStochRoundFp8Op op,
PackedStochRoundFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
struct ScaledExtPackedOpLowering final
: public ConvertOpToLLVMPattern<ScaledExtPackedOp> {
ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
chipset(chipset) {}
Chipset chipset;
LogicalResult
matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
struct PackedScaledTruncOpLowering final
: public ConvertOpToLLVMPattern<PackedScaledTruncOp> {
PackedScaledTruncOpLowering(const LLVMTypeConverter &converter,
Chipset chipset)
: ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
chipset(chipset) {}
Chipset chipset;
LogicalResult
matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
} // end namespace
LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
Type v4i8 =
getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
Type f32 = getTypeConverter()->convertType(op.getResult().getType());
Value source = adaptor.getSource();
auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
Type sourceElemType = getElementTypeOrSelf(op.getSource());
// Extend to a v4i8
if (!sourceVecType || sourceVecType.getNumElements() < 4) {
Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
if (!sourceVecType) {
longVec = LLVM::InsertElementOp::create(
rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
} else {
for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
Value idx = createI32Constant(rewriter, loc, i);
Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
longVec =
LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
}
}
source = longVec;
}
Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
if (resultVecType) {
if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
op.getIndex());
} else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
op.getIndex());
}
} else {
if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
op.getIndex());
} else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
op.getIndex());
}
}
return success();
}
int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, int32_t scaleWaveHalf,
int32_t firstScaleByte) {
// When lowering amdgpu.scaled_ext_packed_matrix to rocdl.cvt.scale.pk*.f*.f*
// operations, the attributes blockSize, sourceType, scaleWaveHalf, and
// firstScaleByte are merged into a single attribute scaleSel. This is how
// those values are merged together. (Note: scaleWaveHalf isn't a high-level
// attribute but is derifed from firstScaleLane).
assert(llvm::is_contained({16, 32}, blockSize));
assert(llvm::is_contained({4u, 6u, 8u}, bitWidth));
const bool isFp8 = bitWidth == 8;
const bool isBlock16 = blockSize == 16;
if (!isFp8) {
int32_t bit0 = isBlock16;
assert(llvm::is_contained({0, 1, 2}, firstScaleByte));
int32_t bit1 = (firstScaleByte == 2) << 1;
assert(llvm::is_contained({0, 1}, scaleWaveHalf));
int32_t bit2 = scaleWaveHalf << 2;
return bit2 | bit1 | bit0;
}
int32_t bit0 = isBlock16;
// firstScaleByte is guaranteed to be defined by two bits.
assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
int32_t bits2and1 = firstScaleByte << 1;
assert(llvm::is_contained({0, 1}, scaleWaveHalf));
int32_t bit3 = scaleWaveHalf << 3;
int32_t bits = bit3 | bits2and1 | bit0;
// These are invalid cases.
assert(!llvm::is_contained(
{0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
return bits;
}
static std::optional<StringRef>
scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
using fp4 = Float4E2M1FNType;
using fp8 = Float8E4M3FNType;
using bf8 = Float8E5M2Type;
using fp6 = Float6E2M3FNType;
using bf6 = Float6E3M2FNType;
if (isa<fp4>(srcElemType)) {
if (destElemType.isF16())
return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
if (destElemType.isBF16())
return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
if (destElemType.isF32())
return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
return std::nullopt;
}
if (isa<fp8>(srcElemType)) {
if (destElemType.isF16())
return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
if (destElemType.isBF16())
return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
if (destElemType.isF32())
return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
return std::nullopt;
}
if (isa<bf8>(srcElemType)) {
if (destElemType.isF16())
return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
if (destElemType.isBF16())
return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
if (destElemType.isF32())
return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
return std::nullopt;
}
if (isa<fp6>(srcElemType)) {
if (destElemType.isF16())
return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
if (destElemType.isBF16())
return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
if (destElemType.isF32())
return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
return std::nullopt;
}
if (isa<bf6>(srcElemType)) {
if (destElemType.isF16())
return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
if (destElemType.isBF16())
return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
if (destElemType.isF32())
return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
return std::nullopt;
}
llvm_unreachable("invalid combination of element types for packed conversion "
"instructions");
}
LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite(
ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
using fp4 = Float4E2M1FNType;
using fp8 = Float8E4M3FNType;
using bf8 = Float8E5M2Type;
using fp6 = Float6E2M3FNType;
using bf6 = Float6E3M2FNType;
Location loc = op.getLoc();
if (chipset != kGfx1250) {
return rewriter.notifyMatchFailure(
loc,
"Scaled fp packed conversion instructions are not available on target "
"architecture and their emulation is not implemented");
}
// Convert user-facing firstScaleLane (0 or 16) to the half of the wave that
// is being selected.
int32_t scaleWaveHalf = op.getFirstScaleLane() / 16;
int32_t firstScaleByte = op.getFirstScaleByte();
int32_t blockSize = op.getBlockSize();
auto sourceType = cast<VectorType>(op.getSource().getType());
auto srcElemType = cast<FloatType>(sourceType.getElementType());
unsigned bitWidth = srcElemType.getWidth();
auto targetType = cast<VectorType>(op.getResult().getType());
auto destElemType = cast<FloatType>(targetType.getElementType());
IntegerType i32 = rewriter.getI32Type();
Value source = adaptor.getSource();
Type llvmResultType = typeConverter->convertType(op.getResult().getType());
Type packedType = nullptr;
if (isa<fp4>(srcElemType)) {
packedType = i32;
packedType = getTypeConverter()->convertType(packedType);
} else if (isa<fp8, bf8>(srcElemType)) {
packedType = VectorType::get(2, i32);
packedType = getTypeConverter()->convertType(packedType);
} else if (isa<fp6, bf6>(srcElemType)) {
packedType = VectorType::get(3, i32);
packedType = getTypeConverter()->convertType(packedType);
} else {
llvm_unreachable("invalid element type for packed scaled ext");
}
if (!packedType || !llvmResultType) {
return rewriter.notifyMatchFailure(op, "type conversion failed");
}
std::optional<StringRef> maybeIntrinsic =
scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
if (!maybeIntrinsic.has_value())
return op.emitOpError(
"no intrinsic matching packed scaled conversion on the given chipset");
int32_t scaleSel =
getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte);
Value castedScale =
LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
Value castedSource =
LLVM::BitcastOp::create(rewriter, loc, packedType, source);
OperationState loweredOp(loc, *maybeIntrinsic);
loweredOp.addTypes({llvmResultType});
loweredOp.addOperands({castedSource, castedScale});
SmallVector<NamedAttribute, 1> attrs;
attrs.push_back(
NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel)));
loweredOp.addAttributes(attrs);
Operation *lowered = rewriter.create(loweredOp);
rewriter.replaceOp(op, lowered);
return success();
}
LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
if (chipset != kGfx950)
return rewriter.notifyMatchFailure(
loc, "Scaled fp conversion instructions are not available on target "
"architecture and their emulation is not implemented");
Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
Value source = adaptor.getSource();
Value scale = adaptor.getScale();
VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
Type sourceElemType = sourceVecType.getElementType();
VectorType destVecType = cast<VectorType>(op.getResult().getType());
Type destElemType = destVecType.getElementType();
VectorType packedVecType;
if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
} else if (isa<Float4E2M1FNType>(sourceElemType)) {
VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
} else {
llvm_unreachable("invalid element type for scaled ext");
}
// Extend to a packedVectorType
if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
if (!sourceVecType) {
longVec = LLVM::InsertElementOp::create(
rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
} else {
for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
Value idx = createI32Constant(rewriter, loc, i);
Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
longVec =
LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
}
}
source = longVec;
}
Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
op, destVecType, i32Source, scale, op.getIndex());
else
return failure();
return success();
}
LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
if (chipset != kGfx950)
return rewriter.notifyMatchFailure(
loc, "Scaled fp conversion instructions are not available on target "
"architecture and their emulation is not implemented");
Type v2i16 = getTypeConverter()->convertType(
VectorType::get(2, rewriter.getI16Type()));
Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
Type resultType = op.getResult().getType();
Type resultElemType = getElementTypeOrSelf(resultType);
VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
Type sourceElemType = sourceVecType.getElementType();
Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;
Value source = adaptor.getSource();
Value scale = adaptor.getScale();
Value existing = adaptor.getExisting();
if (existing)
existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
else
existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);
if (sourceVecType.getNumElements() < 2) {
Value c0 = createI32Constant(rewriter, loc, 0);
Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
VectorType v2 = VectorType::get(2, sourceElemType);
source = LLVM::ZeroOp::create(rewriter, loc, v2);
source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
}
Value sourceA, sourceB;
if (sourceElemType.isF32()) {
Value c0 = createI32Constant(rewriter, loc, 0);
Value c1 = createI32Constant(rewriter, loc, 1);
sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
}
Value result;
if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
existing, sourceA, sourceB,
scale, op.getIndex());
else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
result = ROCDL::CvtScaleF32PkBf8F16Op::create(
rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
existing, sourceA, sourceB,
scale, op.getIndex());
else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
result = ROCDL::CvtScaleF32PkFp8F16Op::create(
rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
existing, sourceA, sourceB,
scale, op.getIndex());
else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
result = ROCDL::CvtScaleF32PkFp4F16Op::create(
rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else
return failure();
result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
op, getTypeConverter()->convertType(resultType), result);
return success();
}
LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
Type resultType = op.getResult().getType();
Type resultElemType = getElementTypeOrSelf(resultType);
Value sourceA = adaptor.getSourceA();
Value sourceB = adaptor.getSourceB();
if (!sourceB)
sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType());
Value existing = adaptor.getExisting();
if (existing)
existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
else
existing = LLVM::UndefOp::create(rewriter, loc, i32);
Value result;
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
existing, op.getWordIndex());
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
existing, op.getWordIndex());
result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
op, getTypeConverter()->convertType(resultType), result);
return success();
}
LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
Type resultType = op.getResult().getType();
Type resultElemType = getElementTypeOrSelf(resultType);
Value source = adaptor.getSource();
Value stoch = adaptor.getStochiasticParam();
Value existing = adaptor.getExisting();
if (existing)
existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
else
existing = LLVM::UndefOp::create(rewriter, loc, i32);
Value result;
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
existing, op.getStoreIndex());
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
existing, op.getStoreIndex());
result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
op, getTypeConverter()->convertType(resultType), result);
return success();
}
// Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
// operation into the corresponding ROCDL instructions.
struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
Chipset chipset;
LogicalResult
matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Convert the source operand to the corresponding LLVM type
Location loc = DppOp.getLoc();
Value src = adaptor.getSrc();
Value old = adaptor.getOld();
Type srcType = src.getType();
Type oldType = old.getType();
Type llvmType = nullptr;
if (srcType.getIntOrFloatBitWidth() < 32) {
llvmType = rewriter.getI32Type();
} else if (isa<FloatType>(srcType)) {
llvmType = (srcType.getIntOrFloatBitWidth() == 32)
? rewriter.getF32Type()
: rewriter.getF64Type();
} else if (isa<IntegerType>(srcType)) {
llvmType = (srcType.getIntOrFloatBitWidth() == 32)
? rewriter.getI32Type()
: rewriter.getI64Type();
}
auto llvmSrcIntType = typeConverter->convertType(
rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
// If the source type is less of 32, use bitcast to convert it to i32.
auto convertOperand = [&](Value operand, Type operandType) {
if (operandType.getIntOrFloatBitWidth() <= 16) {
if (llvm::isa<FloatType>(operandType)) {
operand =
LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
}
auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
operand =
LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
createI32Constant(rewriter, loc, 0));
operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
}
return operand;
};
src = convertOperand(src, srcType);
old = convertOperand(old, oldType);
// This is taken from the following file llvm/lib/Target/AMDGPU/SIDefines.h
enum DppCtrl : unsigned {
ROW_SHL0 = 0x100,
ROW_SHR0 = 0x110,
ROW_ROR0 = 0x120,
WAVE_SHL1 = 0x130,
WAVE_ROL1 = 0x134,
WAVE_SHR1 = 0x138,
WAVE_ROR1 = 0x13C,
ROW_MIRROR = 0x140,
ROW_HALF_MIRROR = 0x141,
BCAST15 = 0x142,
BCAST31 = 0x143,
};
auto kind = DppOp.getKind();
auto permArgument = DppOp.getPermArgument();
uint32_t DppCtrl = 0;
switch (kind) {
case DPPPerm::quad_perm: {
auto quadPermAttr = cast<ArrayAttr>(*permArgument);
int32_t i = 0;
for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
uint32_t num = elem.getInt();
DppCtrl |= num << (i * 2);
i++;
}
break;
}
case DPPPerm::row_shl: {
auto intAttr = cast<IntegerAttr>(*permArgument);
DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
break;
}
case DPPPerm::row_shr: {
auto intAttr = cast<IntegerAttr>(*permArgument);
DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
break;
}
case DPPPerm::row_ror: {
auto intAttr = cast<IntegerAttr>(*permArgument);
DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
break;
}
case DPPPerm::wave_shl:
DppCtrl = DppCtrl::WAVE_SHL1;
break;
case DPPPerm::wave_shr:
DppCtrl = DppCtrl::WAVE_SHR1;
break;
case DPPPerm::wave_rol:
DppCtrl = DppCtrl::WAVE_ROL1;
break;
case DPPPerm::wave_ror:
DppCtrl = DppCtrl::WAVE_ROR1;
break;
case DPPPerm::row_mirror:
DppCtrl = DppCtrl::ROW_MIRROR;
break;
case DPPPerm::row_half_mirror:
DppCtrl = DppCtrl::ROW_HALF_MIRROR;
break;
case DPPPerm::row_bcast_15:
DppCtrl = DppCtrl::BCAST15;
break;
case DPPPerm::row_bcast_31:
DppCtrl = DppCtrl::BCAST31;
break;
}
// Check for row_mask, bank_mask, bound_ctrl if they exist and create
// constants
auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();
// create a ROCDL_DPPMovOp instruction with the appropriate attributes
auto dppMovOp =
ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
rowMask, bankMask, boundCtrl);
Value result = dppMovOp.getRes();
if (srcType.getIntOrFloatBitWidth() < 32) {
result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
if (!llvm::isa<IntegerType>(srcType)) {
result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
}
}
// We are replacing the AMDGPU_DPPOp instruction with the new
// ROCDL_DPPMovOp instruction
rewriter.replaceOp(DppOp, ValueRange(result));
return success();
}
};
struct AMDGPUSwizzleBitModeLowering
: public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Type i32 = rewriter.getI32Type();
Value src = adaptor.getSrc();
SmallVector<Value> decomposed =
LLVM::decomposeValue(rewriter, loc, src, i32);
unsigned andMask = op.getAndMask();
unsigned orMask = op.getOrMask();
unsigned xorMask = op.getXorMask();
// bit 15 is 0 for the BitMode swizzle.
// https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
Value maskValue = createI32Constant(rewriter, loc, mask);
SmallVector<Value> swizzled;
for (Value v : decomposed) {
Value res =
ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
swizzled.emplace_back(res);
}
Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
rewriter.replaceOp(op, result);
return success();
}
};
struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
Chipset chipset;
LogicalResult
matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (chipset < kGfx950)
return op->emitOpError("permlane_swap is only supported on gfx950+");
Location loc = op.getLoc();
Type i32 = rewriter.getI32Type();
Value src = adaptor.getSrc();
unsigned rowLength = op.getRowLength();
bool fi = op.getFetchInactive();
bool boundctrl = op.getBoundCtrl();
SmallVector<Value> decomposed =
LLVM::decomposeValue(rewriter, loc, src, i32);
SmallVector<Value> permuted;
for (Value v : decomposed) {
Value res;
Type i32pair = LLVM::LLVMStructType::getLiteral(
rewriter.getContext(), {v.getType(), v.getType()});
if (rowLength == 16)
res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
boundctrl);
else if (rowLength == 32)
res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
boundctrl);
else
llvm_unreachable("unsupported row length");
Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
LLVM::ICmpPredicate::eq, vdst0, v);
// Per `permlane(16|32)` semantics: if the first extracted element equals
// 'v', the result is the second element; otherwise it is the first.
Value vdstNew =
LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
permuted.emplace_back(vdstNew);
}
Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
rewriter.replaceOp(op, result);
return success();
}
};
static Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
Value accumulator, Value value, int64_t shift) {
shift = shift % 32;
Value shiftAmount;
if (shift != 0) {
shiftAmount = createI32Constant(rewriter, loc, shift % 32);
value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
}
if (matchPattern(accumulator, mlir::m_Zero()))
return value;
constexpr bool isDisjoint = true;
return LLVM::OrOp::create(rewriter, loc, accumulator, value, isDisjoint);
}
template <typename BaseOp>
struct AMDGPUMakeDmaBaseLowering : public ConvertOpToLLVMPattern<BaseOp> {
using ConvertOpToLLVMPattern<BaseOp>::ConvertOpToLLVMPattern;
using Adaptor = typename ConvertOpToLLVMPattern<BaseOp>::OpAdaptor;
AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<BaseOp>(converter), chipset(chipset) {}
Chipset chipset;
LogicalResult
matchAndRewrite(BaseOp op, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (chipset < kGfx1250)
return op->emitOpError("make_dma_base is only supported on gfx1250");
Location loc = op.getLoc();
constexpr int32_t constlen = 4;
Value consts[constlen];
for (int64_t i = 0; i < constlen; ++i)
consts[i] = createI32Constant(rewriter, loc, i);
constexpr int32_t sgprslen = constlen;
Value sgprs[sgprslen];
for (int64_t i = 0; i < sgprslen; ++i) {
sgprs[i] = consts[0];
}
sgprs[0] = consts[1];
if constexpr (BaseOp::isGather()) {
sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 30);
auto type = cast<TDMGatherBaseType>(op.getResult().getType());
Type indexType = type.getIndexType();
unsigned indexSize = indexType.getIntOrFloatBitWidth();
assert(llvm::is_contained({16u, 32u}, indexSize) &&
"expected index_size to be 16 or 32");
unsigned idx = (indexSize / 16) - 1;
if (idx)
sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 31);
}
ValueRange ldsIndices = adaptor.getLdsIndices();
Value lds = adaptor.getLds();
auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
Value ldsPtr = ConvertToLLVMPattern::getStridedElementPtr(
rewriter, loc, ldsMemRefType, lds, ldsIndices);
ValueRange globalIndices = adaptor.getGlobalIndices();
Value global = adaptor.getGlobal();
auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
Value globalPtr = ConvertToLLVMPattern::getStridedElementPtr(
rewriter, loc, globalMemRefType, global, globalIndices);
Type i32 = rewriter.getI32Type();
Type i64 = rewriter.getI64Type();
sgprs[1] = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
Value castForGlobalAddr =
LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
sgprs[2] = LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
createI64Constant(rewriter, loc, 32));
Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
Value mask = createI32Constant(rewriter, loc, (1ull << 25) - 1);
highHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
sgprs[3] = setValueAtOffset(rewriter, loc, highHalf, consts[2], 30);
Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
assert(v4i32 && "expected type conversion to succeed");
Value result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts))
result =
LLVM::InsertElementOp::create(rewriter, loc, result, sgpr, constant);
rewriter.replaceOp(op, result);
return success();
}
};
template <typename DescriptorOp>
struct AMDGPULowerDescriptor : public ConvertOpToLLVMPattern<DescriptorOp> {
using ConvertOpToLLVMPattern<DescriptorOp>::ConvertOpToLLVMPattern;
using OpAdaptor = typename ConvertOpToLLVMPattern<DescriptorOp>::OpAdaptor;
AMDGPULowerDescriptor(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<DescriptorOp>(converter), chipset(chipset) {}
Chipset chipset;
Value getDGroup0(OpAdaptor adaptor) const { return adaptor.getBase(); }
Value setWorkgroupMask(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
Value sgpr0) const {
Value mask = op.getWorkgroupMask();
if (!mask)
return sgpr0;
Type i16 = rewriter.getI16Type();
mask = LLVM::BitcastOp::create(rewriter, loc, i16, mask);
Type i32 = rewriter.getI32Type();
Value extendedMask = LLVM::ZExtOp::create(rewriter, loc, i32, mask);
return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0);
}
Value setDataSize(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
Value sgpr0, ArrayRef<Value> consts) const {
unsigned elementTypeWidthInBits = op.getElementTypeWidth();
assert(llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidthInBits) &&
"expected type width to be 8, 16, 32, or 64.");
int64_t idx = llvm::Log2_32(elementTypeWidthInBits / 8);
Value size = consts[idx];
return setValueAtOffset(rewriter, loc, sgpr0, size, 16);
}
Value setAtomicBarrier(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
Value sgpr0, ArrayRef<Value> consts) const {
if (!adaptor.getAtomicBarrierAddress())
return sgpr0;
return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18);
}
Value setIterateEnable(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
Value sgpr0, ArrayRef<Value> consts) const {
if (!adaptor.getGlobalIncrement())
return sgpr0;
// Value is ignored when in gather mode.
// TODO: emit error earlier?
return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
}
Value setPadEnable(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
Value sgpr0, ArrayRef<Value> consts) const {
if (!op.getPadAmount())
return sgpr0;
return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
}
Value setEarlyTimeout(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
Value sgpr0, ArrayRef<Value> consts) const {
if (!op.getWorkgroupMask())
return sgpr0;
return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21);
}
Value setPadInterval(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
Value sgpr0, ArrayRef<Value> consts) const {
if (!op.getPadAmount())
return sgpr0;
// pre-condition: padInterval can be a power of two between 2 and 256.
// TODO: Validation if the value breaks the pre-condition.
// If the pre-condition fails, there is a possibility of
// affecting the higher bits. In a following PR implement
// RuntimeVerifiableOpInterface that instruments conditions that need to be
// checked at runtime.
IntegerType i32 = rewriter.getI32Type();
Value padInterval = adaptor.getPadInterval();
padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32,
padInterval, false);
padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]);
// post-condition: padInterval can be a value between 0 and 7.
return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22);
}
Value setPadAmount(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
Value sgpr0, ArrayRef<Value> consts) const {
if (!op.getPadAmount())
return sgpr0;
// pre-condition: padAmount is a value between 1-128.
// TODO: Validation if the value breaks the pre-condition.
// If the pre-condition fails, there is a possibility of
// affecting the higher bits. In a following PR implement
// RuntimeVerifiableOpInterface that instruments conditions that need to be
// checked at runtime.
Value padAmount = adaptor.getPadAmount();
padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]);
// post-condition: padAmount is a value between 0-127.
return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25);
}
Value setAtomicBarrierAddress(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Location loc, Value sgpr1,
ArrayRef<Value> consts) const {
if (!adaptor.getAtomicBarrierAddress())
return sgpr1;
Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress();
auto barrierAddressTy =
cast<MemRefType>(op.getAtomicBarrierAddress().getType());
ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices();
atomicBarrierAddress = ConvertToLLVMPattern::getStridedElementPtr(
rewriter, loc, barrierAddressTy, atomicBarrierAddress,
atomicBarrierIndices);
IntegerType i32 = rewriter.getI32Type();
// pre-condition: atomicBarrierAddress is aligned to 8 bytes which implies
// that the 3 LSBs are zero.
// TODO: Validation if the value breaks the pre-condition.
// In a following PR implement RuntimeVerifiableOpInterface
// that instruments conditions that need to be checked at runtime.
atomicBarrierAddress =
LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress);
atomicBarrierAddress =
LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]);
Value mask = createI32Constant(rewriter, loc, 0xFFFF);
atomicBarrierAddress =
LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask);
return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32);
}
std::pair<Value, Value> setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Location loc, Value sgpr1, Value sgpr2,
ArrayRef<Value> consts, uint64_t dimX,
uint32_t offset) const {
ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
SmallVector<OpFoldResult> mixedGlobalSizes =
getMixedValues(globalStaticSizes, globalDynamicSizes, rewriter);
if (mixedGlobalSizes.size() <= dimX)
return {sgpr1, sgpr2};
OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
// pre-condition: tensorDimX is less than 2^32-1
// TODO: Validation if the value breaks the pre-condition.
// In a following PR implement RuntimeVerifiableOpInterface that instruments
// conditions that need to be checked at runtime. This could also be fixed
// by saying that mixedGlobalSizes is a DynamicI32List.
Value tensorDimX;
if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
tensorDimX =
createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
} else {
IntegerType i32 = rewriter.getI32Type();
tensorDimX = cast<Value>(tensorDimXOpFoldResult);
tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
}
sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDimX, offset);
Value c16 = createI32Constant(rewriter, loc, 16);
Value tensorDimXHigh = LLVM::LShrOp::create(rewriter, loc, tensorDimX, c16);
sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDimXHigh, offset + 16);
return {sgpr1, sgpr2};
}
std::pair<Value, Value> setTensorDim0(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Location loc, Value sgpr1, Value sgpr2,
ArrayRef<Value> consts) const {
return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, 0,
48);
}
std::pair<Value, Value> setTensorDim1(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Location loc, Value sgpr2, Value sgpr3,
ArrayRef<Value> consts) const {
return setTensorDimX(op, adaptor, rewriter, loc, sgpr2, sgpr3, consts, 1,
80);
}
Value setTileDimX(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
Value sgpr, ArrayRef<Value> consts, size_t dimX,
int64_t offset) const {
ArrayRef<int64_t> sharedStaticSizes = adaptor.getSharedStaticSizes();
ValueRange sharedDynamicSizes = adaptor.getSharedDynamicSizes();
SmallVector<OpFoldResult> mixedSharedSizes =
getMixedValues(sharedStaticSizes, sharedDynamicSizes, rewriter);
if (mixedSharedSizes.size() <= dimX)
return sgpr;
OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX);
// pre-condition: tileDimX is less than 2^16-1
// TODO: Validation if the value breaks the pre-condition.
// If the pre-condition fails, there is a possibility of
// affecting the higher bits. In a following PR implement
// RuntimeVerifiableOpInterface that instruments conditions that need to be
// checked at runtime. This could also be fixed by saying that
// mixedSharedSizes is a DynamicI16List.
Value tileDimX;
if (auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult)) {
tileDimX =
createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
} else {
IntegerType i32 = rewriter.getI32Type();
tileDimX = cast<Value>(tileDimXOpFoldResult);
tileDimX = LLVM::TruncOp::create(rewriter, loc, i32, tileDimX);
}
return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset);
}
Value setTileDim0(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
Value sgpr3, ArrayRef<Value> consts) const {
return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112);
}
Value setTileDim1(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
Value sgpr4, ArrayRef<Value> consts) const {
return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128);
}
Value setValidIndices(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
Value sgpr4, ArrayRef<Value> consts) const {
auto type = cast<VectorType>(op.getIndices().getType());
ArrayRef<int64_t> shape = type.getShape();
assert(shape.size() == 1 && "expected shape to be of rank 1.");
unsigned length = shape.back();
assert(0 < length && length <= 16 && "expected length to be at most 16.");
Value value = createI32Constant(rewriter, loc, length);
return setValueAtOffset(rewriter, loc, sgpr4, value, 128);
}
Value setTileDim1OrValidIndices(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Location loc, Value sgpr4,
ArrayRef<Value> consts) const {
if constexpr (DescriptorOp::isGather())
return setValidIndices(op, adaptor, rewriter, loc, sgpr4, consts);
return setTileDim1(op, adaptor, rewriter, loc, sgpr4, consts);
}
Value setTileDim2(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
Value sgpr4, ArrayRef<Value> consts) const {
// Value is ignored when in gather mode.
if constexpr (DescriptorOp::isGather())
return sgpr4;
return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144);
}
std::pair<Value, Value>
setTensorDimXStride(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
Value sgprY, Value sgprZ, ArrayRef<Value> consts,
size_t dimX, int64_t offset) const {
ArrayRef<int64_t> globalStaticStrides = adaptor.getGlobalStaticStrides();
ValueRange globalDynamicStrides = adaptor.getGlobalDynamicStrides();
SmallVector<OpFoldResult> mixedGlobalStrides =
getMixedValues(globalStaticStrides, globalDynamicStrides, rewriter);
if (mixedGlobalStrides.size() <= (dimX + 1))
return {sgprY, sgprZ};
OpFoldResult tensorDimXStrideOpFoldResult =
*(mixedGlobalStrides.rbegin() + dimX + 1);
// pre-condition: tensorDimXStride is less than 2^48-1
// TODO: Validation if the value breaks the pre-condition.
// In a following PR implement RuntimeVerifiableOpInterface that instruments
// conditions that need to be checked at runtime.
Value tensorDimXStride;
if (auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult))
tensorDimXStride =
createI64Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
else
tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult);
constexpr int64_t first48bits = (1ll << 48) - 1;
Value mask = createI64Constant(rewriter, loc, first48bits);
tensorDimXStride =
LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride);
IntegerType i32 = rewriter.getI32Type();
Value tensorDimXStrideLow =
LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride);
sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset);
int64_t shift = (offset % 32) == 0 ? 32 : offset % 32;
Value shiftVal = createI64Constant(rewriter, loc, shift);
Value tensorDimXStrideHigh =
LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal);
tensorDimXStrideHigh =
LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh);
sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh,
offset + shift);
return {sgprY, sgprZ};
}
std::pair<Value, Value>
setTensorDim0Stride(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
0, 160);
}
std::pair<Value, Value>
setTensorDim1Stride(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
// Value is ignored when in gather mode.
if constexpr (DescriptorOp::isGather())
return {sgpr5, sgpr6};
return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
1, 208);
}
Value getDGroup1(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> consts) const {
Value sgprs[8];
for (int64_t i = 0; i < 8; ++i) {
sgprs[i] = consts[0];
}
sgprs[0] = setWorkgroupMask(op, adaptor, rewriter, loc, sgprs[0]);
sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts);
sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts);
sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
sgprs[0] = setEarlyTimeout(op, adaptor, rewriter, loc, sgprs[0], consts);
sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts);
sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts);
sgprs[1] =
setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts);
std::tie(sgprs[1], sgprs[2]) =
setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
std::tie(sgprs[2], sgprs[3]) =
setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts);
sgprs[4] =
setTileDim1OrValidIndices(op, adaptor, rewriter, loc, sgprs[4], consts);
sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts);
std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride(
op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts);
std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride(
op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts);
IntegerType i32 = rewriter.getI32Type();
Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32));
assert(v8i32 && "expected type conversion to succeed");
Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32);
for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) {
dgroup1 =
LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant);
}
return dgroup1;
}
Value setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
Value sgpr0, ArrayRef<Value> consts, int64_t dimX,
int64_t offset) const {
ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
SmallVector<OpFoldResult> mixedGlobalSizes =
getMixedValues(globalStaticSizes, globalDynamicSizes, rewriter);
if (mixedGlobalSizes.size() <= static_cast<unsigned long>(dimX))
return sgpr0;
OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
Value tensorDimX;
if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
tensorDimX =
createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
} else {
IntegerType i32 = rewriter.getI32Type();
tensorDimX = cast<Value>(tensorDimXOpFoldResult);
tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
}
return setValueAtOffset(rewriter, loc, sgpr0, tensorDimX, offset);
}
Value setTensorDim2(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
Value sgpr0, ArrayRef<Value> consts) const {
return setTensorDimX(op, adaptor, rewriter, loc, sgpr0, consts, 2, 0);
}
Value truncateAndSetValueAtOffset(ConversionPatternRewriter &rewriter,
Location loc, Value accumulator,
Value value, int64_t shift) const {
IntegerType i32 = rewriter.getI32Type();
value = LLVM::TruncOp::create(rewriter, loc, i32, value);
return setValueAtOffset(rewriter, loc, accumulator, value, shift);
}
Value setLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
Value sgpr1, ArrayRef<Value> consts,
int64_t offset) const {
Value ldsAddrIncrement = adaptor.getLdsIncrement();
return setValueAtOffset(rewriter, loc, sgpr1, ldsAddrIncrement, offset);
}
std::pair<Value, Value>
setGlobalAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
Value sgpr2, Value sgpr3, ArrayRef<Value> consts,
int64_t offset) const {
Value globalAddrIncrement = adaptor.getGlobalIncrement();
sgpr2 = truncateAndSetValueAtOffset(rewriter, loc, sgpr2,
globalAddrIncrement, offset);
Value shift = createI64Constant(rewriter, loc, 32);
globalAddrIncrement =
LLVM::LShrOp::create(rewriter, loc, globalAddrIncrement, shift);
constexpr int64_t first16BitsHigh = (1ll << 16) - 1;
sgpr3 = truncateAndSetValueAtOffset(rewriter, loc, sgpr3,
globalAddrIncrement, offset + 32);
Value mask = createI32Constant(rewriter, loc, first16BitsHigh);
sgpr3 = LLVM::AndOp::create(rewriter, loc, sgpr3, mask);
return {sgpr2, sgpr3};
}
Value setTensorDim3OrLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Location loc, Value sgpr1,
ArrayRef<Value> consts) const {
Value ldsIncrement = op.getLdsIncrement();
constexpr int64_t dim = 3;
constexpr int64_t offset = 32;
if (!ldsIncrement)
return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, consts, dim,
offset);
return setLDSAddrIncrement(op, adaptor, rewriter, loc, sgpr1, consts,
offset);
}
std::pair<Value, Value> setTensorDim2StrideOrGlobalAddrIncrement(
DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
Location loc, Value sgpr2, Value sgpr3, ArrayRef<Value> consts) const {
Value globalIncrement = op.getGlobalIncrement();
constexpr int32_t dim = 2;
constexpr int32_t offset = 64;
if (!globalIncrement)
return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr2, sgpr3,
consts, dim, offset);
return setGlobalAddrIncrement(op, adaptor, rewriter, loc, sgpr2, sgpr3,
consts, offset);
}
Value setIterateCount(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
Value sgpr3, ArrayRef<Value> consts,
int32_t offset) const {
Value iterationCount = adaptor.getIterationCount();
IntegerType i32 = rewriter.getI32Type();
// pre-condition: iterationCount is in the inclusive interval [1, 256].
// TODO: validation if the value breaks the pre-condition.
// If the pre-condition fails, there is a possibility of
// affecting the higher bits. In a following PR implement
// RuntimeVerifiableOpInterface that instruments conditions that need to be
// checked at runtime.
iterationCount = LLVM::TruncOp::create(rewriter, loc, i32, iterationCount);
iterationCount =
LLVM::SubOp::create(rewriter, loc, iterationCount, consts[1]);
return setValueAtOffset(rewriter, loc, sgpr3, iterationCount, offset);
}
Value setTileDim3OrIterateCount(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Location loc, Value sgpr3,
ArrayRef<Value> consts) const {
Value iterateCount = op.getIterationCount();
constexpr int32_t dim = 2;
constexpr int32_t offset = 112;
if (!iterateCount)
return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, dim,
offset);
return setIterateCount(op, adaptor, rewriter, loc, sgpr3, consts, offset);
}
Value getDGroup2(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> consts) const {
if constexpr (DescriptorOp::isGather())
return getDGroup2Gather(op, adaptor, rewriter, loc, consts);
return getDGroup2NonGather(op, adaptor, rewriter, loc, consts);
}
Value getDGroup2NonGather(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> consts) const {
IntegerType i32 = rewriter.getI32Type();
Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
assert(v4i32 && "expected type conversion to succeed.");
bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
if (onlyNeedsTwoDescriptors)
return LLVM::ZeroOp::create(rewriter, loc, v4i32);
constexpr int64_t sgprlen = 4;
Value sgprs[sgprlen];
for (int i = 0; i < sgprlen; ++i)
sgprs[i] = consts[0];
sgprs[0] = setTensorDim2(op, adaptor, rewriter, loc, sgprs[0], consts);
sgprs[1] = setTensorDim3OrLDSAddrIncrement(op, adaptor, rewriter, loc,
sgprs[1], consts);
std::tie(sgprs[2], sgprs[3]) = setTensorDim2StrideOrGlobalAddrIncrement(
op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
sgprs[3] =
setTileDim3OrIterateCount(op, adaptor, rewriter, loc, sgprs[3], consts);
Value dgroup2 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
for (auto [sgpr, constant] : llvm::zip(sgprs, consts))
dgroup2 =
LLVM::InsertElementOp::create(rewriter, loc, dgroup2, sgpr, constant);
return dgroup2;
}
Value getGatherIndices(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> consts, bool firstHalf) const {
IntegerType i32 = rewriter.getI32Type();
Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
assert(v4i32 && "expected type conversion to succeed.");
Value indices = adaptor.getIndices();
auto vectorType = cast<VectorType>(indices.getType());
unsigned length = vectorType.getShape().back();
Type elementType = vectorType.getElementType();
unsigned maxLength = elementType == i32 ? 4 : 8;
int32_t offset = firstHalf ? 0 : maxLength;
unsigned discountedLength =
std::max(static_cast<int32_t>(length - offset), 0);
unsigned targetSize = std::min(maxLength, discountedLength);
SmallVector<Value> indicesVector;
for (unsigned i = offset; i < targetSize + offset; ++i) {
Value idx;
if (i < consts.size())
idx = consts[i];
else
idx = createI32Constant(rewriter, loc, i);
Value elem = LLVM::ExtractElementOp::create(rewriter, loc, indices, idx);
indicesVector.push_back(elem);
}
SmallVector<Value> indicesI32Vector;
if (elementType == i32) {
indicesI32Vector = indicesVector;
} else {
for (unsigned i = 0; i < targetSize; ++i) {
Value index = indicesVector[i];
indicesI32Vector.push_back(
LLVM::ZExtOp::create(rewriter, loc, i32, index));
}
if ((targetSize % 2) != 0)
// Add padding when not divisible by two.
indicesI32Vector.push_back(consts[0]);
}
SmallVector<Value> indicesToInsert;
if (elementType == i32) {
indicesToInsert = indicesI32Vector;
} else {
unsigned size = indicesI32Vector.size() / 2;
for (unsigned i = 0; i < size; ++i) {
Value first = indicesI32Vector[2 * i];
Value second = indicesI32Vector[2 * i + 1];
Value joined = setValueAtOffset(rewriter, loc, first, second, 16);
indicesToInsert.push_back(joined);
}
}
Value dgroup = LLVM::PoisonOp::create(rewriter, loc, v4i32);
for (auto [sgpr, constant] : llvm::zip_first(indicesToInsert, consts))
dgroup =
LLVM::InsertElementOp::create(rewriter, loc, dgroup, sgpr, constant);
return dgroup;
}
Value getDGroup2Gather(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> consts) const {
return getGatherIndices(op, adaptor, rewriter, loc, consts, true);
}
std::pair<Value, Value>
setTensorDim3Stride(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
Value sgpr0, Value sgpr1, ArrayRef<Value> consts) const {
constexpr int32_t dim = 3;
constexpr int32_t offset = 0;
return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr0, sgpr1, consts,
dim, offset);
}
std::pair<Value, Value> setTensorDim4(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Location loc, Value sgpr1, Value sgpr2,
ArrayRef<Value> consts) const {
constexpr int32_t dim = 4;
constexpr int32_t offset = 48;
return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, dim,
offset);
}
Value setTileDim4(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
Value sgpr2, ArrayRef<Value> consts) const {
constexpr int32_t dim = 4;
constexpr int32_t offset = 80;
return setTileDimX(op, adaptor, rewriter, loc, sgpr2, consts, dim, offset);
}
Value getDGroup3(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> consts) const {
if constexpr (DescriptorOp::isGather())
return getDGroup3Gather(op, adaptor, rewriter, loc, consts);
return getDGroup3NonGather(op, adaptor, rewriter, loc, consts);
}
Value getDGroup3NonGather(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> consts) const {
IntegerType i32 = rewriter.getI32Type();
Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
assert(v4i32 && "expected type conversion to succeed.");
bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
if (onlyNeedsTwoDescriptors)
return LLVM::ZeroOp::create(rewriter, loc, v4i32);
constexpr int32_t sgprlen = 4;
Value sgprs[sgprlen];
for (int i = 0; i < sgprlen; ++i)
sgprs[i] = consts[0];
std::tie(sgprs[0], sgprs[1]) = setTensorDim3Stride(
op, adaptor, rewriter, loc, sgprs[0], sgprs[1], consts);
std::tie(sgprs[1], sgprs[2]) =
setTensorDim4(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
sgprs[2] = setTileDim4(op, adaptor, rewriter, loc, sgprs[2], consts);
Value dgroup3 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
for (auto [sgpr, constant] : llvm::zip(sgprs, consts))
dgroup3 =
LLVM::InsertElementOp::create(rewriter, loc, dgroup3, sgpr, constant);
return dgroup3;
}
Value getDGroup3Gather(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> consts) const {
return getGatherIndices(op, adaptor, rewriter, loc, consts, false);
}
LogicalResult
matchAndRewrite(DescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (chipset < kGfx1250)
return op->emitOpError(
"make_dma_descriptor is only supported on gfx1250");
Location loc = op.getLoc();
SmallVector<Value> consts;
for (int64_t i = 0; i < 8; ++i)
consts.push_back(createI32Constant(rewriter, loc, i));
Value dgroup0 = this->getDGroup0(adaptor);
Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts);
Value dgroup2 = this->getDGroup2(op, adaptor, rewriter, loc, consts);
Value dgroup3 = this->getDGroup3(op, adaptor, rewriter, loc, consts);
SmallVector<Value> results = {dgroup0, dgroup1, dgroup2, dgroup3};
rewriter.replaceOpWithMultiple(op, {results});
return success();
}
};
template <typename SourceOp, typename TargetOp>
struct AMDGPUTensorLoadStoreOpLowering
: public ConvertOpToLLVMPattern<SourceOp> {
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
using Adaptor = typename ConvertOpToLLVMPattern<SourceOp>::OneToNOpAdaptor;
AMDGPUTensorLoadStoreOpLowering(const LLVMTypeConverter &converter,
Chipset chipset)
: ConvertOpToLLVMPattern<SourceOp>(converter), chipset(chipset) {}
Chipset chipset;
LogicalResult
matchAndRewrite(SourceOp op, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (chipset < kGfx1250)
return op->emitOpError("is only supported on gfx1250");
ValueRange desc = adaptor.getDesc();
rewriter.replaceOpWithNewOp<TargetOp>(op, desc[0], desc[1], desc[2],
desc[3], /*cachePolicy=*/0,
/*alias_scopes=*/nullptr,
/*noalias_scopes=*/nullptr,
/*tbaa=*/nullptr);
return success();
}
};
struct ConvertAMDGPUToROCDLPass
: public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
using Base::Base;
void runOnOperation() override {
MLIRContext *ctx = &getContext();
FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
if (failed(maybeChipset)) {
emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
return signalPassFailure();
}
RewritePatternSet patterns(ctx);
LLVMTypeConverter converter(ctx);
populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
amdgpu::populateCommonGPUTypeAndAttributeConversions(converter);
LLVMConversionTarget target(getContext());
target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
}
};
} // namespace
void mlir::amdgpu::populateCommonGPUTypeAndAttributeConversions(
TypeConverter &typeConverter) {
populateGpuMemorySpaceAttributeConversions(
typeConverter, [](gpu::AddressSpace space) {
switch (space) {
case gpu::AddressSpace::Global:
return ROCDL::ROCDLDialect::kGlobalMemoryAddressSpace;
case gpu::AddressSpace::Workgroup:
return ROCDL::ROCDLDialect::kSharedMemoryAddressSpace;
case gpu::AddressSpace::Private:
return ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace;
}
llvm_unreachable("unknown address space enum value");
});
}
void mlir::populateAMDGPUTypeAndAttributeConversions(
TypeConverter &typeConverter) {
typeConverter.addTypeAttributeConversion(
[](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
-> TypeConverter::AttributeConversionResult {
MLIRContext *ctx = as.getContext();
Type i64 = IntegerType::get(ctx, 64);
switch (as.getValue()) {
case amdgpu::AddressSpace::FatRawBuffer:
return IntegerAttr::get(i64, 7);
case amdgpu::AddressSpace::BufferRsrc:
return IntegerAttr::get(i64, 8);
case amdgpu::AddressSpace::FatStructuredBuffer:
return IntegerAttr::get(i64, 9);
}
return TypeConverter::AttributeConversionResult::abort();
});
typeConverter.addConversion([&](TDMBaseType type) -> Type {
Type i32 = IntegerType::get(type.getContext(), 32);
return typeConverter.convertType(VectorType::get(4, i32));
});
typeConverter.addConversion([&](TDMGatherBaseType type) -> Type {
Type i32 = IntegerType::get(type.getContext(), 32);
return typeConverter.convertType(VectorType::get(4, i32));
});
typeConverter.addConversion(
[&](TDMDescriptorType type,
SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
Type i32 = IntegerType::get(type.getContext(), 32);
Type v4i32 = typeConverter.convertType(VectorType::get(4, i32));
Type v8i32 = typeConverter.convertType(VectorType::get(8, i32));
llvm::append_values(result, v4i32, v8i32, v4i32, v4i32);
return success();
});
auto addUnrealizedCast = [](OpBuilder &builder, TypeRange types,
ValueRange inputs,
Location loc) -> SmallVector<Value> {
// Only create unrealized_conversion_cast for TDMDescriptorType.
// All other types which are not expected, should be
// materialized by other target materialization functions.
if (inputs.size() != 1)
return {};
if (!isa<TDMDescriptorType>(inputs[0].getType()))
return {};
auto cast = UnrealizedConversionCastOp::create(builder, loc, types, inputs);
return cast.getResults();
};
typeConverter.addTargetMaterialization(addUnrealizedCast);
}
void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns,
Chipset chipset) {
populateAMDGPUTypeAndAttributeConversions(converter);
patterns
.add<FatRawBufferCastLowering,
RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
RawBufferOpLowering<RawBufferAtomicFaddOp,
ROCDL::RawPtrBufferAtomicFaddOp>,
RawBufferOpLowering<RawBufferAtomicFmaxOp,
ROCDL::RawPtrBufferAtomicFmaxOp>,
RawBufferOpLowering<RawBufferAtomicSmaxOp,
ROCDL::RawPtrBufferAtomicSmaxOp>,
RawBufferOpLowering<RawBufferAtomicUminOp,
ROCDL::RawPtrBufferAtomicUminOp>,
RawBufferOpLowering<RawBufferAtomicCmpswapOp,
ROCDL::RawPtrBufferAtomicCmpSwap>,
AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
SparseMFMAOpLowering, WMMAOpLowering, ScaledWMMAOpLowering,
ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
GatherToLDSOpLowering, TransposeLoadOpLowering,
AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
AMDGPULowerDescriptor<MakeDmaDescriptorOp>,
AMDGPULowerDescriptor<MakeGatherDmaDescriptorOp>,
AMDGPUTensorLoadStoreOpLowering<TensorLoadToLDSOp,
ROCDL::TensorLoadToLDSOp>,
AMDGPUTensorLoadStoreOpLowering<TensorStoreFromLDSOp,
ROCDL::TensorStoreFromLDSOp>>(
converter, chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}