//===-- XeVMToLLVM.cpp - XeVM to LLVM dialect conversion --------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Types.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
using namespace xevm;
namespace {
/// Bundle of LLVM function attributes applied to the device-function
/// declarations created below and mirrored onto their call sites.
struct LLVMFuncAttributeOptions {
  bool isConvergent = false;
  bool isNoUnwind = false;
  bool isWillReturn = false;
  // Optional memory-effects attribute; a null attr means "effects unknown".
  LLVM::MemoryEffectsAttr memEffectsAttr{};
};
// Commonly used attribute combinations, in field order:
// {isConvergent, isNoUnwind, isWillReturn, memEffectsAttr}.
static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
    false, true, false, {}};
static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
    false, true, true, {}};
static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
    true, true, true, {}};
/// Return the Itanium-style mangling code for \p ty as used in OpenCL builtin
/// names (e.g. "f" for f32, "Dh" for f16, "Dv4_i" for vector<4xi32>).
/// \p isUnsigned selects the unsigned integer codes (h/t/j/m) over the signed
/// ones (c/s/i/l); it is propagated recursively to vector element types.
std::string getTypeMangling(Type ty, bool isUnsigned = false) {
  return TypeSwitch<Type, std::string>(ty)
      .Case([isUnsigned](VectorType ty) -> std::string {
        // Vectors mangle as "Dv<numElements>_<elementMangling>".
        return "Dv" + std::to_string(ty.getNumElements()) + "_" +
               getTypeMangling(ty.getElementType(), isUnsigned);
      })
      .Case([](Float16Type) -> std::string { return "Dh"; })
      .Case([](Float32Type) -> std::string { return "f"; })
      .Case([](Float64Type) -> std::string { return "d"; })
      .Case([isUnsigned](IntegerType ty) -> std::string {
        switch (ty.getWidth()) {
        case 8:
          return isUnsigned ? "h" : "c";
        case 16:
          return isUnsigned ? "t" : "s";
        case 32:
          return isUnsigned ? "j" : "i";
        case 64:
          return isUnsigned ? "m" : "l";
        default:
          llvm_unreachable("unhandled integer type");
        }
      })
      .DefaultUnreachable("unhandled type for mangling");
}
/// Mangle \p baseName with the argument \p types following the Itanium-style
/// scheme used by OpenCL builtins: "_Z<len><name><type codes...>".
/// Repeated non-builtin types (e.g. vectors) are emitted as substitutions
/// ("S_", "S0_", ...). \p isUnsigned, when non-empty, supplies per-argument
/// signedness for integer mangling codes (must match \p types in size).
std::string mangle(StringRef baseName, ArrayRef<Type> types,
                   ArrayRef<bool> isUnsigned = {}) {
  assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
         "Signedness info doesn't match");
  std::string s;
  llvm::raw_string_ostream os(s);
  llvm::SmallDenseMap<Type, unsigned> substitutions;
  os << "_Z" << baseName.size() << baseName;
  for (auto [idx, type] : llvm::enumerate(types)) {
    auto it = substitutions.find(type);
    if (it != substitutions.end()) {
      os << "S";
      // First substitution is `S_`, second is `S0_`, and so on.
      if (unsigned firstIdx = it->getSecond(); firstIdx > 0)
        os << firstIdx - 1;
      os << "_";
    } else {
      // Only non-builtin types (here: anything not a scalar int/float) become
      // substitution candidates, mirroring the Itanium ABI rules.
      if (!type.isIntOrFloat())
        substitutions[type] = substitutions.size();
      os << getTypeMangling(type, isUnsigned.empty() ? false : isUnsigned[idx]);
    }
  }
  return os.str();
}
/// Map a combined load cache-control enum to its L1 control value:
/// 0 = uncached (default), 1 = cached (L1C_*), 2 = streaming (L1S_*),
/// 3 = invalidate-after-read.
static int32_t getL1CacheControl(LoadCacheControl cc) {
  int32_t control = 0;
  switch (cc) {
  case LoadCacheControl::L1C_L2UC_L3UC:
  case LoadCacheControl::L1C_L2UC_L3C:
  case LoadCacheControl::L1C_L2C_L3UC:
  case LoadCacheControl::L1C_L2C_L3C:
    control = 1;
    break;
  case LoadCacheControl::L1S_L2UC_L3UC:
  case LoadCacheControl::L1S_L2UC_L3C:
  case LoadCacheControl::L1S_L2C_L3UC:
  case LoadCacheControl::L1S_L2C_L3C:
    control = 2;
    break;
  case LoadCacheControl::INVALIDATE_READ:
    control = 3;
    break;
  default:
    break;
  }
  return control;
}
/// Map a combined store cache-control enum to its L1 control value:
/// 0 = uncached (default), 1 = write-through (L1WT_*),
/// 2 = write-back (L1WB_*), 3 = streaming (L1S_*).
static int32_t getL1CacheControl(StoreCacheControl cc) {
  int32_t control = 0;
  switch (cc) {
  case StoreCacheControl::L1WT_L2UC_L3UC:
  case StoreCacheControl::L1WT_L2UC_L3WB:
  case StoreCacheControl::L1WT_L2WB_L3UC:
  case StoreCacheControl::L1WT_L2WB_L3WB:
    control = 1;
    break;
  case StoreCacheControl::L1WB_L2UC_L3UC:
  case StoreCacheControl::L1WB_L2WB_L3UC:
  case StoreCacheControl::L1WB_L2UC_L3WB:
    control = 2;
    break;
  case StoreCacheControl::L1S_L2UC_L3UC:
  case StoreCacheControl::L1S_L2UC_L3WB:
  case StoreCacheControl::L1S_L2WB_L3UC:
  case StoreCacheControl::L1S_L2WB_L3WB:
    control = 3;
    break;
  default:
    break;
  }
  return control;
}
/// Map a combined load cache-control enum to its L3 control value:
/// 0 = uncached (default), 1 = cached (*_L3C), 3 = invalidate-after-read.
static int32_t getL3CacheControl(LoadCacheControl cc) {
  int32_t control = 0;
  switch (cc) {
  case LoadCacheControl::L1UC_L2UC_L3C:
  case LoadCacheControl::L1UC_L2C_L3C:
  case LoadCacheControl::L1C_L2UC_L3C:
  case LoadCacheControl::L1C_L2C_L3C:
  case LoadCacheControl::L1S_L2UC_L3C:
  case LoadCacheControl::L1S_L2C_L3C:
    control = 1;
    break;
  case LoadCacheControl::INVALIDATE_READ:
    control = 3;
    break;
  default:
    break;
  }
  return control;
}
/// Map a combined store cache-control enum to its L3 control value:
/// 0 = uncached (default), 2 = write-back (*_L3WB).
static int32_t getL3CacheControl(StoreCacheControl cc) {
  int32_t control = 0;
  switch (cc) {
  case StoreCacheControl::L1UC_L2UC_L3WB:
  case StoreCacheControl::L1UC_L2WB_L3WB:
  case StoreCacheControl::L1WT_L2UC_L3WB:
  case StoreCacheControl::L1WT_L2WB_L3WB:
  case StoreCacheControl::L1S_L2UC_L3WB:
  case StoreCacheControl::L1S_L2WB_L3WB:
  case StoreCacheControl::L1WB_L2UC_L3WB:
    control = 2;
    break;
  default:
    break;
  }
  return control;
}
// Uniform accessors for the optional cache-control attribute of each XeVM
// memory operation; these let the generic templates below work across op
// kinds.
static std::optional<LoadCacheControl> getCacheControl(PrefetchOp op) {
  return op.getCacheControl();
}
static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
  return op.getCacheControl();
}
static std::optional<LoadCacheControl> getCacheControl(BlockLoadOp op) {
  return op.getCacheControl();
}
static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
  return op.getCacheControl();
}
static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
  return op.getCacheControl();
}
static std::optional<StoreCacheControl> getCacheControl(BlockStoreOp op) {
  return op.getCacheControl();
}
/// Read the optional "cache_control" discardable attribute from an llvm.load.
/// Returns std::nullopt when the attribute is absent or not a
/// xevm::LoadCacheControlAttr.
static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
  // getAttrOfType already returns a null attribute when the attribute is
  // missing or has the wrong type, so no separate hasAttr check is needed.
  auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>("cache_control");
  if (!attr)
    return std::nullopt;
  return attr.getValue();
}
/// Read the optional "cache_control" discardable attribute from an
/// llvm.store. Returns std::nullopt when the attribute is absent or not a
/// xevm::StoreCacheControlAttr.
static std::optional<StoreCacheControl> getCacheControl(LLVM::StoreOp op) {
  // getAttrOfType already returns a null attribute when the attribute is
  // missing or has the wrong type, so no separate hasAttr check is needed.
  auto attr = op->getAttrOfType<xevm::StoreCacheControlAttr>("cache_control");
  if (!attr)
    return std::nullopt;
  return attr.getValue();
}
// Generic wrappers over the enum-specific mappers above.
// Precondition: the op actually carries a cache-control attribute; the
// optional returned by getCacheControl(op) is dereferenced unchecked.
template <typename OpType>
int32_t getL1CacheControl(OpType op) {
  return getL1CacheControl(*getCacheControl(op));
}
template <typename OpType>
int32_t getL3CacheControl(OpType op) {
  return getL3CacheControl(*getCacheControl(op));
}
/// Build the cache-control metadata for \p op as a two-entry ArrayAttr:
/// [[key, 0, L1 control], [key, 1, L3 control]], where key is the SPIR-V
/// decoration id for load (6442) or store (6443) cache controls (per the
/// SPV_INTEL_cache_controls extension).
/// Returns an empty optional when the op carries no cache-control info.
template <typename OpType>
static std::optional<ArrayAttr>
getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
  if (!getCacheControl(op))
    return {};
  constexpr int32_t decorationCacheControlArity{3};
  constexpr int32_t loadCacheControlKey{6442};
  constexpr int32_t storeCacheControlKey{6443};
  // Classify the op as load-like or store-like at compile time to pick the
  // decoration key.
  constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
                          std::is_same_v<OpType, BlockPrefetch2dOp> ||
                          std::is_same_v<OpType, LLVM::LoadOp> ||
                          std::is_same_v<OpType, BlockLoadOp> ||
                          std::is_same_v<OpType, PrefetchOp>;
  const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
  // Each triple is [decoration key, cache level (0 = L1, 1 = L3), control].
  SmallVector<int32_t, decorationCacheControlArity> decorationsL1{
      controlKey, 0, getL1CacheControl<OpType>(op)};
  SmallVector<int32_t, decorationCacheControlArity> decorationsL3{
      controlKey, 1, getL3CacheControl<OpType>(op)};
  auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
  auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
  SmallVector<Attribute, 2> combinedAttrs = {arrayAttrL1, arrayAttrL3};
  return rewriter.getArrayAttr(combinedAttrs);
}
//===----------------------------------------------------------------------===//
// Cache control annotation utilities
//
// Instead of attaching cache control as MLIR attributes and handling them
// during LLVM translation, we directly emit llvm.intr.ptr.annotation op in
// MLIR.
//===----------------------------------------------------------------------===//
/// Build one cache-control payload string per attribute.
///
/// Each Attribute is expected to be an ArrayAttr of 3 IntegerAttr values:
/// [SPIR-V decoration token, cache level, cache control value]
///
/// A single entry produces a string like: {6442:"0,1"}
/// where the quote characters (0x22) will appear as \22 in LLVM IR textual
/// form.
static SmallVector<std::string>
buildCacheControlPayloads(ArrayRef<Attribute> attrs) {
  SmallVector<std::string> result;
  // Tracks payload strings already emitted so duplicates are dropped.
  llvm::StringMap<bool> emitted;
  for (Attribute attr : attrs) {
    auto triple = dyn_cast<ArrayAttr>(attr);
    if (!triple)
      continue;
    ArrayRef<Attribute> elems = triple.getValue();
    assert(elems.size() == 3 &&
           "Expected exactly 3 integer values (Token, CacheLevel, "
           "ControlValue) in cache control attribute.");
    auto token = dyn_cast<IntegerAttr>(elems[0]);
    auto level = dyn_cast<IntegerAttr>(elems[1]);
    auto control = dyn_cast<IntegerAttr>(elems[2]);
    if (!token || !level || !control)
      continue;
    // Produce: {SPIR-V decoration token:"L1 cache control,L3 cache control"}
    // The quote char (0x22) is embedded literally; LLVM IR prints it as \22.
    std::string payload = llvm::formatv("'{'{0}:\"{1},{2}\"'}'",
                                        token.getValue().getZExtValue(),
                                        level.getValue().getZExtValue(),
                                        control.getValue().getZExtValue());
    // Only keep the first occurrence of each distinct annotation.
    if (emitted.insert({payload, true}).second)
      result.push_back(std::move(payload));
  }
  return result;
}
/// Counter for generating unique global variable names. Atomic so that
/// concurrently running conversions never hand out the same suffix.
static std::atomic<uint64_t> globalNameCounter{0};
/// Get or create a global metadata string and return a !llvm.ptr<1> value
/// pointing to it. The AddressOfOp is created at the current rewriter
/// insertion point; the GlobalOp is created at the module start.
///
/// Globals in the "llvm.metadata" section with identical contents are
/// reused; otherwise a fresh private constant named
/// "<nameHint>.<counter>" is created.
static Value createMetadataStringPtr(ConversionPatternRewriter &rewriter,
                                     Operation *moduleOp, Location loc,
                                     StringRef value, StringRef nameHint) {
  // Build null-terminated string.
  std::string strWithNull = value.str();
  strWithNull.push_back('\0');
  StringRef strRef(strWithNull.data(), strWithNull.size());
  auto as1PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
  // Search for an existing global with the same content. Only globals in
  // the "llvm.metadata" section are candidates for reuse.
  for (auto &op : moduleOp->getRegion(0).front()) {
    if (auto existingGlobal = dyn_cast<LLVM::GlobalOp>(&op)) {
      if (!existingGlobal.getSection() ||
          *existingGlobal.getSection() != "llvm.metadata")
        continue;
      if (auto strAttr =
              dyn_cast_or_null<StringAttr>(existingGlobal.getValueOrNull())) {
        if (strAttr.getValue() == strRef) {
          return LLVM::AddressOfOp::create(rewriter, loc, as1PtrTy,
                                           existingGlobal.getSymName());
        }
      }
    }
  }
  // Create new global at module start. The atomic counter guarantees a
  // unique symbol name per created global.
  auto i8Type = rewriter.getI8Type();
  auto arrayType = LLVM::LLVMArrayType::get(i8Type, strWithNull.size());
  std::string globalName =
      llvm::formatv("{0}.{1}", nameHint,
                    globalNameCounter.fetch_add(1, std::memory_order_relaxed))
          .str();
  {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&moduleOp->getRegion(0).front());
    auto globalOp =
        LLVM::GlobalOp::create(rewriter, loc, arrayType,
                               /*isConstant=*/true, LLVM::Linkage::Private,
                               globalName, rewriter.getStringAttr(strRef));
    globalOp.setSection(StringRef("llvm.metadata"));
    globalOp.setUnnamedAddr(LLVM::UnnamedAddr::Global);
    globalOp.setAlignment(1);
    globalOp.setAddrSpace(1);
  }
  // InsertionGuard restores the original insertion point here.
  return LLVM::AddressOfOp::create(rewriter, loc, as1PtrTy, globalName);
}
/// Annotate a pointer value with cache control metadata by emitting chained
/// `llvm.intr.ptr.annotation` ops (LLVM::PtrAnnotation).
///
/// This is the MLIR-level equivalent of handleDecorationCacheControl() from
/// the LLVM translation layer. For each cache control attribute, it emits:
///
///   %ann = llvm.intr.ptr.annotation %ptr, @".str.cachecontrol.N",
///          @".str.file.N", 0, null : !llvm.ptr<AS>
///
/// Multiple annotations are chained: the result of each annotation op is
/// fed as the pointer input to the next one.
///
/// \param rewriter The pattern rewriter.
/// \param loc Source location for created ops.
/// \param ptr The pointer value to annotate.
/// \param cacheControls The cache control ArrayAttr (from
///        getCacheControlMetadata).
/// \param moduleOp The enclosing module (for creating globals).
/// \returns The annotated pointer value (or the original ptr if no
///          annotations).
static Value annotatePtrWithCacheControl(ConversionPatternRewriter &rewriter,
                                         Location loc, Value ptr,
                                         ArrayAttr cacheControls,
                                         Operation *moduleOp) {
  SmallVector<std::string> payloads =
      buildCacheControlPayloads(cacheControls.getValue());
  // No valid cache-control entries: nothing to annotate.
  if (payloads.empty())
    return ptr;
  auto ptrType = cast<LLVM::LLVMPointerType>(ptr.getType());
  auto as1PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
  auto i32Ty = rewriter.getI32Type();
  // Create shared constants for all annotations on this pointer: an empty
  // file-name string, line number 0, and a null "args" pointer.
  Value fileStr =
      createMetadataStringPtr(rewriter, moduleOp, loc, "", ".str.file");
  Value lineVal = LLVM::ConstantOp::create(rewriter, loc, i32Ty, 0);
  Value nullAS1 = LLVM::ZeroOp::create(rewriter, loc, as1PtrTy);
  // Chain: each annotation takes the result of the previous one as its
  // pointer operand. The annotation result has the same type as its input.
  Value curPtr = ptr;
  for (const std::string &payload : payloads) {
    Value annStr = createMetadataStringPtr(rewriter, moduleOp, loc, payload,
                                           ".str.cachecontrol");
    auto annOp = LLVM::PtrAnnotation::create(rewriter, loc, ptrType, curPtr,
                                             annStr, fileStr, lineVal, nullAS1);
    curPtr = annOp.getResult();
  }
  return curPtr;
}
/// Helper to apply cache control annotation on a pointer operand of a call.
/// Replaces the pointer argument of the call with an annotated version.
///
/// For operations that produce a call (like block load/store/prefetch), the
/// pointer is typically the first argument. This function:
/// 1. Builds the annotation chain on the pointer.
/// 2. Replaces the pointer operand in the provided args list.
///
/// \param rewriter The pattern rewriter.
/// \param loc Source location.
/// \param ptr The original pointer value (first arg to the call).
/// \param cacheControls The cache control metadata.
/// \param moduleOp The enclosing module.
/// \param args The argument list (modified in place: args[ptrIdx] is
/// replaced).
/// \param ptrIdx Index of the pointer in the args list (default 0).
template <typename OpType>
static void
applyCacheControlAnnotation(ConversionPatternRewriter &rewriter, Location loc,
                            OpType op, SmallVectorImpl<Value> &args,
                            Operation *moduleOp, unsigned ptrIdx = 0) {
  // Nothing to do when the op carries no cache-control metadata.
  std::optional<ArrayAttr> cacheControls =
      getCacheControlMetadata(rewriter, op);
  if (!cacheControls)
    return;
  // Swap in the annotated pointer; its type is unchanged by annotation.
  args[ptrIdx] = annotatePtrWithCacheControl(rewriter, loc, args[ptrIdx],
                                             *cacheControls, moduleOp);
}
//===----------------------------------------------------------------------===//
// End cache control annotation utilities
//===----------------------------------------------------------------------===//
/// Declare (or look up) a SPIR_FUNC device function \p funcName with the
/// given signature, attach the requested function and parameter attributes,
/// and emit a call to it at the rewriter's current insertion point.
///
/// \param paramAttrs Unit attributes (e.g. nonnull/readonly) to set on the
///        callee's parameters, as (index, attribute name) pairs.
/// \param funcAttributeOptions Function attributes to apply to the callee.
/// \param op The op being lowered; provides the location and the enclosing
///        symbol table (module) in which the function is declared.
/// \returns The created call op. The call inherits the callee's attributes
///          so the same guarantees are visible at the call site.
static LLVM::CallOp createDeviceFunctionCall(
    ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
    ArrayRef<Type> argTypes, ArrayRef<Value> args,
    mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
    LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) {
  auto *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
  assert(moduleOp && "Expecting module");
  Location loc = op->getLoc();
  auto funcOpRes =
      LLVM::lookupOrCreateFn(rewriter, moduleOp, funcName, argTypes, retType);
  assert(!failed(funcOpRes) && "Failed to lookup or create function");
  LLVM::LLVMFuncOp funcOp = funcOpRes.value();
  funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
  funcOp.setConvergent(funcAttributeOptions.isConvergent);
  funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
  funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
  if (funcAttributeOptions.memEffectsAttr)
    funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
  for (auto [idx, attrName] : paramAttrs)
    funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
  auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args);
  callOp->setAttrs(funcOp->getAttrs());
  return callOp;
}
/// Lower xevm.mma to a call to the corresponding
/// `intel_sub_group_{a}_{b}_matrix_mad_k{N}` OpenCL builtin, bitcasting the
/// A/B/C operands to the packed element types the builtin expects.
class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getC()) {
      return rewriter.notifyMatchFailure(op, "OCL requires C operand");
    }
    auto precisionA = op.getTypes().getA();
    auto precisionB = op.getTypes().getB();
    auto precisionC = op.getTypes().getC();
    auto precisionD = op.getTypes().getD();
    if (precisionC != precisionD) {
      return rewriter.notifyMatchFailure(op, "type of C and D need to match");
    }
    if (precisionC != xevm::ElemType::S32 &&
        precisionC != xevm::ElemType::F32 &&
        precisionC != xevm::ElemType::F16 &&
        precisionC != xevm::ElemType::BF16) {
      return rewriter.notifyMatchFailure(
          op, "type of C and D must be S32, F32, F16 or BF16");
    }
    if (precisionA == xevm::ElemType::S32 ||
        precisionA == xevm::ElemType::F32) {
      return rewriter.notifyMatchFailure(op, "type of A cannot be S32 or F32");
    }
    if (precisionB == xevm::ElemType::S32 ||
        precisionB == xevm::ElemType::F32) {
      return rewriter.notifyMatchFailure(op, "type of B cannot be S32 or F32");
    }
    // A operands are packed into 16-bit lanes, B operands into 32-bit lanes
    // (except TF32, which stays f32 below).
    constexpr uint32_t bitWidthPackedA{16};
    constexpr uint32_t bitWidthPackedB{32};
    auto loc = op.getLoc();
    // Bitcast a vector operand to a same-total-bit-width vector of
    // `packedType` elements when the types differ.
    auto castIfNeeded = [&](Value val, Type packedType) -> Value {
      VectorType origTy = cast<VectorType>(val.getType());
      const uint32_t vecBitSize =
          origTy.getNumElements() *
          origTy.getElementType().getIntOrFloatBitWidth();
      VectorType newTy = VectorType::get(
          vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
      if (origTy != newTy)
        val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
      return val;
    };
    Value a = op.getA();
    Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
                           ? cast<Type>(rewriter.getF32Type())
                           : rewriter.getIntegerType(bitWidthPackedA);
    a = castIfNeeded(a, packedAType);
    Value b = op.getB();
    Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
                           ? cast<Type>(rewriter.getF32Type())
                           : rewriter.getIntegerType(bitWidthPackedB);
    b = castIfNeeded(b, packedBType);
    Value c = op.getC();
    VectorType cOrigTy = cast<VectorType>(c.getType());
    VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
    assert(cOrigTy == resOrigTy && "Accumulator and result type mismatch");
    // OCL builtins encode bfloat16 as int16
    VectorType cTy =
        cOrigTy.getElementType().isBF16()
            ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
            : cOrigTy;
    VectorType resTy = cTy;
    if (cOrigTy != cTy)
      c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
    // K in the builtin name is systolicDepth * elements-per-dword for the A
    // precision (see getNumOperandsPerDword).
    constexpr int32_t systolicDepth{8};
    std::string fnName =
        llvm::formatv("intel_sub_group_{0}_{1}_matrix_mad_k{2}",
                      stringifyElemType(op.getTypes().getA()).str(),
                      stringifyElemType(op.getTypes().getB()).str(),
                      systolicDepth *
                          getNumOperandsPerDword(op.getTypes().getA()))
            .str();
    SmallVector<Type> argTypes{a.getType(), b.getType(), cTy};
    fnName = mangle(fnName, argTypes);
    SmallVector<Value> args{a, b, c};
    // The builtin is a pure computation: it touches no memory at all.
    auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
        /*other=*/LLVM::ModRefInfo::NoModRef,
        /*argMem=*/LLVM::ModRefInfo::NoModRef,
        /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
        /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
        /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
        /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
    auto funcAttrs = convergentNoUnwindWillReturnAttrs;
    funcAttrs.memEffectsAttr = memAttr;
    Value result =
        createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
                                 funcAttrs, op.getOperation())
            ->getResult(0);
    // Cast back to the original accumulator element type (bf16) if needed.
    if (resOrigTy != resTy)
      result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy, result);
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  /// Number of elements of type \p pTy packed into one 32-bit dword; used to
  /// derive the K dimension in the builtin name.
  static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
    switch (pTy) {
    case xevm::ElemType::TF32:
      return 1;
    case xevm::ElemType::BF16:
    case xevm::ElemType::F16:
      return 2;
    case xevm::ElemType::U8:
    case xevm::ElemType::S8:
      return 4;
    default:
      llvm_unreachable("unsupported xevm::ElemType");
    }
  }
};
/// Lower xevm.prefetch to the OpenCL `prefetch` builtin (pre-mangled name
/// "_Z8prefetchPU3AS1Kcm": global const char* plus a ulong count of 1),
/// annotating the pointer with cache-control metadata first.
class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
    const std::string fnName{"_Z8prefetchPU3AS1Kcm"};
    // Prefetch a single element.
    Value one =
        LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1);
    SmallVector<Value> args{op.getPtr(), one};
    // Annotate pointer with cache control before passing to the call.
    applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
                                /*ptrIdx=*/0);
    SmallVector<Type> argTypes;
    for (auto arg : args)
      argTypes.push_back(arg.getType());
    auto funcAttr = noUnwindAttrs;
    // The builtin only reads through its pointer argument.
    auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
        /*other=*/LLVM::ModRefInfo::NoModRef,
        /*argMem=*/LLVM::ModRefInfo::Ref,
        /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
        /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
        /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
        /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
    funcAttr.memEffectsAttr = memAttr;
    createDeviceFunctionCall(rewriter, fnName,
                             LLVM::LLVMVoidType::get(rewriter.getContext()),
                             argTypes, args, {}, funcAttr, op.getOperation());
    rewriter.eraseOp(op);
    return success();
  }
};
/// Lower xevm.memfence to a call to the OpenCL `atomic_work_item_fence`
/// builtin with acquire-release ordering. Only the global/shared address
/// spaces and workgroup/device scopes are representable in OpenCL; other
/// combinations fail to match.
class MemfenceToOCLPattern : public OpConversionPattern<MemfenceOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    const std::string fnName{"atomic_work_item_fence"};
    int memScope, addrSpace;
    switch (op.getAddrspace()) {
    case xevm::AddrSpace::SHARED:
      addrSpace = 1; // CLK_LOCAL_MEM_FENCE
      break;
    case xevm::AddrSpace::GLOBAL:
      addrSpace = 2; // CLK_GLOBAL_MEM_FENCE
      break;
    default:
      // GENERIC is not supported in OpenCL
      return rewriter.notifyMatchFailure(
          op, "Fence only supports global and shared address spaces.");
    }
    switch (op.getScope()) {
    case xevm::MemScope::WORKGROUP:
      memScope = 1; // memory_scope_work_group
      break;
    case xevm::MemScope::DEVICE:
      memScope = 2; // memory_scope_device
      break;
    default:
      // CLUSTER and SYSTEM are not supported in OpenCL
      return rewriter.notifyMatchFailure(
          op, "Fence only supports workgroup and device memory scopes.");
    }
    Type i32Type = rewriter.getI32Type();
    // 4 == memory_order_acq_rel in the OpenCL ABI.
    Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4);
    Value memScopeConst =
        LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope);
    Value addrSpaceConst =
        LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace);
    // atomic_work_item_fence(flags, order, scope)
    SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
    SmallVector<Type> argTypes{3, i32Type};
    createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
                             LLVM::LLVMVoidType::get(rewriter.getContext()),
                             argTypes, args, {}, noUnwindAttrs,
                             op.getOperation());
    rewriter.eraseOp(op);
    return success();
  }
};
/// Lower the 2D block load/store/prefetch ops to the corresponding
/// `intel_sub_group_2d_block_{read,write,prefetch}` OpenCL builtins.
/// Loads and stores go through a private stack slot because the builtins
/// exchange the data through a pointer argument rather than by value.
template <typename OpType>
class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp>;
    constexpr bool isPrefetch = std::is_same_v<OpType, BlockPrefetch2dOp>;
    auto loc = op.getLoc();
    auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
    VectorType vecType;
    bool packReg = false;
    bool transpose = false;
    if constexpr (isLoad) {
      vecType = op.getRes().getType();
      packReg = op.getPackRegister();
      transpose = op.getTranspose();
    } else if constexpr (!isPrefetch) {
      vecType = op.getStoredVal().getType();
    }
    auto i32Type = rewriter.getI32Type();
    // Pack the (x, y) element coordinates into the <2 x i32> coord argument.
    Value byteCoord =
        LLVM::UndefOp::create(rewriter, loc, VectorType::get(2, i32Type));
    Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0);
    Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1);
    byteCoord = LLVM::InsertElementOp::create(
        rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
    byteCoord = LLVM::InsertElementOp::create(
        rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
    SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
                            op.getBasePitch(), byteCoord};
    // Annotate pointer (args[0]) with cache control before the call.
    applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
                                /*ptrIdx=*/0);
    SmallVector<Type> retTypes;
    Value spvLoadDstPtr;
    std::string funcName{"intel_sub_group_2d_block_"};
    std::string bitWidthId;
    LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
    SmallVector<std::pair<unsigned, StringRef>, 4> paramAttrs;
    if constexpr (isPrefetch) { // Prefetch
      funcName += "prefetch";
      paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())};
      // Prefetch only reads through its pointer argument.
      auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
          /*other=*/LLVM::ModRefInfo::NoModRef,
          /*argMem=*/LLVM::ModRefInfo::Ref,
          /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
          /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
          /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
          /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
      funcAttr = noUnwindAttrs;
      funcAttr.memEffectsAttr = memAttr;
    } else {
      auto vecElemType = vecType.getElementType();
      auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
      Value numElems = LLVM::ConstantOp::create(rewriter, loc, i32Type,
                                                vecType.getNumElements());
      // Private stack slot: the builtin writes the loaded data into it
      // (load) or reads the data to store from it (store).
      auto dstOrSrcPtr = LLVM::AllocaOp::create(
          rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()),
          vecElemType, numElems);
      args.push_back(dstOrSrcPtr);
      if constexpr (isLoad) { // Load
        funcName += "read";
        bitWidthId = getTypeMangling(vecElemType, /*isUnsigned=*/true);
        if (packReg)
          funcName += "_transform";
        else if (transpose)
          funcName += "_transpose";
        spvLoadDstPtr = dstOrSrcPtr;
        retTypes.push_back(vecType);
        paramAttrs = {
            std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
            std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()),
            std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
            std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
        };
      } else { // Store
        funcName += "write";
        // Element width selects the unsigned integer mangling code.
        bitWidthId = (vecElemBitWidth == 32)
                         ? "j"
                         : ((vecElemBitWidth == 16) ? "t" : "h");
        LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr);
        paramAttrs = {
            std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
            std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
            std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
            std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
        };
      }
    }
    // Append the tile-shape suffix, e.g. "_16b_8r16x1c".
    funcName =
        llvm::formatv("{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
                      op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
            .str();
    // Read/write variants take an extra private-pointer argument, which adds
    // a trailing "P<elem>" to the mangled name; prefetch does not.
    std::string prefetchCode("");
    if (!isPrefetch)
      prefetchCode += "P";
    funcName = llvm::formatv("_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(),
                             funcName, prefetchCode, bitWidthId)
                   .str();
    SmallVector<Type> argTypes;
    for (auto arg : args) {
      argTypes.push_back(arg.getType());
    }
    createDeviceFunctionCall(
        rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
        argTypes, args, paramAttrs, funcAttr, op.getOperation());
    // For loads, read the result back out of the stack slot.
    if constexpr (isLoad)
      rewriter.replaceOp(
          op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
    else
      rewriter.eraseOp(op);
    return success();
  }
};
/// Lower 1D block load/store ops to the `intel_sub_group_block_{read,write}`
/// OpenCL builtins (cl_intel_subgroup_local_block_io), building the mangled
/// name by hand and annotating the pointer with cache-control metadata.
template <typename OpType>
class BlockLoadStore1DToOCLPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStoreOp>;
    auto loc = op.getLoc();
    auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
    // Get OpenCL function name
    // https://registry.khronos.org/OpenCL/extensions/
    // intel/cl_intel_subgroup_local_block_io.html
    std::string funcName{"intel_sub_group_block_"};
    // Value or Result type can be vector or scalar
    Type valOrResTy;
    if constexpr (isStore) {
      funcName += "write_u";
      valOrResTy = op.getVal().getType();
    } else {
      funcName += "read_u";
      valOrResTy = op.getType();
    }
    // Get element type of the vector/scalar
    VectorType vecTy = dyn_cast<VectorType>(valOrResTy);
    Type elemType = vecTy ? vecTy.getElementType() : valOrResTy;
    // Builtin name suffix: element code plus vector width, e.g. "_uc16".
    funcName += getTypeMangling(elemType);
    if (vecTy)
      funcName += std::to_string(vecTy.getNumElements());
    SmallVector<Type, 2> argTypes{};
    // XeVM BlockLoad/StoreOp always use signless integer types
    // but OpenCL builtins expect unsigned types
    // use unsigned types for mangling
    SmallVector<bool, 2> isUnsigned{};
    // arg0: pointer to the src/dst address
    // arg1 - only if store : vector to store
    // Prepare arguments
    SmallVector<Value, 2> args{};
    args.push_back(op.getPtr());
    argTypes.push_back(op.getPtr().getType());
    isUnsigned.push_back(true);
    // Annotate pointer (args[0]) with cache control.
    applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
                                /*ptrIdx=*/0);
    // Update argTypes[0] in case the pointer type changed (it shouldn't
    // change type, but the value is now the annotated pointer).
    argTypes[0] = args[0].getType();
    Type retType;
    if constexpr (isStore) {
      args.push_back(op.getVal());
      argTypes.push_back(op.getVal().getType());
      isUnsigned.push_back(true);
      retType = LLVM::LLVMVoidType::get(rewriter.getContext());
    } else {
      retType = valOrResTy;
    }
    // Hand-build the mangled name: "_Z<len><name>PU3AS<n><elem>[<value>]".
    funcName = std::string("_Z") + std::to_string(funcName.size()) + funcName +
               "PU3AS" +
               std::to_string(op.getPtr().getType().getAddressSpace());
    funcName += getTypeMangling(elemType, /*isUnsigned=*/true);
    if constexpr (isStore)
      funcName += getTypeMangling(valOrResTy, /*isUnsigned=*/true);
    LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
    LLVM::CallOp call =
        createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
                                 {}, funcAttr, op.getOperation());
    if constexpr (isStore)
      rewriter.eraseOp(op);
    else
      rewriter.replaceOp(op, call->getResult(0));
    return success();
  }
};
/// Lower the "cache_control" discardable attribute on plain llvm.load /
/// llvm.store ops into pointer-annotation intrinsics on the accessed
/// pointer, then strip the attribute.
template <typename OpType>
class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op->hasAttr("cache_control"))
      return failure();
    auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
    std::optional<ArrayAttr> optCacheControls =
        getCacheControlMetadata(rewriter, op);
    if (!optCacheControls) {
      // Attribute present but not usable (wrong type): just drop it. The
      // in-place change must go through the rewriter so dialect-conversion
      // bookkeeping (and rollback) stays valid.
      rewriter.modifyOpInPlace(op,
                               [&] { op->removeAttr("cache_control"); });
      return success();
    }
    // Determine which operand is the pointer (store operands: value, ptr).
    constexpr bool isStore = std::is_same_v<OpType, LLVM::StoreOp>;
    unsigned ptrIdx = isStore ? 1 : 0;
    Value ptr = op->getOperand(ptrIdx);
    // Emit annotation intrinsic calls on the pointer.
    Value annotatedPtr = annotatePtrWithCacheControl(
        rewriter, op->getLoc(), ptr, *optCacheControls, moduleOp);
    // Replace the pointer operand and strip the attribute via
    // modifyOpInPlace rather than direct mutation: conversion patterns must
    // notify the rewriter of root updates.
    rewriter.modifyOpInPlace(op, [&] {
      op->setOperand(ptrIdx, annotatedPtr);
      op->removeAttr("cache_control");
    });
    return success();
  }
};
//===----------------------------------------------------------------------===//
// GPU index id operations
//===----------------------------------------------------------------------===//
/*
// Launch Config ops
// dimidx - x, y, z - is fixed to i32
// return type is set by XeVM type converter
// get_local_id
xevm::WorkitemIdXOp;
xevm::WorkitemIdYOp;
xevm::WorkitemIdZOp;
// get_local_size
xevm::WorkgroupDimXOp;
xevm::WorkgroupDimYOp;
xevm::WorkgroupDimZOp;
// get_group_id
xevm::WorkgroupIdXOp;
xevm::WorkgroupIdYOp;
xevm::WorkgroupIdZOp;
// get_num_groups
xevm::GridDimXOp;
xevm::GridDimYOp;
xevm::GridDimZOp;
// get_global_id : to be added if needed
*/
// Helpers to get the OpenCL function name and dimension argument for each op.
// Each launch-config op maps to a (builtin name, dimension index) pair with
// dimension 0 = x, 1 = y, 2 = z.
static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
  return {"get_local_id", 0};
}
static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) {
  return {"get_local_id", 1};
}
static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) {
  return {"get_local_id", 2};
}
static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) {
  return {"get_local_size", 0};
}
static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) {
  return {"get_local_size", 1};
}
static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) {
  return {"get_local_size", 2};
}
static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) {
  return {"get_group_id", 0};
}
static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) {
  return {"get_group_id", 1};
}
static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) {
  return {"get_group_id", 2};
}
static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) {
  return {"get_num_groups", 0};
}
static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) {
  return {"get_num_groups", 1};
}
static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) {
  return {"get_num_groups", 2};
}
/// Replace an `xevm.*` launch-config query with an `llvm.call` to the
/// matching OpenCL builtin, passing the queried axis (x = 0, y = 1, z = 2)
/// as a constant i32 argument.
template <typename OpType>
class LaunchConfigOpToOCLPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto [funcBase, dimension] = getConfig(op);
    Type i32Ty = rewriter.getI32Type();
    Value dimArg = LLVM::ConstantOp::create(rewriter, loc, i32Ty, dimension);
    // The dimension argument is mangled as unsigned.
    std::string mangled = mangle(funcBase, {i32Ty}, {true});
    auto call = createDeviceFunctionCall(rewriter, mangled, op.getType(),
                                         {i32Ty}, {dimArg}, {},
                                         noUnwindWillReturnAttrs,
                                         op.getOperation());
    // The builtin touches no memory whatsoever.
    constexpr auto none = LLVM::ModRefInfo::NoModRef;
    call.setMemoryEffectsAttr(rewriter.getAttr<LLVM::MemoryEffectsAttr>(
        /*other=*/none, /*argMem=*/none, /*inaccessibleMem=*/none,
        /*errnoMem=*/none, /*targetMem0=*/none, /*targetMem1=*/none));
    rewriter.replaceOp(op, call);
    return success();
  }
};
/*
// Subgroup ops
// get_sub_group_local_id
xevm::LaneIdOp;
// get_sub_group_id
xevm::SubgroupIdOp;
// get_sub_group_size
xevm::SubgroupSizeOp;
// get_num_sub_groups : to be added if needed
*/
// Helpers to get the OpenCL function name for each subgroup query op.
// These builtins take no arguments, so only the name is needed.
static StringRef getConfig(xevm::LaneIdOp) { return "get_sub_group_local_id"; }
static StringRef getConfig(xevm::SubgroupIdOp) { return "get_sub_group_id"; }
static StringRef getConfig(xevm::SubgroupSizeOp) {
  return "get_sub_group_size";
}
/// Replace an `xevm.*` subgroup query with an `llvm.call` to the matching
/// zero-argument OpenCL builtin.
template <typename OpType>
class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    std::string mangled = mangle(getConfig(op).str(), {});
    auto call = createDeviceFunctionCall(rewriter, mangled, op.getType(), {},
                                         {}, {}, noUnwindWillReturnAttrs,
                                         op.getOperation());
    // The builtin touches no memory whatsoever.
    constexpr auto none = LLVM::ModRefInfo::NoModRef;
    call.setMemoryEffectsAttr(rewriter.getAttr<LLVM::MemoryEffectsAttr>(
        /*other=*/none, /*argMem=*/none, /*inaccessibleMem=*/none,
        /*errnoMem=*/none, /*targetMem0=*/none, /*targetMem1=*/none));
    rewriter.replaceOp(op, call);
    return success();
  }
};
/// Rewrites an `llvm.alloca` in address space 3 (OpenCL Workgroup) as a
/// module-level `llvm.mlir.global` plus an `llvm.mlir.addressof`, because
/// the SPIRV backend does not handle such allocas.
class AllocaToGlobalPattern : public OpConversionPattern<LLVM::AllocaOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LLVM::AllocaOp op, LLVM::AllocaOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultPtrTy = cast<LLVM::LLVMPointerType>(op.getType());
    unsigned addrSpace = resultPtrTy.getAddressSpace();
    // Only workgroup (addrspace 3) allocas need this rewrite.
    if (addrSpace != 3)
      return failure();
    Operation *symTable = op->getParentWithTrait<OpTrait::SymbolTable>();
    if (!symTable)
      return failure();
    // The new global must live directly in the enclosing (GPU) module body.
    Block *moduleBody = nullptr;
    if (auto mod = dyn_cast<ModuleOp>(symTable))
      moduleBody = mod.getBody();
    else if (auto gpuMod = dyn_cast<gpu::GPUModuleOp>(symTable))
      moduleBody = gpuMod.getBody();
    if (!moduleBody)
      return failure();
    // The element count must be a compile-time constant to size the global.
    APInt numElems;
    if (!matchPattern(op.getArraySize(), m_ConstantInt(&numElems)))
      return failure();
    Location loc = op.getLoc();
    auto arrayTy = LLVM::LLVMArrayType::get(
        rewriter.getContext(), op.getElemType(), numElems.getZExtValue());
    LLVM::GlobalOp globalVar;
    {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(moduleBody);
      auto alignment = op.getAlignment();
      globalVar = LLVM::GlobalOp::create(
          rewriter, loc, arrayTy, /*isConstant=*/false,
          /*linkage=*/LLVM::Linkage::Internal,
          /*name=*/std::string("__global_alloca_") +
              std::to_string(getNextGlobalIdx()),
          /*value=*/Attribute(),
          /*alignment=*/alignment.value_or(0), /*addrSpace=*/addrSpace);
    }
    rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, globalVar);
    return success();
  }

private:
  // Monotonically increasing suffix keeping generated global names unique.
  static unsigned getNextGlobalIdx() {
    static unsigned globalIdx = 0;
    return globalIdx++;
  }
};
// Checks if shufflevector is used as a way to extract a contiguous slice
// from a vector.
// - source vector V1 and V2 are the same vector.
// - mask size is not greater than the source vector size
// - mask values represent a sequence of consecutive increasing numbers
//   that stay in bounds of the source vector when used for indexing.
static bool isExtractingContiguousSlice(LLVM::ShuffleVectorOp op) {
  if (op.getV1() != op.getV2())
    return false;
  auto maskAttr = op.getMask();
  int64_t maskSize = static_cast<int64_t>(maskAttr.size());
  int64_t sourceSize = op.getV1().getType().getNumElements();
  if (maskSize > sourceSize)
    return false;
  int64_t firstIndex = maskAttr[0];
  // Check the first index too: it can be negative (-1 denotes a poison lane
  // in LLVM shuffle masks) or out of range for a single-element mask, cases
  // the per-element loop below would never reject. Requiring the whole run
  // [firstIndex, firstIndex + maskSize) to fit in the source vector also
  // subsumes the per-index upper-bound check.
  if (firstIndex < 0 || firstIndex + maskSize > sourceSize)
    return false;
  for (int64_t i = 1; i < maskSize; ++i)
    if (maskAttr[i] != firstIndex + i)
      return false;
  return true;
}
// Input vector of a shuffle vector op extracting a contiguous slice is an
// illegal vector in SPIRV kernel if the vector size is > 16 elements.
// To legalize this case, keep applying the following transformations until no
// more match:
// 1. keep hoisting the shuffle vector op past unary element-wise operations
//    start with fpext, fptrunc and bitcast for now.
// 2. merge with another shuffle vector op
// 3. merge with load as a smaller load
class HandleVectorExtractPattern
    : public OpRewritePattern<LLVM::ShuffleVectorOp> {
  using OpRewritePattern<LLVM::ShuffleVectorOp>::OpRewritePattern;
  // The rewrite re-creates shuffle ops that may match this same pattern
  // again (hoisting/merging), so declare bounded recursion to keep the
  // greedy driver from flagging it as non-converging.
  void initialize() { setHasBoundedRewriteRecursion(); }
  LogicalResult matchAndRewrite(LLVM::ShuffleVectorOp op,
                                PatternRewriter &rewriter) const override {
    if (!isExtractingContiguousSlice(op))
      return failure();
    auto mask = op.getMask();
    auto loc = op.getLoc();
    auto ty = op.getType();
    // Check source operand to determine rewrite pattern.
    auto src = op.getV1();
    // 1. Hoist past unary element-wise operations
    if (auto srcOp = src.getDefiningOp()) {
      if (isa<LLVM::FPExtOp>(srcOp) || isa<LLVM::FPTruncOp>(srcOp)) {
        Value srcInput = srcOp->getOperand(0);
        // Create new shuffle vector op with unary input as source.
        // NOTE(review): srcVecTy is not null-checked — assumes the
        // fpext/fptrunc operand is always a vector here; confirm scalar
        // operands cannot reach this point.
        auto srcVecTy = dyn_cast<VectorType>(srcInput.getType());
        auto newShuffleVecTy =
            VectorType::get(mask.size(), srcVecTy.getElementType());
        auto newShuffle = LLVM::ShuffleVectorOp::create(
            rewriter, loc, newShuffleVecTy, srcInput, srcInput, mask);
        // Create new unary op with new shuffle as input.
        Value newUnaryOp;
        if (isa<LLVM::FPExtOp>(srcOp)) {
          newUnaryOp = LLVM::FPExtOp::create(rewriter, loc, ty, newShuffle);
        } else {
          newUnaryOp = LLVM::FPTruncOp::create(rewriter, loc, ty, newShuffle);
        }
        rewriter.replaceOp(op, newUnaryOp);
      } else if (isa<LLVM::BitcastOp>(srcOp)) {
        Value srcInput = srcOp->getOperand(0);
        // Create new shuffle vector op with unary input as source.
        auto srcInputVecTy = dyn_cast<VectorType>(srcInput.getType());
        auto srcInputSize = srcInputVecTy.getNumElements();
        auto srcResVecTy = dyn_cast<VectorType>(srcOp->getResult(0).getType());
        auto srcResSize = srcResVecTy.getNumElements();
        auto maskSize = static_cast<int32_t>(mask.size());
        // Only handle bitcasts to the same or a wider element count
        // (e.g. <8 x i32> -> <16 x i16>), and only when the counts divide
        // evenly.
        if (srcInputSize > srcResSize) {
          return failure();
        }
        if (srcResSize % srcInputSize != 0) {
          return failure();
        }
        auto maskScale = srcResSize / srcInputSize;
        if (maskScale != 1) {
          // The extracted slice must start on a source-element boundary so
          // it can be expressed as a shuffle of the pre-bitcast vector.
          if (mask[0] % maskScale != 0) {
            return failure();
          }
          // Create a new mask that maps to the source vector
          // NOTE(review): a maskSize that is not a multiple of maskScale
          // truncates in the division below, silently shrinking the slice —
          // confirm such masks cannot match.
          SmallVector<int32_t> newMask;
          int32_t newMaskSize = maskSize / maskScale;
          int32_t maskStart = mask[0] / maskScale;
          for (int32_t i = 0; i < newMaskSize; ++i) {
            newMask.push_back(maskStart + i);
          }
          mask = newMask;
        }
        auto newShuffleVecTy =
            VectorType::get(srcInputSize, srcInputVecTy.getElementType());
        auto newShuffle = LLVM::ShuffleVectorOp::create(
            rewriter, loc, newShuffleVecTy, srcInput, srcInput, mask);
        // Create new unary op with new shuffle as input.
        auto newBitcast =
            LLVM::BitcastOp::create(rewriter, loc, ty, newShuffle);
        rewriter.replaceOp(op, newBitcast);
      } else if (isa<LLVM::ShuffleVectorOp>(srcOp)) {
        // 2. Merge with source shuffle vector op if, the source op is
        // also extracting a contiguous slice and create a new
        // shuffle vector op directly from the source of
        // the first shuffle.
        auto srcShuffle = cast<LLVM::ShuffleVectorOp>(srcOp);
        if (!isExtractingContiguousSlice(srcShuffle))
          return failure();
        // Compose the two masks: element `index` of this shuffle selects
        // element `srcMask[index]` of the original source.
        auto srcMask = srcShuffle.getMask();
        SmallVector<int32_t> combinedMask;
        for (auto index : mask) {
          combinedMask.push_back(srcMask[index]);
        }
        auto newShuffle = LLVM::ShuffleVectorOp::create(
            rewriter, loc, ty, srcShuffle.getV1(), srcShuffle.getV1(),
            DenseI32ArrayAttr::get(rewriter.getContext(), combinedMask));
        rewriter.replaceOp(op, newShuffle);
      } else if (isa<LLVM::LoadOp>(srcOp)) {
        // 3. Merge with load as a smaller load
        // NOTE(review): the replacement load carries no alignment/volatile
        // flags from the original — confirm producers never set them.
        auto loadOp = cast<LLVM::LoadOp>(srcOp);
        auto loadPtr = loadOp.getAddr();
        auto loadTy = dyn_cast<VectorType>(loadOp.getType());
        auto elemTy = loadTy.getElementType();
        auto firstIndex = mask[0];
        auto newVecTy = VectorType::get(mask.size(), elemTy);
        // GEPOp is needed if first index is not zero
        if (firstIndex) {
          auto newPtr = LLVM::GEPOp::create(
              rewriter, loc,
              LLVM::LLVMPointerType::get(rewriter.getContext(),
                                         loadPtr.getType().getAddressSpace()),
              elemTy, loadPtr, ArrayRef<LLVM::GEPArg>{firstIndex});
          auto newLoad = LLVM::LoadOp::create(rewriter, loc, newVecTy, newPtr);
          rewriter.replaceOp(op, newLoad);
        } else {
          auto newLoad = LLVM::LoadOp::create(rewriter, loc, newVecTy, loadPtr);
          rewriter.replaceOp(op, newLoad);
        }
      } else {
        return failure();
      }
    }
    return success();
  }
};
//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
/// Pass converting the XeVM dialect (plus a few special-cased LLVM ops, see
/// populateXeVMToLLVMConversionPatterns) to plain LLVM dialect, followed by
/// a greedy in-dialect cleanup of illegal vector extracts.
struct ConvertXeVMToLLVMPass
    : public impl::ConvertXeVMToLLVMPassBase<ConvertXeVMToLLVMPass> {
  using Base::Base;
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<LLVM::LLVMDialect, XeVMDialect>();
  }
  void runOnOperation() override {
    ConversionTarget target(getContext());
    RewritePatternSet patterns(&getContext());
    populateXeVMToLLVMConversionPatterns(target, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
    // Apply in-dialect lowerings to handle illegal vectors
    {
      RewritePatternSet vectorPatterns(&getContext());
      vectorPatterns.add<HandleVectorExtractPattern>(&getContext());
      GreedyRewriteConfig config{};
      // folding can remove ops with temporary attributes used to
      // represent LLVM metadata, so disable it here.
      // Effectively just this single pattern is applied without any
      // op folding patterns from dialects.
      config.enableFolding(false);
      // Best-effort cleanup: failure to converge is deliberately ignored
      // and does not fail the pass.
      (void)applyPatternsGreedily(getOperation(), std::move(vectorPatterns),
                                  config);
    }
  }
};
} // namespace
//===----------------------------------------------------------------------===//
// Pattern Population
//===----------------------------------------------------------------------===//
/// Configures `target` legality and registers all XeVM-to-LLVM conversion
/// patterns into `patterns`.
void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
                                                  RewritePatternSet &patterns) {
  // some LLVM operations need to be converted.
  target.addDynamicallyLegalDialect<LLVM::LLVMDialect>([](Operation *op) {
    // llvm alloca op with addrspace 3 for OpenCL (Workgroup) is not handled
    // properly by SPIRV backend. It needs to be rewritten as a sequence with
    // llvm global.
    if (auto aOp = dyn_cast<LLVM::AllocaOp>(op)) {
      // Single dyn_cast instead of isa-then-cast (LLVM idiom).
      auto pTy = cast<LLVM::LLVMPointerType>(aOp.getType());
      return pTy.getAddressSpace() != 3;
    }
    // cache_control attribute should be converted.
    return !op->hasAttr("cache_control");
  });
  target.addIllegalDialect<XeVMDialect>();
  patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
               LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
               LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
               MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
               LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
               LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
               BlockLoadStore1DToOCLPattern<BlockLoadOp>,
               BlockLoadStore1DToOCLPattern<BlockStoreOp>,
               LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
               LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
               LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
               LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
               LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
               LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
               LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
               LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
               LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
               LaunchConfigOpToOCLPattern<GridDimXOp>,
               LaunchConfigOpToOCLPattern<GridDimYOp>,
               LaunchConfigOpToOCLPattern<GridDimZOp>,
               SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
               SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
               SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>,
               AllocaToGlobalPattern>(patterns.getContext());
}