blob: 6ce80c7456d6aa873e82f6a1e2c22bb290accd92 [file]
//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the types and operation details for the NVVM IR dialect in
// MLIR, and the LLVM IR dialect. It also registers the dialect.
//
// The NVVM dialect only contains GPU specific additions on top of the general
// LLVM dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/NVVMIntrinsicUtils.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/NVPTXAddrSpace.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>
#include <string>
using namespace mlir;
using namespace NVVM;
#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
// Shorthand for llvm::Intrinsic::not_intrinsic. NOTE(review): its uses are
// outside this view; presumably it marks "no LLVM intrinsic applies" results.
static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
//===----------------------------------------------------------------------===//
// Helper/Utility methods
//===----------------------------------------------------------------------===//
/// Returns true when the address space of `ptr` equals `targetAS`.
/// `ptr` must be of `LLVM::LLVMPointerType` (enforced by `llvm::cast`).
static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) {
  const unsigned actualSpace =
      llvm::cast<LLVM::LLVMPointerType>(ptr.getType()).getAddressSpace();
  const unsigned wantedSpace = static_cast<unsigned>(targetAS);
  return actualSpace == wantedSpace;
}
/// True if `ptr` points into the generic address space.
static bool isPtrInGenericSpace(mlir::Value ptr) {
  const NVVMMemorySpace space = NVVMMemorySpace::Generic;
  return isPtrInAddrSpace(ptr, space);
}
/// True if `ptr` points into the shared (CTA-local) address space.
static bool isPtrInSharedCTASpace(mlir::Value ptr) {
  const NVVMMemorySpace space = NVVMMemorySpace::Shared;
  return isPtrInAddrSpace(ptr, space);
}
/// True if `ptr` points into the shared::cluster address space.
static bool isPtrInSharedClusterSpace(mlir::Value ptr) {
  const NVVMMemorySpace space = NVVMMemorySpace::SharedCluster;
  return isPtrInAddrSpace(ptr, space);
}
/// Builds an address-space cast of `ptr` to an opaque pointer type in
/// `targetAS`, using `builder`'s insertion point and context.
static llvm::Value *castPtrToAddrSpace(llvm::IRBuilderBase &builder,
                                       llvm::Value *ptr,
                                       NVVMMemorySpace targetAS) {
  llvm::PointerType *dstPtrTy = llvm::PointerType::get(
      builder.getContext(), static_cast<unsigned>(targetAS));
  return builder.CreateAddrSpaceCast(ptr, dstPtrTy);
}
// Translates the NVVM dialect's CTAGroupKind into the equivalent
// llvm::nvvm::CTAGroupKind enumerator.
static llvm::nvvm::CTAGroupKind
getNVVMCtaGroupKind(NVVM::CTAGroupKind ctaGroup) {
  using SrcKind = NVVM::CTAGroupKind;
  using DstKind = llvm::nvvm::CTAGroupKind;
  switch (ctaGroup) {
  case SrcKind::CTA_1:
    return DstKind::CG_1;
  case SrcKind::CTA_2:
    return DstKind::CG_2;
  }
  llvm_unreachable("unsupported cta_group value");
}
//===----------------------------------------------------------------------===//
// Verifier methods
//===----------------------------------------------------------------------===//
// Common verification for the TMA store-like ops:
//   CpAsyncBulkTensorSharedCTAToGlobalOp (TMA Store)
//   CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
// Checks that there are 1..5 coordinates. In im2col mode, it also checks
// that the tensor is at least 3-d and that any im2col offsets number exactly
// (coordinates - 2).
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                     bool isIm2Col,
                                                     size_t numIm2ColOffsets,
                                                     Location loc) {
  const bool dimsInRange = tensorDims >= 1 && tensorDims <= 5;
  if (!dimsInRange)
    return emitError(loc, "expects coordinates between 1 to 5 dimension");

  // Tile mode has no further constraints.
  if (!isIm2Col)
    return success();

  if (tensorDims < 3)
    return emitError(
        loc, "to use im2col mode, the tensor has to be at least 3-dimensional");

  // A zero offset count is accepted; otherwise it must be (dims - 2).
  const bool hasOffsets = numIm2ColOffsets != 0;
  if (hasOffsets && tensorDims != numIm2ColOffsets + 2)
    return emitError(
        loc, "im2col offsets must be 2 less than number of coordinates");

  return success();
}
LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
  const TMAStoreMode storeMode = getMode();

  // A predicate selects the inline-ptx lowering, which only handles TILE mode
  // and cannot carry an L2 cache hint.
  if (getPredicate()) {
    if (storeMode != TMAStoreMode::TILE)
      return emitError("Inline-ptx lowering supported only for Tile mode.");
    if (getL2CacheHint())
      return emitError("Inline-ptx lowering unsupported with L2 cache-hint.");
  }

  const size_t numCoords = getCoordinates().size();
  switch (storeMode) {
  case TMAStoreMode::TILE:
  case TMAStoreMode::IM2COL:
    return cpAsyncBulkTensorCommonVerifier(
        numCoords, /*isIm2Col=*/storeMode == TMAStoreMode::IM2COL,
        /*numIm2ColOffsets=*/0, getLoc());
  case TMAStoreMode::TILE_SCATTER4:
    if (numCoords != 5)
      return emitError("Scatter4 mode expects 5 coordinates");
    break;
  }
  return success();
}
LogicalResult CpAsyncOp::verify() {
  const LoadCacheModifierKind modifier = getModifier();
  const bool isCG = modifier == LoadCacheModifierKind::CG;
  const bool isCA = modifier == LoadCacheModifierKind::CA;
  if (!isCG && !isCA)
    return emitError("Only CG and CA cache modifiers are supported.");

  const auto copySize = getSize();
  const bool validSize = copySize == 4 || copySize == 8 || copySize == 16;
  if (!validSize)
    return emitError("expected byte size to be either 4, 8 or 16.");

  // The CG modifier is further restricted to 16-byte copies.
  if (isCG && copySize != 16)
    return emitError("CG cache modifier is only support for 16 bytes copy.");
  return success();
}
// Validates the coordinate and im2col-offset counts of a TMA load-like op
// for the given load mode. Shared by the TMA Load and TMA Prefetch verifiers.
static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff,
                                         TMALoadMode mode, Location loc) {
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 to 5 dimension");

  // Requires a rank of at least 3 when `needsIm2colRank` is set, and exactly
  // `wantOffsets` im2col offsets.
  auto verifyOffsets = [&](bool needsIm2colRank,
                           size_t wantOffsets) -> LogicalResult {
    if (needsIm2colRank && tensorDims < 3)
      return emitError(loc)
             << "to use " << stringifyEnum(mode)
             << " mode, the tensor has to be at least 3-dimensional";
    if (numIm2colOff == wantOffsets)
      return success();
    return emitError(loc) << " im2col offsets expected " << wantOffsets
                          << " (provided " << numIm2colOff << ")";
  };

  switch (mode) {
  case TMALoadMode::TILE:
    return verifyOffsets(/*needsIm2colRank=*/false, /*wantOffsets=*/0);
  case TMALoadMode::IM2COL:
    return verifyOffsets(/*needsIm2colRank=*/true, tensorDims - 2);
  case TMALoadMode::IM2COL_W:
  case TMALoadMode::IM2COL_W_128:
    return verifyOffsets(/*needsIm2colRank=*/true, /*wantOffsets=*/2);
  case TMALoadMode::TILE_GATHER4:
    if (tensorDims != 5)
      return emitError(loc, "Gather4 mode expects 5 coordinates");
    return verifyOffsets(/*needsIm2colRank=*/false, /*wantOffsets=*/0);
  }
  return success();
}
// Prefetch uses the same coordinate/offset rules as the TMA load op.
LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
  const size_t numCoords = getCoordinates().size();
  const size_t numOffsets = getIm2colOffsets().size();
  return verifyTMALoadParams(numCoords, numOffsets, getMode(), getLoc());
}
LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
  const TMALoadMode loadMode = getMode();
  const bool ctaOnly = getIsCTAOnly();

  if (getPredicate()) {
    // Inline-asm lowering: only shared::cluster destinations in Tile or
    // Im2col mode are handled.
    if (ctaOnly)
      return emitError("Predicate is supported only for shared::cluster mode.");
    const bool tileOrIm2col =
        loadMode == TMALoadMode::TILE || loadMode == TMALoadMode::IM2COL;
    if (!tileOrIm2col)
      return emitError(
          "Predicate is supported only for Tile and Im2col modes.");
  } else {
    // Intrinsic lowering: the destination pointer's address space must match
    // the selected destination kind.
    const NVVMMemorySpace dstSpace =
        ctaOnly ? NVVMMemorySpace::Shared : NVVMMemorySpace::SharedCluster;
    const unsigned dstAS =
        llvm::cast<LLVM::LLVMPointerType>(getDstMem().getType())
            .getAddressSpace();
    if (dstAS != static_cast<unsigned>(dstSpace))
      return emitError()
             << (ctaOnly
                     ? "Shared::cta destination requires address-space 3."
                     : "Shared::cluster destination requires address-space 7.");
    // shared::cta destinations reject multicast and a CTA group.
    if (ctaOnly) {
      if (getMulticastMask())
        return emitError("Multicast is not supported with shared::cta mode.");
      if (getGroup())
        return emitError("CTAGroup is not supported with shared::cta mode.");
    }
  }

  return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
                             loadMode, getLoc());
}
LogicalResult CpAsyncBulkTensorReduceOp::verify() {
  const TMAStoreMode storeMode = getMode();
  const size_t numCoords = getCoordinates().size();
  switch (storeMode) {
  case TMAStoreMode::TILE:
  case TMAStoreMode::IM2COL:
    return cpAsyncBulkTensorCommonVerifier(
        numCoords, /*isIm2Col=*/storeMode == TMAStoreMode::IM2COL,
        /*numIm2ColOffsets=*/0, getLoc());
  case TMAStoreMode::TILE_SCATTER4:
    // This op rejects the scatter4 mode outright.
    return emitError("Scatter mode unsupported for CpAsyncBulkTensorReduceOp");
  }
  return success();
}
// A shared::cta destination cannot be combined with a multicast mask.
LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() {
  if (!isPtrInSharedCTASpace(getDstMem()))
    return success();
  if (getMulticastMask())
    return emitError("Multicast is not supported with shared::cta mode.");
  return success();
}
// Shared checks for mbarrier ops:
//  - `scope` must be CTA or CLUSTER;
//  - if `addr` is in the shared_cluster address space, no result value
//    (`retVal`, null when absent) may be produced.
// Diagnostics are attached to `op`.
static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr,
                                                NVVM::MemScopeKind scope,
                                                Value retVal = nullptr) {
  const bool validScope = scope == NVVM::MemScopeKind::CTA ||
                          scope == NVVM::MemScopeKind::CLUSTER;
  if (!validScope)
    return op->emitError("mbarrier scope must be either CTA or Cluster");

  const bool inSharedCluster = isPtrInSharedClusterSpace(addr);
  if (inSharedCluster && retVal)
    return op->emitError(
        "mbarrier in shared_cluster space cannot return any value");
  return success();
}
// Both arrive variants forward their optional result so the shared verifier
// can reject a result for shared_cluster mbarriers.
LogicalResult MBarrierArriveOp::verify() {
  auto res = getRes();
  return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(), res);
}
LogicalResult MBarrierArriveDropOp::verify() {
  auto res = getRes();
  return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(), res);
}
LogicalResult MBarrierArriveExpectTxOp::verify() {
  // Without a predicate the op uses the regular lowering; only the shared
  // mbarrier checks apply.
  if (!getPredicate())
    return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
                                      getRes());

  // With a predicate the op lowers to inline-ptx, which supports only a
  // subset of features: CTA scope, non-cluster address, no result, and no
  // relaxed semantics.
  if (getScope() != NVVM::MemScopeKind::CTA)
    return emitError("mbarrier scope must be CTA when using predicate");
  if (isPtrInSharedClusterSpace(getAddr()))
    return emitError("mbarrier in shared_cluster space is not supported when "
                     "using predicate");
  if (getRes())
    return emitError("return-value is not supported when using predicate");
  if (getRelaxed() == true)
    return emitError("mbarrier with relaxed semantics is not supported when "
                     "using predicate");

  return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
                                    getRes());
}
LogicalResult MBarrierArriveDropExpectTxOp::verify() {
  auto res = getRes();
  return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(), res);
}
// The verifiers below pass no result value, so only the scope check of the
// shared helper can fail for them.
LogicalResult MBarrierExpectTxOp::verify() {
  Operation *op = getOperation();
  return verifyMBarrierArriveLikeOp(op, getAddr(), getScope());
}
LogicalResult MBarrierCompleteTxOp::verify() {
  Operation *op = getOperation();
  return verifyMBarrierArriveLikeOp(op, getAddr(), getScope());
}
LogicalResult MBarrierTestWaitOp::verify() {
  Operation *op = getOperation();
  return verifyMBarrierArriveLikeOp(op, getAddr(), getScope());
}
LogicalResult MBarrierTryWaitOp::verify() {
  Operation *op = getOperation();
  return verifyMBarrierArriveLikeOp(op, getAddr(), getScope());
}
// Accepts RN, RZ and RNA rounding; relu cannot be combined with RNA.
LogicalResult ConvertFloatToTF32Op::verify() {
  using RndMode = NVVM::FPRoundingMode;
  const RndMode rnd = getRnd();
  const bool supported =
      rnd == RndMode::RN || rnd == RndMode::RZ || rnd == RndMode::RNA;
  if (!supported)
    return emitError(
        "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");
  if (rnd == RndMode::RNA && getRelu())
    return emitError("Relu not supported with rna rounding mode.");
  return success();
}
// The destination must be Float6E2M3FNType or Float6E3M2FNType.
LogicalResult ConvertF32x2ToF6x2Op::verify() {
  if (llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy()))
    return success();
  mlir::MLIRContext *ctx = getContext();
  return emitOpError("Only ")
         << mlir::Float6E2M3FNType::get(ctx) << " and "
         << mlir::Float6E3M2FNType::get(ctx)
         << " types are supported for conversions from f32x2 to f6x2.";
}
// Per-destination rules for f32x2 -> f8x2:
//  - Float8E4M3FNType / Float8E5M2Type: RN rounding and SATFINITE only;
//  - Float8E8M0FNUType: RZ or RP rounding, and no relu;
//  - any other destination type is rejected.
LogicalResult ConvertF32x2ToF8x2Op::verify() {
  using RndMode = NVVM::FPRoundingMode;
  using SatMode = NVVM::SaturationMode;
  mlir::MLIRContext *ctx = getContext();
  const auto rnd = getRnd();
  const mlir::Type dstTy = getDstTy();

  if (llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(dstTy)) {
    if (rnd != RndMode::RN)
      return emitOpError("Only RN rounding mode is supported for "
                         "conversions from f32x2 to ")
             << mlir::Float8E4M3FNType::get(ctx) << " and "
             << mlir::Float8E5M2Type::get(ctx) << " types";
    if (getSat() != SatMode::SATFINITE)
      return emitOpError("Only SATFINITE saturation mode is supported "
                         "for conversions from f32x2 to ")
             << mlir::Float8E4M3FNType::get(ctx) << " and "
             << mlir::Float8E5M2Type::get(ctx) << " types";
    return success();
  }

  if (llvm::isa<mlir::Float8E8M0FNUType>(dstTy)) {
    if (rnd != RndMode::RZ && rnd != RndMode::RP)
      return emitOpError("Only RZ and RP rounding modes are supported for "
                         "conversions from f32x2 to ")
             << mlir::Float8E8M0FNUType::get(ctx) << " type";
    const bool hasRelu = getRelu();
    if (hasRelu)
      return emitOpError("relu not supported for conversions to ")
             << mlir::Float8E8M0FNUType::get(ctx) << " type";
    return success();
  }

  return emitOpError("Only ")
         << mlir::Float8E4M3FNType::get(ctx) << ", "
         << mlir::Float8E5M2Type::get(ctx) << ", and "
         << mlir::Float8E8M0FNUType::get(ctx)
         << " types are supported for conversions from f32x2 to f8x2";
}
// The destination must be Float8E4M3FNType or Float8E5M2Type.
LogicalResult ConvertF16x2ToF8x2Op::verify() {
  if (llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy()))
    return success();
  mlir::MLIRContext *ctx = getContext();
  return emitOpError("Only ")
         << mlir::Float8E4M3FNType::get(ctx) << " and "
         << mlir::Float8E5M2Type::get(ctx)
         << " types are supported for conversions from f16x2 to f8x2.";
}
// The destination must be Float8E8M0FNUType, rounded with RZ or RP.
LogicalResult ConvertBF16x2ToF8x2Op::verify() {
  using RndMode = NVVM::FPRoundingMode;
  const bool isE8M0Dst = llvm::isa<mlir::Float8E8M0FNUType>(getDstTy());
  if (!isE8M0Dst)
    return emitOpError("Only ") << mlir::Float8E8M0FNUType::get(getContext())
                                << " type is supported for conversions from "
                                   "bf16x2 to f8x2.";
  const auto rnd = getRnd();
  const bool rzOrRp = rnd == RndMode::RZ || rnd == RndMode::RP;
  if (!rzOrRp)
    return emitOpError("Only RZ and RP rounding modes are supported for "
                       "conversions from bf16x2 to f8x2.");
  return success();
}
// The destination must be Float4E2M1FNType.
LogicalResult ConvertF32x2ToF4x2Op::verify() {
  if (llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
    return success();
  return emitOpError("Only ")
         << mlir::Float4E2M1FNType::get(getContext())
         << " type is supported for conversions from f32x2 to f4x2.";
}
// The source must be Float8E4M3FNType or Float8E5M2Type.
LogicalResult ConvertF8x2ToF16x2Op::verify() {
  if (llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType()))
    return success();
  mlir::MLIRContext *ctx = getContext();
  return emitOpError("Only ")
         << mlir::Float8E4M3FNType::get(ctx) << " and "
         << mlir::Float8E5M2Type::get(ctx)
         << " types are supported for conversions from f8x2 to f16x2.";
}
// The source must be Float8E8M0FNUType.
LogicalResult ConvertF8x2ToBF16x2Op::verify() {
  if (llvm::isa<Float8E8M0FNUType>(getSrcType()))
    return success();
  return emitOpError("Only ")
         << mlir::Float8E8M0FNUType::get(getContext())
         << " type is supported for conversions from f8x2 to bf16x2.";
}
// The source must be Float6E2M3FNType or Float6E3M2FNType.
LogicalResult ConvertF6x2ToF16x2Op::verify() {
  if (llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType()))
    return success();
  mlir::MLIRContext *ctx = getContext();
  return emitOpError("Only ")
         << mlir::Float6E2M3FNType::get(ctx) << " and "
         << mlir::Float6E3M2FNType::get(ctx)
         << " types are supported for conversions from f6x2 to f16x2.";
}
// The source must be Float4E2M1FNType.
LogicalResult ConvertF4x2ToF16x2Op::verify() {
  if (llvm::isa<Float4E2M1FNType>(getSrcType()))
    return success();
  return emitOpError("Only ")
         << mlir::Float4E2M1FNType::get(getContext())
         << " type is supported for conversions from f4x2 to f16x2.";
}
// DEFAULT, F4E and B4E require the optional `hi` operand. RC8, ECL, ECR and
// RC16 forbid it.
LogicalResult PermuteOp::verify() {
  using Mode = NVVM::PermuteMode;
  const Mode mode = getMode();
  const bool hasHi = static_cast<bool>(getHi());

  // Unset only for a mode value outside the listed cases; such modes are
  // accepted without checks.
  std::optional<bool> requiresHi;
  switch (mode) {
  case Mode::DEFAULT:
  case Mode::F4E:
  case Mode::B4E:
    requiresHi = true;
    break;
  case Mode::RC8:
  case Mode::ECL:
  case Mode::ECR:
  case Mode::RC16:
    requiresHi = false;
    break;
  }

  if (!requiresHi || *requiresHi == hasHi)
    return success();
  if (*requiresHi)
    return emitError("mode '")
           << stringifyPermuteMode(mode) << "' requires 'hi' operand.";
  return emitError("mode '") << stringifyPermuteMode(mode)
                             << "' does not accept 'hi' operand.";
}
//===----------------------------------------------------------------------===//
// Stochastic Rounding Conversion Ops
//===----------------------------------------------------------------------===//
// Shared verifier for the f32x2 -> {f16x2, bf16x2} conversions.
//   dstType       - destination type name, used only in the diagnostic.
//   rnd           - requested rounding mode; must be RN, RZ or RS.
//   hasRandomBits - whether the op has a random_bits operand. It is required
//                   for RS and rejected for RN/RZ.
//   op            - operation that diagnostics are attached to.
// `dstType` is taken by const reference: LLVM's Twine must not be passed or
// stored by value.
static LogicalResult verifyConvertF32x2ToFP16x2Op(const Twine &dstType,
                                                  FPRoundingMode rnd,
                                                  bool hasRandomBits,
                                                  Operation *op) {
  static constexpr FPRoundingMode validRndModes[] = {
      FPRoundingMode::RN, FPRoundingMode::RZ, FPRoundingMode::RS};

  if (!llvm::is_contained(validRndModes, rnd)) {
    return op->emitOpError(
               "Only RN, RZ, and RS rounding modes are supported for "
               "conversions from f32x2 to ")
           << dstType << ".";
  }

  // Stochastic rounding (RS) consumes random bits; RN and RZ must not.
  const bool isStochastic = rnd == FPRoundingMode::RS;
  if (isStochastic && !hasRandomBits)
    return op->emitOpError("random_bits is required for RS rounding mode.");
  if (!isStochastic && hasRandomBits)
    return op->emitOpError(
        "random_bits not supported for RN and RZ rounding modes.");

  return success();
}
// Both stochastic-rounding-capable conversions use the shared verifier. A
// non-null random_bits operand counts as present.
LogicalResult ConvertF32x2ToF16x2Op::verify() {
  const bool hasRandomBits = static_cast<bool>(getRandomBits());
  return verifyConvertF32x2ToFP16x2Op("f16x2", getRnd(), hasRandomBits, *this);
}
LogicalResult ConvertF32x2ToBF16x2Op::verify() {
  const bool hasRandomBits = static_cast<bool>(getRandomBits());
  return verifyConvertF32x2ToFP16x2Op("bf16x2", getRnd(), hasRandomBits,
                                      *this);
}
// The destination must be Float8E4M3FNType or Float8E5M2Type.
LogicalResult ConvertF32x4ToF8x4Op::verify() {
  if (llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy()))
    return success();
  mlir::MLIRContext *ctx = getContext();
  return emitOpError("Only ")
         << mlir::Float8E4M3FNType::get(ctx) << " and "
         << mlir::Float8E5M2Type::get(ctx)
         << " types are supported for conversions from f32x4 to f8x4.";
}
// The destination must be Float6E2M3FNType or Float6E3M2FNType.
LogicalResult ConvertF32x4ToF6x4Op::verify() {
  if (llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy()))
    return success();
  mlir::MLIRContext *ctx = getContext();
  return emitOpError("Only ")
         << mlir::Float6E2M3FNType::get(ctx) << " and "
         << mlir::Float6E3M2FNType::get(ctx)
         << " types are supported for conversions from f32x4 to f6x4.";
}
// The destination must be Float4E2M1FNType.
LogicalResult ConvertF32x4ToF4x4Op::verify() {
  if (llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
    return success();
  return emitOpError("Only ") << mlir::Float4E2M1FNType::get(getContext())
                              << " type is supported for conversions from "
                                 "f32x4 to f4x4.";
}
// Only an initialization value of 0 is accepted.
LogicalResult BulkStoreOp::verify() {
  const auto initVal = getInitVal();
  if (initVal == 0)
    return success();
  return emitOpError("only 0 is supported for initVal, got ") << initVal;
}
// Exactly one of `id` / `mask` must be set. A plain `id` must lie in
// [0, 15].
LogicalResult PMEventOp::verify() {
  auto eventId = getEventId();
  auto maskedEventId = getMaskedEventId();

  if (!eventId && !maskedEventId)
    return emitOpError() << "either `id` or `mask` must be set";
  if (eventId && maskedEventId)
    return emitOpError() << "`id` and `mask` cannot be set at the same time";

  const bool idOutOfRange = eventId && (eventId < 0 || eventId > 15);
  if (idOutOfRange)
    return emitOpError() << "`id` must be between 0 and 15";
  return llvm::success();
}
// Maps the element type of an MMA operand to its PTX type (`NVVM::MMATypes`).
// `isAccumulator` separates C/result operands from the A/B multiplicands:
//  - f32 maps to f32 for accumulators and tf32 otherwise;
//  - integers map to s32 only for accumulators.
// f64, f16 and vector<2xf16> map to f64/f16 either way. LLVM struct types are
// classified by their first member; empty structs and unknown types yield
// std::nullopt.
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
  const Type f16x2Ty =
      VectorType::get(2, Float16Type::get(operandElType.getContext()));

  if (operandElType.isF64())
    return NVVM::MMATypes::f64;
  if (operandElType.isF16() || operandElType == f16x2Ty)
    return NVVM::MMATypes::f16;
  if (operandElType.isF32())
    return isAccumulator ? NVVM::MMATypes::f32 : NVVM::MMATypes::tf32;
  if (llvm::isa<IntegerType>(operandElType)) {
    if (!isAccumulator)
      return std::nullopt;
    return NVVM::MMATypes::s32;
  }

  auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType);
  if (!structType || structType.getBody().empty())
    return std::nullopt;
  return inferOperandMMAType(structType.getBody().front(), isAccumulator);
}
/// True for the 4-bit integer PTX types (u4, s4).
static bool isInt4PtxType(MMATypes type) {
  switch (type) {
  case MMATypes::u4:
  case MMATypes::s4:
    return true;
  default:
    return false;
  }
}
/// True for the 8-bit integer PTX types (u8, s8).
static bool isInt8PtxType(MMATypes type) {
  switch (type) {
  case MMATypes::u8:
  case MMATypes::s8:
    return true;
  default:
    return false;
  }
}
/// True for any integer-like PTX type: 4/8-bit ints, b1, or s32.
static bool isIntegerPtxType(MMATypes type) {
  if (type == MMATypes::b1 || type == MMATypes::s32)
    return true;
  return isInt4PtxType(type) || isInt8PtxType(type);
}
// PTX type of the accumulator. It is inferred from the first type of operand
// segment 2 (the C registers) and is asserted to be inferrable.
MMATypes MmaOp::accumPtxType() {
  Type accumTy = getODSOperands(2).getTypes().front();
  std::optional<mlir::NVVM::MMATypes> val =
      inferOperandMMAType(accumTy, /*isAccumulator=*/true);
  assert(val.has_value() && "accumulator PTX type should always be inferrable");
  return val.value();
}
// PTX type of the result. It is inferred from the result type, treated as an
// accumulator, and is asserted to be inferrable.
MMATypes MmaOp::resultPtxType() {
  Type resTy = getResult().getType();
  std::optional<mlir::NVVM::MMATypes> val =
      inferOperandMMAType(resTy, /*isAccumulator=*/true);
  assert(val.has_value() && "result PTX type should always be inferrable");
  return val.value();
}
// Prints the custom form:
//   A[...] B[...] C[...] attr-dict : (typeA, typeB, typeC) -> resultType
// The multiplicand PTX type attribute of A or B is elided only when it can
// be re-inferred (as a multiplicand) from that fragment's own first register
// type *and* equals the stored value. This keeps the printed form
// round-tripping through MmaOp::parse, which infers missing attributes the
// same way.
void MmaOp::print(OpAsmPrinter &p) {
  struct MMAOperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}
  };
  std::array<MMAOperandFragment, 3> frags{
      MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      MMAOperandFragment("C", "")};
  // PTX types currently attached to A and B; C carries no such attribute.
  std::array<std::optional<MMATypes>, 3> attachedPtxTypes{
      getMultiplicandAPtxType(), getMultiplicandBPtxType(), std::nullopt};
  SmallVector<StringRef, 4> ignoreAttrNames{
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
         operandIdx++)
      frag.regs.push_back(this->getOperand(operandIdx));

    // Only the multiplicands have a PTX type attribute to elide. An empty
    // segment has no type to infer from, so its attribute stays printed.
    if (frag.ptxTypeAttr.empty() || frag.regs.empty())
      continue;
    // Infer from this fragment's own register type (previously only A's type
    // was recorded, so B was decided from A).
    std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
        frag.regs.front().getType(), /*isAccumulator=*/false);
    if (inferredType && inferredType == attachedPtxTypes[fragIdx])
      ignoreAttrNames.push_back(frag.ptxTypeAttr);
  }

  auto printMmaOperand = [&](const MMAOperandFragment &frag) -> void {
    p << " " << frag.operandName;
    p << "[";
    p.printOperands(frag.regs);
    p << "] ";
  };

  for (const auto &frag : frags) {
    printMmaOperand(frag);
  }

  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);

  // Print the types of the operands and result.
  p << " : "
    << "(";
  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                             frags[1].regs[0].getType(),
                                             frags[2].regs[0].getType()},
                        p);
  p << ")";
  p.printArrowTypeList(TypeRange{this->getRes().getType()});
}
// Builds an MMA op from its A/B/C register segments and an (m, n, k) shape.
// If `multiplicandPtxTypes` is absent, the PTX types of A and B are inferred
// from each segment's first register; an attribute is omitted when inference
// fails. Layouts default to A=row, B=col. `intOverflow` and `b1Op` are only
// attached when provided. The operand segment sizes record the three
// segment lengths.
void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
                  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
                  std::optional<MMAIntOverflow> intOverflow,
                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  MLIRContext *ctx = builder.getContext();

  result.addAttribute(
      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));

  for (ValueRange segment : {operandA, operandB, operandC})
    result.addOperands(segment);

  std::optional<MMATypes> ptxTypeA;
  std::optional<MMATypes> ptxTypeB;
  if (multiplicandPtxTypes) {
    ptxTypeA = (*multiplicandPtxTypes)[0];
    ptxTypeB = (*multiplicandPtxTypes)[1];
  } else {
    ptxTypeA = inferOperandMMAType(operandA[0].getType(), false);
    ptxTypeB = inferOperandMMAType(operandB[0].getType(), false);
  }
  if (ptxTypeA)
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, *ptxTypeA));
  if (ptxTypeB)
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, *ptxTypeB));

  const std::array<MMALayout, 2> layouts = multiplicandLayouts.value_or(
      std::array<MMALayout, 2>{MMALayout::row, MMALayout::col});
  result.addAttribute("layoutA", MMALayoutAttr::get(ctx, layouts[0]));
  result.addAttribute("layoutB", MMALayoutAttr::get(ctx, layouts[1]));

  if (intOverflow)
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));
  if (b1Op)
    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));

  result.addTypes(resultType);

  const int32_t segmentSizes[] = {static_cast<int32_t>(operandA.size()),
                                  static_cast<int32_t>(operandB.size()),
                                  static_cast<int32_t>(operandC.size())};
  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
}
// <operation> :=
//   A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
//   attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
//     `->` type($res)
//
// Exactly one type is given per operand segment and applies to every
// register of that segment. Each segment must contain at least one register.
// When `multiplicandAPtxType` / `multiplicandBPtxType` are not written
// explicitly, they are inferred from the A / B segment types (as
// multiplicands, matching print() and build()). Parsing fails if a missing
// attribute cannot be inferred.
ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
  struct MMAOperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
    SmallVector<Type> regTypes;
  };

  Builder &builder = parser.getBuilder();
  // Entries 0-2 are the A, B and C segments; entry 3 holds the type inferred
  // for the result.
  std::array<MMAOperandFragment, 4> frags;

  NamedAttrList namedAttributes;

  // Parses `<operandName> [ %r0, %r1, ... ]` into `frag.regs`.
  auto parseMmaOperand = [&](StringRef operandName,
                             MMAOperandFragment &frag) -> LogicalResult {
    if (parser.parseKeyword(operandName).failed())
      return failure();
    if (parser
            .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
            .failed())
      return failure();
    return success();
  };

  // Parse the operand segments.
  if (parseMmaOperand("A", frags[0]).failed())
    return failure();
  if (parseMmaOperand("B", frags[1]).failed())
    return failure();
  if (parseMmaOperand("C", frags[2]).failed())
    return failure();

  if (parser.parseOptionalAttrDict(namedAttributes).failed())
    return failure();

  // Parse the type specification and resolve operands.
  SmallVector<Type, 3> operandTypes;
  if (failed(parser.parseColon()))
    return failure();
  if (failed(parser.parseLParen()))
    return failure();
  if (failed(parser.parseTypeList(operandTypes)))
    return failure();
  if (failed(parser.parseRParen()))
    return failure();
  // Must run unconditionally: it also keeps the `frags[...]` indexing below
  // in bounds.
  if (operandTypes.size() != 3)
    return parser.emitError(
        parser.getNameLoc(),
        "expected one type for each operand segment but got " +
            Twine(operandTypes.size()) + " types");

  for (const auto &iter : llvm::enumerate(operandTypes)) {
    auto &frag = frags[iter.index()];
    // The first register's type is read below for PTX type inference.
    if (frag.regs.empty())
      return parser.emitError(
          parser.getNameLoc(),
          "expected at least one operand in each operand segment");
    frag.regTypes.resize(frag.regs.size(), iter.value());
    if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
                                      parser.getNameLoc(), result.operands)))
      return failure();
    // Segments A (0) and B (1) are multiplicands; only C (2) accumulates.
    frag.elemtype = inferOperandMMAType(frag.regTypes[0],
                                        /*isAccumulator=*/iter.index() >= 2);
  }

  Type resultType;
  if (parser.parseArrow() || parser.parseType(resultType))
    return failure();
  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator*/ true);

  // Add inferred multiplicand PTX types unless they were given explicitly.
  std::array<StringRef, 2> names{"multiplicandAPtxType",
                                 "multiplicandBPtxType"};
  for (unsigned idx = 0; idx < names.size(); idx++) {
    const auto &frag = frags[idx];
    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
    if (!frag.elemtype.has_value() && !attr.has_value()) {
      return parser.emitError(
          parser.getNameLoc(),
          "attribute " + names[idx] +
              " is not provided explicitly and cannot be inferred");
    }
    if (!attr.has_value())
      result.addAttribute(
          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
  }

  result.addTypes(resultType);
  if (!namedAttributes.empty())
    result.addAttributes(namedAttributes);
  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr({
                          static_cast<int32_t>(frags[0].regs.size()),
                          static_cast<int32_t>(frags[1].regs.size()),
                          static_cast<int32_t>(frags[2].regs.size()),
                      }));
  return success();
}
LogicalResult MmaOp::verify() {
MLIRContext *context = getContext();
auto f16Ty = Float16Type::get(context);
auto i32Ty = IntegerType::get(context, 32);
auto f16x2Ty = VectorType::get(2, f16Ty);
auto f32Ty = Float32Type::get(context);
auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
auto s32x4StructTy =
LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
auto f32x8StructTy =
LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
auto f16x2x2StructTy =
LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
auto f32x4StructTy =
LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
auto s32x2StructTy =
LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
getShapeAttr().getK()};
// These variables define the set of allowed data types for matrices A, B, C,
// and result.
using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
AllowedShapes allowedShapes;
AllowedTypes expectedA;
AllowedTypes expectedB;
AllowedTypes expectedC;
SmallVector<Type> expectedResult;
// When M = 16, we just need to calculate the number of 8xk tiles, where
// k is a factor that depends on the data type.
if (mmaShape[0] == 16) {
int64_t kFactor;
Type multiplicandFragType;
switch (*getMultiplicandAPtxType()) {
case MMATypes::tf32:
kFactor = 4;
multiplicandFragType = i32Ty;
expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
context, {f32Ty, f32Ty, f32Ty, f32Ty}));
break;
case MMATypes::bf16:
kFactor = 8;
multiplicandFragType = i32Ty;
expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
context, {f32Ty, f32Ty, f32Ty, f32Ty}));
break;
case MMATypes::f16:
kFactor = 8;
multiplicandFragType = f16x2Ty;
expectedResult.push_back(f16x2x2StructTy);
expectedResult.push_back(f32x4StructTy);
break;
case MMATypes::s4:
case MMATypes::u4:
kFactor = 32;
break;
case MMATypes::b1:
kFactor = 128;
break;
case MMATypes::s8:
case MMATypes::u8:
kFactor = 16;
break;
default:
return emitError("invalid shape or multiplicand type: " +
stringifyEnum(getMultiplicandAPtxType().value()));
}
if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
expectedResult.push_back(s32x4StructTy);
expectedC.emplace_back(4, i32Ty);
multiplicandFragType = i32Ty;
} else {
expectedC.emplace_back(2, f16x2Ty);
expectedC.emplace_back(4, f32Ty);
}
int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
expectedA.emplace_back(unitA, multiplicandFragType);
expectedB.emplace_back(unitB, multiplicandFragType);
allowedShapes.push_back({16, 8, kFactor});
allowedShapes.push_back({16, 8, kFactor * 2});
if (resultPtxType() != accumPtxType())
return emitOpError("ctype does not match dtype");
}
// In the M=8 case, there is only 1 possible case per data type.
if (mmaShape[0] == 8) {
if (*getMultiplicandAPtxType() == MMATypes::f16) {
expectedA.emplace_back(2, f16x2Ty);
expectedB.emplace_back(2, f16x2Ty);
expectedResult.push_back(f16x2x4StructTy);
expectedResult.push_back(f32x8StructTy);
expectedC.emplace_back(4, f16x2Ty);
expectedC.emplace_back(8, f32Ty);
allowedShapes.push_back({8, 8, 4});
}
if (*getMultiplicandAPtxType() == MMATypes::f64) {
Type f64Ty = Float64Type::get(context);
expectedA.emplace_back(1, f64Ty);
expectedB.emplace_back(1, f64Ty);
expectedC.emplace_back(2, f64Ty);
expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
context, SmallVector<Type>(2, f64Ty)));
allowedShapes.push_back({8, 8, 4});
}
if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
expectedA.push_back({i32Ty});
expectedB.push_back({i32Ty});
expectedC.push_back({i32Ty, i32Ty});
expectedResult.push_back(s32x2StructTy);
if (isInt4PtxType(getMultiplicandAPtxType().value()))
allowedShapes.push_back({8, 8, 32});
if (isInt8PtxType(getMultiplicandAPtxType().value()))
allowedShapes.push_back({8, 8, 16});
if (getMultiplicandAPtxType().value() == MMATypes::b1)
allowedShapes.push_back({8, 8, 128});
}
}
std::string errorMessage;
llvm::raw_string_ostream errorStream(errorMessage);
// Check that we matched an existing shape/dtype combination.
if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
!llvm::is_contained(allowedShapes, mmaShape)) {
errorStream << "unimplemented variant for MMA shape <";
llvm::interleaveComma(mmaShape, errorStream);
errorStream << ">";
return emitOpError(errorMessage);
}
// Verify the operand types for segments of A, B, and C operands.
std::array<StringRef, 3> operandNames{"A", "B", "C"};
for (const auto &iter : llvm::enumerate(
SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
auto spec = this->getODSOperandIndexAndLength(iter.index());
SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
operand_type_begin() + spec.first +
spec.second);
bool match = llvm::is_contained(iter.value(), operandTySeg);
if (!match) {
errorStream << "Could not match types for the "
<< operandNames[iter.index()]
<< " operands; expected one of ";
for (const auto &x : iter.value()) {
errorStream << x.size() << "x" << x[0] << " ";
}
errorStream << "but got ";
llvm::interleaveComma(operandTySeg, errorStream);
return emitOpError(errorMessage);
}
}
// Check the result type
if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
return expectedResultType == getResult().getType();
})) {
errorStream
<< "Could not match allowed types for the result; expected one of ";
llvm::interleaveComma(expectedResult, errorStream);
errorStream << " but got " << getResult().getType();
return emitOpError(errorMessage);
}
// Ensure that binary MMA variants have a b1 MMA operation defined.
if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
return emitOpError("op requires " + getB1OpAttrName().strref() +
" attribute");
}
// Ensure int4/int8 MMA variants specify the accum overflow behavior
// attribute.
if (isInt4PtxType(*getMultiplicandAPtxType()) ||
isInt8PtxType(*getMultiplicandAPtxType())) {
if (!getIntOverflowBehavior())
return emitOpError("op requires " +
getIntOverflowBehaviorAttrName().strref() +
" attribute");
}
// Validate layout combinations. According to the operation description, most
// MMA operations require layoutA=row and layoutB=col. Only m8n8k4 with f16
// can use other layout combinations.
bool isM8N8K4_F16 =
(mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 &&
getMultiplicandAPtxType() == MMATypes::f16);
if (!isM8N8K4_F16) {
// For all other shapes/types, layoutA must be row and layoutB must be col
if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) {
return emitOpError("requires layoutA = #nvvm.mma_layout<row> and "
"layoutB = #nvvm.mma_layout<col> for shape <")
<< mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2]
<< "> with element types "
<< stringifyEnum(*getMultiplicandAPtxType()) << " and "
<< stringifyEnum(*getMultiplicandBPtxType())
<< ". Only m8n8k4 with f16 supports other layouts.";
}
}
return success();
}
MMATypes MmaSpOp::accumPtxType() {
  // The accumulator (C) fragment is ODS operand group 2; the type of its first
  // register determines the PTX accumulator type.
  Type accumulatorRegType = getODSOperands(2).getTypes().front();
  std::optional<mlir::NVVM::MMATypes> inferred =
      MmaOp::inferOperandMMAType(accumulatorRegType, /*isAccumulator=*/true);
  assert(inferred.has_value() &&
         "accumulator PTX type should always be inferrable");
  return inferred.value();
}
MMATypes MmaSpOp::resultPtxType() {
  // The result PTX type is inferred from the op's result type, treated as an
  // accumulator-kind type.
  Type resultTy = getResult().getType();
  std::optional<mlir::NVVM::MMATypes> inferred =
      MmaOp::inferOperandMMAType(resultTy, /*isAccumulator=*/true);
  assert(inferred.has_value() && "result PTX type should always be inferrable");
  return inferred.value();
}
mlir::NVVM::IDArgPair
MmaSpOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
                               llvm::IRBuilderBase &builder) {
  auto spOp = cast<NVVM::MmaSpOp>(op);

  // Every MLIR operand (A, B, C, sparse metadata, selector) is forwarded to
  // the intrinsic in operand order.
  llvm::SmallVector<llvm::Value *> intrinsicArgs;
  intrinsicArgs.reserve(spOp->getNumOperands());
  for (mlir::Value operand : spOp.getOperands())
    intrinsicArgs.push_back(mt.lookupValue(operand));

  // Select the intrinsic from the shape, the optional overflow/ordering
  // attributes, the kind, and the A/B/accumulator/result PTX types.
  auto shape = spOp.getShape();
  auto intrinsicId = MmaSpOp::getIntrinsicID(
      shape.getM(), shape.getN(), shape.getK(), spOp.getIntOverflowBehavior(),
      spOp.getOrderedMetadata(), spOp.getKind(),
      *spOp.getMultiplicandAPtxType(), *spOp.getMultiplicandBPtxType(),
      spOp.accumPtxType(), spOp.resultPtxType());
  return {intrinsicId, intrinsicArgs};
}
void MmaSpOp::print(OpAsmPrinter &p) {
  // Keyword and PTX-type attribute name of each variadic operand segment,
  // indexed by ODS operand group: 0 = A, 1 = B, 2 = C (C has no such attr).
  const std::array<StringRef, 3> segmentKeywords{"A", "B", "C"};
  const std::array<StringRef, 3> segmentPtxAttrNames{
      getMultiplicandAPtxTypeAttrName(), getMultiplicandBPtxTypeAttrName(),
      ""};
  std::array<SmallVector<Value, 4>, 3> segmentValues;
  // Type of the first register of each segment, in segment order.
  SmallVector<Type, 4> segmentTypes;
  // Segment sizes are implied by the bracketed lists; a multiplicand PTX type
  // attribute is elided when it can be inferred from the register type.
  SmallVector<StringRef, 4> elidedAttrs{
      mlir::NVVM::MmaSpOp::getOperandSegmentSizeAttr()};

  for (unsigned seg = 0; seg < segmentValues.size(); ++seg) {
    auto [start, length] = getODSOperandIndexAndLength(seg);
    for (unsigned idx = start, end = start + length; idx != end; ++idx) {
      Value operand = getOperand(idx);
      segmentValues[seg].push_back(operand);
      if (idx == start)
        segmentTypes.push_back(operand.getType());
    }
    bool inferable = MmaOp::inferOperandMMAType(segmentTypes.back(),
                                                /*isAccumulator=*/seg >= 2)
                         .has_value();
    if (inferable)
      elidedAttrs.push_back(segmentPtxAttrNames[seg]);
  }

  // Prints ` keyword[operands]`.
  auto printSegment = [&](StringRef keyword, ArrayRef<Value> values) {
    p << " " << keyword << "[";
    p.printOperands(values);
    p << "]";
  };
  for (unsigned seg = 0; seg < segmentValues.size(); ++seg)
    printSegment(segmentKeywords[seg], segmentValues[seg]);
  printSegment("sparseMetadata", {getSparseMetadata()});
  printSegment("selector", {getSparsitySelector()});

  p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);

  // Trailing signature: ` : (A-type, B-type, C-type) -> result-type`.
  p << " : (";
  llvm::interleaveComma(segmentTypes, p);
  p << ") -> " << getResult().getType();
}
/// Builds an `nvvm.mma.sp` op.
///
/// Operands are appended as A..., B..., C..., sparseMetadata, sparsitySelector
/// and the matching `operandSegmentSizes` attribute is attached. When
/// `multiplicandPtxTypes` is not given, the A/B PTX types are inferred from the
/// first register of each multiplicand; if that is impossible (including an
/// empty operand range) the attribute is simply left unset.
void MmaSpOp::build(
    OpBuilder &builder, OperationState &result, Type resultType,
    ValueRange operandA, ValueRange operandB, ValueRange operandC,
    Value sparseMetadata, Value sparsitySelector, ArrayRef<int64_t> shape,
    std::optional<MMAIntOverflow> intOverflow,
    std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  MLIRContext *ctx = builder.getContext();
  result.addAttribute(
      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);
  result.addOperands(sparseMetadata);
  result.addOperands(sparsitySelector);
  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
  } else {
    // Guard against empty ranges: indexing `[0]` on them is out of bounds.
    auto inferFromFirstReg =
        [](ValueRange regs) -> std::optional<MMATypes> {
      if (regs.empty())
        return std::nullopt;
      return MmaOp::inferOperandMMAType(regs[0].getType(),
                                        /*isAccumulator=*/false);
    };
    if (auto res = inferFromFirstReg(operandA))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = inferFromFirstReg(operandB))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
  }
  if (intOverflow.has_value())
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));
  result.addTypes(resultType);
  result.addAttribute(
      MmaSpOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
                                    static_cast<int32_t>(operandB.size()),
                                    static_cast<int32_t>(operandC.size()), 1,
                                    1})); // sparseMetadata and sparsitySelector
}
/// Parses
///   `A[...] B[...] C[...] sparseMetadata[...] selector[...] attr-dict
///    : (A-type, B-type, C-type) -> result-type`.
///
/// Every operand of a segment takes that segment's declared type; the sparse
/// metadata and selector must each be exactly one i32 operand. Missing
/// multiplicand PTX type attributes are inferred from the A/B types, and it is
/// a parse error if they can be neither inferred nor found explicitly.
ParseResult MmaSpOp::parse(OpAsmParser &parser, OperationState &result) {
  struct MMAOperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
    SmallVector<Type> regTypes;
  };
  Builder &builder = parser.getBuilder();
  // Operand groups in textual order: A, B, C, sparseMetadata, selector.
  std::array<MMAOperandFragment, 5> frags;
  NamedAttrList namedAttributes;
  // A helper to parse one `keyword[operands]` segment.
  auto parseMmaSpOperand = [&](StringRef operandName,
                               MMAOperandFragment &frag) -> LogicalResult {
    if (parser.parseKeyword(operandName).failed())
      return failure();
    if (parser
            .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
            .failed())
      return failure();
    return success();
  };
  // Parse the operand segments.
  if (parseMmaSpOperand("A", frags[0]).failed())
    return failure();
  if (parseMmaSpOperand("B", frags[1]).failed())
    return failure();
  if (parseMmaSpOperand("C", frags[2]).failed())
    return failure();
  if (parseMmaSpOperand("sparseMetadata", frags[3]).failed())
    return failure();
  if (parseMmaSpOperand("selector", frags[4]).failed())
    return failure();
  if (parser.parseOptionalAttrDict(namedAttributes).failed())
    return failure();
  // Parse the type specification and resolve operands.
  SmallVector<Type, 3> operandTypes;
  if (failed(parser.parseColon()))
    return failure();
  if (failed(parser.parseLParen()))
    return failure();
  if (failed(parser.parseTypeList(operandTypes)))
    return failure();
  if (failed(parser.parseRParen()))
    return failure();
  if (operandTypes.size() != 3)
    return parser.emitError(
        parser.getNameLoc(),
        "expected one type for each operand segment but got " +
            Twine(operandTypes.size()) + " types");
  for (const auto &iter : llvm::enumerate(operandTypes)) {
    auto &frag = frags[iter.index()];
    frag.regTypes.resize(frag.regs.size(), iter.value());
    if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
                                      parser.getNameLoc(), result.operands)))
      return failure();
    // Infer from the segment's declared type rather than `frag.regTypes[0]`,
    // which does not exist when the segment is empty (e.g. `C[]`).
    frag.elemtype =
        MmaOp::inferOperandMMAType(iter.value(),
                                   /*isAccumulator*/ iter.index() >= 2);
  }
  Type resultType;
  if (parser.parseArrow() || parser.parseType(resultType))
    return failure();
  // The operand segment sizes below hard-code exactly one sparse metadata and
  // one selector operand, so reject any other count here.
  if (frags[3].regs.size() != 1 || frags[4].regs.size() != 1)
    return parser.emitError(parser.getNameLoc(),
                            "expected exactly one sparseMetadata operand and "
                            "one selector operand");
  // Resolve sparse metadata and selector (assume i32 type)
  Type i32Type = builder.getIntegerType(32);
  if (parser
          .resolveOperands(frags[3].regs, i32Type, parser.getCurrentLocation(),
                           result.operands)
          .failed())
    return failure();
  if (parser
          .resolveOperands(frags[4].regs, i32Type, parser.getCurrentLocation(),
                           result.operands)
          .failed())
    return failure();
  std::array<StringRef, 2> names{"multiplicandAPtxType",
                                 "multiplicandBPtxType"};
  for (unsigned idx = 0; idx < names.size(); idx++) {
    const auto &frag = frags[idx];
    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
    if (!frag.elemtype.has_value() && !attr.has_value()) {
      return parser.emitError(
          parser.getNameLoc(),
          "attribute " + names[idx] +
              " is not provided explicitly and cannot be inferred");
    }
    if (!attr.has_value())
      result.addAttribute(
          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
  }
  result.addTypes(resultType);
  if (!namedAttributes.empty())
    result.addAttributes(namedAttributes);
  result.addAttribute(MmaSpOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr({
                          static_cast<int32_t>(frags[0].regs.size()),
                          static_cast<int32_t>(frags[1].regs.size()),
                          static_cast<int32_t>(frags[2].regs.size()),
                          1, // sparseMetadata
                          1  // sparsitySelector
                      }));
  return success();
}
/// Verifies an `nvvm.mma.sp` op.
///
/// Checks, in order:
/// - the multiplicand PTX type attributes are present;
/// - the (shape, A type) pair is a supported sparse variant (only M = 16 shapes
///   exist for mma.sp);
/// - the A/B/C register counts and types and the result type match;
/// - the accumulator and result PTX types agree;
/// - integer variants carry an overflow behavior;
/// - the sparse metadata and selector are i32.
LogicalResult MmaSpOp::verify() {
  MLIRContext *context = getContext();
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = VectorType::get(2, f16Ty);
  auto f32Ty = Float32Type::get(context);
  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});

  // The multiplicand PTX types select the variant checked below and the
  // intrinsic used for lowering. They are optional attributes (the builder
  // only sets them when it can infer them), so require them explicitly rather
  // than dereferencing an empty optional.
  if (!getMultiplicandAPtxType() || !getMultiplicandBPtxType())
    return emitOpError("requires ")
           << getMultiplicandAPtxTypeAttrName().strref() << " and "
           << getMultiplicandBPtxTypeAttrName().strref() << " attributes";
  const MMATypes aPtxType = *getMultiplicandAPtxType();

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};
  // These variables define the set of allowed data types for matrices A, B, C,
  // and result.
  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;
  SmallVector<Type> expectedResult;

  // Sparse MMA is only defined for M = 16 (m16n8kK) shapes. Any other M leaves
  // the expected lists empty and is reported below as an unimplemented
  // variant (the dense-only m8n8k* shapes have no mma.sp form).
  if (mmaShape[0] == 16) {
    // kFactor is 4x the number of elements packed in one 32-bit register:
    // tf32 -> 1, f16/bf16 -> 2, 8-bit containers -> 4, 4-bit -> 8.
    int64_t kFactor;
    Type multiplicandFragType;
    switch (aPtxType) {
    case MMATypes::tf32:
      kFactor = 4;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(f32x4StructTy);
      // Sparse MMA supports m16n8k8 and m16n8k16 for tf32
      allowedShapes.push_back({16, 8, 8});
      allowedShapes.push_back({16, 8, 16});
      break;
    case MMATypes::bf16:
      kFactor = 8;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(f32x4StructTy);
      // Sparse MMA supports m16n8k16 and m16n8k32 for bf16
      allowedShapes.push_back({16, 8, 16});
      allowedShapes.push_back({16, 8, 32});
      break;
    case MMATypes::f16:
      kFactor = 8;
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      // Sparse MMA supports m16n8k16 and m16n8k32 for f16
      allowedShapes.push_back({16, 8, 16});
      allowedShapes.push_back({16, 8, 32});
      break;
    case MMATypes::s4:
    case MMATypes::u4:
      kFactor = 32;
      // Sparse MMA supports m16n8k64 and m16n8k128 for s4/u4
      allowedShapes.push_back({16, 8, 64});
      allowedShapes.push_back({16, 8, 128});
      break;
    case MMATypes::s8:
    case MMATypes::u8:
      kFactor = 16;
      // Sparse MMA supports m16n8k32 and m16n8k64 for s8/u8
      allowedShapes.push_back({16, 8, 32});
      allowedShapes.push_back({16, 8, 64});
      break;
    case MMATypes::e4m3:
    case MMATypes::e5m2:
    case MMATypes::e3m2:
    case MMATypes::e2m3:
    case MMATypes::e2m1:
      // These elements occupy 8-bit containers (four per 32-bit register),
      // exactly like s8/u8. At m16n8k64 the compressed A tile is 16x32 bytes
      // and B is 64x8 bytes, i.e. 16 bytes = four i32 registers per thread
      // for each; kFactor = 32 would wrongly expect 4-bit packing (two regs).
      kFactor = 16;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      // Sparse MMA supports m16n8k64 for FP8 types
      allowedShapes.push_back({16, 8, 64});
      break;
    default:
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(aPtxType));
    }
    if (isIntegerPtxType(aPtxType)) {
      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
    } else {
      // Floating-point variants (including the FP8 family) accumulate either
      // in two f16x2 registers or in four f32 registers.
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
    }
    // For sparse MMA, A operand is compressed (2:4 sparsity means half the
    // elements)
    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor) / 2;
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
  }

  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);
  // Check that we matched an existing shape/dtype combination.
  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);
    errorStream << ">";
    return emitOpError(errorMessage);
  }
  // Verify the operand types for segments of A, B, and C operands.
  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
    auto spec = this->getODSOperandIndexAndLength(iter.index());
    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
                                      operand_type_begin() + spec.first +
                                          spec.second);
    bool match = llvm::is_contained(iter.value(), operandTySeg);
    if (!match) {
      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      }
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);
      return emitOpError(errorMessage);
    }
  }
  // Check the result type
  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
      })) {
    errorStream
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
    return emitOpError(errorMessage);
  }
  // Only compare accumulator and result PTX types now: the checks above
  // guarantee a non-empty C segment and inferrable C/result types, which
  // accumPtxType()/resultPtxType() rely on (they assert otherwise).
  if (resultPtxType() != accumPtxType())
    return emitOpError("ctype does not match dtype");
  // Ensure int4/int8 MMA variants specify the accum overflow behavior
  // attribute.
  if (isInt4PtxType(aPtxType) || isInt8PtxType(aPtxType)) {
    if (!getIntOverflowBehavior())
      return emitOpError("op requires " +
                         getIntOverflowBehaviorAttrName().strref() +
                         " attribute");
  }
  // Validate sparse metadata type (should be i32)
  if (!getSparseMetadata().getType().isInteger(32)) {
    return emitOpError() << "sparse metadata must be i32 type";
  }
  // Validate sparsity selector type (should be i32)
  if (!getSparsitySelector().getType().isInteger(32)) {
    return emitOpError() << "sparsity selector must be i32 type";
  }
  return success();
}
//===----------------------------------------------------------------------===//
// MMA Block Scale Operations - Shared Helpers
//===----------------------------------------------------------------------===//
namespace {
// Shared structure for MMA operand fragments (A, B, C), used by the printers
// of the block-scale MMA ops below.
struct MMAOperandFragment {
// Keyword printed in front of the fragment's bracketed operand list.
StringRef operandName;
// Name of the attribute holding the fragment's PTX type; empty when the
// fragment has no such attribute (the printers pass "" for C).
StringRef ptxTypeAttr;
// The fragment's operand values, collected in operand order.
SmallVector<Value, 4> regs;
explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
: operandName(name), ptxTypeAttr(ptxTypeName) {}
};
} // namespace
// Helper to print operand list in the format: name[operands]
// Prints ` name[op0, op1, ...]`: a separating space, the keyword, then the
// comma-separated operands enclosed in square brackets.
static void printOperandList(OpAsmPrinter &p, StringRef name,
                             ArrayRef<Value> operands) {
  p << ' ' << name << '[';
  p.printOperands(operands);
  p << ']';
}
// Helper to parse operand list in the format: name[operands]
// Parses `operandName` followed by an optional `[...]`-delimited operand
// list, appending the unresolved operands to `regs`.
static LogicalResult
parseMmaOperand(OpAsmParser &parser, StringRef operandName,
                SmallVectorImpl<OpAsmParser::UnresolvedOperand> &regs) {
  const bool parseFailed =
      parser.parseKeyword(operandName).failed() ||
      parser.parseOperandList(regs, OpAsmParser::Delimiter::OptionalSquare)
          .failed();
  return failure(parseFailed);
}
// Helper to process operand fragments and determine which attributes can be
// inferred
template <typename Op>
static void
processOperandFragments(Op &op, std::array<MMAOperandFragment, 3> &frags,
SmallVectorImpl<Type> &regTypes,
SmallVectorImpl<StringRef> &ignoreAttrNames) {
for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
auto &frag = frags[fragIdx];
auto varOperandSpec = op.getODSOperandIndexAndLength(fragIdx);
for (auto operandIdx = varOperandSpec.first;
operandIdx < varOperandSpec.first + varOperandSpec.second;
operandIdx++) {
frag.regs.push_back(op.getOperand(operandIdx));
if (fragIdx == 0 && operandIdx == varOperandSpec.first) {
regTypes.push_back(op.getOperand(operandIdx).getType());
}
}
if (fragIdx < 2) {
regTypes.push_back(frag.regs[0].getType());
}
std::optional<MMATypes> inferredType =
MmaOp::inferOperandMMAType(regTypes.back(),
/*isAccumulator=*/fragIdx >= 2);
if (inferredType)
ignoreAttrNames.push_back(frag.ptxTypeAttr);
}
}
// Helper to parse type signature: (A_type, B_type, C_type)
// Parses the type signature `: (A-type, B-type, C-type)` into `operandTypes`,
// failing unless exactly three types are listed.
static LogicalResult
parseMmaTypeSignature(OpAsmParser &parser,
                      SmallVectorImpl<Type> &operandTypes) {
  if (parser.parseColon() || parser.parseLParen())
    return failure();
  auto parseOneType = [&]() -> ParseResult {
    Type parsedType;
    if (parser.parseType(parsedType))
      return failure();
    operandTypes.push_back(parsedType);
    return success();
  };
  if (parser.parseCommaSeparatedList(parseOneType))
    return failure();
  if (operandTypes.size() == 3)
    return parser.parseRParen();
  return parser.emitError(parser.getCurrentLocation(),
                          "expected exactly 3 types");
}
// Helper to infer and set multiplicand PTX type attributes
// For each of `multiplicandAPtxType` / `multiplicandBPtxType` not already in
// `attrs`, infers the PTX type from operandTypes[0] / operandTypes[1]
// (as non-accumulator types) and stores it when inference succeeds.
static void
inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs,
                             const SmallVectorImpl<Type> &operandTypes) {
  const std::array<StringRef, 2> attrNames{"multiplicandAPtxType",
                                           "multiplicandBPtxType"};
  for (size_t i = 0; i < attrNames.size(); ++i) {
    if (attrs.get(attrNames[i]))
      continue;
    std::optional<MMATypes> inferred =
        MmaOp::inferOperandMMAType(operandTypes[i], /*isAccumulator=*/false);
    if (inferred)
      attrs.set(attrNames[i], MMATypesAttr::get(ctx, *inferred));
  }
}
// Helper to add common block scale properties
// Stores the shape, block-scale kind, scale vector size and block scale
// format into the properties of `result` (an op of type `OpType`).
template <typename OpType>
static void addBlockScaleProperties(OpBuilder &builder, OperationState &result,
                                    ArrayRef<int64_t> shape,
                                    ScaleVecSize scaleVecSize,
                                    BlockScaleFormat blockScaleFormat,
                                    MMABlockScaleKind kind) {
  MLIRContext *context = builder.getContext();
  typename OpType::Properties &props =
      result.getOrAddProperties<typename OpType::Properties>();
  auto shapeAttr = builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]);
  props.setShape(shapeAttr);
  props.setKind(MMABlockScaleKindAttr::get(context, kind));
  props.setScaleVecSize(ScaleVecSizeAttr::get(context, scaleVecSize));
  props.setBlockScaleFormat(BlockScaleFormatAttr::get(context, blockScaleFormat));
}
// Helper to infer and add multiplicand PTX types to builder
// Helper to infer and add multiplicand PTX types to builder.
//
// Explicit `multiplicandPtxTypes` win. Otherwise each type is inferred from the
// first register of the corresponding multiplicand; when that is impossible
// (including an empty operand range) the attribute is left unset.
static void addInferredMultiplicandTypes(
    MLIRContext *ctx, OperationState &result, ValueRange operandA,
    ValueRange operandB,
    std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
    return;
  }
  // Guard against empty ranges: indexing `[0]` on them is out of bounds.
  auto inferFromFirstReg = [](ValueRange regs) -> std::optional<MMATypes> {
    if (regs.empty())
      return std::nullopt;
    return MmaOp::inferOperandMMAType(regs[0].getType(),
                                      /*isAccumulator=*/false);
  };
  if (auto res = inferFromFirstReg(operandA))
    result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
  if (auto res = inferFromFirstReg(operandB))
    result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
}
// Template helper for common accumPtxType/resultPtxType implementation
// Template helper for common accumPtxType/resultPtxType implementation: infers
// the PTX type (as an accumulator type) from the first element of the op's
// LLVM struct result. Like MmaSpOp::accumPtxType, it asserts instead of
// silently dereferencing an empty optional or indexing an empty struct.
template <typename OpTy>
static MMATypes inferPtxTypeFromResult(OpTy op) {
  auto resultStructTy = cast<LLVM::LLVMStructType>(op.getRes().getType());
  assert(!resultStructTy.getBody().empty() &&
         "expected a non-empty struct result");
  std::optional<MMATypes> ptxType = MmaOp::inferOperandMMAType(
      resultStructTy.getBody()[0], /*isAccumulator=*/true);
  assert(ptxType.has_value() && "result PTX type should always be inferrable");
  return *ptxType;
}
//===----------------------------------------------------------------------===//
// MmaBlockScaleOp
//===----------------------------------------------------------------------===//
// Prints `A[..] B[..] C[..] scaleA[..] scaleB[..] attr-dict
//         : (A-type, B-type, C-type) -> result-type`.
void MmaBlockScaleOp::print(OpAsmPrinter &p) {
  std::array<MMAOperandFragment, 3> fragments{
      MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      MMAOperandFragment("C", "")};
  SmallVector<Type, 4> fragmentTypes;
  SmallVector<StringRef, 4> elidedAttrs{
      mlir::NVVM::MmaBlockScaleOp::getOperandSegmentSizeAttr()};
  processOperandFragments(*this, fragments, fragmentTypes, elidedAttrs);

  for (const MMAOperandFragment &fragment : fragments)
    printOperandList(p, fragment.operandName, fragment.regs);
  printOperandList(p, "scaleA",
                   {getScaleAData(), getByteIdA(), getThreadIdA()});
  printOperandList(p, "scaleB",
                   {getScaleBData(), getByteIdB(), getThreadIdB()});

  p.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs);

  // Each segment's type is taken from its first register.
  SmallVector<Type, 3> signatureTypes;
  for (const MMAOperandFragment &fragment : fragments)
    signatureTypes.push_back(fragment.regs[0].getType());
  p << " : (";
  llvm::interleaveComma(signatureTypes, p);
  p << ")";
  p.printArrowTypeList(TypeRange{getRes().getType()});
}
// Parses `A[..] B[..] C[..] scaleA[..] scaleB[..] attr-dict
//         : (A-type, B-type, C-type) -> result-types`.
// Each scale group is resolved as (i32, i16, i16).
ParseResult MmaBlockScaleOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  Builder &builder = parser.getBuilder();
  // Unresolved operands of the variadic A, B and C segments, in that order.
  const std::array<StringRef, 3> segmentNames{"A", "B", "C"};
  std::array<SmallVector<OpAsmParser::UnresolvedOperand, 4>, 3> segments;
  SmallVector<OpAsmParser::UnresolvedOperand, 3> scaleAOperands;
  SmallVector<OpAsmParser::UnresolvedOperand, 3> scaleBOperands;
  NamedAttrList namedAttributes;

  for (size_t i = 0; i < segments.size(); ++i)
    if (failed(parseMmaOperand(parser, segmentNames[i], segments[i])))
      return failure();
  if (failed(parseMmaOperand(parser, "scaleA", scaleAOperands)) ||
      failed(parseMmaOperand(parser, "scaleB", scaleBOperands)))
    return failure();
  if (failed(parser.parseOptionalAttrDict(namedAttributes)))
    return failure();

  SmallVector<Type, 3> operandTypes;
  if (failed(parseMmaTypeSignature(parser, operandTypes)))
    return failure();
  SmallVector<Type, 1> resultTypes;
  if (failed(parser.parseArrowTypeList(resultTypes)))
    return failure();

  // All operands of a segment share the single type listed for it.
  for (size_t i = 0; i < segments.size(); ++i)
    if (failed(parser.resolveOperands(segments[i], operandTypes[i],
                                      parser.getNameLoc(), result.operands)))
      return failure();

  SmallVector<Type, 3> scaleTypes = {builder.getI32Type(), builder.getI16Type(),
                                     builder.getI16Type()};
  for (auto *scaleOperands : {&scaleAOperands, &scaleBOperands})
    if (failed(parser.resolveOperands(*scaleOperands, scaleTypes,
                                      parser.getNameLoc(), result.operands)))
      return failure();

  result.addAttributes(namedAttributes);
  inferAndSetMultiplicandTypes(parser.getContext(), result.attributes,
                               operandTypes);
  result.addTypes(resultTypes);

  SmallVector<int32_t, 9> segmentSizes;
  for (const auto &segment : segments)
    segmentSizes.push_back(static_cast<int32_t>(segment.size()));
  // scaleAData, byteIdA, threadIdA, scaleBData, byteIdB, threadIdB.
  segmentSizes.append(6, 1);
  result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
  return success();
}
// Builds an MMA block-scale op: operands are A..., B..., C..., then the six
// single scale operands; properties, inferred/explicit multiplicand PTX types,
// the result type and the operand segment sizes are attached.
void MmaBlockScaleOp::build(
    OpBuilder &builder, OperationState &result, Type resultType,
    ValueRange operandA, ValueRange operandB, ValueRange operandC,
    Value scaleAData, Value byteIdA, Value threadIdA, Value scaleBData,
    Value byteIdB, Value threadIdB, ArrayRef<int64_t> shape,
    std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
    ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
    MMABlockScaleKind kind) {
  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  addBlockScaleProperties<MmaBlockScaleOp>(builder, result, shape, scaleVecSize,
                                           blockScaleFormat, kind);

  for (ValueRange fragment : {operandA, operandB, operandC})
    result.addOperands(fragment);
  result.addOperands(
      {scaleAData, byteIdA, threadIdA, scaleBData, byteIdB, threadIdB});

  addInferredMultiplicandTypes(builder.getContext(), result, operandA, operandB,
                               multiplicandPtxTypes);
  result.addTypes(resultType);

  SmallVector<int32_t, 9> segmentSizes{static_cast<int32_t>(operandA.size()),
                                       static_cast<int32_t>(operandB.size()),
                                       static_cast<int32_t>(operandC.size())};
  // scaleAData, byteIdA, threadIdA, scaleBData, byteIdB, threadIdB.
  segmentSizes.append(6, 1);
  result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
}
NVVM::IDArgPair MmaBlockScaleOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto blockScaleOp = cast<NVVM::MmaBlockScaleOp>(op);
  SmallVector<llvm::Value *> args;
  auto appendTranslated = [&](ValueRange values) {
    for (Value value : values)
      args.push_back(mt.lookupValue(value));
  };

  // Intrinsic argument order: A..., B..., C..., then the A scale operands
  // (data, byte id, thread id) and the B scale operands.
  appendTranslated(blockScaleOp.getOperandA());
  appendTranslated(blockScaleOp.getOperandB());
  appendTranslated(blockScaleOp.getOperandC());
  const Value scaleOperands[] = {
      blockScaleOp.getScaleAData(), blockScaleOp.getByteIdA(),
      blockScaleOp.getThreadIdA(),  blockScaleOp.getScaleBData(),
      blockScaleOp.getByteIdB(),    blockScaleOp.getThreadIdB()};
  for (Value scaleOperand : scaleOperands)
    args.push_back(mt.lookupValue(scaleOperand));

  auto shape = blockScaleOp.getShape();
  unsigned intrinsicId = MmaBlockScaleOp::getIntrinsicID(
      shape.getM(), shape.getN(), shape.getK(),
      *blockScaleOp.getMultiplicandAPtxType(),
      *blockScaleOp.getMultiplicandBPtxType(),
      inferPtxTypeFromResult(blockScaleOp), blockScaleOp.getScaleVecSize(),
      blockScaleOp.getBlockScaleFormat(), blockScaleOp.getKind());
  return {intrinsicId, args};
}
// Verifies the (shape, types, kind, scale vector size, block scale format)
// combination of a block-scaled MMA. Every violated constraint emits its own
// diagnostic; verification fails if at least one was emitted.
LogicalResult MmaBlockScaleOp::verify() {
  bool hasError = false;
  auto reject = [&](const Twine &message) {
    (void)emitOpError(message);
    hasError = true;
  };

  const int m = getShape().getM();
  const int n = getShape().getN();
  const int k = getShape().getK();
  const bool isM16N8K64 = m == 16 && n == 8 && k == 64;
  const bool isM16N8K32 = m == 16 && n == 8 && k == 32;

  if (isM16N8K64) {
    const bool bothE2M1 =
        getMultiplicandAPtxType() == NVVM::MMATypes::e2m1 &&
        getMultiplicandBPtxType() == NVVM::MMATypes::e2m1;
    if (!bothE2M1)
      reject(
          "unsupported MMATypes attribute for mma.m16n8k64.(mxf4nvf4|mxf4)");
    switch (getKind()) {
    case NVVM::MMABlockScaleKind::MXF4:
      if (getScaleVecSize() != NVVM::ScaleVecSize::X2)
        reject("unsupported ScaleVecSize attribute for mma.m16n8k64.mxf4");
      if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0)
        reject(
            "unsupported BlockScaleFormat attribute for mma.m16n8k64.mxf4");
      break;
    case NVVM::MMABlockScaleKind::MXF4NVF4: {
      const bool x2WithUE8M0 =
          getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
          getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0;
      const bool x4WithUE4M3 =
          getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
          getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3;
      if (!x2WithUE8M0 && !x4WithUE4M3)
        reject("unsupported ScaleVecSize and BlockScaleFormat "
               "attributes for mma.m16n8k64.mxf4nvf4");
      break;
    }
    default:
      reject("unsupported Kind attribute for mma.m16n8k64");
      break;
    }
  } else if (isM16N8K32) {
    const bool supported =
        getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
        getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
        getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0;
    if (!supported)
      reject("unsupported Kind, ScaleVecSize and BlockScaleFormat "
             "attributes for mma.m16n8k32");
  } else {
    reject("unsupported Geom for mma with block scaling");
  }
  return success(!hasError);
}
//===----------------------------------------------------------------------===//
// MmaSpBlockScaleOp
//===----------------------------------------------------------------------===//
// Prints `A[..] B[..] C[..] sparseMetadata[..] selector[..] scaleA[..]
//         scaleB[..] attr-dict : (A-type, B-type, C-type) -> result-type`.
void MmaSpBlockScaleOp::print(OpAsmPrinter &p) {
  SmallVector<Type, 4> fragmentTypes;
  SmallVector<StringRef, 4> elidedAttrNames{
      mlir::NVVM::MmaSpBlockScaleOp::getOperandSegmentSizeAttr()};
  std::array<MMAOperandFragment, 3> abcFragments{
      MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      MMAOperandFragment("C", "")};
  processOperandFragments(*this, abcFragments, fragmentTypes, elidedAttrNames);

  for (const MMAOperandFragment &fragment : abcFragments)
    printOperandList(p, fragment.operandName, fragment.regs);
  printOperandList(p, "sparseMetadata", {getSparseMetadata()});
  printOperandList(p, "selector", {getSparsitySelector()});
  printOperandList(p, "scaleA",
                   {getScaleAData(), getByteIdA(), getThreadIdA()});
  printOperandList(p, "scaleB",
                   {getScaleBData(), getByteIdB(), getThreadIdB()});
  p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrNames);

  // Signature types are taken from the first register of each segment.
  SmallVector<Type, 3> signatureTypes;
  for (const MMAOperandFragment &fragment : abcFragments)
    signatureTypes.push_back(fragment.regs[0].getType());
  p << " : (";
  llvm::interleaveComma(signatureTypes, p);
  p << ")";
  p.printArrowTypeList(TypeRange{getRes().getType()});
}
/// Parses the custom assembly form emitted by MmaSpBlockScaleOp::print:
///
///   A[...] B[...] C[...] sparseMetadata[...] selector[...]
///   scaleA[...] scaleB[...] attr-dict : (typeA, typeB, typeC) -> result
///
/// Operands are appended to `result.operands` in operand-segment order: A, B,
/// C, metadata, selector, then the scaleA and scaleB (data, byteId, threadId)
/// triples. The `operandSegmentSizes` attribute written at the end mirrors
/// that order. Every scalar operand slot is resolved so that exactly one
/// operand is required, which keeps the hard-coded segment size of 1 valid.
ParseResult MmaSpBlockScaleOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  struct LocalOperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
  };
  Builder &builder = parser.getBuilder();
  std::array<LocalOperandFragment, 3> frags;
  NamedAttrList namedAttributes;

  // Parse A[...] B[...] C[...]
  if (parseMmaOperand(parser, "A", frags[0].regs).failed() ||
      parseMmaOperand(parser, "B", frags[1].regs).failed() ||
      parseMmaOperand(parser, "C", frags[2].regs).failed())
    return failure();

  // Parse sparse-specific operands
  SmallVector<OpAsmParser::UnresolvedOperand, 1> metadataOperands,
      selectorOperands;
  if (parseMmaOperand(parser, "sparseMetadata", metadataOperands).failed() ||
      parseMmaOperand(parser, "selector", selectorOperands).failed())
    return failure();

  // Parse scale operands
  SmallVector<OpAsmParser::UnresolvedOperand, 3> scaleAOperands, scaleBOperands;
  if (parseMmaOperand(parser, "scaleA", scaleAOperands).failed() ||
      parseMmaOperand(parser, "scaleB", scaleBOperands).failed())
    return failure();

  if (parser.parseOptionalAttrDict(namedAttributes).failed())
    return failure();

  // Parse type signature: one type per A/B/C fragment.
  SmallVector<Type, 3> operandTypes;
  if (parseMmaTypeSignature(parser, operandTypes).failed())
    return failure();
  // The loop below reads operandTypes[0..2]; reject a short signature
  // instead of indexing out of bounds.
  if (operandTypes.size() < frags.size())
    return parser.emitError(parser.getNameLoc())
           << "expected " << frags.size()
           << " operand types in the type signature, got "
           << operandTypes.size();

  // Parse result type
  SmallVector<Type, 1> resultTypes;
  if (parser.parseArrowTypeList(resultTypes).failed())
    return failure();

  // Infer element types and resolve the fragment operands; fragment C
  // (idx 2) is the accumulator.
  for (const auto &[idx, frag] : llvm::enumerate(frags)) {
    frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
                                               /*isAccumulator=*/idx >= 2);
    if (parser
            .resolveOperands(frag.regs, operandTypes[idx], parser.getNameLoc(),
                             result.operands)
            .failed())
      return failure();
  }

  // Resolve sparse metadata and selector. Each one is a single i32. Resolving
  // against a one-element type list, rather than a bare Type, makes
  // resolveOperands reject any other operand count. That keeps the fixed
  // segment size of 1 recorded below accurate.
  Type i32Type = builder.getI32Type();
  SmallVector<Type, 1> singleI32Type = {i32Type};
  if (parser
          .resolveOperands(metadataOperands, singleI32Type,
                           parser.getNameLoc(), result.operands)
          .failed() ||
      parser
          .resolveOperands(selectorOperands, singleI32Type,
                           parser.getNameLoc(), result.operands)
          .failed())
    return failure();

  // Resolve scale operands: (i32 data, i16 byte id, i16 thread id) each.
  SmallVector<Type, 3> scaleTypes = {i32Type, builder.getI16Type(),
                                     builder.getI16Type()};
  if (parser
          .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(),
                           result.operands)
          .failed() ||
      parser
          .resolveOperands(scaleBOperands, scaleTypes, parser.getNameLoc(),
                           result.operands)
          .failed())
    return failure();

  // Add attributes
  result.addAttributes(namedAttributes);
  inferAndSetMultiplicandTypes(parser.getContext(), result.attributes,
                               operandTypes);
  // orderedMetadata is mandatory; supply it when the input omitted it.
  if (!result.attributes.get("orderedMetadata"))
    result.addAttribute("orderedMetadata", builder.getUnitAttr());
  result.addTypes(resultTypes);
  result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr({
                          static_cast<int32_t>(frags[0].regs.size()),
                          static_cast<int32_t>(frags[1].regs.size()),
                          static_cast<int32_t>(frags[2].regs.size()),
                          1, // sparseMetadata
                          1, // sparsitySelector
                          1, // scaleAData
                          1, // byteIdA
                          1, // threadIdA
                          1, // scaleBData
                          1, // byteIdB
                          1  // threadIdB
                      }));
  return success();
}
/// Builds a sparse block-scaled MMA op.
///
/// `operandA`, `operandB` and `operandC` are the per-fragment register lists.
/// The remaining Value parameters are single operands. They are appended in
/// the fixed order below, which must match the segment sizes recorded at the
/// end. `shape` must hold exactly {m, n, k}. When `multiplicandPtxTypes` is
/// empty, addInferredMultiplicandTypes (defined elsewhere) presumably derives
/// the A/B PTX types from the operand types -- confirm against its definition.
/// The `orderedMetadata` unit attribute is always set.
void MmaSpBlockScaleOp::build(
    OpBuilder &builder, OperationState &result, Type resultType,
    ValueRange operandA, ValueRange operandB, ValueRange operandC,
    Value sparseMetadata, Value sparsitySelector, Value scaleAData,
    Value byteIdA, Value threadIdA, Value scaleBData, Value byteIdB,
    Value threadIdB, ArrayRef<int64_t> shape,
    std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
    ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
    MMABlockScaleKind kind) {
  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  addBlockScaleProperties<MmaSpBlockScaleOp>(
      builder, result, shape, scaleVecSize, blockScaleFormat, kind);
  result.addAttribute("orderedMetadata", builder.getUnitAttr());
  // Operand order: A..., B..., C..., then the eight scalar operands.
  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);
  result.addOperands({sparseMetadata, sparsitySelector, scaleAData, byteIdA,
                      threadIdA, scaleBData, byteIdB, threadIdB});
  addInferredMultiplicandTypes(builder.getContext(), result, operandA, operandB,
                               multiplicandPtxTypes);
  result.addTypes(resultType);
  // One entry per operand group, in the same order as the addOperands calls.
  result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr({
                          static_cast<int32_t>(operandA.size()),
                          static_cast<int32_t>(operandB.size()),
                          static_cast<int32_t>(operandC.size()),
                          1, // sparseMetadata
                          1, // sparsitySelector
                          1, // scaleAData
                          1, // byteIdA
                          1, // threadIdA
                          1, // scaleBData
                          1, // byteIdB
                          1  // threadIdB
                      }));
}
/// Maps the op to its LLVM intrinsic ID and the translated argument list.
///
/// Arguments are emitted in the op's operand order: A..., B..., C...,
/// sparse metadata, sparsity selector, then scaleAData/byteIdA/threadIdA and
/// scaleBData/byteIdB/threadIdB. The intrinsic is selected from the shape,
/// the A/B PTX types, the result-derived accumulator type, and the
/// block-scale attributes.
NVVM::IDArgPair MmaSpBlockScaleOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::MmaSpBlockScaleOp>(op);
  SmallVector<llvm::Value *> args;
  // Add A, B, C operands
  for (Value operand : curOp.getOperandA())
    args.push_back(mt.lookupValue(operand));
  for (Value operand : curOp.getOperandB())
    args.push_back(mt.lookupValue(operand));
  for (Value operand : curOp.getOperandC())
    args.push_back(mt.lookupValue(operand));
  // Add sparse metadata and selector
  args.push_back(mt.lookupValue(curOp.getSparseMetadata()));
  args.push_back(mt.lookupValue(curOp.getSparsitySelector()));
  // Add scale operands
  args.push_back(mt.lookupValue(curOp.getScaleAData()));
  args.push_back(mt.lookupValue(curOp.getByteIdA()));
  args.push_back(mt.lookupValue(curOp.getThreadIdA()));
  args.push_back(mt.lookupValue(curOp.getScaleBData()));
  args.push_back(mt.lookupValue(curOp.getByteIdB()));
  args.push_back(mt.lookupValue(curOp.getThreadIdB()));
  // NOTE(review): the multiplicand PTX types are dereferenced
  // unconditionally, so they are assumed to always be set. Both parse and
  // build call helpers that presumably set them -- confirm.
  unsigned intId = MmaSpBlockScaleOp::getIntrinsicID(
      curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
      *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
      inferPtxTypeFromResult(curOp), curOp.getScaleVecSize(),
      curOp.getBlockScaleFormat(), curOp.getKind());
  return {intId, args};
}
/// Verifies the sparse block-scaled MMA op.
///
/// `orderedMetadata` must be present. The (m, n, k) geometry then selects
/// which Kind / ScaleVecSize / BlockScaleFormat / PTX-type combinations are
/// legal. Every violated constraint gets its own diagnostic; verification
/// does not stop at the first one.
LogicalResult MmaSpBlockScaleOp::verify() {
  if (!getOrderedMetadata())
    return emitOpError("'orderedMetadata' attribute is mandatory");

  // Record each violation and keep going, so several diagnostics can be
  // emitted for one op.
  bool isValid = true;
  auto reject = [&](const Twine &message) {
    (void)emitOpError(message);
    isValid = false;
  };

  const int m = getShape().getM();
  const int n = getShape().getN();
  const int k = getShape().getK();
  const auto kind = getKind();
  const auto vecSize = getScaleVecSize();
  const auto scaleFmt = getBlockScaleFormat();
  const bool isM16N8 = m == 16 && n == 8;

  if (isM16N8 && k == 128) {
    const bool bothE2M1 =
        getMultiplicandAPtxType() == NVVM::MMATypes::e2m1 &&
        getMultiplicandBPtxType() == NVVM::MMATypes::e2m1;
    if (!bothE2M1)
      reject(
          "unsupported MMATypes attribute for mma.m16n8k128.(mxf4nvf4|mxf4)");

    if (kind == NVVM::MMABlockScaleKind::MXF4) {
      if (vecSize != NVVM::ScaleVecSize::X2)
        reject("unsupported ScaleVecSize attribute for mma.m16n8k128.mxf4");
      if (scaleFmt != NVVM::BlockScaleFormat::UE8M0)
        reject(
            "unsupported BlockScaleFormat attribute for mma.m16n8k128.mxf4");
    } else if (kind == NVVM::MMABlockScaleKind::MXF4NVF4) {
      const bool isX2UE8M0 = vecSize == NVVM::ScaleVecSize::X2 &&
                             scaleFmt == NVVM::BlockScaleFormat::UE8M0;
      const bool isX4UE4M3 = vecSize == NVVM::ScaleVecSize::X4 &&
                             scaleFmt == NVVM::BlockScaleFormat::UE4M3;
      if (!isX2UE8M0 && !isX4UE4M3)
        reject("unsupported ScaleVecSize and BlockScaleFormat "
               "attributes for mma.m16n8k128.mxf4nvf4");
    } else {
      reject("unsupported Kind attribute for mma.m16n8k128");
    }
  } else if (isM16N8 && k == 64) {
    const bool isMxf8X1UE8M0 = kind == NVVM::MMABlockScaleKind::MXF8F6F4 &&
                               vecSize == NVVM::ScaleVecSize::X1 &&
                               scaleFmt == NVVM::BlockScaleFormat::UE8M0;
    if (!isMxf8X1UE8M0)
      reject("unsupported Kind, ScaleVecSize and BlockScaleFormat "
             "attributes for mma.m16n8k64");
  } else {
    reject("unsupported Geom for sparse mma with block scaling");
  }

  return success(isValid);
}
/// Verifies the result type of the shuffle against the shuffled value.
///
/// With `return_value_and_is_valid` the result must be the two-element struct
/// `{typeof(val), i1}`. Without it, the result must have exactly the type of
/// `val`.
LogicalResult ShflOp::verify() {
  Type valType = getVal().getType();
  // Produces "expected <what> to be of type <want> but got <got> instead".
  auto typeMismatch = [&](Twine what, Type want, Type got) -> LogicalResult {
    return emitOpError("expected " + what + " to be of type ")
           << want << " but got " << got << " instead";
  };

  auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
  if (!structType) {
    // Plain result: the validity flag must not be requested.
    if (getReturnValueAndIsValid())
      return emitOpError("expected return type to be a two-element struct");
    if (getType() != valType)
      return typeMismatch("return type", valType, getType());
    return success();
  }

  if (!getReturnValueAndIsValid())
    return emitOpError("\"return_value_and_is_valid\" attribute must be "
                       "specified when the return type is a struct type");
  ArrayRef<Type> members = structType.getBody();
  if (members.size() != 2)
    return emitOpError("expected return type to be a two-element struct");
  if (members[0] != valType)
    return typeMismatch("first element in the returned struct", valType,
                        members[0]);
  if (!members[1].isInteger(1))
    return typeMismatch("second element in the returned struct",
                        mlir::IntegerType::get(getContext(), 1), members[1]);
  return success();
}
/// Returns the per-thread register element type and register count for one
/// MMA fragment.
///
/// `nRow`/`nCol` are the fragment's matrix dimensions. Only the s8/u8 case
/// depends on them: it uses the row count for fragment `a` and the column
/// count for fragment `b`. An unsupported combination trips the assertion at
/// the end.
std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
                                                   NVVM::MMAFrag frag, int nRow,
                                                   int nCol,
                                                   MLIRContext *context) {
  OpBuilder builder(context);
  const bool isMultiplicand =
      frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b;

  Type elementType;
  unsigned numberElements = 0;
  switch (type) {
  case NVVM::MMATypes::f16:
    // f16 data is packed two per register as vector<2xf16>.
    elementType = VectorType::get(2, builder.getF16Type());
    numberElements = isMultiplicand ? 8 : 4;
    break;
  case NVVM::MMATypes::f32:
    elementType = builder.getF32Type();
    numberElements = 8;
    break;
  case NVVM::MMATypes::f64:
    elementType = builder.getF64Type();
    numberElements = isMultiplicand ? 1 : 2;
    break;
  case NVVM::MMATypes::tf32:
    elementType = builder.getI32Type();
    numberElements = 4;
    break;
  case NVVM::MMATypes::s8:
  case NVVM::MMATypes::u8: {
    elementType = builder.getI32Type();
    int parallelSize = 0;
    if (frag == NVVM::MMAFrag::a)
      parallelSize = nRow;
    else if (frag == NVVM::MMAFrag::b)
      parallelSize = nCol;
    switch (parallelSize) {
    case 16: // m == 16 && n == 16 && k == 16
      numberElements = 2;
      break;
    case 8: // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
      numberElements = 1;
      break;
    case 32:
      numberElements = 4;
      break;
    default:
      break;
    }
    break;
  }
  case NVVM::MMATypes::s32:
    elementType = builder.getI32Type();
    numberElements = 8;
    break;
  default:
    break;
  }
  assert(numberElements != 0 && elementType != nullptr);
  return std::make_pair(elementType, numberElements);
}
/// Like NVVM::inferMMAType, but derives the fragment's row/column counts
/// from the MMA shape (m, n, k).
static std::pair<mlir::Type, unsigned>
inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
                    int k, MLIRContext *context) {
  // Fragment A is m x k and fragment B is k x n. Any other fragment (the
  // accumulator) is m x n.
  const int rows = (frag == NVVM::MMAFrag::b) ? k : m;
  const int cols = (frag == NVVM::MMAFrag::a) ? k : n;
  assert(rows && cols);
  return inferMMAType(type, frag, rows, cols, context);
}
/// Verifies the WMMA load op. It checks three things:
///   - the source pointer is in address space 0, global or shared;
///   - the attribute combination maps to an intrinsic;
///   - the result type matches the fragment layout inferred from
///     (m, n, k, eltype, frag).
LogicalResult NVVM::WMMALoadOp::verify() {
  const unsigned addrSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  const bool isSupportedSpace = addrSpace == 0 ||
                                addrSpace == NVVMMemorySpace::Global ||
                                addrSpace == NVVMMemorySpace::Shared;
  if (!isSupportedSpace)
    return emitOpError("expected source pointer in memory "
                       "space 0, 1, 3");

  if (!NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                        getEltype(), getFrag()))
    return emitOpError() << "invalid attribute combination";

  std::pair<Type, unsigned> fragInfo = inferMMATypeFromMNK(
      getEltype(), getFrag(), getM(), getN(), getK(), getContext());
  Type elemType = fragInfo.first;
  unsigned numElems = fragInfo.second;

  // A lone f64 element is returned bare rather than wrapped in a struct.
  Type f64Ty = Float64Type::get(getContext());
  if (elemType == f64Ty && numElems == 1) {
    if (getType() == f64Ty)
      return success();
    return emitOpError("expected destination type to be f64");
  }

  Type structTy = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(numElems, elemType));
  if (getType() == structTy)
    return success();
  return emitOpError("expected destination type is a structure of ")
         << numElems << " elements of type " << elemType;
}
/// Verifies the WMMA store op. It checks three things:
///   - the destination pointer is in address space 0, global or shared;
///   - the attribute combination maps to an intrinsic;
///   - the data operands match the accumulator (`c`) fragment layout in both
///     count and type.
LogicalResult NVVM::WMMAStoreOp::verify() {
  const unsigned addrSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  const bool isSupportedSpace = addrSpace == 0 ||
                                addrSpace == NVVMMemorySpace::Global ||
                                addrSpace == NVVMMemorySpace::Shared;
  if (!isSupportedSpace)
    return emitOpError("expected operands to be a source pointer in memory "
                       "space 0, 1, 3");

  if (!NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                         getEltype()))
    return emitOpError() << "invalid attribute combination";

  std::pair<Type, unsigned> fragInfo = inferMMATypeFromMNK(
      getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  const Type elemType = fragInfo.first;
  const unsigned numElems = fragInfo.second;

  if (getArgs().size() != numElems)
    return emitOpError() << "expected " << numElems << " data operands";

  const bool allOperandsMatch = llvm::all_of(
      getArgs(), [elemType](Value arg) { return arg.getType() == elemType; });
  if (!allOperandsMatch)
    return emitOpError() << "expected data operands of type " << elemType;
  return success();
}
/// Verifies the WMMA multiply-accumulate op.
///
/// The argument list must be the A fragment registers, then B, then C, each
/// with the inferred element type. The result must be a struct shaped like
/// the C fragment. `eltypeA` types both multiplicand fragments (A and B);
/// `eltypeB` types the accumulator fragment (C).
LogicalResult NVVM::WMMAMmaOp::verify() {
  if (!NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                       getLayoutB(), getEltypeA(),
                                       getEltypeB()))
    return emitOpError() << "invalid attribute combination";

  auto fragInfoFor = [&](auto eltype, NVVM::MMAFrag frag) {
    return inferMMATypeFromMNK(eltype, frag, getM(), getN(), getK(),
                               getContext());
  };
  std::pair<Type, unsigned> fragA = fragInfoFor(getEltypeA(), NVVM::MMAFrag::a);
  std::pair<Type, unsigned> fragB = fragInfoFor(getEltypeA(), NVVM::MMAFrag::b);
  std::pair<Type, unsigned> fragC = fragInfoFor(getEltypeB(), NVVM::MMAFrag::c);

  // Flatten the expected argument types: A registers, then B, then C.
  SmallVector<Type, 32> expectedTypes;
  for (const std::pair<Type, unsigned> &frag : {fragA, fragB, fragC})
    expectedTypes.append(frag.second, frag.first);

  const unsigned numArgs = expectedTypes.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned idx = 0; idx != numArgs; ++idx) {
    Type want = expectedTypes[idx];
    if (getArgs()[idx].getType() != want)
      return emitOpError() << "expected argument " << idx << " to be of type "
                           << want;
  }

  Type resultStruct = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(fragC.second, fragC.first));
  if (getType() == resultStruct)
    return success();
  return emitOpError("expected destination type is a structure of ")
         << fragC.second << " elements of type " << fragC.first;
}
/// Verifies the ldmatrix op.
///
/// It checks `num`, layout and element type for each supported shape (8x8,
/// 8x16, 16x16). The result must be a single i32 or a struct of i32s; the
/// element count is `num`, doubled for 16x16.
LogicalResult NVVM::LdMatrixOp::verify() {
  const uint32_t num = getNum();
  const uint32_t m = getShape().getM();
  const uint32_t n = getShape().getN();
  const auto eltType = getEltType();
  const bool numIsOneTwoOrFour = num == 1 || num == 2 || num == 4;

  if (m == 8 && n == 8) {
    if (!numIsOneTwoOrFour)
      return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 "
                         "matrix");
    if (eltType != LdStMatrixEltType::B16)
      return emitOpError("expected element type to be b16 for 8x8 matrix");
  } else if (m == 8 && n == 16) {
    if (!numIsOneTwoOrFour)
      return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 "
                         "matrix");
    if (getLayout() != MMALayout::row)
      return emitOpError("expected layout to be row for 8x16 matrix");
    const bool isPackedB8 = eltType == LdStMatrixEltType::B8X16_B4X16_P64 ||
                            eltType == LdStMatrixEltType::B8X16_B6X16_P32;
    if (!isPackedB8)
      return emitOpError("expected element type to be b8x16.b4x16_p64 or "
                         "b8x16.b6x16_p32 for 8x16 matrix");
  } else if (m == 16 && n == 16) {
    if (num != 1 && num != 2)
      return emitOpError("expected num attribute to be 1 or 2 for 16x16 "
                         "matrix");
    if (getLayout() != MMALayout::col)
      return emitOpError("expected layout to be col for 16x16 matrix");
    const bool isB8Family = eltType == LdStMatrixEltType::B8 ||
                            eltType == LdStMatrixEltType::B8X16_B4X16_P64 ||
                            eltType == LdStMatrixEltType::B8X16_B6X16_P32;
    if (!isB8Family)
      return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or "
                         "b8x16.b6x16_p32 for 16x16 matrix");
  } else {
    return emitOpError("expected shape to be 8x8, 8x16 or 16x16");
  }

  // A 16x16 load yields two i32 registers per matrix.
  Type i32 = IntegerType::get(getContext(), 32);
  const uint32_t numElements = (m == 16 && n == 16) ? 2 * num : num;
  if (numElements == 1) {
    if (getType() != i32)
      return emitOpError("expected destination type is i32");
    return success();
  }
  if (numElements == 2 || numElements == 4) {
    Type expectedStruct = LLVM::LLVMStructType::getLiteral(
        getContext(), SmallVector<Type>(numElements, i32));
    if (getType() != expectedStruct)
      return emitOpError("expected destination type is a structure of ")
             << numElements << " elements of type i32";
  }
  return success();
}
/// Verifies the stmatrix op. It checks three things:
///   - the number of source matrices is 1, 2 or 4;
///   - the shape is 8x8 (b16) or 16x8 (b8);
///   - 16x8 uses the column layout.
LogicalResult NVVM::StMatrixOp::verify() {
  switch (getSources().size()) {
  case 1:
  case 2:
  case 4:
    break;
  default:
    return emitOpError("expected num attribute to be 1, 2 or 4");
  }

  const int m = getShape().getM();
  const int n = getShape().getN();
  const bool is8x8 = m == 8 && n == 8;
  const bool is16x8 = m == 16 && n == 8;
  if (!is8x8 && !is16x8)
    return emitOpError("expected shape to be 8x8 or 16x8");

  if (is8x8) {
    if (getEltType() != NVVM::LdStMatrixEltType::B16)
      return emitOpError("expected element type to be B16 for 8x8 matrix");
    return success();
  }

  // 16x8 from here on.
  if (getEltType() != NVVM::LdStMatrixEltType::B8)
    return emitOpError("expected element type to be B8 for 16x8 matrix");
  if (getLayout() != NVVM::MMALayout::col)
    return emitOpError("expected layout to be col for 16x8 matrix");
  return success();
}
/// Returns the only K dimension wgmma.mma_async accepts for A operand type
/// `typeA`. Accumulator-only types (f32, s32) have no valid K, so they yield
/// failure.
static FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
  switch (typeA) {
  case NVVM::WGMMATypes::tf32:
    return 8;
  case NVVM::WGMMATypes::f16:
  case NVVM::WGMMATypes::bf16:
    return 16;
  case NVVM::WGMMATypes::s8:
  case NVVM::WGMMATypes::u8:
  case NVVM::WGMMATypes::e4m3:
  case NVVM::WGMMATypes::e5m2:
    return 32;
  case NVVM::WGMMATypes::b1:
    return 256;
  case NVVM::WGMMATypes::f32:
  case NVVM::WGMMATypes::s32:
    break;
  }
  return failure();
}
/// Returns success if `typeD += typeA * typeB` is a data-type combination
/// accepted by wgmma.mma_async, and failure otherwise.
///
/// `typeA` comes straight from a user-provided attribute, so every
/// enumerator must produce a result rather than abort. That includes the
/// accumulator-only f32 and s32. A verifier has to reject invalid IR with a
/// diagnostic, not crash on it.
static LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
                                            NVVM::WGMMATypes typeA,
                                            NVVM::WGMMATypes typeB) {
  switch (typeA) {
  case NVVM::WGMMATypes::f16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::f16)
      return success();
    break;
  case NVVM::WGMMATypes::tf32:
    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
      return success();
    break;
  case NVVM::WGMMATypes::u8:
  case NVVM::WGMMATypes::s8:
    if (typeD == NVVM::WGMMATypes::s32 &&
        (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
      return success();
    break;
  case NVVM::WGMMATypes::b1:
    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
      return success();
    break;
  case NVVM::WGMMATypes::bf16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::bf16)
      return success();
    break;
  case NVVM::WGMMATypes::e4m3:
  case NVVM::WGMMATypes::e5m2:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    // Accumulator-only types are never valid for the A operand. Report
    // failure instead of llvm_unreachable. IR such as `typeA = f32` is
    // syntactically valid, and hitting unreachable would crash, or be
    // undefined behaviour in release builds.
    break;
  }
  return failure();
}
/// Returns success if `sizeN` is a legal N dimension for wgmma.mma_async
/// with A operand type `typeA`.
///
/// Floating-point inputs accept every multiple of 8 in [8, 256]. Integer and
/// b1 inputs accept the shorter list below. The accumulator-only types f32
/// and s32 yield failure() rather than llvm_unreachable, because `typeA`
/// originates from a user-provided attribute.
static LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
  // Compile-time constant tables, so nothing is rebuilt on each call.
  static constexpr int allowedN[] = {8,   16,  24,  32,  40,  48,  56,  64,
                                     72,  80,  88,  96,  104, 112, 120, 128,
                                     136, 144, 152, 160, 168, 176, 184, 192,
                                     200, 208, 216, 224, 232, 240, 248, 256};
  static constexpr int allowedNshort[] = {8,   16,  24,  32,  48,  64,
                                          80,  96,  112, 128, 144, 160,
                                          176, 192, 208, 224, 240, 256};
  switch (typeA) {
  case WGMMATypes::f16:
  case WGMMATypes::tf32:
  case WGMMATypes::bf16:
  case WGMMATypes::e4m3:
  case WGMMATypes::e5m2:
    if (llvm::is_contained(allowedN, sizeN))
      return success();
    break;
  case WGMMATypes::u8:
  case WGMMATypes::s8:
  case WGMMATypes::b1:
    if (llvm::is_contained(allowedNshort, sizeN))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    // Not valid A-operand types; reject gracefully instead of aborting.
    break;
  }
  return failure();
}
/// Verifies the wgmma.mma_async op. Checks run in this order:
///   1. the result is a homogeneous struct;
///   2. the output type and the D/A/B data-type combination are legal;
///   3. the M, K and N shape constraints hold;
///   4. transposed layouts are used only with f16/bf16 inputs;
///   5. the result register count matches N;
///   6. satfinite is used only with an s32 accumulator.
LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
  Value outValue = getResults();
  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
  if (!stype)
    return emitOpError() << "expected results to be struct";
  int outputSize = stype.getBody().size();
  WGMMATypes typeD = getTypeD();
  WGMMATypes typeA = getTypeA();
  WGMMATypes typeB = getTypeB();

  // Every result register must share a single element type.
  for (Type t : stype.getBody()) {
    if (t != stype.getBody().front())
      return emitOpError()
             << "all elements in struct must be same type but there is " << t;
  }

  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
      typeD != WGMMATypes::s32) {
    return emitOpError() << "does not support the given output type "
                         << NVVM::stringifyWGMMATypes(typeD);
  }
  if (typeD == WGMMATypes::s32 &&
      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
  }
  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
                         << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
                         << NVVM::stringifyWGMMATypes(typeB)
                         << ", it is not supported.";
  }

  // Check M
  if (getShape().getM() != 64)
    return emitOpError() << "shape 'm' must be 64";

  // Check K. getAllowedSizeK fails for types that can never be the A
  // operand. Report that case on its own rather than dereferencing an empty
  // FailureOr while building the shape diagnostic.
  FailureOr<int> allowedK = getAllowedSizeK(typeA);
  if (failed(allowedK))
    return emitOpError() << "unsupported input type "
                         << NVVM::stringifyWGMMATypes(typeA);
  if (*allowedK != getShape().getK())
    return emitOpError() << "shape 'k' must be " << *allowedK
                         << " for input type "
                         << NVVM::stringifyWGMMATypes(typeA);

  // Check N
  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
    return emitOpError() << "has input type "
                         << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
                         << getShape().getN() << ", it is not supported.";
  }

  // Check transpose (only available for f16/bf16).
  // Matrix A should be stored row-major and B column-major. Only f16/bf16
  // matrices may be stored either way, by setting the transpose values
  // (imm-trans-a, imm-trans-b) in the PTX code.
  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
      (getLayoutA() == mlir::NVVM::MMALayout::col ||
       getLayoutB() == mlir::NVVM::MMALayout::row)) {
    return emitOpError()
           << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
           << " and layout_b = " << stringifyMMALayout(getLayoutB())
           << " for input types " << stringifyWGMMATypes(typeA) << " and "
           << stringifyWGMMATypes(typeB)
           << " requires transpose. However, this is only supported for: "
           << stringifyMMATypes(MMATypes::f16) << " and "
           << stringifyMMATypes(MMATypes::bf16);
  }

  // Check result registers: f32/s32 accumulators use N/2 registers and f16
  // uses N/4.
  int expectedOutput = 0;
  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
    expectedOutput = getShape().getN() / 2;
  if (typeD == WGMMATypes::f16)
    expectedOutput = getShape().getN() / 4;
  if (outputSize != expectedOutput) {
    return emitOpError() << "results " << expectedOutput
                         << ", however output struct has " << outputSize
                         << " elements";
  }

  // Check satfinite (only available for s32 accumulator)
  if (typeD != WGMMATypes::s32 &&
      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
          NVVM::MMAIntOverflow::satfinite) {
    return emitOpError()
           << " `satfinite` can be only used with s32 accumulator, however "
              "the current accumulator is "
           << NVVM::stringifyWGMMATypes(typeD);
  }

  return success();
}
/// Builds the inline-PTX template for this wgmma.mma_async op.
///
/// Let R be the number of result registers (N/4 for an f16 accumulator,
/// N/2 otherwise). The `$i` placeholders follow the order in which
/// getAsmValues() pushes values:
///   $0 .. $R-1       result registers (listed in the output brace group)
///   $R .. $2R-1      presumably the inouts registers; they are not named
///                    below -- confirm against the PTX builder's expansion
///   $2R, $2R+1       descriptor A, descriptor B
///   $2R+2            scale-d, turned into predicate `p` via setp.ne
///   $2R+3, $2R+4     scale-a, scale-b (omitted for an s32 accumulator)
///   $2R+5, $2R+6     imm-trans-a, imm-trans-b (only for f16/bf16 inputs)
/// The transpose indices assume scale-a/b were emitted. The verifier's
/// data-type table makes that hold, since f16/bf16 inputs require an
/// f32/f16 accumulator.
std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
  // Same register-count rule the verifier enforces on the result struct.
  int expectedOutputRegisters = 0;
  if (getTypeD() == WGMMATypes::f16)
    expectedOutputRegisters = getShape().getN() / 4;
  else
    expectedOutputRegisters = getShape().getN() / 2;
  std::string ptx;
  llvm::raw_string_ostream ss(ptx);
  // The predicate is built from scale-d, which follows the 2R register
  // operands and the two descriptors.
  ss << "{\n"
        ".reg .pred p;\n"
        "setp.ne.b32 p, $"
     << ((expectedOutputRegisters * 2) + 2)
     << ", 0;\n"
        "wgmma.mma_async.sync.aligned.m"
     << m << "n" << n << "k" << k << "." << outputTypeName << "."
     << stringifyWGMMATypes(getTypeA()) << "."
     << stringifyWGMMATypes(getTypeB());
  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
      NVVM::MMAIntOverflow::satfinite)
    ss << ".satfinite";
  // Output brace group: "$0, $1, ..., $R-1".
  ss << " {";
  int regCnt = 0;
  for (; regCnt < expectedOutputRegisters; ++regCnt) {
    ss << "$" << regCnt;
    if (regCnt != expectedOutputRegisters - 1)
      ss << ", ";
  }
  ss << "},";
  // Need to map read/write registers correctly: skip over the R result and
  // R inout registers to reach the first descriptor at $2R.
  regCnt = (regCnt * 2);
  ss << " $" << (regCnt) << ","
     << " $" << (regCnt + 1) << ","
     << " p";
  // $2R+2 (scale-d) is consumed by the setp above, so it is skipped here.
  if (getTypeD() != WGMMATypes::s32) {
    ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
  }
  // Don't add transpose parameters unless needed.
  if (isF16) {
    ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
  }
  ss << ";\n"
     << "}\n";
  return ptx;
}
/// Collects the values bound to the `$i` placeholders of getPtx(), in order.
///
/// The push order fixes each value's index, so it must stay in sync with the
/// numbering in getPtx():
///   - results (Write), then inouts (ReadWrite);
///   - descriptor A, descriptor B;
///   - scale-d as an i32 constant;
///   - unless the accumulator is s32: scale-a and scale-b as -1/+1 constants;
///   - only for f16/bf16 inputs: the two transpose immediates.
/// Returns true to signal that this op provides its own operand mapping.
bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
  if (getResults())
    asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
  if (getInouts())
    asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
  asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
  asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
  // scale-d enum value as an integer; getPtx() tests it with `setp.ne ..., 0`.
  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
                       mlir::NVVM::PTXRegisterMod::Read});
  if (getTypeD() != WGMMATypes::s32) {
    // Input scaling is encoded as -1 for `neg` and +1 otherwise.
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
  }
  if (isF16) {
    // Transpose immediates. A defaults to row-major and B to column-major,
    // hence the `1 -` for B. NOTE(review): this assumes MMALayout::row == 0
    // and MMALayout::col == 1 -- confirm against the enum definition.
    asmValues.push_back(
        {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
         mlir::NVVM::PTXRegisterMod::Read});
  }
  return true; // Has manual mapping
}
/// A sync_restrict fence accepts only acquire or release ordering.
LogicalResult NVVM::FenceSyncRestrictOp::verify() {
  const auto order = getOrder();
  const bool isAcquireOrRelease = order == NVVM::MemOrderKind::ACQUIRE ||
                                  order == NVVM::MemOrderKind::RELEASE;
  if (isAcquireOrRelease)
    return success();
  return emitOpError("only acquire and release semantics are supported");
}
/// Verifies the bidirectional proxy fence. The tensormap and generic kinds
/// are rejected. A space attribute is required for the async_shared kind and
/// forbidden for every other kind.
LogicalResult NVVM::FenceProxyOp::verify() {
  const auto kind = getKind();
  if (kind == NVVM::ProxyKind::TENSORMAP)
    return emitOpError() << "tensormap proxy is not a supported proxy kind";
  if (kind == NVVM::ProxyKind::GENERIC)
    return emitOpError() << "generic proxy not a supported proxy kind";

  const bool isAsyncShared = kind == NVVM::ProxyKind::async_shared;
  const bool hasSpace = getSpace().has_value();
  if (isAsyncShared && !hasSpace)
    return emitOpError() << "async_shared fence requires space attribute";
  if (!isAsyncShared && hasSpace)
    return emitOpError() << "only async_shared fence can have space attribute";
  return success();
}
/// The uni-directional acquire fence supports only generic -> tensormap.
LogicalResult NVVM::FenceProxyAcquireOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC) {
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");
  }
  if (getToProxy() == NVVM::ProxyKind::TENSORMAP)
    return success();
  return emitOpError("uni-directional proxies only support tensormap "
                     "for to_proxy attribute");
}
/// The uni-directional release fence supports only generic -> tensormap.
LogicalResult NVVM::FenceProxyReleaseOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC) {
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");
  }
  if (getToProxy() == NVVM::ProxyKind::TENSORMAP)
    return success();
  return emitOpError("uni-directional proxies only support tensormap "
                     "for to_proxy attribute");
}
/// A proxy sync_restrict fence needs acquire/release ordering and supports
/// only the generic -> async proxy direction.
LogicalResult NVVM::FenceProxySyncRestrictOp::verify() {
  const auto order = getOrder();
  const bool isAcquireOrRelease = order == NVVM::MemOrderKind::ACQUIRE ||
                                  order == NVVM::MemOrderKind::RELEASE;
  if (!isAcquireOrRelease)
    return emitOpError("only acquire and release semantics are supported");

  const bool fromGeneric = getFromProxy() == NVVM::ProxyKind::GENERIC;
  if (!fromGeneric)
    return emitOpError("only generic is support for from_proxy attribute");

  const bool toAsync = getToProxy() == NVVM::ProxyKind::async;
  if (!toAsync)
    return emitOpError("only async is supported for to_proxy attribute");
  return success();
}
/// The requested register count must be a multiple of 8 in [24, 256].
LogicalResult NVVM::SetMaxRegisterOp::verify() {
  const auto regCount = getRegCount();
  if (regCount % 8 != 0)
    return emitOpError("new register size must be multiple of 8");
  const bool inRange = regCount >= 24 && regCount <= 256;
  if (!inRange)
    return emitOpError("new register size must be in between 24 to 256");
  return success();
}
/// Verifies the barrier op's optional operands. The rules are:
///   - a thread count requires a barrier id;
///   - a reduction cannot be combined with an explicit barrier id;
///   - the reduction op and reduction predicate must be given together.
LogicalResult NVVM::BarrierOp::verify() {
  const bool hasThreadCount = static_cast<bool>(getNumberOfThreads());
  const bool hasBarrierId = static_cast<bool>(getBarrierId());
  const bool hasReductionOp = static_cast<bool>(getReductionOp());
  const bool hasReductionPred = static_cast<bool>(getReductionPredicate());

  if (hasThreadCount && !hasBarrierId)
    return emitOpError(
        "barrier id is missing, it should be set between 0 to 15");
  if (hasBarrierId && (hasReductionOp || hasReductionPred))
    return emitOpError("reduction are only available when id is 0");
  if (hasReductionOp != hasReductionPred)
    return emitOpError("reduction predicate and reduction operation must be "
                       "specified together");
  return success();
}
/// Verifies the multicast mode allowed for each tcgen05.cp shape:
///   - 128x256b, 128x128b and 4x256b: no multicast;
///   - 64x128b: warpx2_01_23 or warpx2_02_13;
///   - 32x128b: warpx4.
LogicalResult NVVM::Tcgen05CpOp::verify() {
  using SH = Tcgen05CpShape;
  using MC = Tcgen05CpMulticast;
  const auto multicast = getMulticast();
  switch (getShape()) {
  case SH::SHAPE_128x256b:
  case SH::SHAPE_128x128b:
  case SH::SHAPE_4x256b:
    if (multicast == MC::NONE)
      return success();
    return emitError("Invalid multicast type for tcgen05.cp Op");
  case SH::SHAPE_64x128b:
    if (multicast == MC::WARPX2_01_23 || multicast == MC::WARPX2_02_13)
      return success();
    return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
                     "warpx2_02_13 for tcgen05.cp Op");
  case SH::SHAPE_32x128b:
    if (multicast == MC::WARPX4)
      return success();
    return emitError(
        "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
  }
  return success();
}
/// match.sync 'all' must return the struct {i32, i1}; every other kind
/// must return a plain i32.
LogicalResult NVVM::MatchSyncOp::verify() {
  if (getKind() != NVVM::MatchSyncKind::all) {
    if (getType().isInteger(32))
      return success();
    return emitOpError("match.sync 'any' returns an i32");
  }

  auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
  const bool isI32I1Pair = structType && structType.getBody().size() == 2 &&
                           structType.getBody()[0].isInteger(32) &&
                           structType.getBody()[1].isInteger(1);
  if (!isI32I1Pair)
    return emitOpError("match.sync 'all' returns a two element struct with "
                       "first element as i32 and second element as i1");
  return success();
}
/// vote.sync 'ballot' returns an i32 mask; every other kind returns an i1.
LogicalResult NVVM::VoteSyncOp::verify() {
  const bool isBallot = getKind() == NVVM::VoteSyncKind::ballot;
  const unsigned expectedWidth = isBallot ? 32 : 1;
  if (getType().isInteger(expectedWidth))
    return success();
  return emitOpError(isBallot
                         ? "vote.sync 'ballot' returns an i32"
                         : "vote.sync 'any', 'all' and 'uni' returns an i1");
}
/// Verifies the prefetch op. Exactly one of `tensormap` or a cache level
/// must be given, and each mode has its own address-space and modifier
/// rules:
///   - tensormap: generic or constant pointer; no eviction priority;
///     in_param_space only with a generic pointer.
///   - cache level: generic, global or local pointer. The uniform form
///     requires L1 and a generic pointer. Eviction priority requires L2, a
///     global pointer, and evict_normal/evict_last. No predicate.
LogicalResult NVVM::PrefetchOp::verify() {
  using MemSpace = NVVM::NVVMMemorySpace;
  using CacheLevel = NVVM::PrefetchCacheLevel;
  const unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
  const std::optional<NVVM::CacheEvictionPriority> evictPriority =
      getEvictPriority();
  const std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();
  const bool isTensormap = static_cast<bool>(getTensormap());

  if (isTensormap && cacheLevel)
    return emitOpError("cannot specify both tensormap and cache level");
  if (!isTensormap && !cacheLevel)
    return emitOpError(
        "requires specification of either cache level or tensormap");

  if (isTensormap) {
    const bool isGenericOrConstant = addressSpace == MemSpace::Generic ||
                                     addressSpace == MemSpace::Constant;
    if (!isGenericOrConstant)
      return emitOpError(
          "prefetch tensormap requires a generic or constant pointer");
    if (evictPriority)
      return emitOpError(
          "prefetch tensormap does not support eviction priority");
    if (getInParamSpace() && addressSpace != MemSpace::Generic)
      return emitOpError(
          "in_param_space can only be specified for a generic pointer");
    return success();
  }

  // From here on this is a cache-level prefetch, so `cacheLevel` is engaged.
  const CacheLevel level = *cacheLevel;
  const bool isCacheableSpace = addressSpace == MemSpace::Generic ||
                                addressSpace == MemSpace::Global ||
                                addressSpace == MemSpace::Local;
  if (!isCacheableSpace)
    return emitOpError("prefetch to cache level requires a generic, global, "
                       "or local pointer");

  if (getUniform()) {
    if (level != CacheLevel::L1)
      return emitOpError(
          "unsupported cache level, the only supported uniform "
          "cache level is L1");
    if (addressSpace != MemSpace::Generic)
      return emitOpError(
          "prefetch to uniform cache requires a generic pointer");
  }

  if (evictPriority) {
    if (level != CacheLevel::L2)
      return emitOpError(
          "cache eviction priority supported only for cache level L2");
    if (addressSpace != MemSpace::Global)
      return emitOpError("cache eviction priority requires a global pointer");
    const bool isSupportedPriority =
        *evictPriority == NVVM::CacheEvictionPriority::EvictNormal ||
        *evictPriority == NVVM::CacheEvictionPriority::EvictLast;
    if (!isSupportedPriority)
      return emitOpError(
          "unsupported cache eviction priority, only evict_last and "
          "evict_normal are supported");
  }

  if (getPredicate())
    return emitOpError("predicate supported only on prefetch tensormap");
  return success();
}
/// The is_canceled query returns an i1; the get_first_cta_id_{x,y,z}
/// queries return an i32.
LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
  using QT = NVVM::ClusterLaunchControlQueryType;
  switch (getQueryType()) {
  case QT::IS_CANCELED:
    if (getType().isInteger(1))
      return success();
    return emitOpError("is_canceled query type returns an i1");
  case QT::GET_FIRST_CTA_ID_X:
  case QT::GET_FIRST_CTA_ID_Y:
  case QT::GET_FIRST_CTA_ID_Z:
    if (getType().isInteger(32))
      return success();
    return emitOpError("get_first_cta_id_x, get_first_cta_id_y, "
                       "get_first_cta_id_z query types return an i32");
  }
  return success();
}
/// Verifies the redux op. `abs` and `nan` are allowed only on f32.
/// Integer reduction kinds require i32, and fmin/fmax require f32.
LogicalResult NVVM::ReduxOp::verify() {
  const mlir::Type reduxType = getType();
  const bool isF32 = reduxType.isF32();
  if (!isF32 && getAbs())
    return emitOpError("abs attribute is supported only for f32 type");
  if (!isF32 && getNan())
    return emitOpError("nan attribute is supported only for f32 type");

  const NVVM::ReduxKind kind = getKind();
  bool typeMatches = true;
  StringRef requiredTypeName;
  switch (kind) {
  case NVVM::ReduxKind::ADD:
  case NVVM::ReduxKind::AND:
  case NVVM::ReduxKind::OR:
  case NVVM::ReduxKind::XOR:
  case NVVM::ReduxKind::MAX:
  case NVVM::ReduxKind::MIN:
  case NVVM::ReduxKind::UMAX:
  case NVVM::ReduxKind::UMIN:
    typeMatches = reduxType.isInteger(32);
    requiredTypeName = "i32";
    break;
  case NVVM::ReduxKind::FMIN:
  case NVVM::ReduxKind::FMAX:
    typeMatches = isF32;
    requiredTypeName = "f32";
    break;
  }
  if (typeMatches)
    return success();
  return emitOpError("'") << stringifyEnum(kind)
                          << "' redux kind unsupported with " << reduxType
                          << " type. Only supported type is '"
                          << requiredTypeName << "'.";
}
/// Verifies that the ordinal, `new_value` and `new_value_attr` operands agree
/// with the tensormap field being replaced:
///  - an ordinal is rejected unless the field is BOX_DIM, GLOBAL_DIM,
///    GLOBAL_STRIDE or ELEMENT_STRIDE, and it is required for those four;
///  - GLOBAL_ADDRESS and GLOBAL_STRIDE need an i64 `new_value`; RANK,
///    BOX_DIM, GLOBAL_DIM and ELEMENT_STRIDE need an i32 `new_value`;
///  - ELEMTYPE, INTERLEAVE_LAYOUT, SWIZZLE_MODE, SWIZZLE_ATOMICITY and
///    FILL_MODE need a `new_value_attr` of the matching attribute class.
LogicalResult NVVM::TensormapReplaceOp::verify() {
  const NVVM::TensormapField field = getField();
  auto ord = getOrd();
  Value newVal = getNewValue();
  auto newValAttr = getNewValueAttr();
  auto fieldName = stringifyEnum(field);

  if (ord && !llvm::is_contained({NVVM::TensormapField::BOX_DIM,
                                  NVVM::TensormapField::GLOBAL_DIM,
                                  NVVM::TensormapField::GLOBAL_STRIDE,
                                  NVVM::TensormapField::ELEMENT_STRIDE},
                                 field))
    return emitOpError("ordinal is not supported for ")
           << fieldName << " field";

  // Twines are taken by const reference (LLVM convention). The concatenation
  // is rendered to a std::string while all of its temporaries are still
  // alive, i.e. within this single full-expression.
  auto invalidNewVal = [&](const llvm::Twine &type) -> std::string {
    return ("new_value must be specified and must be an " + type + " for " +
            llvm::Twine(fieldName) + " field")
        .str();
  };
  auto invalidNewValAttr = [&]() -> std::string {
    return (llvm::Twine(
                "new_value_attr must be specified and must be a valid ") +
            llvm::Twine(fieldName) + " attribute for " + fieldName + " field")
        .str();
  };

  switch (field) {
  case NVVM::TensormapField::GLOBAL_ADDRESS:
    if (!(newVal && newVal.getType().isInteger(64)))
      return emitOpError(invalidNewVal("i64"));
    break;
  case NVVM::TensormapField::RANK:
    if (!(newVal && newVal.getType().isInteger(32)))
      return emitOpError(invalidNewVal("i32"));
    break;
  case NVVM::TensormapField::GLOBAL_STRIDE:
    if (!ord)
      return emitOpError("ordinal is required for global_stride field");
    if (!(newVal && newVal.getType().isInteger(64)))
      return emitOpError(invalidNewVal("i64"));
    break;
  case NVVM::TensormapField::BOX_DIM:
  case NVVM::TensormapField::GLOBAL_DIM:
  case NVVM::TensormapField::ELEMENT_STRIDE:
    if (!ord)
      return emitOpError("ordinal is required for ") << fieldName << " field";
    if (!(newVal && newVal.getType().isInteger(32)))
      return emitOpError(invalidNewVal("i32"));
    break;
  case NVVM::TensormapField::ELEMTYPE:
    if (!(newValAttr && llvm::isa<TensormapElemtypeAttr>(*newValAttr)))
      return emitOpError(invalidNewValAttr());
    break;
  case NVVM::TensormapField::INTERLEAVE_LAYOUT:
    if (!(newValAttr && llvm::isa<TensormapInterleaveLayoutAttr>(*newValAttr)))
      return emitOpError(invalidNewValAttr());
    break;
  case NVVM::TensormapField::SWIZZLE_MODE:
    if (!(newValAttr && llvm::isa<TensormapSwizzleModeAttr>(*newValAttr)))
      return emitOpError(invalidNewValAttr());
    break;
  case NVVM::TensormapField::SWIZZLE_ATOMICITY:
    if (!(newValAttr && llvm::isa<TensormapSwizzleAtomicityAttr>(*newValAttr)))
      return emitOpError(invalidNewValAttr());
    break;
  case NVVM::TensormapField::FILL_MODE:
    if (!(newValAttr && llvm::isa<TensormapFillModeAttr>(*newValAttr)))
      return emitOpError(invalidNewValAttr());
    break;
  }
  return success();
}
/// ORs the low `sizeInBits` bits of `field` into the 64-bit `result` at bit
/// offset `start`, and returns the updated value.
///
/// `field` must be 32 bits wide or narrower. It is first zero-extended or
/// bitcast to i32, then masked down to `sizeInBits` bits (no mask when the
/// width is 32 or more), zero-extended to i64 and shifted left by `start`.
/// The bits of `result` that the field covers are expected to still be zero,
/// because the field is OR-ed in rather than inserted.
static llvm::Value *packValInto64Bits(llvm::IRBuilderBase &builder,
                                      llvm::Value *result, llvm::Value *field,
                                      unsigned sizeInBits, unsigned start) {
  llvm::Value *bits = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());
  if (sizeInBits < 32) {
    const unsigned lowBitsMask = (1u << sizeInBits) - 1;
    bits = builder.CreateAnd(bits, builder.getInt32(lowBitsMask));
  }
  llvm::Value *wide = builder.CreateZExtOrBitCast(bits, builder.getInt64Ty());
  llvm::Value *positioned = builder.CreateShl(wide, start);
  return builder.CreateOr(result, positioned);
}
/// Assembles the 64-bit tcgen05 shared-memory descriptor for `op` and maps
/// the op's result to the resulting LLVM value.
///
/// Fields are OR-ed into an all-zero i64 (bit offset : width):
///   0:14  start address         16:14 leading-dim offset
///   32:14 stride-dim offset     46:3  constant 1 (0b001)
///   49:3  base offset           52:1  leading-dim mode
///   61:3  swizzle mode
/// All other bits remain zero.
void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
                                                LLVM::ModuleTranslation &mt,
                                                llvm::IRBuilderBase &builder) {
  auto descOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);

  struct DescField {
    llvm::Value *value;
    unsigned width;
    unsigned offset;
  };
  // Packed in this order, lowest bit offset first.
  const DescField layout[] = {
      {mt.lookupValue(descOp.getStartAddr()), 14, 0},
      {mt.lookupValue(descOp.getLeadingDimOffset()), 14, 16},
      {mt.lookupValue(descOp.getStrideDimOffset()), 14, 32},
      {builder.getInt32(1), 3, 46},
      {mt.lookupValue(descOp.getBaseOffset()), 3, 49},
      {mt.lookupValue(descOp.getLeadingDimMode()), 1, 52},
      {mt.lookupValue(descOp.getSwizzleMode()), 3, 61},
  };

  llvm::Value *desc = builder.getInt64(0);
  for (const DescField &f : layout)
    desc = packValInto64Bits(builder, desc, f.value, f.width, f.offset);
  mt.mapValue(descOp.getRes()) = desc;
}
//===----------------------------------------------------------------------===//
// getPtx methods
//===----------------------------------------------------------------------===//
/// Returns the inline-PTX template for mbarrier.init. The `.shared`
/// qualifier is used only when `isPtrInSharedCTASpace` holds for the address.
std::string NVVM::MBarrierInitOp::getPtx() {
  if (isPtrInSharedCTASpace(getAddr()))
    return "mbarrier.init.shared.b64 [%0], %1;";
  return "mbarrier.init.b64 [%0], %1;";
}
/// Returns the inline-PTX template for mbarrier.arrive.expect_tx, with `_`
/// as the destination operand. The `.shared` qualifier is added only when
/// `isPtrInSharedCTASpace` holds for the address.
std::string NVVM::MBarrierArriveExpectTxOp::getPtx() {
  if (!isPtrInSharedCTASpace(getAddr()))
    return "mbarrier.arrive.expect_tx.b64 _, [%0], %1;";
  return "mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;";
}
/// Returns an inline-PTX block that loops on `mbarrier.try_wait.parity`
/// until predicate P1 is set, then branches to DONE. `%0`, `%1` and `%2` are
/// bound to the op's operands by the caller. `{0}` is replaced by `.shared`
/// when `isPtrInSharedCTASpace` holds for the address, and by nothing
/// otherwise.
std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
  bool isShared = isPtrInSharedCTASpace(getAddr());
  llvm::StringRef space = isShared ? ".shared" : "";
  // formatv reserves '{' for replacement fields, so the literal opening brace
  // of the PTX scope block is escaped as "{{". A lone '}' needs no escaping.
  // The rendered text is the same as the previous unescaped form, which only
  // worked through formatv's lenient fallback parsing.
  return llvm::formatv("{{\n\t"
                       ".reg .pred P1; \n\t"
                       "LAB_WAIT: \n\t"
                       "mbarrier.try_wait.parity{0}.b64 P1, [%0], %1, %2; \n\t"
                       "@P1 bra.uni DONE; \n\t"
                       "bra.uni LAB_WAIT; \n\t"
                       "DONE: \n\t"
                       "}",
                       space);
}
//===----------------------------------------------------------------------===//
// getIntrinsicID/getIntrinsicIDAndArgs methods
//===----------------------------------------------------------------------===//
/// Selects the barrier.cta intrinsic for `op` and builds its arguments.
///
/// The first argument is always the barrier id, or i32 0 when none is given.
///  - If a thread count is given: `..._sync_aligned_count`, plus the count.
///  - Else, if a reduction op is given: `..._red_{and,or,popc}_aligned_all`,
///    plus the reduction predicate turned into an i1 via `!= 0`.
///  - Otherwise: `..._sync_aligned_all`.
mlir::NVVM::IDArgPair NVVM::BarrierOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto barrierOp = cast<NVVM::BarrierOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  if (mlir::Value barrierId = barrierOp.getBarrierId())
    args.push_back(mt.lookupValue(barrierId));
  else
    args.push_back(builder.getInt32(0));

  if (mlir::Value numThreads = barrierOp.getNumberOfThreads()) {
    args.push_back(mt.lookupValue(numThreads));
    return {llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count,
            std::move(args)};
  }

  auto reduction = barrierOp.getReductionOp();
  if (!reduction)
    return {llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all,
            std::move(args)};

  llvm::Intrinsic::ID id = notIntrinsic;
  switch (*reduction) {
  case NVVM::BarrierReduction::AND:
    id = llvm::Intrinsic::nvvm_barrier_cta_red_and_aligned_all;
    break;
  case NVVM::BarrierReduction::OR:
    id = llvm::Intrinsic::nvvm_barrier_cta_red_or_aligned_all;
    break;
  case NVVM::BarrierReduction::POPC:
    id = llvm::Intrinsic::nvvm_barrier_cta_red_popc_aligned_all;
    break;
  }
  llvm::Value *predicate = mt.lookupValue(barrierOp.getReductionPredicate());
  args.push_back(builder.CreateICmpNE(predicate, builder.getInt32(0)));
  return {id, std::move(args)};
}
/// Lowers `op` to `nvvm_pm_event_mask` with a single i16 mask argument.
/// When an event id is given, the mask is the one-hot value `1 << event_id`.
/// Otherwise the masked-event-id attribute value is used unchanged.
mlir::NVVM::IDArgPair
PMEventOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
                                 llvm::IRBuilderBase &builder) {
  auto pmEventOp = cast<NVVM::PMEventOp>(op);
  llvm::Type *maskTy = llvm::Type::getInt16Ty(mt.getLLVMContext());

  llvm::Value *mask = nullptr;
  auto eventId = pmEventOp.getEventIdAttr();
  if (!eventId) {
    mask = llvm::ConstantInt::get(maskTy,
                                  pmEventOp.getMaskedEventIdAttr().getValue());
  } else {
    // NOTE(review): assumes the verifier keeps the event id below 16 so the
    // shifted bit fits into the i16 mask; confirm.
    const auto oneHot = static_cast<uint16_t>(1u << eventId.getInt());
    mask = llvm::ConstantInt::get(maskTy, oneHot);
  }
  return {llvm::Intrinsic::nvvm_pm_event_mask, {mask}};
}
/// Lowers mbarrier.init. Uses `nvvm_mbarrier_init_shared` when
/// `isPtrInSharedCTASpace` holds for the address, and `nvvm_mbarrier_init`
/// otherwise. Arguments: barrier address, count.
mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto initOp = cast<NVVM::MBarrierInitOp>(op);
  auto addr = initOp.getAddr();
  const llvm::Intrinsic::ID id =
      isPtrInSharedCTASpace(addr) ? llvm::Intrinsic::nvvm_mbarrier_init_shared
                                  : llvm::Intrinsic::nvvm_mbarrier_init;
  llvm::Value *mbar = mt.lookupValue(addr);
  llvm::Value *count = mt.lookupValue(initOp.getCount());
  return {id, {mbar, count}};
}
/// Lowers mbarrier.inval. Uses `nvvm_mbarrier_inval_shared` when
/// `isPtrInSharedCTASpace` holds for the address, and `nvvm_mbarrier_inval`
/// otherwise. The only argument is the barrier address.
mlir::NVVM::IDArgPair MBarrierInvalOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto invalOp = cast<NVVM::MBarrierInvalOp>(op);
  auto addr = invalOp.getAddr();
  if (isPtrInSharedCTASpace(addr))
    return {llvm::Intrinsic::nvvm_mbarrier_inval_shared,
            {mt.lookupValue(addr)}};
  return {llvm::Intrinsic::nvvm_mbarrier_inval, {mt.lookupValue(addr)}};
}
/// Lowers mbarrier.expect_tx to one of four scope/space intrinsics.
/// Scope is cluster when the op's scope is CLUSTER, and cta otherwise. Space
/// is cluster when `isPtrInSharedClusterSpace` holds for the address, and cta
/// otherwise. Arguments: barrier address, tx count.
mlir::NVVM::IDArgPair MBarrierExpectTxOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto expectOp = cast<NVVM::MBarrierExpectTxOp>(op);
  const unsigned spaceIdx = isPtrInSharedClusterSpace(expectOp.getAddr()) ? 1 : 0;
  const unsigned scopeIdx =
      expectOp.getScope() == NVVM::MemScopeKind::CLUSTER ? 1 : 0;

  // Indexed as [scope][space], with 0 = cta and 1 = cluster.
  static constexpr llvm::Intrinsic::ID kExpectTxIDs[2][2] = {
      {llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cta,
       llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cluster},
      {llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cta,
       llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cluster}};

  llvm::Value *mbar = mt.lookupValue(expectOp.getAddr());
  llvm::Value *txcount = mt.lookupValue(expectOp.getTxcount());
  return {kExpectTxIDs[scopeIdx][spaceIdx], {mbar, txcount}};
}
/// Lowers mbarrier.complete_tx to one of four scope/space intrinsics.
/// Scope is cluster when the op's scope is CLUSTER, and cta otherwise. Space
/// is cluster when `isPtrInSharedClusterSpace` holds for the address, and cta
/// otherwise. Arguments: barrier address, tx count.
mlir::NVVM::IDArgPair MBarrierCompleteTxOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto completeOp = cast<NVVM::MBarrierCompleteTxOp>(op);
  auto addr = completeOp.getAddr();
  const bool onClusterSpace = isPtrInSharedClusterSpace(addr);
  const bool atClusterScope =
      completeOp.getScope() == NVVM::MemScopeKind::CLUSTER;

  // Rows: scope (cta, cluster); columns: space (cta, cluster).
  static constexpr llvm::Intrinsic::ID kCompleteTxIDs[2][2] = {
      {llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cta,
       llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cluster},
      {llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cta,
       llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cluster}};
  const llvm::Intrinsic::ID id =
      kCompleteTxIDs[atClusterScope ? 1 : 0][onClusterSpace ? 1 : 0];

  return {id, {mt.lookupValue(addr), mt.lookupValue(completeOp.getTxcount())}};
}
/// Lowers mbarrier.arrive.
///
/// The intrinsic is chosen by relaxed-ness, scope (CLUSTER vs. cta) and
/// space (cluster when `isPtrInSharedClusterSpace` holds, cta otherwise).
/// Generic-space addresses are cast to the shared address space first.
/// The arguments are the barrier pointer and the count, which defaults to
/// i32 1 when omitted. The one exception is the non-relaxed cta-scope,
/// cta-space form without an explicit count: it maps to the legacy
/// `nvvm_mbarrier_arrive_shared` intrinsic, which takes only the pointer.
mlir::NVVM::IDArgPair MBarrierArriveOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto arriveOp = cast<NVVM::MBarrierArriveOp>(op);
  auto addr = arriveOp.getAddr();
  const unsigned spaceIdx = isPtrInSharedClusterSpace(addr) ? 1 : 0;
  const unsigned scopeIdx =
      arriveOp.getScope() == NVVM::MemScopeKind::CLUSTER ? 1 : 0;
  const unsigned relaxedIdx = arriveOp.getRelaxed() ? 1 : 0;

  // Indexed as [relaxed][scope][space]; for scope/space, 0 = cta and
  // 1 = cluster.
  // clang-format off
  static constexpr llvm::Intrinsic::ID kArriveIDs[2][2][2] = {
      {{llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta,
        llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cluster},
       {llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cta,
        llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cluster}},
      {{llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cta,
        llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cluster},
       {llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cta,
        llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cluster}}};
  // clang-format on
  const llvm::Intrinsic::ID id = kArriveIDs[relaxedIdx][scopeIdx][spaceIdx];

  llvm::Value *mbar = mt.lookupValue(addr);
  if (isPtrInGenericSpace(addr))
    mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);

  auto countOperand = arriveOp.getCount();
  if (countOperand)
    return {id, {mbar, mt.lookupValue(countOperand)}};

  // The most basic arrive (supported since sm_80): cta space, cta scope,
  // not relaxed, no explicit count. Use the legacy intrinsic for it.
  if (id == llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta)
    return {llvm::Intrinsic::nvvm_mbarrier_arrive_shared, {mbar}};

  llvm::Value *defaultCount =
      llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.getLLVMContext()), 1);
  return {id, {mbar, defaultCount}};
}
/// Lowers mbarrier.arrive_drop.
///
/// The intrinsic is chosen by relaxed-ness, scope (CLUSTER vs. cta) and
/// space (cluster when `isPtrInSharedClusterSpace` holds, cta otherwise).
/// Generic-space addresses are cast to the shared address space first.
/// Arguments: barrier pointer and count (i32 1 when omitted).
mlir::NVVM::IDArgPair MBarrierArriveDropOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto dropOp = cast<NVVM::MBarrierArriveDropOp>(op);
  auto addr = dropOp.getAddr();

  // Indexed as [relaxed][scope][space]; for scope/space, 0 = cta and
  // 1 = cluster.
  // clang-format off
  static constexpr llvm::Intrinsic::ID kDropIDs[2][2][2] = {
      {{llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cta,
        llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cluster},
       {llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cta,
        llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cluster}},
      {{llvm::Intrinsic::nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cta,
        llvm::Intrinsic::nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cluster},
       {llvm::Intrinsic::nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cta,
        llvm::Intrinsic::nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cluster}}};
  // clang-format on
  const unsigned relaxedIdx = dropOp.getRelaxed() ? 1 : 0;
  const unsigned scopeIdx =
      dropOp.getScope() == NVVM::MemScopeKind::CLUSTER ? 1 : 0;
  const unsigned spaceIdx = isPtrInSharedClusterSpace(addr) ? 1 : 0;
  const llvm::Intrinsic::ID id = kDropIDs[relaxedIdx][scopeIdx][spaceIdx];

  llvm::Value *mbar = mt.lookupValue(addr);
  if (isPtrInGenericSpace(addr))
    mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);

  auto countOperand = dropOp.getCount();
  llvm::Value *count =
      countOperand
          ? mt.lookupValue(countOperand)
          : llvm::ConstantInt::get(
                llvm::Type::getInt32Ty(mt.getLLVMContext()), 1);
  return {id, {mbar, count}};
}
/// Collects the values bound to the inline-PTX template: every SSA operand,
/// in operand order, as a read-only register. Attributes are not forwarded;
/// they only select intrinsic variants on the intrinsic-lowering path.
/// Always returns false.
bool MBarrierArriveExpectTxOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  for (mlir::Value operand : getOperands())
    asmValues.emplace_back(operand, mlir::NVVM::PTXRegisterMod::Read);
  return false;
}
/// Lowers mbarrier.arrive.expect_tx.
///
/// The intrinsic is chosen by relaxed-ness, scope (CLUSTER vs. cta) and
/// space (cluster when `isPtrInSharedClusterSpace` holds, cta otherwise).
/// Generic-space addresses are cast to the shared address space first.
/// Arguments: barrier pointer, tx count.
mlir::NVVM::IDArgPair MBarrierArriveExpectTxOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto arriveOp = cast<NVVM::MBarrierArriveExpectTxOp>(op);
  auto addr = arriveOp.getAddr();
  const unsigned spaceIdx = isPtrInSharedClusterSpace(addr) ? 1 : 0;
  const unsigned scopeIdx =
      arriveOp.getScope() == NVVM::MemScopeKind::CLUSTER ? 1 : 0;
  const unsigned relaxedIdx = arriveOp.getRelaxed() ? 1 : 0;

  // Indexed as [relaxed][scope][space]; for scope/space, 0 = cta and
  // 1 = cluster.
  // clang-format off
  static constexpr llvm::Intrinsic::ID kExpectTxIDs[2][2][2] = {
      {{llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cta,
        llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cluster},
       {llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cta,
        llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cluster}},
      {{llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cta,
        llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cluster},
       {llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cta,
        llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cluster}}};
  // clang-format on

  llvm::Value *mbar = mt.lookupValue(addr);
  if (isPtrInGenericSpace(addr))
    mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
  llvm::Value *txcount = mt.lookupValue(arriveOp.getTxcount());
  return {kExpectTxIDs[relaxedIdx][scopeIdx][spaceIdx], {mbar, txcount}};
}
/// Lowers mbarrier.arrive_drop.expect_tx.
///
/// The intrinsic is chosen by relaxed-ness, scope (CLUSTER vs. cta) and
/// space (cluster when `isPtrInSharedClusterSpace` holds, cta otherwise).
/// Generic-space addresses are cast to the shared address space first.
/// Arguments: barrier pointer, tx count.
mlir::NVVM::IDArgPair MBarrierArriveDropExpectTxOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto dropOp = cast<NVVM::MBarrierArriveDropExpectTxOp>(op);
  auto addr = dropOp.getAddr();

  // Indexed as [relaxed][scope][space]; for scope/space, 0 = cta and
  // 1 = cluster.
  // clang-format off
  static constexpr llvm::Intrinsic::ID kDropExpectTxIDs[2][2][2] = {
      {{llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cta,
        llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cluster},
       {llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cta,
        llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cluster}},
      {{llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cta,
        llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cluster},
       {llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cta,
        llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cluster}}};
  // clang-format on
  const llvm::Intrinsic::ID id =
      kDropExpectTxIDs[dropOp.getRelaxed() ? 1 : 0]
                      [dropOp.getScope() == NVVM::MemScopeKind::CLUSTER ? 1 : 0]
                      [isPtrInSharedClusterSpace(addr) ? 1 : 0];

  llvm::Value *mbar = mt.lookupValue(addr);
  if (isPtrInGenericSpace(addr))
    mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
  return {id, {mbar, mt.lookupValue(dropOp.getTxcount())}};
}
/// Lowers mbarrier.arrive.noComplete. Uses the `_shared` intrinsic when
/// `isPtrInSharedCTASpace` holds for the address, and the generic one
/// otherwise. Arguments: barrier address, count.
mlir::NVVM::IDArgPair MBarrierArriveNocompleteOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto arriveOp = cast<NVVM::MBarrierArriveNocompleteOp>(op);
  auto addr = arriveOp.getAddr();
  llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete;
  if (isPtrInSharedCTASpace(addr))
    id = llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete_shared;
  return {id, {mt.lookupValue(addr), mt.lookupValue(arriveOp.getCount())}};
}
/// Lowers mbarrier.arrive_drop.noComplete. Uses the `_shared` intrinsic when
/// `isPtrInSharedCTASpace` holds for the address, and the generic one
/// otherwise. Arguments: barrier address, count.
mlir::NVVM::IDArgPair MBarrierArriveDropNocompleteOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto dropOp = cast<NVVM::MBarrierArriveDropNocompleteOp>(op);
  auto addr = dropOp.getAddr();
  const bool inSharedCTA = isPtrInSharedCTASpace(addr);
  llvm::Value *mbar = mt.lookupValue(addr);
  llvm::Value *count = mt.lookupValue(dropOp.getCount());
  if (inSharedCTA)
    return {llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete_shared,
            {mbar, count}};
  return {llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete, {mbar, count}};
}
/// Lowers mbarrier.test_wait.
///
/// An i32 state-or-phase operand selects the `_parity` form; any other type
/// selects the state form. The intrinsic is also chosen by relaxed-ness and
/// scope (CLUSTER vs. cta); every table entry uses the space_cta form.
/// Generic-space addresses are cast to the shared address space first.
/// Arguments: barrier pointer, state-or-phase.
mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto waitOp = cast<NVVM::MBarrierTestWaitOp>(op);
  auto addr = waitOp.getAddr();
  auto stateOrPhase = waitOp.getStateOrPhase();

  const unsigned parityIdx = stateOrPhase.getType().isInteger(32) ? 1 : 0;
  const unsigned scopeIdx =
      waitOp.getScope() == NVVM::MemScopeKind::CLUSTER ? 1 : 0;
  const unsigned relaxedIdx = waitOp.getRelaxed() ? 1 : 0;

  // Indexed as [relaxed][scope][parity]; scope 0 = cta, 1 = cluster.
  // clang-format off
  static constexpr llvm::Intrinsic::ID kTestWaitIDs[2][2][2] = {
      {{llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cta_space_cta,
        llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cta_space_cta},
       {llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cluster_space_cta,
        llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cluster_space_cta}},
      {{llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cta_space_cta,
        llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cta_space_cta},
       {llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cluster_space_cta,
        llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cluster_space_cta}}};
  // clang-format on

  llvm::Value *mbar = mt.lookupValue(addr);
  if (isPtrInGenericSpace(addr))
    mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
  return {kTestWaitIDs[relaxedIdx][scopeIdx][parityIdx],
          {mbar, mt.lookupValue(stateOrPhase)}};
}
/// Lowers mbarrier.try_wait.
///
/// The intrinsic is chosen by relaxed-ness, the phase-parity form (i32
/// state-or-phase operand), scope (CLUSTER vs. cta), and whether ticks are
/// given (the `_tl` variants). Generic-space addresses are cast to the shared
/// address space first. Arguments: barrier pointer, state-or-phase, and the
/// ticks value when present.
mlir::NVVM::IDArgPair MBarrierTryWaitOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto waitOp = cast<NVVM::MBarrierTryWaitOp>(op);
  auto addr = waitOp.getAddr();
  auto stateOrPhase = waitOp.getStateOrPhase();
  auto ticks = waitOp.getTicks();

  // Flat table index: bit 0 = phase parity, bit 1 = cluster scope,
  // bit 2 = ticks present.
  size_t index = 0;
  if (stateOrPhase.getType().isInteger(32))
    index |= 1u;
  if (waitOp.getScope() == NVVM::MemScopeKind::CLUSTER)
    index |= 2u;
  if (ticks)
    index |= 4u;

  // clang-format off
  static constexpr llvm::Intrinsic::ID kIDs[] = {
      llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cta_space_cta,
      llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cta_space_cta,
      llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cluster_space_cta,
      llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cluster_space_cta,
      llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cta_space_cta,
      llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cta_space_cta,
      llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cluster_space_cta,
      llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cluster_space_cta};
  static constexpr llvm::Intrinsic::ID kRelaxedIDs[] = {
      llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cta_space_cta,
      llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cta_space_cta,
      llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cluster_space_cta,
      llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cluster_space_cta,
      llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cta_space_cta,
      llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cta_space_cta,
      llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cluster_space_cta,
      llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cluster_space_cta};
  // clang-format on
  const llvm::Intrinsic::ID *table = waitOp.getRelaxed() ? kRelaxedIDs : kIDs;
  const llvm::Intrinsic::ID id = table[index];

  llvm::Value *mbar = mt.lookupValue(addr);
  if (isPtrInGenericSpace(addr))
    mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);

  llvm::SmallVector<llvm::Value *> args = {mbar, mt.lookupValue(stateOrPhase)};
  if (ticks)
    args.push_back(mt.lookupValue(ticks));
  return {id, std::move(args)};
}
/// Lowers cp.async.mbarrier.arrive. Picks the `noinc` variant when the op
/// has the noinc flag, and the `_shared` variant when `isPtrInSharedCTASpace`
/// holds for the address. The only argument is the barrier address.
mlir::NVVM::IDArgPair CpAsyncMBarrierArriveOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto arriveOp = cast<NVVM::CpAsyncMBarrierArriveOp>(op);
  auto addr = arriveOp.getAddr();

  // Indexed as [noinc][shared].
  static constexpr llvm::Intrinsic::ID kArriveIDs[2][2] = {
      {llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive,
       llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_shared},
      {llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc,
       llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc_shared}};
  const unsigned noincIdx = arriveOp.getNoinc() ? 1 : 0;
  const unsigned sharedIdx = isPtrInSharedCTASpace(addr) ? 1 : 0;
  return {kArriveIDs[noincIdx][sharedIdx], {mt.lookupValue(addr)}};
}
// Expands to the enumerator
// llvm::Intrinsic::nvvm_cp_async_<mod>_shared_global_<size><suffix>.
#define CP_ASYNC_ID_IMPL(mod, size, suffix)                                    \
  llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
// Chooses the `_s` variant (explicit cp-size operand) when `has_cpsize` is
// true, and the plain variant otherwise. Both the condition and the whole
// conditional are parenthesized so the macro composes safely inside larger
// expressions (nested conditionals, comparisons, arithmetic).
#define GET_CP_ASYNC_ID(mod, size, has_cpsize)                                 \
  ((has_cpsize) ? CP_ASYNC_ID_IMPL(mod, size, _s)                              \
                : CP_ASYNC_ID_IMPL(mod, size, ))
/// Picks the cp.async intrinsic for `op`, appends its arguments to `args`,
/// and returns the intrinsic id.
///
/// Sizes 4 and 8 always use the `ca` variant. Size 16 uses `cg` when the
/// modifier is CG and `ca` otherwise. When a cp-size operand is present, the
/// `_s` variant is chosen and the operand is appended after dst and src.
/// Any other size is treated as unreachable.
llvm::Intrinsic::ID
CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
                                 llvm::SmallVector<llvm::Value *> &args) {
  auto copyOp = cast<NVVM::CpAsyncOp>(op);
  mlir::Value cpSize = copyOp.getCpSize();
  const bool hasCpSize = static_cast<bool>(cpSize);

  llvm::Intrinsic::ID id = notIntrinsic;
  switch (copyOp.getSize()) {
  case 4:
    id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
    break;
  case 8:
    id = GET_CP_ASYNC_ID(ca, 8, hasCpSize);
    break;
  case 16:
    if (copyOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
      id = GET_CP_ASYNC_ID(cg, 16, hasCpSize);
    else
      id = GET_CP_ASYNC_ID(ca, 16, hasCpSize);
    break;
  default:
    llvm_unreachable("Invalid copy size in CpAsyncOp.");
  }

  args.push_back(mt.lookupValue(copyOp.getDst()));
  args.push_back(mt.lookupValue(copyOp.getSrc()));
  if (hasCpSize)
    args.push_back(mt.lookupValue(cpSize));
  return id;
}
/// Lowers cp.async.bulk.prefetch.L2. Arguments: source address, size, L2
/// cache hint (i64 0 when absent), and an i1 flag that is true only when a
/// cache hint was provided.
mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto prefetchOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
  mlir::Value hint = prefetchOp.getL2CacheHint();

  llvm::Value *hintArg =
      hint ? mt.lookupValue(hint)
           : llvm::ConstantInt::get(
                 llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);

  llvm::SmallVector<llvm::Value *> args = {
      mt.lookupValue(prefetchOp.getSrcMem()),
      mt.lookupValue(prefetchOp.getSize()), hintArg,
      builder.getInt1(static_cast<bool>(hint))};
  return {llvm::Intrinsic::nvvm_cp_async_bulk_prefetch_L2, std::move(args)};
}
/// Lowers cp.async.bulk from global to shared memory.
///
/// When `isPtrInSharedCTASpace` holds for the destination, the
/// `..._to_shared_cta` intrinsic is used and no multicast operands are
/// passed. Otherwise the `..._to_shared_cluster` intrinsic is used with a
/// multicast mask (i16 0 when absent) and its i1 presence flag. Both forms
/// take an L2 cache hint (i64 0 when absent) and its i1 presence flag.
///
/// Argument order: dst, mbar, src, size, [mcast], hint, [has_mcast], has_hint.
mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto copyOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op);
  const bool toSharedCTA = isPtrInSharedCTASpace(copyOp.getDstMem());
  mlir::Value mcastMask = copyOp.getMulticastMask();
  mlir::Value cacheHint = copyOp.getL2CacheHint();

  llvm::SmallVector<llvm::Value *> args = {
      mt.lookupValue(copyOp.getDstMem()), mt.lookupValue(copyOp.getMbar()),
      mt.lookupValue(copyOp.getSrcMem()), mt.lookupValue(copyOp.getSize())};

  if (!toSharedCTA)
    args.push_back(mcastMask ? mt.lookupValue(mcastMask)
                             : llvm::ConstantInt::get(builder.getInt16Ty(), 0));
  args.push_back(cacheHint ? mt.lookupValue(cacheHint)
                           : llvm::ConstantInt::get(builder.getInt64Ty(), 0));
  if (!toSharedCTA)
    args.push_back(builder.getInt1(static_cast<bool>(mcastMask)));
  args.push_back(builder.getInt1(static_cast<bool>(cacheHint)));

  if (toSharedCTA)
    return {llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cta,
            std::move(args)};
  return {llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster,
          std::move(args)};
}
/// Lowers cp.async.bulk from shared::cta to global memory.
/// Arguments: dst, src, size, L2 cache hint (i64 0 when absent), and an i1
/// flag that is true only when a hint was provided. When a byte mask is
/// given, it is appended and the `_bytemask` intrinsic is used instead.
mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto copyOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
  mlir::Value cacheHint = copyOp.getL2CacheHint();

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(copyOp.getDstMem()));
  args.push_back(mt.lookupValue(copyOp.getSrcMem()));
  args.push_back(mt.lookupValue(copyOp.getSize()));
  args.push_back(cacheHint
                     ? mt.lookupValue(cacheHint)
                     : llvm::ConstantInt::get(
                           llvm::Type::getInt64Ty(mt.getLLVMContext()), 0));
  args.push_back(builder.getInt1(static_cast<bool>(cacheHint)));

  mlir::Value byteMask = copyOp.getByteMask();
  if (!byteMask)
    return {llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global,
            std::move(args)};

  args.push_back(mt.lookupValue(byteMask));
  return {llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask,
          std::move(args)};
}
/// Fills `asmValues` with every SSA operand of the op, in operand order, each
/// marked as a read-only register for the inline-PTX path. Attributes are
/// skipped; they only steer intrinsic-variant selection. Always returns false.
bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  for (mlir::OpOperand &operand : getOperation()->getOpOperands())
    asmValues.push_back({operand.get(), mlir::NVVM::PTXRegisterMod::Read});
  return false;
}
/// Lowers the bulk-tensor (TMA) load from global to shared memory.
///
/// Arguments are: dst, mbar, tensor-map descriptor, the coordinates and the
/// im2col offsets, followed by
///  - shared::cluster form: multicast mask (i16 0 when absent), L2 cache hint
///    (i64 0 when absent), i1 has-multicast, i1 has-cache-hint, and the i32
///    CTA-group flag (0 without a group, otherwise the group enum value + 1);
///  - shared::cta form (`isCTAOnly`): L2 cache hint and i1 has-cache-hint.
///
/// The intrinsic comes from a [mode][number of coordinates] table. Invalid
/// combinations hold `notIntrinsic`.
mlir::NVVM::IDArgPair
CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
  const bool isCTAOnly = thisOp.getIsCTAOnly();
  llvm::SmallVector<llvm::Value *> args;

  // Fill the Intrinsic Args
  args.push_back(mt.lookupValue(thisOp.getDstMem()));
  args.push_back(mt.lookupValue(thisOp.getMbar()));
  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));

  // Coordinates and im2col-offsets
  for (mlir::Value v : thisOp.getCoordinates())
    args.push_back(mt.lookupValue(v));
  for (mlir::Value v : thisOp.getIm2colOffsets())
    args.push_back(mt.lookupValue(v));

  // MulticastMask, if available
  mlir::Value mcMask = thisOp.getMulticastMask();
  const bool hasMC = static_cast<bool>(mcMask);
  llvm::Value *i16Zero =
      llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.getLLVMContext()), 0);

  // CacheHint, if available
  mlir::Value cacheHint = thisOp.getL2CacheHint();
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Zero =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);

  // Flag argument CTAGroup
  // CTA_1/2 is mapped to values 1 and 2 for the intrinsics.
  // Hence, the +1 to getGroup().
  const int32_t val =
      thisOp.getGroup() ? (static_cast<int32_t>(*thisOp.getGroup()) + 1) : 0;
  llvm::Value *cg =
      llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.getLLVMContext()), val);

  if (!isCTAOnly) {
    // For shared::cluster, all the arguments that we build are applicable.
    args.push_back(hasMC ? mt.lookupValue(mcMask) : i16Zero);
    args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
    args.push_back(builder.getInt1(hasMC));
    args.push_back(builder.getInt1(hasCacheHint));
    args.push_back(cg);
  } else {
    // For shared::cta, only cache-hint is applicable.
    args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
    args.push_back(builder.getInt1(hasCacheHint));
  }

  constexpr size_t numDims = 5;  // 1D to 5D
  constexpr size_t numModes = 5; // Tile, Im2col, w, w_128, gather4
  using rowTy = std::array<llvm::Intrinsic::ID, numDims + 1>;
  using TableTy = std::array<rowTy, numModes>;
  static constexpr TableTy IDTable{
      {{notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
       {notIntrinsic, notIntrinsic, notIntrinsic,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
       {notIntrinsic, notIntrinsic, notIntrinsic,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
       {notIntrinsic, notIntrinsic, notIntrinsic,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
       {notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}}};

  static constexpr TableTy IDTableCTA{
      {{notIntrinsic,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_1d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_5d},
       {notIntrinsic, notIntrinsic, notIntrinsic,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_5d},
       {notIntrinsic, notIntrinsic, notIntrinsic,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_5d},
       {notIntrinsic, notIntrinsic, notIntrinsic,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_5d},
       {notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d}}};

  static_assert(
      (getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1) &&
          (getMaxEnumValForTMALoadMode() == std::size(IDTableCTA) - 1),
      "TMALoadModes must match number of rows in IDTable and IDTableCTA");
  size_t mode = static_cast<size_t>(thisOp.getMode());
  size_t dim = thisOp.getCoordinates().size();
  // Check the indices before the lookup: an out-of-range coordinate count
  // would read past the end of a row (undefined behavior) before the
  // notIntrinsic assertion below could catch it.
  assert(mode < numModes && dim <= numDims &&
         "Unsupported mode or coordinate count in "
         "CpAsyncBulkTensorGlobalToSharedClusterOp.");
  auto id = isCTAOnly ? IDTableCTA[mode][dim] : IDTable[mode][dim];
  assert(id != notIntrinsic &&
         "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");

  return {id, std::move(args)};
}
/// Builds the intrinsic ID and argument list for the TMA tensor prefetch op.
///
/// Arguments, in order: the TMA descriptor, one value per tensor coordinate,
/// the im2col offsets (if any), the L2 cache hint (an i64 zero when the op has
/// none) and an i1 flag saying whether the cache-hint operand is meaningful.
/// The intrinsic is looked up by (mode, number of coordinates).
mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
  llvm::SmallVector<llvm::Value *> args;

  // Fill the Intrinsic Args
  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
  for (auto v : thisOp.getCoordinates())
    args.push_back(mt.lookupValue(v));
  for (auto v : thisOp.getIm2colOffsets())
    args.push_back(mt.lookupValue(v));

  mlir::Value cacheHint = thisOp.getL2CacheHint();
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));

  // Rows are indexed by TMALoadMode, columns by the number of coordinates.
  // The file-scope `notIntrinsic` marks unsupported (mode, dim) combinations
  // (e.g. im2col variants start at 3 dimensions; the gather4 variant sits in
  // column 5, i.e. it is used with five coordinates).
  static constexpr llvm::Intrinsic::ID IDTable[][6] = {
      {notIntrinsic,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
      {notIntrinsic, notIntrinsic, notIntrinsic,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
      {notIntrinsic, notIntrinsic, notIntrinsic,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
      {notIntrinsic, notIntrinsic, notIntrinsic,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
      {notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};

  static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
                "TMALoadModes must match number of rows in IDTable");

  size_t mode = static_cast<size_t>(thisOp.getMode());
  size_t dim = thisOp.getCoordinates().size();
  // Guard the table lookup itself; an out-of-range index would be UB.
  assert(mode < std::size(IDTable) &&
         "Invalid mode for CpAsyncBulkTensorPrefetchOp");
  assert(dim < std::size(IDTable[mode]) &&
         "Invalid dim for CpAsyncBulkTensorPrefetchOp");

  llvm::Intrinsic::ID id = IDTable[mode][dim];
  if (id == notIntrinsic)
    llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
  return {id, std::move(args)};
}
/// Builds the intrinsic ID and argument list for the TMA shared::cta -> global
/// tensor copy op.
///
/// Arguments, in order: the source shared-memory pointer, the TMA descriptor,
/// one value per tensor coordinate, the L2 cache hint (an i64 zero when the op
/// has none) and an i1 flag saying whether the cache-hint operand is
/// meaningful. The intrinsic is looked up by (mode, number of coordinates).
mlir::NVVM::IDArgPair
CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op);
  llvm::SmallVector<llvm::Value *> args;

  // Fill the Intrinsic Args
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
  for (auto v : thisOp.getCoordinates())
    args.push_back(mt.lookupValue(v));

  mlir::Value cacheHint = thisOp.getL2CacheHint();
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));

  // Rows are indexed by TMAStoreMode, columns by the number of coordinates.
  // The file-scope `notIntrinsic` marks unsupported (mode, dim) combinations
  // (im2col starts at 3 dimensions; the scatter4 variant sits in column 5,
  // i.e. it is used with five coordinates).
  static constexpr llvm::Intrinsic::ID IDTable[][6] = {
      {notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d},
      {notIntrinsic, notIntrinsic, notIntrinsic,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d},
      {notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}};

  static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1,
                "TMAStoreModes must match number of rows in IDTable");

  size_t mode = static_cast<size_t>(thisOp.getMode());
  size_t dim = thisOp.getCoordinates().size();
  // Guard the table lookup itself; an out-of-range index would be UB.
  assert(mode < std::size(IDTable) &&
         "Invalid mode for CpAsyncBulkTensorSharedCTAToGlobalOp");
  assert(dim < std::size(IDTable[mode]) &&
         "Invalid dim for CpAsyncBulkTensorSharedCTAToGlobalOp");

  llvm::Intrinsic::ID id = IDTable[mode][dim];
  if (id == notIntrinsic)
    llvm_unreachable(
        "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp.");
  return {id, std::move(args)};
}
/// Builds the intrinsic ID and argument list for the TMA tensor reduce op.
///
/// Arguments, in order: the source shared-memory pointer, the TMA descriptor,
/// one value per tensor coordinate, the L2 cache hint (an i64 zero when the op
/// has none) and an i1 flag saying whether the cache-hint operand is
/// meaningful. The intrinsic is looked up by (reduction kind, mode, number of
/// coordinates).
NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
  llvm::LLVMContext &ctx = mt.getLLVMContext();

  llvm::SmallVector<llvm::Value *> args;
  // Arguments to the intrinsic:
  // shared_mem_ptr, tmaDesc, tensorDims
  // cache_hint(if applicable) and flag(boolean)
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
  for (Value v : thisOp.getCoordinates())
    args.push_back(mt.lookupValue(v));

  mlir::Value cacheHint = thisOp.getL2CacheHint();
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64ZeroValue =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64ZeroValue);
  args.push_back(builder.getInt1(hasCacheHint));

  // Unsupported (kind, layout, dim) combinations (e.g. im2col below 3D) are
  // marked with the file-scope `notIntrinsic` constant, used directly rather
  // than shadowed by an identical local copy.
  constexpr unsigned numRedKinds = 8; // ADD, MIN, MAX, INC, DEC, AND, OR, XOR
  constexpr unsigned numLayouts = 2;  // TILE, IM2COL
  constexpr unsigned maxDim = 5;      // 1D to 5D
  using row = std::array<llvm::Intrinsic::ID, maxDim + 1>;
  using layoutTable = std::array<row, numLayouts>;
  using fullTable = std::array<layoutTable, numRedKinds>;
  static constexpr fullTable IDTable{
      {// RedTy::ADD
       {{{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d}},
          {{notIntrinsic, notIntrinsic, notIntrinsic,
            llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d,
            llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d,
            llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d}}}},
       // RedTy::MIN
       {{{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d}},
          {{notIntrinsic, notIntrinsic, notIntrinsic,
            llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d,
            llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d,
            llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d}}}},
       // RedTy::MAX
       {{{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d}},
          {{notIntrinsic, notIntrinsic, notIntrinsic,
            llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d,
            llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d,
            llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d}}}},
       // RedTy::INC
       {{{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d}},
          {{notIntrinsic, notIntrinsic, notIntrinsic,
            llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d,
            llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d,
            llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d}}}},
       // RedTy::DEC
       {{{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d}},
          {{notIntrinsic, notIntrinsic, notIntrinsic,
            llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d,
            llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d,
            llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d}}}},
       // RedTy::AND
       {{{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d}},
          {{notIntrinsic, notIntrinsic, notIntrinsic,
            llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d,
            llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d,
            llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d}}}},
       // RedTy::OR
       {{{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d}},
          {{notIntrinsic, notIntrinsic, notIntrinsic,
            llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d,
            llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d,
            llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d}}}},
       // RedTy::XOR
       {{{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d}},
          {{notIntrinsic, notIntrinsic, notIntrinsic,
            llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d,
            llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d,
            llvm::Intrinsic::
                nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d}}}}}};

  static_assert(getMaxEnumValForTMAReduxKind() == std::size(IDTable) - 1,
                "TMAReduxKinds must match number of rows in IDTable");

  size_t redKind = static_cast<size_t>(thisOp.getRedKind());
  size_t mode = static_cast<size_t>(thisOp.getMode());
  size_t dim = thisOp.getCoordinates().size();

  assert(redKind < IDTable.size() &&
         "Invalid redKind for CpAsyncBulkTensorReduceOp");
  assert(mode < IDTable[redKind].size() &&
         "Invalid mode for CpAsyncBulkTensorReduceOp");
  assert(dim < IDTable[redKind][mode].size() &&
         "Invalid dim for CpAsyncBulkTensorReduceOp");

  llvm::Intrinsic::ID intrinsicID = IDTable[redKind][mode][dim];
  assert(intrinsicID != notIntrinsic &&
         "Invalid intrinsic for CpAsyncBulkTensorReduceOp.");
  return {intrinsicID, std::move(args)};
}
#define _none
#define CVT_F2TF32_ID_IMPL(rnd, relu, sf)                                      \
  hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf                       \
          : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf
#define GET_CVT_F2TF32_ID(rnd, relu, sf)                                       \
  hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf)                             \
               : CVT_F2TF32_ID_IMPL(rnd, relu, )

/// Selects the f32 -> tf32 conversion intrinsic.
///
/// RN and RZ honour both the relu and the satfinite modifiers. RNA only has
/// plain and satfinite variants, so `hasRelu` does not affect its choice.
/// Any other rounding mode is a programming error.
llvm::Intrinsic::ID
ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                     NVVM::SaturationMode sat, bool hasRelu) {
  const bool satFinite = sat == NVVM::SaturationMode::SATFINITE;

  if (rnd == NVVM::FPRoundingMode::RN) {
    if (satFinite)
      return hasRelu ? llvm::Intrinsic::nvvm_f2tf32_rn_relu_satfinite
                     : llvm::Intrinsic::nvvm_f2tf32_rn_satfinite;
    return hasRelu ? llvm::Intrinsic::nvvm_f2tf32_rn_relu
                   : llvm::Intrinsic::nvvm_f2tf32_rn;
  }

  if (rnd == NVVM::FPRoundingMode::RZ) {
    if (satFinite)
      return hasRelu ? llvm::Intrinsic::nvvm_f2tf32_rz_relu_satfinite
                     : llvm::Intrinsic::nvvm_f2tf32_rz_satfinite;
    return hasRelu ? llvm::Intrinsic::nvvm_f2tf32_rz_relu
                   : llvm::Intrinsic::nvvm_f2tf32_rz;
  }

  if (rnd == NVVM::FPRoundingMode::RNA)
    return satFinite ? llvm::Intrinsic::nvvm_f2tf32_rna_satfinite
                     : llvm::Intrinsic::nvvm_f2tf32_rna;

  llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
}
/// Lowers the f32x2 -> e2m1x2 (FP4) conversion.
///
/// Both f32 sources are forwarded unchanged. The relu flag selects between
/// the two round-to-nearest, satfinite intrinsic variants.
NVVM::IDArgPair
ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
                                            LLVM::ModuleTranslation &mt,
                                            llvm::IRBuilderBase &builder) {
  llvm::SmallVector<llvm::Value *> operands{mt.lookupValue(op.getA()),
                                            mt.lookupValue(op.getB())};

  llvm::Intrinsic::ID intrinsic =
      llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
  if (op.getRelu())
    intrinsic = llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite;

  return {intrinsic, std::move(operands)};
}
#define GET_F32x2_TO_F6x2_ID(type, has_relu)                                   \
  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite            \
           : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite

/// Picks the f32x2 -> FP6x2 intrinsic from the destination element type
/// (e2m3 or e3m2). The relu flag selects the relu variant. Any other
/// destination type is a programming error.
llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
                                                         bool hasRelu) {
  if (isa<mlir::Float6E2M3FNType>(dstTy))
    return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
  if (isa<mlir::Float6E3M2FNType>(dstTy))
    return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
  llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
}
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf)                                 \
  has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite             \
           : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd

#define GET_F32x2_TO_F8X2_S_ID(type, has_relu)                                 \
  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu                      \
           : llvm::Intrinsic::nvvm_ff_to_##type##_rn

/// Picks the f32x2 -> FP8x2 intrinsic.
///
/// The signed formats (e4m3, e5m2) always round to nearest and only vary in
/// relu. The unsigned ue8m0 format supports RZ or RP rounding, each with an
/// optional satfinite modifier. Every other combination is a programming
/// error.
llvm::Intrinsic::ID
ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd,
                                     NVVM::SaturationMode sat, bool hasRelu) {
  const bool satFinite = sat == NVVM::SaturationMode::SATFINITE;

  if (isa<mlir::Float8E4M3FNType>(dstTy))
    return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
  if (isa<mlir::Float8E5M2Type>(dstTy))
    return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);

  if (isa<mlir::Float8E8M0FNUType>(dstTy)) {
    switch (rnd) {
    case NVVM::FPRoundingMode::RZ:
      return GET_F32x2_TO_F8X2_US_ID(rz, satFinite);
    case NVVM::FPRoundingMode::RP:
      return GET_F32x2_TO_F8X2_US_ID(rp, satFinite);
    default:
      break;
    }
  }

  llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
}
#define GET_F16x2_TO_F8X2_ID(type, has_relu)                                   \
  has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu                   \
           : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn

/// Picks the f16x2 -> FP8x2 intrinsic from the destination element type
/// (e4m3 or e5m2). The relu flag selects the relu variant. Any other
/// destination type is a programming error.
llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
                                                         bool hasRelu) {
  if (isa<mlir::Float8E4M3FNType>(dstTy))
    return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
  if (isa<mlir::Float8E5M2Type>(dstTy))
    return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
  llvm_unreachable("Invalid conversion in ConvertF16x2ToF8x2Op");
}
#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf)                                   \
  has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite         \
           : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd

/// Picks the bf16x2 -> ue8m0x2 intrinsic. Only RZ and RP rounding are
/// supported, each with an optional satfinite modifier. Any other rounding
/// mode is a programming error.
llvm::Intrinsic::ID
ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                      NVVM::SaturationMode sat) {
  const bool satFinite = sat == NVVM::SaturationMode::SATFINITE;
  if (rnd == NVVM::FPRoundingMode::RZ)
    return GET_BF16X2_TO_F8X2_ID(rz, satFinite);
  if (rnd == NVVM::FPRoundingMode::RP)
    return GET_BF16X2_TO_F8X2_ID(rp, satFinite);
  llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
}
/// Lowers the FP8x2 -> f16x2 conversion.
///
/// The intrinsic is chosen from the source element type (e4m3 or e5m2) and
/// the relu flag. The source value is bitcast to i16 before being passed as
/// the single intrinsic operand.
NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto cvtOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);
  const bool relu = cvtOp.getRelu();
  mlir::Type srcElemTy = cvtOp.getSrcType();

  llvm::Intrinsic::ID id = llvm::Intrinsic::not_intrinsic;
  if (isa<Float8E4M3FNType>(srcElemTy)) {
    id = relu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
              : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
  } else if (isa<Float8E5M2Type>(srcElemTy)) {
    id = relu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
              : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
  } else {
    llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op");
  }

  llvm::Type *i16Ty = llvm::Type::getInt16Ty(builder.getContext());
  llvm::Value *srcAsI16 =
      builder.CreateBitCast(mt.lookupValue(cvtOp.getSrc()), i16Ty);
  return {id, {srcAsI16}};
}
/// Lowers the FP8x2 -> bf16x2 conversion.
///
/// Always uses the ue8m0x2 -> bf16x2 intrinsic. The source value is bitcast
/// to i16 before being passed as the single intrinsic operand.
NVVM::IDArgPair ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto cvtOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);
  const llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
  llvm::Type *i16Ty = llvm::Type::getInt16Ty(builder.getContext());
  llvm::Value *srcAsI16 =
      builder.CreateBitCast(mt.lookupValue(cvtOp.getSrc()), i16Ty);
  return {id, {srcAsI16}};
}
/// Lowers the FP6x2 -> f16x2 conversion.
///
/// The intrinsic is chosen from the source element type (e2m3 or e3m2) and
/// the relu flag. The source value is bitcast to i16 before being passed as
/// the single intrinsic operand.
NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto cvtOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);
  const bool relu = cvtOp.getRelu();
  mlir::Type srcElemTy = cvtOp.getSrcType();

  llvm::Intrinsic::ID id = llvm::Intrinsic::not_intrinsic;
  if (isa<Float6E2M3FNType>(srcElemTy)) {
    id = relu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
              : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
  } else if (isa<Float6E3M2FNType>(srcElemTy)) {
    id = relu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
              : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
  } else {
    llvm_unreachable("Invalid type for ConvertF6x2ToF16x2Op");
  }

  llvm::Type *i16Ty = llvm::Type::getInt16Ty(builder.getContext());
  llvm::Value *srcAsI16 =
      builder.CreateBitCast(mt.lookupValue(cvtOp.getSrc()), i16Ty);
  return {id, {srcAsI16}};
}
/// Lowers the FP4x2 (e2m1) -> f16x2 conversion.
///
/// The relu flag selects the intrinsic variant. Unlike the FP6/FP8 cases, the
/// source value is zero-extended (not bitcast) to i16 before being passed as
/// the single intrinsic operand.
NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto cvtOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
  const bool relu = cvtOp.getRelu();

  if (!isa<Float4E2M1FNType>(cvtOp.getSrcType()))
    llvm_unreachable("Invalid type for ConvertF4x2ToF16x2Op");
  const llvm::Intrinsic::ID id =
      relu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
           : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;

  llvm::Type *i16Ty = llvm::Type::getInt16Ty(builder.getContext());
  llvm::Value *srcAsI16 =
      builder.CreateZExt(mt.lookupValue(cvtOp.getSrc()), i16Ty);
  return {id, {srcAsI16}};
}
/// Appends the tcgen05.alloc operands (address, number of columns) to `args`
/// and returns the intrinsic.
///
/// The intrinsic variant depends on whether the address pointer is in the
/// shared address space and on the CTA group (cg1 vs. cg2).
llvm::Intrinsic::ID
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
                                      LLVM::ModuleTranslation &mt,
                                      llvm::SmallVector<llvm::Value *> &args) {
  auto allocOp = cast<NVVM::Tcgen05AllocOp>(op);
  mlir::Value addr = allocOp.getAddr();
  auto addrPtrTy = llvm::cast<LLVM::LLVMPointerType>(addr.getType());
  const bool inShared =
      addrPtrTy.getAddressSpace() == NVVMMemorySpace::Shared;
  const bool useCG2 = allocOp.getGroup() == CTAGroupKind::CTA_2;

  args.push_back(mt.lookupValue(addr));
  args.push_back(mt.lookupValue(allocOp.getNCols()));

  if (inShared)
    return useCG2 ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
                  : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
  return useCG2 ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
                : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
}
/// Appends the tcgen05.dealloc operands (tensor-memory address, number of
/// columns) to `args`. Returns the cg1 intrinsic for CTA_1 and the cg2
/// intrinsic otherwise.
llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt,
    llvm::SmallVector<llvm::Value *> &args) {
  auto deallocOp = cast<NVVM::Tcgen05DeallocOp>(op);

  args.push_back(mt.lookupValue(deallocOp.getTaddr()));
  args.push_back(mt.lookupValue(deallocOp.getNCols()));

  if (deallocOp.getGroup() == CTAGroupKind::CTA_1)
    return llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1;
  return llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;
}
#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc)                                 \
  is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg         \
            : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg

#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)                    \
  has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc)                      \
         : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )

/// Appends the tcgen05.commit operands to `args` and returns the intrinsic.
///
/// The address is always passed. The multicast mask is passed only when the
/// op has one. The intrinsic variant depends on the CTA group, on whether the
/// address is in the shared address space, and on whether multicast is used.
llvm::Intrinsic::ID
Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
                                       LLVM::ModuleTranslation &mt,
                                       llvm::SmallVector<llvm::Value *> &args) {
  auto commitOp = cast<NVVM::Tcgen05CommitOp>(op);
  mlir::Value addr = commitOp.getAddr();
  mlir::Value mcMask = commitOp.getMulticastMask();

  auto addrPtrTy = llvm::cast<LLVM::LLVMPointerType>(addr.getType());
  const bool inShared =
      addrPtrTy.getAddressSpace() == NVVMMemorySpace::Shared;
  const bool multicast = static_cast<bool>(mcMask);

  args.push_back(mt.lookupValue(addr));
  if (multicast)
    args.push_back(mt.lookupValue(mcMask));

  if (commitOp.getGroup() == CTAGroupKind::CTA_2)
    return GET_TCGEN05_COMMIT_ID(cg2, inShared, multicast);
  return GET_TCGEN05_COMMIT_ID(cg1, inShared, multicast);
}
// Helper macros that build a tcgen05.cp intrinsic name by token pasting.
// Resulting name: nvvm_tcgen05_cp<shape_mc><src_fmt><cg>, where each piece is
// a token that starts with '_' (or is empty for src_fmt).
//
// TCGEN05_CP_IMPL: pastes the three suffix tokens onto the common prefix.
#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg) \
  llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg
// TCGEN05_CP_2CTA: chooses the _cg2 or _cg1 suffix from the runtime boolean
// `is_2cta`.
#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta) \
  is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2) \
          : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)
// GET_TCGEN05_CP_ID: `src_fmt` here is a runtime value compared against
// Tcgen05CpSrcFormat. B6x16_P32 and B4x16_P64 add a source-format suffix;
// any other value uses the name without a source-format suffix. Wrapped in
// an immediately-invoked lambda so it can be used as an expression.
#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta) \
  [&]() -> auto { \
    if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32) \
      return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta); \
    if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64) \
      return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta); \
    return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
  }()
/// Lowers the f32x2 -> f16x2 conversion.
///
/// Operands are srcHi, srcLo and, when the op carries them, the random bits.
/// The intrinsic is chosen by rounding mode (RN, RZ or RS) together with the
/// relu and satfinite modifiers. Any other rounding mode is a programming
/// error.
NVVM::IDArgPair
ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op,
                                             LLVM::ModuleTranslation &mt,
                                             llvm::IRBuilderBase &builder) {
  // Row = rounding mode (RN, RZ, RS); column bit 0 = relu, bit 1 = satfinite.
  static constexpr llvm::Intrinsic::ID idTable[3][4] = {
      {llvm::Intrinsic::nvvm_ff2f16x2_rn,
       llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
       llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
       llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite},
      {llvm::Intrinsic::nvvm_ff2f16x2_rz,
       llvm::Intrinsic::nvvm_ff2f16x2_rz_relu,
       llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite,
       llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite},
      {llvm::Intrinsic::nvvm_ff2f16x2_rs,
       llvm::Intrinsic::nvvm_ff2f16x2_rs_relu,
       llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite,
       llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite},
  };

  unsigned column = op.getRelu() ? 1u : 0u;
  if (op.getSat() == NVVM::SaturationMode::SATFINITE)
    column |= 2u;

  llvm::SmallVector<llvm::Value *> operands{mt.lookupValue(op.getSrcHi()),
                                            mt.lookupValue(op.getSrcLo())};
  if (mlir::Value randomBits = op.getRandomBits())
    operands.push_back(mt.lookupValue(randomBits));

  unsigned row = 0;
  switch (op.getRnd()) {
  case FPRoundingMode::RN:
    row = 0;
    break;
  case FPRoundingMode::RZ:
    row = 1;
    break;
  case FPRoundingMode::RS:
    row = 2;
    break;
  default:
    llvm_unreachable("Invalid rounding mode for ConvertF32x2ToF16x2Op");
  }
  return {idTable[row][column], std::move(operands)};
}
/// Lowers the f32x2 -> bf16x2 conversion.
///
/// Operands are srcHi, srcLo and, when the op carries them, the random bits.
/// The intrinsic is chosen by rounding mode (RN, RZ or RS) together with the
/// relu and satfinite modifiers. Any other rounding mode is a programming
/// error.
NVVM::IDArgPair
ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
                                              LLVM::ModuleTranslation &mt,
                                              llvm::IRBuilderBase &builder) {
  // Row = rounding mode (RN, RZ, RS); column bit 0 = relu, bit 1 = satfinite.
  static constexpr llvm::Intrinsic::ID idTable[3][4] = {
      {llvm::Intrinsic::nvvm_ff2bf16x2_rn,
       llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu,
       llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite,
       llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite},
      {llvm::Intrinsic::nvvm_ff2bf16x2_rz,
       llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu,
       llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite,
       llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite},
      {llvm::Intrinsic::nvvm_ff2bf16x2_rs,
       llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu,
       llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite,
       llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite},
  };

  unsigned column = op.getRelu() ? 1u : 0u;
  if (op.getSat() == NVVM::SaturationMode::SATFINITE)
    column |= 2u;

  llvm::SmallVector<llvm::Value *> operands{mt.lookupValue(op.getSrcHi()),
                                            mt.lookupValue(op.getSrcLo())};
  if (mlir::Value randomBits = op.getRandomBits())
    operands.push_back(mt.lookupValue(randomBits));

  unsigned row = 0;
  switch (op.getRnd()) {
  case FPRoundingMode::RN:
    row = 0;
    break;
  case FPRoundingMode::RZ:
    row = 1;
    break;
  case FPRoundingMode::RS:
    row = 2;
    break;
  default:
    llvm_unreachable("Invalid rounding mode for ConvertF32x2ToBF16x2Op");
  }
  return {idTable[row][column], std::move(operands)};
}
/// Picks the f32x4 -> FP8x4 intrinsic (stochastic rounding, satfinite) from
/// the destination element type (e4m3 or e5m2). The relu flag selects the
/// relu variant. Any other destination type is a programming error.
llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
  mlir::Type dstTy = getDstTy();
  const bool relu = getRelu();

  if (isa<mlir::Float8E4M3FNType>(dstTy))
    return relu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite
                : llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite;
  if (isa<mlir::Float8E5M2Type>(dstTy))
    return relu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite
                : llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite;
  llvm_unreachable("Invalid F8 type in ConvertF32x4ToF8x4Op");
}
/// Picks the f32x4 -> FP6x4 intrinsic (stochastic rounding, satfinite) from
/// the destination element type (e2m3 or e3m2). The relu flag selects the
/// relu variant. Any other destination type is a programming error.
llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() {
  mlir::Type dstTy = getDstTy();
  const bool relu = getRelu();

  if (isa<mlir::Float6E2M3FNType>(dstTy))
    return relu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite
                : llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite;
  if (isa<mlir::Float6E3M2FNType>(dstTy))
    return relu ? llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite
                : llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite;
  llvm_unreachable("Invalid F6 type in ConvertF32x4ToF6x4Op");
}
/// Picks the f32x4 -> FP4x4 (e2m1) intrinsic (stochastic rounding,
/// satfinite). The relu flag selects the relu variant. Any other destination
/// type is a programming error.
llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() {
  mlir::Type dstTy = getDstTy();
  const bool relu = getRelu();

  if (!isa<mlir::Float4E2M1FNType>(dstTy))
    llvm_unreachable("Invalid F4 type in ConvertF32x4ToF4x4Op");
  return relu ? llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite
              : llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite;
}
/// Maps a tcgen05.cp op to its intrinsic.
///
/// The shape (and, for 64x128b, the multicast pattern) selects the intrinsic
/// family. GET_TCGEN05_CP_ID then appends the optional source-format suffix
/// and the cg1/cg2 suffix.
llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
  auto cpOp = cast<NVVM::Tcgen05CpOp>(op);
  const bool twoCTA = cpOp.getGroup() == CTAGroupKind::CTA_2;
  auto fmt = cpOp.getSrcFormat();
  auto multicast = cpOp.getMulticast();

  switch (cpOp.getShape()) {
  case Tcgen05CpShape::SHAPE_128x256b:
    return GET_TCGEN05_CP_ID(_128x256b, fmt, twoCTA);
  case Tcgen05CpShape::SHAPE_128x128b:
    return GET_TCGEN05_CP_ID(_128x128b, fmt, twoCTA);
  case Tcgen05CpShape::SHAPE_4x256b:
    return GET_TCGEN05_CP_ID(_4x256b, fmt, twoCTA);
  case Tcgen05CpShape::SHAPE_32x128b:
    // 32x128b is always lowered to its warpx4 multicast variant.
    return GET_TCGEN05_CP_ID(_32x128b_warpx4, fmt, twoCTA);
  case Tcgen05CpShape::SHAPE_64x128b:
    // 64x128b has two warpx2 multicast variants; 02_13 is the fallback.
    if (multicast == Tcgen05CpMulticast::WARPX2_01_23)
      return GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, fmt, twoCTA);
    return GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, fmt, twoCTA);
  }
  llvm_unreachable("Invalid shape in tcgen05 cp Op");
}
// Returns true if `vecLen` is an acceptable vector length for `shape`. This
// models the table mentioned in the tcgen05.{ld, st} Op description:
// shape 16x128b needs at least 2 elements, shape 16x256b needs at least 4,
// and every other shape accepts any length.
static bool isValidVectorLength(NVVM::Tcgen05LdStShape shape,
                                unsigned vecLen) {
  switch (shape) {
  case NVVM::Tcgen05LdStShape::SHAPE_16X128B:
    return vecLen >= 2;
  case NVVM::Tcgen05LdStShape::SHAPE_16X256B:
    return vecLen >= 4;
  default:
    return true;
  }
}
/// Verifies tcgen05.ld.
///
/// The offset operand must be present for shape 16x32bx2 and absent for
/// every other shape. The result's vector length (1 for a scalar result) must
/// be valid for the shape. All applicable errors are emitted before failure
/// is returned.
LogicalResult Tcgen05LdOp::verify() {
  LogicalResult result = success();
  auto shape = getShape();
  const bool isPairShape = shape == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2;

  if (isPairShape && !getOffset())
    result = emitError("shape 16x32bx2 requires offset argument");
  if (!isPairShape && getOffset())
    result = emitError("offset argument is only supported for shape 16x32bx2");

  unsigned resLen = 1;
  if (auto vecTy = dyn_cast<VectorType>(getRes().getType()))
    resLen = vecTy.getNumElements();

  if (!isValidVectorLength(shape, resLen))
    result = emitError(llvm::formatv("invalid result type length {0} for shape "
                                     "{1} in tcgen05.ld Op",
                                     resLen, stringifyEnum(shape)));
  return result;
}
/// Verifies tcgen05.st.
///
/// Shape 16x32bx2 requires the offset operand. The stored value's vector
/// length (1 for a scalar value) must be valid for the shape. All applicable
/// errors are emitted before failure is returned.
LogicalResult Tcgen05StOp::verify() {
  LogicalResult result = success();
  auto shape = getShape();

  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
    result = emitError("shape 16x32bx2 requires offset argument");

  unsigned valLen = 1;
  if (auto vecTy = dyn_cast<VectorType>(getVal().getType()))
    valLen = vecTy.getNumElements();

  if (!isValidVectorLength(shape, valLen))
    result = emitError(llvm::formatv("invalid input length {0} for shape "
                                     "{1} in tcgen05.st Op",
                                     valLen, stringifyEnum(shape)));
  return result;
}
/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
/// have ConstantRangeAttr.
///
/// The `range` attribute follows llvm::ConstantRange semantics: it denotes the
/// half-open set [lower, upper), which may wrap around, and lower == upper
/// encodes the full (max value) or empty (min value) set. ConstantIntRanges
/// instead stores inclusive unsigned and signed bounds, so the attribute is
/// converted as follows:
///   * no attribute, or lower == upper -> maximal range (the empty set is
///     conservatively widened as well);
///   * otherwise, per domain (unsigned / signed): [lower, upper - 1] if that
///     interval does not wrap in the domain, else the domain's maximal range.
/// `argRanges` is not consulted.
static void nvvmInferResultRanges(Operation *op, Value result,
                                  ArrayRef<::mlir::ConstantIntRanges> argRanges,
                                  SetIntRangeFn setResultRanges) {
  auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range");
  if (!rangeAttr || rangeAttr.getLower() == rangeAttr.getUpper()) {
    setResultRanges(result, IntegerValueRange::getMaxRange(result).getValue());
    return;
  }

  const llvm::APInt &lower = rangeAttr.getLower();
  // The attribute's upper bound is exclusive; ConstantIntRanges is inclusive.
  const llvm::APInt maxIncl = rangeAttr.getUpper() - 1;
  const unsigned bitWidth = lower.getBitWidth();

  // If the set wraps past the unsigned max, only [0, MAX] is a valid hull.
  llvm::APInt umin = lower;
  llvm::APInt umax = maxIncl;
  if (lower.ugt(maxIncl)) {
    umin = llvm::APInt::getMinValue(bitWidth);
    umax = llvm::APInt::getMaxValue(bitWidth);
  }

  // Likewise for the signed view, where the wrap point is SMAX -> SMIN.
  llvm::APInt smin = lower;
  llvm::APInt smax = maxIncl;
  if (lower.sgt(maxIncl)) {
    smin = llvm::APInt::getSignedMinValue(bitWidth);
    smax = llvm::APInt::getSignedMaxValue(bitWidth);
  }

  setResultRanges(result, {umin, umax, smin, smax});
}
/// Verify the range attribute satisfies LLVM ConstantRange constructor
/// requirements for NVVM SpecialRangeableRegisterOp.
///
/// An absent attribute is accepted. llvm::ConstantRange only allows
/// lower == upper when that value is the minimum or maximum for the bit
/// width, so any other lower == upper pair is reported as an error.
static LogicalResult
verifyConstantRangeAttr(Operation *op,
                        std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
  if (rangeAttr) {
    const llvm::APInt &lo = rangeAttr->getLower();
    const llvm::APInt &hi = rangeAttr->getUpper();
    const bool isExtremal = lo.isMaxValue() || lo.isMinValue();
    if (lo == hi && !isExtremal) {
      const unsigned width = lo.getBitWidth();
      return op->emitOpError(
                 "invalid range attribute: Lower == Upper, but they aren't min (")
             << llvm::toString(llvm::APInt::getMinValue(width), 10, false)
             << ") or max ("
             << llvm::toString(llvm::APInt::getMaxValue(width), 10, false)
             << ") value! This is an invalid constant range.";
    }
  }
  return success();
}
/// Reinterpret `arg` as an i32 using a bitcast. The operand must already be
/// 32 bits wide (bitcast preserves size). Presumably this is a packed small
/// vector at the call sites; TODO confirm.
static llvm::Value *getAsPackedI32(llvm::Value *arg,
                                   llvm::IRBuilderBase &builder) {
  llvm::Type *i32Ty = builder.getInt32Ty();
  return builder.CreateBitCast(arg, i32Ty);
}
/// Lower nvvm.dot.accumulate.4way to one of the llvm.nvvm.idp4a.* intrinsics.
///
/// A and B are bitcast to i32 and passed along with C. The intrinsic suffix
/// encodes the signedness of A and then B (u = unsigned, s = signed).
NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto dotOp = cast<NVVM::DotAccumulate4WayOp>(op);
  llvm::Value *a = getAsPackedI32(mt.lookupValue(dotOp.getA()), builder);
  llvm::Value *b = getAsPackedI32(mt.lookupValue(dotOp.getB()), builder);
  llvm::Value *c = mt.lookupValue(dotOp.getC());
  llvm::SmallVector<llvm::Value *> args = {a, b, c};

  const bool aSigned = dotOp.getAType() == NVVM::DotAccumulateType::SIGNED;
  const bool bSigned = dotOp.getBType() == NVVM::DotAccumulateType::SIGNED;
  llvm::Intrinsic::ID id;
  if (aSigned)
    id = bSigned ? llvm::Intrinsic::nvvm_idp4a_s_s
                 : llvm::Intrinsic::nvvm_idp4a_s_u;
  else
    id = bSigned ? llvm::Intrinsic::nvvm_idp4a_u_s
                 : llvm::Intrinsic::nvvm_idp4a_u_u;
  return {id, args};
}
/// Lower nvvm.dot.accumulate.2way to one of the llvm.nvvm.idp2a.* intrinsics.
///
/// Arguments are emitted as: A and B (each bitcast to i32), the `b_hi` flag
/// as an i1 immediate, then C. The intrinsic suffix encodes the signedness of
/// A and then B (u = unsigned, s = signed).
NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto dotOp = cast<NVVM::DotAccumulate2WayOp>(op);
  llvm::Value *a = getAsPackedI32(mt.lookupValue(dotOp.getA()), builder);
  llvm::Value *b = getAsPackedI32(mt.lookupValue(dotOp.getB()), builder);
  llvm::Value *bHi = builder.getInt1(dotOp.getBHi());
  llvm::Value *c = mt.lookupValue(dotOp.getC());
  llvm::SmallVector<llvm::Value *> args = {a, b, bHi, c};

  const bool aSigned = dotOp.getAType() == NVVM::DotAccumulateType::SIGNED;
  const bool bSigned = dotOp.getBType() == NVVM::DotAccumulateType::SIGNED;
  llvm::Intrinsic::ID id;
  if (aSigned)
    id = bSigned ? llvm::Intrinsic::nvvm_idp2a_s_s
                 : llvm::Intrinsic::nvvm_idp2a_s_u;
  else
    id = bSigned ? llvm::Intrinsic::nvvm_idp2a_u_s
                 : llvm::Intrinsic::nvvm_idp2a_u_u;
  return {id, args};
}
/// Emit an addrspacecast of `addr` to a pointer in the NVPTX param address
/// space.
static llvm::Value *getParamCastedAddr(llvm::Value *addr,
                                       llvm::IRBuilderBase &builder) {
  auto *paramPtrTy = llvm::PointerType::get(
      builder.getContext(), llvm::NVPTXAS::AddressSpace::ADDRESS_SPACE_PARAM);
  return builder.CreateAddrSpaceCast(addr, paramPtrTy);
}
/// Select the llvm.nvvm.prefetch* intrinsic for an nvvm.prefetch op.
///
/// The only argument is the prefetched address. It is first cast to the param
/// address space when `getInParamSpace()` is set. Intrinsics are chosen in
/// this order of precedence:
///   1. tensormap prefetch;
///   2. uniform prefetch into L1 (prefetchu.L1);
///   3. global L2 prefetch with an eviction priority;
///   4. L1/L2 prefetch keyed on the pointer's address space
///      (generic, global or local).
/// A cache level is required for every non-tensormap prefetch.
NVVM::IDArgPair
PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
                                  LLVM::ModuleTranslation &mt,
                                  llvm::IRBuilderBase &builder) {
  using MemSpace = NVVM::NVVMMemorySpace;
  using CacheLevel = NVVM::PrefetchCacheLevel;

  const std::optional<NVVM::PrefetchCacheLevel> cacheLevel =
      op.getCacheLevel();
  const std::optional<NVVM::CacheEvictionPriority> evictPriority =
      op.getEvictPriority();
  const unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
          .getAddressSpace();

  llvm::Value *addr = mt.lookupValue(op.getAddr());
  if (op.getInParamSpace())
    addr = getParamCastedAddr(addr, builder);
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(addr);

  if (op.getTensormap())
    return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};

  assert(cacheLevel && "expected cache level for non-tensormap prefetch");
  const bool intoL1 = *cacheLevel == CacheLevel::L1;

  if (op.getUniform() && intoL1)
    return {llvm::Intrinsic::nvvm_prefetchu_L1, args};

  if (evictPriority && *cacheLevel == CacheLevel::L2) {
    switch (*evictPriority) {
    case NVVM::CacheEvictionPriority::EvictLast:
      return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
    case NVVM::CacheEvictionPriority::EvictNormal:
      return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
    default:
      llvm_unreachable("Invalid cache eviction priority");
    }
  }

  // Each supported address space has an L1 and an L2 flavour.
  llvm::Intrinsic::ID l1ID, l2ID;
  switch (static_cast<MemSpace>(addressSpace)) {
  case MemSpace::Generic:
    l1ID = llvm::Intrinsic::nvvm_prefetch_L1;
    l2ID = llvm::Intrinsic::nvvm_prefetch_L2;
    break;
  case MemSpace::Global:
    l1ID = llvm::Intrinsic::nvvm_prefetch_global_L1;
    l2ID = llvm::Intrinsic::nvvm_prefetch_global_L2;
    break;
  case MemSpace::Local:
    l1ID = llvm::Intrinsic::nvvm_prefetch_local_L1;
    l2ID = llvm::Intrinsic::nvvm_prefetch_local_L2;
    break;
  default:
    llvm_unreachable("Invalid pointer address space");
  }
  return {intoL1 ? l1ID : l2ID, args};
}
/// Collect the values referenced by the inline PTX together with their
/// register modifiers. Values are appended in this order:
/// read-write arguments (ReadWrite), results (Write), read-only arguments
/// (Read), then the optional predicate (Read).
/// Returns false to signal that no manual mapping is needed.
bool NVVM::InlinePtxOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  using Mod = mlir::NVVM::PTXRegisterMod;
  auto appendAll = [&](auto values, Mod mod) {
    for (Value v : values)
      asmValues.push_back({v, mod});
  };
  appendAll(getReadWriteArgs(), Mod::ReadWrite);
  appendAll(getResults(), Mod::Write);
  appendAll(getReadOnlyArgs(), Mod::Read);
  if (Value pred = getPredicate())
    asmValues.push_back({pred, Mod::Read});
  return false; // No manual mapping needed
}
/// Lower nvvm.clusterlaunchcontrol.try_cancel to the shared-memory
/// try_cancel.async intrinsic, or to its multicast variant when the
/// `multicast` flag is set. Arguments: shared-memory address, mbarrier.
NVVM::IDArgPair ClusterLaunchControlTryCancelOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto cancelOp = cast<NVVM::ClusterLaunchControlTryCancelOp>(op);
  llvm::Value *smemAddr = mt.lookupValue(cancelOp.getSmemAddress());
  llvm::Value *mbar = mt.lookupValue(cancelOp.getMbarrier());
  llvm::SmallVector<llvm::Value *> args = {smemAddr, mbar};

  if (cancelOp.getMulticast())
    return {llvm::Intrinsic::
                nvvm_clusterlaunchcontrol_try_cancel_async_multicast_shared,
            args};
  return {llvm::Intrinsic::nvvm_clusterlaunchcontrol_try_cancel_async_shared,
          args};
}
/// Lower nvvm.clusterlaunchcontrol.query_cancel to the
/// llvm.nvvm.clusterlaunchcontrol.query_cancel.* intrinsic selected by the
/// op's query type. The only argument is the try-cancel response value.
NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::ClusterLaunchControlQueryCancelOp>(op);
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getTryCancelResponse()));

  // Every case returns directly, so no path can read an uninitialized ID.
  // An out-of-range enum value hits llvm_unreachable instead of falling
  // through.
  llvm::Intrinsic::ID intrinsicID = [&]() -> llvm::Intrinsic::ID {
    switch (curOp.getQueryType()) {
    case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
      return llvm::Intrinsic::
          nvvm_clusterlaunchcontrol_query_cancel_is_canceled;
    case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
      return llvm::Intrinsic::
          nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x;
    case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
      return llvm::Intrinsic::
          nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y;
    case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
      return llvm::Intrinsic::
          nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z;
    }
    llvm_unreachable("Invalid cluster launch control query type");
  }();
  return {intrinsicID, args};
}
/// Lower nvvm.prmt to the llvm.nvvm.prmt* intrinsic for the op's mode.
///
/// The intrinsic table is indexed directly by the PermuteMode value, so its
/// order must track the enum: default, f4e, b4e, rc8, ecl, ecr, rc16.
/// Arguments: lo, [hi for the first three modes only], selector.
mlir::NVVM::IDArgPair
PermuteOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
                                 llvm::IRBuilderBase &builder) {
  static constexpr llvm::Intrinsic::ID kModeToIntrinsic[] = {
      llvm::Intrinsic::nvvm_prmt,     llvm::Intrinsic::nvvm_prmt_f4e,
      llvm::Intrinsic::nvvm_prmt_b4e, llvm::Intrinsic::nvvm_prmt_rc8,
      llvm::Intrinsic::nvvm_prmt_ecl, llvm::Intrinsic::nvvm_prmt_ecr,
      llvm::Intrinsic::nvvm_prmt_rc16};

  auto prmtOp = cast<NVVM::PermuteOp>(op);
  const unsigned modeIdx = static_cast<unsigned>(prmtOp.getMode());
  // Only the first three modes (default, f4e, b4e) consume the `hi` operand.
  const bool usesHi = modeIdx < 3;

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(prmtOp.getLo()));
  if (usesHi)
    args.push_back(mt.lookupValue(prmtOp.getHi()));
  args.push_back(mt.lookupValue(prmtOp.getSelector()));
  return {kModeToIntrinsic[modeIdx], args};
}
/// Lower nvvm.tensormap.replace to the llvm.nvvm.tensormap.replace.*
/// intrinsic selected by the op's field.
///
/// Arguments: the tensormap address; then, when present, the ordinal as an
/// i32 immediate; the new value as an SSA operand; and/or the new value given
/// as one of the tensormap enum attributes, encoded as an i32 immediate.
/// The intrinsic table is indexed directly by the field enum value.
mlir::NVVM::IDArgPair TensormapReplaceOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto replaceOp = cast<NVVM::TensormapReplaceOp>(op);
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(replaceOp.getAddr()));

  if (auto ord = replaceOp.getOrd())
    args.push_back(builder.getInt32(*ord));

  if (auto newVal = replaceOp.getNewValue())
    args.push_back(mt.lookupValue(newVal));

  if (auto enumAttr = replaceOp.getNewValueAttr()) {
    const unsigned encoded =
        llvm::TypeSwitch<mlir::Attribute, unsigned>(*enumAttr)
            .Case<TensormapElemtypeAttr, TensormapInterleaveLayoutAttr,
                  TensormapSwizzleModeAttr, TensormapSwizzleAtomicityAttr,
                  TensormapFillModeAttr>([](auto typedAttr) {
              return static_cast<unsigned>(typedAttr.getValue());
            })
            .Default([](auto) -> unsigned {
              llvm_unreachable("Invalid attribute type");
            });
    args.push_back(builder.getInt32(encoded));
  }

  static constexpr llvm::Intrinsic::ID kFieldToIntrinsic[] = {
      llvm::Intrinsic::nvvm_tensormap_replace_global_address,
      llvm::Intrinsic::nvvm_tensormap_replace_rank,
      llvm::Intrinsic::nvvm_tensormap_replace_box_dim,
      llvm::Intrinsic::nvvm_tensormap_replace_global_dim,
      llvm::Intrinsic::nvvm_tensormap_replace_global_stride,
      llvm::Intrinsic::nvvm_tensormap_replace_element_stride,
      llvm::Intrinsic::nvvm_tensormap_replace_elemtype,
      llvm::Intrinsic::nvvm_tensormap_replace_interleave_layout,
      llvm::Intrinsic::nvvm_tensormap_replace_swizzle_mode,
      llvm::Intrinsic::nvvm_tensormap_replace_swizzle_atomicity,
      llvm::Intrinsic::nvvm_tensormap_replace_fill_mode,
  };
  const unsigned fieldIdx = static_cast<unsigned>(replaceOp.getField());
  return {kFieldToIntrinsic[fieldIdx], args};
}
//===----------------------------------------------------------------------===//
// NVVM tcgen05.mma functions
//===----------------------------------------------------------------------===//
/// Translate nvvm.tcgen05.mma into a call to the matching
/// llvm.nvvm.tcgen05.mma.* intrinsic.
///
/// Arguments are emitted in this order: D, A, B, idesc, enable-input-D,
/// [scale-input-D], [disable-output-lane], kind (i32), [CTA group (i32), only
/// when disable-output-lane is absent], collector op (i32).
///
/// The intrinsic is looked up in a 5-D table indexed by
/// [hasDisableOutputLane][hasScaleInputD][isATensor][ctaGroup - 1][aShift].
/// `notIntrinsic` entries mark A-shift with A in shared memory. That
/// combination is rejected by verifyTcgen05MMAOp, and an assert guards the
/// lookup.
mlir::NVVM::IDArgPair
Tcgen05MMAOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::Tcgen05MMAOp>(op);
llvm::SmallVector<llvm::Value *> args;
args.push_back(mt.lookupValue(thisOp.getMatrixD()));
llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
// A pointer-typed A operand selects the tensor-memory intrinsic variants.
const bool isATensor = isa<llvm::PointerType>(A->getType());
args.push_back(A);
args.push_back(mt.lookupValue(thisOp.getMatrixB()));
args.push_back(mt.lookupValue(thisOp.getIdesc()));
args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
using CtaGroupArray = std::array<EnableAShiftArray, 2>;
using IsATensorArray = std::array<CtaGroupArray, 2>;
using HasScaleInputDArray = std::array<IsATensorArray, 2>;
using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
// [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift]
// Without disable-output-lane, cg1 and cg2 share one intrinsic (the CTA group
// is passed as an i32 argument). With it, the CTA group is encoded in the
// intrinsic name (_cg1 / _cg2).
static constexpr HasDisableOutputLaneArray tcgen05MMAIDs = {
{ // without disable output lane
{{// without scale input D
{{
// shared
{{// cg1
{llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic},
// cg2
{llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic}}},
{{// tensor
{
// cg1
llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
},
{
// cg2
llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
}}},
}},
// with scale input D
{{ // shared
{{// cg1
{llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic},
// cg2
{llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic}}},
{{// tensor
{
// cg1
llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
},
{
// cg2
llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
}}}}}}},
// with disable output lane
{{ // without scale input D
{{ // shared
{{// cg1
{llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1,
notIntrinsic},
// cg2
{llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2,
notIntrinsic}}},
{{// tensor, cg1
{
llvm::Intrinsic::
nvvm_tcgen05_mma_tensor_disable_output_lane_cg1,
llvm::Intrinsic::
nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift,
},
// cg2
{
llvm::Intrinsic::
nvvm_tcgen05_mma_tensor_disable_output_lane_cg2,
llvm::Intrinsic::
nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift,
}}}}},
// with scale input D
{{ // shared
{{// cg1
{llvm::Intrinsic::
nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1,
notIntrinsic},
// cg2
{llvm::Intrinsic::
nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2,
notIntrinsic}}},
// tensor
{{// cg1
{llvm::Intrinsic::
nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1,
llvm::Intrinsic::
nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift},
// cg2
{
llvm::Intrinsic::
nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2,
llvm::Intrinsic::
nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift,
}}}}}}}}};
// Optional operands: a null mapping is treated as "operand absent".
llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
bool hasScaleInputD = ScaleInputD != nullptr;
llvm::Value *DisableOutputLane =
mt.lookupValue(thisOp.getDisableOutputLane());
bool hasDisableOutputLane = DisableOutputLane != nullptr;
// The CTA group value is assumed to be 1 or 2 (hence `- 1` when indexing the
// 2-entry CtaGroup dimension).
const unsigned ctaGroup =
static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
llvm::Intrinsic::ID ID =
tcgen05MMAIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
[ctaGroup - 1][thisOp.getAShift()];
assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMAOp.");
if (hasScaleInputD)
args.push_back(ScaleInputD);
if (hasDisableOutputLane)
args.push_back(DisableOutputLane);
args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
// The disable-output-lane intrinsics already encode the CTA group in their
// name, so the explicit CTA-group immediate is only passed otherwise.
if (!hasDisableOutputLane)
args.push_back(builder.getInt32(ctaGroup));
args.push_back(
builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
return {ID, args};
}
/// Shared verifier for tcgen05.mma and tcgen05.mma.sp.
///
/// - A disable-output-lane vector must have 4 elements for CTA_1 and
///   8 elements for CTA_2.
/// - A-shift requires matrix A to be in tensor memory.
/// - A-shift cannot be combined with the `fill` or `use` collector ops.
static LogicalResult
verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane,
                   NVVM::CTAGroupKind ctaGroup, bool hasAShift,
                   NVVM::Tcgen05MMACollectorOp collectorOp, Location loc) {
  if (disableOutputLane) {
    auto laneTy = cast<mlir::VectorType>(disableOutputLane.getType());
    const int64_t numLanes = laneTy.getNumElements();
    const bool badForCta1 =
        ctaGroup == NVVM::CTAGroupKind::CTA_1 && numLanes != 4;
    const bool badForCta2 =
        ctaGroup == NVVM::CTAGroupKind::CTA_2 && numLanes != 8;
    if (badForCta1 || badForCta2)
      return emitError(loc) << "Disable Output Lane of length " << numLanes
                            << " is incompatible with CtaGroupAttr";
  }

  if (!hasAShift)
    return success();

  if (!isATensor)
    return emitError(
        loc, "A-shift can be applied only when matrix A is in tensor memory");

  const bool usesCollectorBuffer =
      collectorOp == Tcgen05MMACollectorOp::FILL ||
      collectorOp == Tcgen05MMACollectorOp::USE;
  if (usesCollectorBuffer)
    return emitError(
        loc, "Cannot use collector buffer operation fill or use with ashift");
  return success();
}
/// Verify nvvm.tcgen05.mma with the shared tcgen05.mma verifier. Matrix A
/// counts as tensor-memory resident when it is an LLVM pointer.
LogicalResult Tcgen05MMAOp::verify() {
  const bool aInTensorMemory =
      isa<LLVM::LLVMPointerType>(getMatrixA().getType());
  return verifyTcgen05MMAOp(aInTensorMemory, getDisableOutputLane(),
                            getCtaGroup(), getAShift(), getCollectorOp(),
                            getLoc());
}
//===----------------------------------------------------------------------===//
// NVVM tcgen05.mma.sp functions
//===----------------------------------------------------------------------===//
/// Translate nvvm.tcgen05.mma.sp into a call to the matching
/// llvm.nvvm.tcgen05.mma.sp.* intrinsic.
///
/// Arguments are emitted in this order: D, A, B, idesc, enable-input-D,
/// sparse metadata, [scale-input-D], [disable-output-lane], kind (i32),
/// [CTA group (i32), only when disable-output-lane is absent],
/// collector op (i32).
///
/// The intrinsic is looked up in a 5-D table indexed by
/// [hasDisableOutputLane][hasScaleInputD][isATensor][ctaGroup - 1][aShift].
/// `notIntrinsic` entries mark A-shift with A in shared memory. That
/// combination is rejected by verifyTcgen05MMAOp, and an assert guards the
/// lookup.
mlir::NVVM::IDArgPair Tcgen05MMASparseOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::Tcgen05MMASparseOp>(op);
llvm::SmallVector<llvm::Value *> args;
args.push_back(mt.lookupValue(thisOp.getMatrixD()));
llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
// A pointer-typed A operand selects the tensor-memory intrinsic variants.
bool isATensor = isa<llvm::PointerType>(A->getType());
args.push_back(A);
args.push_back(mt.lookupValue(thisOp.getMatrixB()));
args.push_back(mt.lookupValue(thisOp.getIdesc()));
args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
using CtaGroupArray = std::array<EnableAShiftArray, 2>;
using IsATensorArray = std::array<CtaGroupArray, 2>;
using HasScaleInputDArray = std::array<IsATensorArray, 2>;
using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
// [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift]
// Without disable-output-lane, cg1 and cg2 share one intrinsic (the CTA group
// is passed as an i32 argument). With it, the CTA group is encoded in the
// intrinsic name (_cg1 / _cg2).
static constexpr HasDisableOutputLaneArray tcgen05MMASparseIDs = {
{ // without disable output lane
{{// without scale input D
{{
// shared
{{// cg1
{llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic},
// cg2
{llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic}}},
{{// tensor
{
// cg1
llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
},
{
// cg2
llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
}}},
}},
// with scale input D
{{ // shared
{{// cg1
{llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
notIntrinsic},
// cg2
{llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
notIntrinsic}}},
{{// tensor
{
// cg1
llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
},
{
// cg2
llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
}}}}}}},
// with disable output lane
{{ // without scale input D
{{ // shared
{{// cg1
{llvm::Intrinsic::
nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1,
notIntrinsic},
// cg2
{llvm::Intrinsic::
nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2,
notIntrinsic}}},
{{// tensor, cg1
{
llvm::Intrinsic::
nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1,
llvm::Intrinsic::
nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift,
},
// cg2
{
llvm::Intrinsic::
nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2,
llvm::Intrinsic::
nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift,
}}}}},
// with scale input D
{{ // shared
{{// cg1
{llvm::Intrinsic::
nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1,
notIntrinsic},
// cg2
{llvm::Intrinsic::
nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2,
notIntrinsic}}},
// tensor
{{// cg1
{llvm::Intrinsic::
nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1,
llvm::Intrinsic::
nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift},
// cg2
{
llvm::Intrinsic::
nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2,
llvm::Intrinsic::
nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift,
}}}}}}}}};
// Optional operands: a null mapping is treated as "operand absent".
llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
bool hasScaleInputD = ScaleInputD != nullptr;
llvm::Value *DisableOutputLane =
mt.lookupValue(thisOp.getDisableOutputLane());
bool hasDisableOutputLane = DisableOutputLane != nullptr;
// The CTA group value is assumed to be 1 or 2 (hence `- 1` when indexing the
// 2-entry CtaGroup dimension).
unsigned ctaGroup =
static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
llvm::Intrinsic::ID ID =
tcgen05MMASparseIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
[ctaGroup - 1][thisOp.getAShift()];
assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMASparseOp.");
if (hasScaleInputD)
args.push_back(ScaleInputD);
if (hasDisableOutputLane)
args.push_back(DisableOutputLane);
args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
// The disable-output-lane intrinsics already encode the CTA group in their
// name, so the explicit CTA-group immediate is only passed otherwise.
if (!hasDisableOutputLane)
args.push_back(builder.getInt32(ctaGroup));
args.push_back(
builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
return {ID, args};
}
/// Verify nvvm.tcgen05.mma.sp with the shared tcgen05.mma verifier. Matrix A
/// counts as tensor-memory resident when it is an LLVM pointer.
LogicalResult Tcgen05MMASparseOp::verify() {
  const bool aInTensorMemory =
      isa<LLVM::LLVMPointerType>(getMatrixA().getType());
  return verifyTcgen05MMAOp(aInTensorMemory, getDisableOutputLane(),
                            getCtaGroup(), getAShift(), getCollectorOp(),
                            getLoc());
}
//===----------------------------------------------------------------------===//
// NVVM tcgen05.mma.block_scale functions
//===----------------------------------------------------------------------===//
/// Translate nvvm.tcgen05.mma.block_scale into a call to the matching
/// llvm.nvvm.tcgen05.mma.{tensor,shared}.<kind>.block_scale[.blockN]
/// intrinsic.
///
/// Arguments: D, A, B, idesc, enable-input-D, scale-A, scale-B,
/// CTA group (i32), collector op (i32).
/// The intrinsic depends on three things: whether A is tensor-memory
/// resident (pointer-typed A), the MMA kind, and the block-scale size.
/// Combinations with no intrinsic are rejected by
/// verifyTcgen05MMABlockScaleOp and reach llvm_unreachable otherwise.
mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto bsOp = cast<NVVM::Tcgen05MMABlockScaleOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(bsOp.getMatrixD()));
  llvm::Value *matA = mt.lookupValue(bsOp.getMatrixA());
  const bool isATensor = isa<llvm::PointerType>(matA->getType());
  args.push_back(matA);
  args.push_back(mt.lookupValue(bsOp.getMatrixB()));
  args.push_back(mt.lookupValue(bsOp.getIdesc()));
  args.push_back(mt.lookupValue(bsOp.getEnableInputD()));
  args.push_back(mt.lookupValue(bsOp.getScaleA()));
  args.push_back(mt.lookupValue(bsOp.getScaleB()));
  args.push_back(builder.getInt32(
      static_cast<unsigned>(getNVVMCtaGroupKind(bsOp.getCtaGroup()))));
  args.push_back(
      builder.getInt32(static_cast<unsigned>(bsOp.getCollectorOp())));

  using Kind = NVVM::MMABlockScaleKind;
  using Scale = NVVM::Tcgen05MMABlockScale;
  const auto kind = bsOp.getKind();
  const auto scale = bsOp.getBlockScale();
  auto is = [&](Kind k, Scale s) { return kind == k && scale == s; };
  // Choose the tensor-memory or shared-memory flavour of an intrinsic pair.
  auto pick = [isATensor](llvm::Intrinsic::ID tensorID,
                          llvm::Intrinsic::ID sharedID) {
    return isATensor ? tensorID : sharedID;
  };

  llvm::Intrinsic::ID ID = [&]() -> llvm::Intrinsic::ID {
    if (is(Kind::MXF8F6F4, Scale::DEFAULT))
      return pick(
          llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale,
          llvm::Intrinsic::nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale);
    if (is(Kind::MXF8F6F4, Scale::BLOCK32))
      return pick(
          llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale_block32,
          llvm::Intrinsic::
              nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale_block32);
    if (is(Kind::MXF4, Scale::DEFAULT))
      return pick(llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4_block_scale,
                  llvm::Intrinsic::nvvm_tcgen05_mma_shared_mxf4_block_scale);
    if (is(Kind::MXF4, Scale::BLOCK32))
      return pick(
          llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4_block_scale_block32,
          llvm::Intrinsic::nvvm_tcgen05_mma_shared_mxf4_block_scale_block32);
    if (is(Kind::MXF4NVF4, Scale::BLOCK32))
      return pick(
          llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block32,
          llvm::Intrinsic::
              nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block32);
    if (is(Kind::MXF4NVF4, Scale::BLOCK16))
      return pick(
          llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block16,
          llvm::Intrinsic::
              nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block16);
    llvm_unreachable("Invalid tcgen05.mma.block_scale attributes");
  }();
  return {ID, args};
}
/// Shared verifier for tcgen05.mma.block_scale and tcgen05.mma.sp.block_scale.
///
/// - mxf4nvf4 requires an explicit block size (the default is rejected).
/// - block16 is accepted only for mxf4nvf4.
/// `collectorOp` is accepted but not currently checked.
static LogicalResult verifyTcgen05MMABlockScaleOp(
    NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::MMABlockScaleKind kind,
    NVVM::Tcgen05MMABlockScale blockScale, Location loc) {
  const bool isNvf4 = kind == MMABlockScaleKind::MXF4NVF4;

  if (isNvf4 && blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT)
    return emitError(loc, "mxf4nvf4 requires block scale attribute");

  if (!isNvf4 && blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16)
    return emitError(loc,
                     llvm::formatv("{} kind does not support block16 attribute",
                                   stringifyEnum(kind)));
  return success();
}
/// Verify nvvm.tcgen05.mma.block_scale with the shared block-scale verifier.
LogicalResult Tcgen05MMABlockScaleOp::verify() {
  const auto kind = getKind();
  const auto blockScale = getBlockScale();
  return verifyTcgen05MMABlockScaleOp(getCollectorOp(), kind, blockScale,
                                      getLoc());
}
//===----------------------------------------------------------------------===//
// NVVM tcgen05.mma.sp.block_scale functions
//===----------------------------------------------------------------------===//
/// Translate nvvm.tcgen05.mma.sp.block_scale into a call to the matching
/// llvm.nvvm.tcgen05.mma.sp.{tensor,shared}.<kind>.block_scale[.blockN]
/// intrinsic.
///
/// Arguments: D, A, B, idesc, enable-input-D, sparse metadata, scale-A,
/// scale-B, CTA group (i32), collector op (i32).
/// The intrinsic depends on three things: whether A is tensor-memory
/// resident (pointer-typed A), the MMA kind, and the block-scale size.
/// Combinations with no intrinsic are rejected by
/// verifyTcgen05MMABlockScaleOp and reach llvm_unreachable otherwise.
mlir::NVVM::IDArgPair Tcgen05MMASparseBlockScaleOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto bsOp = cast<NVVM::Tcgen05MMASparseBlockScaleOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(bsOp.getMatrixD()));
  llvm::Value *matA = mt.lookupValue(bsOp.getMatrixA());
  const bool isATensor = isa<llvm::PointerType>(matA->getType());
  args.push_back(matA);
  args.push_back(mt.lookupValue(bsOp.getMatrixB()));
  args.push_back(mt.lookupValue(bsOp.getIdesc()));
  args.push_back(mt.lookupValue(bsOp.getEnableInputD()));
  args.push_back(mt.lookupValue(bsOp.getSparseMetadata()));
  args.push_back(mt.lookupValue(bsOp.getScaleA()));
  args.push_back(mt.lookupValue(bsOp.getScaleB()));
  args.push_back(builder.getInt32(
      static_cast<unsigned>(getNVVMCtaGroupKind(bsOp.getCtaGroup()))));
  args.push_back(
      builder.getInt32(static_cast<unsigned>(bsOp.getCollectorOp())));

  using Kind = NVVM::MMABlockScaleKind;
  using Scale = NVVM::Tcgen05MMABlockScale;
  const auto kind = bsOp.getKind();
  const auto scale = bsOp.getBlockScale();
  auto is = [&](Kind k, Scale s) { return kind == k && scale == s; };
  // Choose the tensor-memory or shared-memory flavour of an intrinsic pair.
  auto pick = [isATensor](llvm::Intrinsic::ID tensorID,
                          llvm::Intrinsic::ID sharedID) {
    return isATensor ? tensorID : sharedID;
  };

  llvm::Intrinsic::ID ID = [&]() -> llvm::Intrinsic::ID {
    if (is(Kind::MXF8F6F4, Scale::DEFAULT))
      return pick(
          llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale,
          llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale);
    if (is(Kind::MXF8F6F4, Scale::BLOCK32))
      return pick(
          llvm::Intrinsic::
              nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale_block32,
          llvm::Intrinsic::
              nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale_block32);
    if (is(Kind::MXF4, Scale::DEFAULT))
      return pick(
          llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale,
          llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_mxf4_block_scale);
    if (is(Kind::MXF4, Scale::BLOCK32))
      return pick(
          llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale_block32,
          llvm::Intrinsic::
              nvvm_tcgen05_mma_sp_shared_mxf4_block_scale_block32);
    if (is(Kind::MXF4NVF4, Scale::BLOCK32))
      return pick(
          llvm::Intrinsic::
              nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block32,
          llvm::Intrinsic::
              nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block32);
    if (is(Kind::MXF4NVF4, Scale::BLOCK16))
      return pick(
          llvm::Intrinsic::
              nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block16,
          llvm::Intrinsic::
              nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block16);
    llvm_unreachable("Invalid tcgen05.mma.sp.block_scale attributes");
  }();
  return {ID, args};
}
/// Verify nvvm.tcgen05.mma.sp.block_scale with the shared block-scale
/// verifier.
LogicalResult Tcgen05MMASparseBlockScaleOp::verify() {
  const auto kind = getKind();
  const auto blockScale = getBlockScale();
  return verifyTcgen05MMABlockScaleOp(getCollectorOp(), kind, blockScale,
                                      getLoc());
}
//===----------------------------------------------------------------------===//
// NVVM tcgen05.mma.ws functions
//===----------------------------------------------------------------------===//
/// Translate nvvm.tcgen05.mma.ws into a call to the matching
/// llvm.nvvm.tcgen05.mma.ws.* intrinsic.
///
/// Arguments: D, A, B, idesc, enable-input-D, [zero-column mask], kind (i32),
/// collector B buffer (i32), collector op (i32).
/// The `_zero_col_mask` variants are used when the mask operand is present.
/// The tensor/shared variant follows from whether A is pointer-typed.
mlir::NVVM::IDArgPair Tcgen05MMAWsOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto wsOp = cast<NVVM::Tcgen05MMAWsOp>(op);
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(wsOp.getMatrixD()));
  llvm::Value *matA = mt.lookupValue(wsOp.getMatrixA());
  const bool isATensor = isa<llvm::PointerType>(matA->getType());
  args.push_back(matA);
  args.push_back(mt.lookupValue(wsOp.getMatrixB()));
  args.push_back(mt.lookupValue(wsOp.getIdesc()));
  args.push_back(mt.lookupValue(wsOp.getEnableInputD()));

  mlir::Value zeroColMask = wsOp.getZeroColMask();
  const bool hasMask = static_cast<bool>(zeroColMask);
  if (hasMask)
    args.push_back(mt.lookupValue(zeroColMask));

  llvm::Intrinsic::ID ID;
  if (isATensor)
    ID = hasMask ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor_zero_col_mask
                 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor;
  else
    ID = hasMask ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared_zero_col_mask
                 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared;

  // Trailing i32 immediates, in intrinsic argument order.
  for (unsigned imm : {static_cast<unsigned>(wsOp.getKind()),
                       static_cast<unsigned>(wsOp.getCollectorBBuffer()),
                       static_cast<unsigned>(wsOp.getCollectorOp())})
    args.push_back(builder.getInt32(imm));
  return {ID, args};
}
//===----------------------------------------------------------------------===//
// NVVM tcgen05.mma.ws.sp functions
//===----------------------------------------------------------------------===//
/// Translate nvvm.tcgen05.mma.ws.sp into a call to the matching
/// llvm.nvvm.tcgen05.mma.ws.sp.* intrinsic.
///
/// Arguments: D, A, B, idesc, enable-input-D, sparse metadata,
/// [zero-column mask], kind (i32), collector B buffer (i32),
/// collector op (i32).
/// The `_zero_col_mask` variants are used when the mask operand is present.
/// The tensor/shared variant follows from whether A is pointer-typed.
mlir::NVVM::IDArgPair Tcgen05MMAWsSparseOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto wsOp = cast<NVVM::Tcgen05MMAWsSparseOp>(op);
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(wsOp.getMatrixD()));
  llvm::Value *matA = mt.lookupValue(wsOp.getMatrixA());
  const bool isATensor = isa<llvm::PointerType>(matA->getType());
  args.push_back(matA);
  args.push_back(mt.lookupValue(wsOp.getMatrixB()));
  args.push_back(mt.lookupValue(wsOp.getIdesc()));
  args.push_back(mt.lookupValue(wsOp.getEnableInputD()));
  args.push_back(mt.lookupValue(wsOp.getSparseMetadata()));

  mlir::Value zeroColMask = wsOp.getZeroColMask();
  const bool hasMask = static_cast<bool>(zeroColMask);
  if (hasMask)
    args.push_back(mt.lookupValue(zeroColMask));

  llvm::Intrinsic::ID ID;
  if (isATensor)
    ID = hasMask
             ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor_zero_col_mask
             : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor;
  else
    ID = hasMask
             ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared_zero_col_mask
             : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared;

  // Trailing i32 immediates, in intrinsic argument order.
  for (unsigned imm : {static_cast<unsigned>(wsOp.getKind()),
                       static_cast<unsigned>(wsOp.getCollectorBBuffer()),
                       static_cast<unsigned>(wsOp.getCollectorOp())})
    args.push_back(builder.getInt32(imm));
  return {ID, args};
}
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
// TODO: This should be the llvm.nvvm dialect once this is supported.
/// Dialect setup. This function:
/// - registers the generated NVVM operations and attributes;
/// - allows operations in this dialect that are not registered;
/// - declares promised implementations of ConvertToLLVMPatternInterface (for
///   the dialect) and gpu::TargetAttrInterface (for NVVMTargetAttr), which
///   are expected to be attached elsewhere.
void NVVMDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
>();
addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
>();
// Support unknown operations because not all NVVM operations are
// registered.
allowUnknownOperations();
declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
}
/// Verify NVVM discardable attributes attached to an operation.
///
/// - The kernel marker may only appear on LLVM functions.
/// - maxntid / reqntid / cluster_dim must be a non-empty i32 array with at
///   most 3 entries.
/// - minctasm / maxnreg / cluster_max_blocks must be integer attributes.
/// - blocksareclusters requires both reqntid and cluster_dim on the same op.
/// Any other attribute is accepted.
LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute attr) {
  StringAttr attrName = attr.getName();
  Attribute value = attr.getValue();

  if (attrName == NVVMDialect::getKernelFuncAttrName() &&
      !isa<LLVM::LLVMFuncOp>(op))
    return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
                           << "' attribute attached to unexpected op";

  const bool isDimAttr = attrName == NVVMDialect::getMaxntidAttrName() ||
                         attrName == NVVMDialect::getReqntidAttrName() ||
                         attrName == NVVMDialect::getClusterDimAttrName();
  if (isDimAttr) {
    auto dims = llvm::dyn_cast<DenseI32ArrayAttr>(value);
    if (!dims || dims.empty() || dims.size() > 3)
      return op->emitError()
             << "'" << attrName
             << "' attribute must be integer array with maximum 3 index";
  }

  const bool isIntAttr =
      attrName == NVVMDialect::getMinctasmAttrName() ||
      attrName == NVVMDialect::getMaxnregAttrName() ||
      attrName == NVVMDialect::getClusterMaxBlocksAttrName();
  if (isIntAttr && !isa<IntegerAttr>(value))
    return op->emitError()
           << "'" << attrName << "' attribute must be integer constant";

  if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) {
    const bool hasPrerequisites =
        op->hasAttr(NVVMDialect::getReqntidAttrName()) &&
        op->hasAttr(NVVMDialect::getClusterDimAttrName());
    if (!hasPrerequisites)
      return op->emitError()
             << "'" << attrName << "' attribute must be used along with "
             << "'" << NVVMDialect::getReqntidAttrName() << "' and "
             << "'" << NVVMDialect::getClusterDimAttrName() << "'";
  }
  return success();
}
/// Verify NVVM attributes attached to function arguments.
///
/// Only grid_constant is checked, and only on FunctionOpInterface ops. It is
/// valid only on arguments of kernel functions, must be a unit attribute, and
/// requires the same argument to also carry the LLVM byval attribute.
LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute argAttr) {
  auto funcOp = dyn_cast<FunctionOpInterface>(op);
  if (!funcOp)
    return success();

  StringAttr attrName = argAttr.getName();
  const bool isGridConstant =
      attrName == NVVM::NVVMDialect::getGridConstantAttrName();
  if (!isGridConstant)
    return success();

  if (!op->hasAttr(NVVMDialect::getKernelFuncAttrName()))
    return op->emitError()
           << "'" << attrName
           << "' attribute must be present only on kernel arguments";

  if (!isa<UnitAttr>(argAttr.getValue()))
    return op->emitError() << "'" << attrName << "' must be a unit attribute";

  if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName()))
    return op->emitError()
           << "'" << attrName
           << "' attribute requires the argument to also have attribute '"
           << LLVM::LLVMDialect::getByValAttrName() << "'";

  return success();
}
//===----------------------------------------------------------------------===//
// NVVM Address Space Attr
//===----------------------------------------------------------------------===//
/// The numeric LLVM address space is the underlying value of the memory-space
/// enum held by this attribute.
unsigned NVVMMemorySpaceAttr::getAddressSpace() const {
  const auto space = getValue();
  return static_cast<unsigned>(space);
}
/// Load validity for NVVM memory spaces is decided entirely by the shared
/// LLVM load/store checker. This memory space adds no extra constraints.
bool NVVMMemorySpaceAttr::isValidLoad(
    Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
    const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  return LLVM::detail::isValidLoadStoreImpl(
      type, ordering, alignment, dataLayout, emitError);
}
/// Returns whether a store of `type` with the given atomic ordering and
/// alignment is valid in this memory space. The same helper used for loads,
/// `LLVM::detail::isValidLoadStoreImpl`, makes the decision and emits any
/// diagnostic through `emitError`.
bool NVVMMemorySpaceAttr::isValidStore(
    Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
    const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  const bool storeIsValid = LLVM::detail::isValidLoadStoreImpl(
      type, ordering, alignment, dataLayout, emitError);
  return storeIsValid;
}
/// Returns whether an atomic read-modify-write of `type` is valid in this
/// memory space. Not implemented yet: debug builds assert, and every build
/// rejects the operation.
bool NVVMMemorySpaceAttr::isValidAtomicOp(
    ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering,
    std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  // TODO: update this method once `ptr.atomic_rmw` is implemented.
  assert(false && "unimplemented, see TODO in the source.");
  // With assertions compiled out we still reject the operation. Report the
  // reason when a diagnostic callback is available, so that verification
  // does not fail silently.
  if (emitError)
    emitError() << "atomic read-modify-write operations are not yet "
                   "supported in the NVVM memory space";
  return false;
}
/// Returns whether an atomic compare-and-exchange of `type` is valid in this
/// memory space. Not implemented yet: debug builds assert, and every build
/// rejects the operation.
bool NVVMMemorySpaceAttr::isValidAtomicXchg(
    Type type, ptr::AtomicOrdering successOrdering,
    ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
    const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  // TODO: update this method once `ptr.atomic_cmpxchg` is implemented.
  assert(false && "unimplemented, see TODO in the source.");
  // With assertions compiled out we still reject the operation. Report the
  // reason when a diagnostic callback is available, so that verification
  // does not fail silently.
  if (emitError)
    emitError() << "atomic compare-and-exchange operations are not yet "
                   "supported in the NVVM memory space";
  return false;
}
/// Returns whether an address-space cast from `src` to `tgt` is valid for
/// this memory space. Not implemented yet: debug builds assert, and every
/// build rejects the cast.
bool NVVMMemorySpaceAttr::isValidAddrSpaceCast(
    Type tgt, Type src, function_ref<InFlightDiagnostic()> emitError) const {
  // TODO: update this method once the `ptr.addrspace_cast` op is added to the
  // dialect.
  assert(false && "unimplemented, see TODO in the source.");
  // With assertions compiled out we still reject the cast. Report the reason
  // when a diagnostic callback is available, so that verification does not
  // fail silently.
  if (emitError)
    emitError() << "address space casts are not yet supported in the NVVM "
                   "memory space";
  return false;
}
/// Returns whether a cast between `intLikeTy` and `ptrLikeTy` is valid for
/// this memory space. Not implemented yet: debug builds assert, and every
/// build rejects the cast.
bool NVVMMemorySpaceAttr::isValidPtrIntCast(
    Type intLikeTy, Type ptrLikeTy,
    function_ref<InFlightDiagnostic()> emitError) const {
  // TODO: update this method once the int-cast ops are added to the `ptr`
  // dialect.
  assert(false && "unimplemented, see TODO in the source.");
  // With assertions compiled out we still reject the cast. Report the reason
  // when a diagnostic callback is available, so that verification does not
  // fail silently.
  if (emitError)
    emitError() << "pointer/integer casts are not yet supported in the NVVM "
                   "memory space";
  return false;
}
//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//
/// Verifies the parameters of an NVVM target attribute.
///
/// Fails with a diagnostic when:
///   - `optLevel` lies outside the inclusive range [0, 3],
///   - `triple` is empty,
///   - `chip` is empty, or
///   - `files` is present and contains an element that is not a StringAttr.
/// `features`, `flags` and `verifyTarget` are not checked here.
LogicalResult
NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                       int optLevel, StringRef triple, StringRef chip,
                       StringRef features, DictionaryAttr flags,
                       ArrayAttr files, bool verifyTarget) {
  const bool optLevelInRange = optLevel >= 0 && optLevel <= 3;
  if (!optLevelInRange)
    return emitError()
           << "The optimization level must be a number between 0 and 3.";

  if (triple.empty())
    return emitError() << "The target triple cannot be empty.";

  if (chip.empty())
    return emitError() << "The target chip cannot be empty.";

  // A null `files` array is allowed; otherwise every entry must be a string.
  auto isNotString = [](::mlir::Attribute entry) {
    return !mlir::isa_and_nonnull<StringAttr>(entry);
  };
  if (files && llvm::any_of(files, isNotString))
    return emitError()
           << "All the elements in the `link` array must be strings.";

  return success();
}
/// Checks that every operation nested in `gpuModule` is supported by the SM
/// version named by this attribute's chip.
///
/// Skipped entirely when the attribute's verify-target flag is false.
/// Otherwise:
///   - fails if `gpuModule` is not a `gpu.module`,
///   - fails if the chip is below the minimum supported SM version, and
///   - walks nested ops, failing at the first op implementing
///     `RequiresSMInterface` whose requirement is incompatible with the chip.
///     Only that first offending op is diagnosed.
LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
  if (!getVerifyTarget())
    return success();

  auto moduleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
  if (!moduleOp)
    return emitError(gpuModule->getLoc(),
                     "NVVM target attribute must be attached to a GPU module");

  const NVVMCheckSMVersion chipVersion =
      NVVMCheckSMVersion::getTargetSMVersionFromStr(getChip());
  if (!chipVersion.isMinimumSMVersion())
    return emitError(gpuModule->getLoc(),
                     "Minimum NVVM target SM version is sm_20");

  WalkResult walkResult = moduleOp->walk([&](Operation *nested) -> WalkResult {
    auto smOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(nested);
    if (!smOp)
      return WalkResult::advance();
    const NVVMCheckSMVersion needed = smOp.getRequiredMinSMVersion();
    if (needed.isCompatibleWith(chipVersion))
      return WalkResult::advance();
    nested->emitOpError() << "is not supported on " << getChip();
    return WalkResult::interrupt();
  });

  return failure(walkResult.wasInterrupted());
}
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"