//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to XeGPU dialect ops.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include <algorithm>
#include <optional>
namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
namespace {
// Return true if value represents a zero constant.
static bool isZeroConstant(Value val) {
auto constant = val.getDefiningOp<arith::ConstantOp>();
if (!constant)
return false;
return TypeSwitch<Attribute, bool>(constant.getValue())
.Case<FloatAttr>(
[](auto floatAttr) { return floatAttr.getValue().isZero(); })
.Case<IntegerAttr>(
[](auto intAttr) { return intAttr.getValue().isZero(); })
.Default([](auto) { return false; });
}
static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
Operation *op, VectorType vecTy) {
  // Validate only the vector type, as the basic vector store and load ops
  // guarantee an XeGPU-compatible memref source.
unsigned vecRank = vecTy.getRank();
if (!(vecRank == 1 || vecRank == 2))
return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");
return success();
}
static LogicalResult transferPreconditions(PatternRewriter &rewriter,
VectorTransferOpInterface xferOp) {
if (xferOp.getMask())
return rewriter.notifyMatchFailure(xferOp,
"Masked transfer is not supported");
auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
if (!srcTy)
return rewriter.notifyMatchFailure(xferOp, "Expects memref source");
// Validate further transfer op semantics.
SmallVector<int64_t> strides;
int64_t offset;
if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1)
return rewriter.notifyMatchFailure(
xferOp, "Buffer must be contiguous in the innermost dimension");
VectorType vecTy = xferOp.getVectorType();
unsigned vecRank = vecTy.getRank();
if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
return rewriter.notifyMatchFailure(
xferOp, "Boundary check is available only for block instructions.");
AffineMap map = xferOp.getPermutationMap();
if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
unsigned numInputDims = map.getNumInputs();
for (AffineExpr expr : map.getResults().take_back(vecRank)) {
auto dim = dyn_cast<AffineDimExpr>(expr);
if (dim.getPosition() < (numInputDims - vecRank))
return rewriter.notifyMatchFailure(
xferOp, "Only the innermost dimensions can be accessed");
}
return success();
}
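// Creates an xegpu.create_nd_tdesc op that describes the tile of `src`
// addressed by `offsets`. For statically shaped sources the descriptor is
// built directly from the memref type; otherwise the shape and strides are
// materialized explicitly.
//
// Illustrative only (operand names are hypothetical); for a static source the
// emitted IR looks roughly like:
//   %desc = xegpu.create_nd_tdesc %src[%i, %j]
//             : memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32>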
static xegpu::CreateNdDescOp
createNdDescriptor(PatternRewriter &rewriter, Location loc,
xegpu::TensorDescType descType, TypedValue<MemRefType> src,
Operation::operand_range offsets) {
MemRefType srcTy = src.getType();
auto [strides, offset] = srcTy.getStridesAndOffset();
xegpu::CreateNdDescOp ndDesc;
if (srcTy.hasStaticShape()) {
ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
getAsOpFoldResult(offsets));
} else {
// In case of any dynamic shapes, source's shape and strides have to be
// explicitly provided.
SmallVector<Value> sourceDims;
unsigned srcRank = srcTy.getRank();
for (unsigned i = 0; i < srcRank; ++i)
sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
SmallVector<int64_t> constOffsets;
SmallVector<Value> dynOffsets;
for (Value offset : offsets) {
std::optional<int64_t> staticVal = getConstantIntValue(offset);
if (!staticVal)
dynOffsets.push_back(offset);
constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
}
SmallVector<Value> dynShapes;
for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
if (shape == ShapedType::kDynamic)
dynShapes.push_back(sourceDims[idx]);
}
// Compute strides in reverse order.
SmallVector<Value> dynStrides;
Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
// Last stride is guaranteed to be static and unit.
for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
accStride =
arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
if (strides[i] == ShapedType::kDynamic)
dynStrides.push_back(accStride);
}
std::reverse(dynStrides.begin(), dynStrides.end());
ndDesc = xegpu::CreateNdDescOp::create(
rewriter, loc, descType, src, dynOffsets, dynShapes, dynStrides,
DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
DenseI64ArrayAttr::get(rewriter.getContext(), strides));
}
return ndDesc;
}
// Adjusts the strides of a memref according to a given permutation map for
// vector operations.
//
// This function updates the innermost strides in the `strides` array to
// reflect the permutation specified by `permMap`. The permutation is computed
// using the inverse and broadcasting-aware version of the permutation map,
// and is applied to the relevant strides. This ensures that memory accesses
// are consistent with the logical permutation of vector elements.
//
// Example:
// Suppose we have a memref of rank 4 with strides `[s0, s1, s2, s3]`.
// If the permutation map swaps the last two dimensions (e.g., [0, 1] -> [1,
// 0]), then after calling this function, the last two strides will be
// swapped:
// Original strides: [s0, s1, s2, s3]
// After permutation: [s0, s1, s3, s2]
//
static void adjustStridesForPermutation(AffineMap permMap,
SmallVectorImpl<Value> &strides) {
AffineMap invMap = inverseAndBroadcastProjectedPermutation(permMap);
SmallVector<unsigned> perms;
invMap.isPermutationOfMinorIdentityWithBroadcasting(perms);
SmallVector<int64_t> perms64(perms.begin(), perms.end());
strides = applyPermutation(strides, perms64);
}
// Computes memory strides and the memref offset for vector transfer, gather,
// and scatter operations, handling both static and dynamic memrefs. For
// transfer ops, the strides are additionally adjusted according to the
// permutation map before being returned.
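//
// Illustrative only (the source type is hypothetical): when strides or the
// offset are dynamic, the metadata is recovered roughly as
//   %base, %offset, %sizes:2, %strides:2 =
//       memref.extract_strided_metadata %src
//       : memref<?x?xf32> -> memref<f32>, index, index, index, index, index
// and the dynamic values are taken from its results.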
template <
typename OpType,
typename = std::enable_if_t<llvm::is_one_of<
std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
vector::GatherOp, vector::ScatterOp>::value>>
static std::pair<SmallVector<Value>, Value>
computeMemrefMeta(OpType xferOp, PatternRewriter &rewriter) {
SmallVector<Value> strides;
Value baseMemref = xferOp.getBase();
MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
Location loc = xferOp.getLoc();
Value offsetVal = nullptr;
if (memrefType.hasStaticShape()) {
int64_t offset;
SmallVector<int64_t> intStrides;
if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
return {{}, offsetVal};
bool hasDynamicStrides = llvm::any_of(intStrides, [](int64_t strideVal) {
return ShapedType::isDynamic(strideVal);
});
if (!hasDynamicStrides)
for (int64_t s : intStrides)
strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
if (!ShapedType::isDynamic(offset))
offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
}
if (strides.empty() || !offsetVal) {
    // For memrefs with dynamic shapes, strides, or offset, use
    // memref.extract_strided_metadata to recover the missing values.
unsigned rank = memrefType.getRank();
Type indexType = rewriter.getIndexType();
// Result types: [base_memref, offset, stride0, stride1, ..., strideN-1,
// size0, size1, ..., sizeN-1]
SmallVector<Type> resultTypes;
resultTypes.push_back(MemRefType::get(
{}, memrefType.getElementType())); // base memref (unranked)
resultTypes.push_back(indexType); // offset
for (unsigned i = 0; i < rank; ++i)
resultTypes.push_back(indexType); // strides
for (unsigned i = 0; i < rank; ++i)
resultTypes.push_back(indexType); // sizes
auto meta = memref::ExtractStridedMetadataOp::create(
rewriter, loc, resultTypes, baseMemref);
if (strides.empty())
strides.append(meta.getStrides().begin(), meta.getStrides().end());
if (!offsetVal)
offsetVal = meta.getOffset();
}
if constexpr (llvm::is_one_of<std::decay_t<OpType>, vector::TransferReadOp,
vector::TransferWriteOp>::value) {
AffineMap permMap = xferOp.getPermutationMap();
// Adjust strides according to the permutation map (e.g., for transpose)
adjustStridesForPermutation(permMap, strides);
}
return {strides, offsetVal};
}
// This function computes the vector of localOffsets for scattered load/stores.
// It is used in the lowering of vector.transfer_read/write to
// load_gather/store_scatter.
//
// Example:
// %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0],
//              %cst {in_bounds = [true, true, true, true]} :
//              memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16>
//
// %6 = vector.step : vector<4xindex>
// %7 = vector.step : vector<2xindex>
// %8 = vector.step : vector<6xindex>
// %9 = vector.step : vector<32xindex>
// %10 = arith.mul %6, 384
// %11 = arith.mul %7, 192
// %12 = arith.mul %8, 32
// %13 = arith.mul %9, 1
// %14 = vector.shape_cast %10 : vector<4xindex> -> vector<4x1x1x1xindex>
// %15 = vector.shape_cast %11 : vector<2xindex> -> vector<1x2x1x1xindex>
// %16 = vector.shape_cast %12 : vector<6xindex> -> vector<1x1x6x1xindex>
// %17 = vector.shape_cast %13 : vector<32xindex> -> vector<1x1x1x32xindex>
// %18 = vector.broadcast %14 : vector<4x1x1x1xindex> -> vector<4x2x6x32xindex>
// %19 = vector.broadcast %15 : vector<1x2x1x1xindex> -> vector<4x2x6x32xindex>
// %20 = vector.broadcast %16 : vector<1x1x6x1xindex> -> vector<4x2x6x32xindex>
// %21 = vector.broadcast %17 : vector<1x1x1x32xindex> -> vector<4x2x6x32xindex>
// %22 = arith.add %18, %19
// %23 = arith.add %20, %21
// %local_offsets = arith.add %22, %23
// %orig_offset = %block_id_y * 1536   // 1536 = 4*2*6*32 (stride of dim 0);
//                                     // consider using an affine map
// %offsets = memref_offset + orig_offset + local_offsets
static Value computeOffsets(VectorTransferOpInterface xferOp,
PatternRewriter &rewriter, ArrayRef<Value> strides,
Value baseOffset) {
Location loc = xferOp.getLoc();
VectorType vectorType = xferOp.getVectorType();
SmallVector<Value> indices(xferOp.getIndices().begin(),
xferOp.getIndices().end());
ArrayRef<int64_t> vectorShape = vectorType.getShape();
  // Create vector.step operations for each vector dimension.
  SmallVector<Value> stepVectors;
  for (int64_t dim : vectorShape) {
    auto stepType = VectorType::get({dim}, rewriter.getIndexType());
    stepVectors.push_back(vector::StepOp::create(rewriter, loc, stepType));
  }
// Multiply step vectors by corresponding strides
size_t memrefRank = strides.size();
size_t vectorRank = vectorShape.size();
SmallVector<Value> strideMultiplied;
for (size_t i = 0; i < vectorRank; ++i) {
size_t memrefDim = memrefRank - vectorRank + i;
Value strideValue = strides[memrefDim];
auto mulType = dyn_cast<VectorType>(stepVectors[i].getType());
auto bcastOp =
vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
strideMultiplied.push_back(mulOp);
}
// Shape cast each multiplied vector to add singleton dimensions
SmallVector<Value> shapeCasted;
for (size_t i = 0; i < vectorRank; ++i) {
SmallVector<int64_t> newShape(vectorRank, 1);
newShape[i] = vectorShape[i];
auto newType = VectorType::get(newShape, rewriter.getIndexType());
auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType,
strideMultiplied[i]);
shapeCasted.push_back(castOp);
}
// Broadcast each shape-casted vector to full vector shape
SmallVector<Value> broadcasted;
auto fullIndexVectorType =
VectorType::get(vectorShape, rewriter.getIndexType());
for (Value shapeCastVal : shapeCasted) {
auto broadcastOp = vector::BroadcastOp::create(
rewriter, loc, fullIndexVectorType, shapeCastVal);
broadcasted.push_back(broadcastOp);
}
// Add all broadcasted vectors together to compute local offsets
Value localOffsets = broadcasted[0];
for (size_t i = 1; i < broadcasted.size(); ++i)
localOffsets =
arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);
  // Compute the base offset from the transfer op indices
for (size_t i = 0; i < indices.size(); ++i) {
Value strideVal = strides[i];
Value offsetContrib =
arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
baseOffset =
arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
}
// Broadcast base offset to match vector shape
Value bcastBase = vector::BroadcastOp::create(
rewriter, loc, fullIndexVectorType, baseOffset);
localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
return localOffsets;
}
// Compute the element-wise offsets for vector.gather or vector.scatter ops.
//
// This function linearizes the base offsets of the gather/scatter operation
// and combines them with the per-element indices to produce a final vector of
// memory offsets.
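//
// Illustrative only (names are hypothetical): for a gather from
// memref<?x?xf32> with base offsets [%i, %j] and indices
// %idx : vector<8xindex>, the produced offsets are roughly
//   %linear_base = memref_offset + %i * stride0 + %j * stride1
//   %offsets     = broadcast(%linear_base) + %idx * broadcast(stride1)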
template <
typename OpType,
typename = std::enable_if_t<llvm::is_one_of<
std::decay_t<OpType>, vector::GatherOp, vector::ScatterOp>::value>>
static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
ArrayRef<Value> strides, Value baseOffset) {
Location loc = gatScatOp.getLoc();
SmallVector<Value> offsets = gatScatOp.getOffsets();
for (size_t i = 0; i < offsets.size(); ++i) {
Value offsetContrib =
arith::MulIOp::create(rewriter, loc, offsets[i], strides[i]);
baseOffset =
arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
}
Value indices = gatScatOp.getIndices();
VectorType vecType = cast<VectorType>(indices.getType());
Value strideVector =
vector::BroadcastOp::create(rewriter, loc, vecType, strides.back())
.getResult();
Value stridedIndices =
arith::MulIOp::create(rewriter, loc, strideVector, indices).getResult();
Value baseVector =
vector::BroadcastOp::create(
rewriter, loc,
VectorType::get(vecType.getShape(), rewriter.getIndexType()),
baseOffset)
.getResult();
return arith::AddIOp::create(rewriter, loc, baseVector, stridedIndices)
.getResult();
}
// Convert a memref to an i64 base pointer.
template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
        vector::GatherOp, vector::ScatterOp>::value>>
static Value memrefToIndexPtr(OpType xferOp, PatternRewriter &rewriter) {
Location loc = xferOp.getLoc();
auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
rewriter, loc, xferOp.getBase())
.getResult();
return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(),
indexPtr)
.getResult();
}
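// Lowers a vector.transfer_read that cannot use 2D block loads to an
// xegpu.load_gather over linearized element offsets, using an all-true
// constant mask. Out-of-bounds reads are rejected by the calling pattern.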
static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
PatternRewriter &rewriter) {
Location loc = readOp.getLoc();
VectorType vectorType = readOp.getVectorType();
ArrayRef<int64_t> vectorShape = vectorType.getShape();
auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
if (!memrefType)
return rewriter.notifyMatchFailure(readOp, "Expected memref source");
auto meta = computeMemrefMeta(readOp, rewriter);
if (meta.first.empty())
return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");
Value localOffsets =
computeOffsets(readOp, rewriter, meta.first, meta.second);
Value flatMemref = memrefToIndexPtr(readOp, rewriter);
Value mask = vector::ConstantMaskOp::create(
rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
vectorShape);
auto gatherOp = xegpu::LoadGatherOp::create(
rewriter, loc, vectorType, flatMemref, localOffsets, mask,
/*chunk_size=*/IntegerAttr{},
/*l1_hint=*/xegpu::CachePolicyAttr{},
/*l2_hint=*/xegpu::CachePolicyAttr{},
/*l3_hint=*/xegpu::CachePolicyAttr{});
rewriter.replaceOp(readOp, gatherOp.getResult());
return success();
}
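// Lowers a vector.transfer_write that cannot use 2D block stores to an
// xegpu.store_scatter over linearized element offsets, using an all-true
// constant mask.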
static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
PatternRewriter &rewriter) {
Location loc = writeOp.getLoc();
VectorType vectorType = writeOp.getVectorType();
ArrayRef<int64_t> vectorShape = vectorType.getShape();
auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType());
if (!memrefType)
return rewriter.notifyMatchFailure(writeOp, "Expected memref source");
auto meta = computeMemrefMeta(writeOp, rewriter);
if (meta.first.empty())
return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides");
Value localOffsets =
computeOffsets(writeOp, rewriter, meta.first, meta.second);
Value flatMemref = memrefToIndexPtr(writeOp, rewriter);
Value mask = vector::ConstantMaskOp::create(
rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
vectorShape);
xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
localOffsets, mask,
/*chunk_size=*/IntegerAttr{},
/*l1_hint=*/xegpu::CachePolicyAttr{},
/*l2_hint=*/xegpu::CachePolicyAttr{},
/*l3_hint=*/xegpu::CachePolicyAttr{});
rewriter.eraseOp(writeOp);
return success();
}
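// Lowers vector.transfer_read to an nd-block load (xegpu.create_nd_tdesc +
// xegpu.load_nd) on targets with 2D block load support, applying a transpose
// attribute for non-minor-identity permutation maps; other targets fall back
// to a scattered load.
//
// Illustrative only (shapes are hypothetical); the block path emits roughly:
//   %desc = xegpu.create_nd_tdesc %src[%i, %j]
//             : memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32>
//   %vec = xegpu.load_nd %desc : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>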
struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
PatternRewriter &rewriter) const override {
Location loc = readOp.getLoc();
if (failed(transferPreconditions(rewriter, readOp)))
return failure();
    // TODO: This check needs to be replaced with a proper uArch capability
    // check.
auto chip = xegpu::getChipStr(readOp);
if (chip != "pvc" && chip != "bmg") {
      // Lower to a scattered load op if the target HW doesn't have 2D block
      // load support.
      // TODO: Add support for out-of-bounds access.
if (readOp.hasOutOfBoundsDim())
return failure();
return lowerToScatteredLoadOp(readOp, rewriter);
}
// Perform common data transfer checks.
VectorType vecTy = readOp.getVectorType();
if (failed(storeLoadPreconditions(rewriter, readOp, vecTy)))
return failure();
bool isOutOfBounds = readOp.hasOutOfBoundsDim();
if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
return rewriter.notifyMatchFailure(
readOp, "Unsupported non-zero padded out-of-bounds read");
AffineMap readMap = readOp.getPermutationMap();
bool isTransposeLoad = !readMap.isMinorIdentity();
Type elementType = vecTy.getElementType();
unsigned minTransposeBitWidth = 32;
if (isTransposeLoad &&
elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
return rewriter.notifyMatchFailure(
readOp, "Unsupported data type for transposition");
// If load is transposed, get the base shape for the tensor descriptor.
SmallVector<int64_t> descShape(vecTy.getShape());
if (isTransposeLoad)
std::reverse(descShape.begin(), descShape.end());
auto descType = xegpu::TensorDescType::get(
descShape, elementType, /*array_length=*/1,
/*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);
xegpu::CreateNdDescOp ndDesc =
createNdDescriptor(rewriter, loc, descType,
dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
readOp.getIndices());
DenseI64ArrayAttr transposeAttr =
!isTransposeLoad ? nullptr
: DenseI64ArrayAttr::get(rewriter.getContext(),
ArrayRef<int64_t>{1, 0});
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
/*packed=*/nullptr, transposeAttr,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(readOp, loadOp);
return success();
}
};
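// Lowers vector.transfer_write to an nd-block store (xegpu.create_nd_tdesc +
// xegpu.store_nd) on targets with 2D block store support; only minor-identity
// permutation maps are accepted on this path. Other targets fall back to a
// scattered store.
//
// Illustrative only (shapes are hypothetical); the block path emits roughly:
//   %desc = xegpu.create_nd_tdesc %dst[%i, %j]
//             : memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32>
//   xegpu.store_nd %vec, %desc : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>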
struct TransferWriteLowering
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
PatternRewriter &rewriter) const override {
Location loc = writeOp.getLoc();
if (failed(transferPreconditions(rewriter, writeOp)))
return failure();
    // TODO: This check needs to be replaced with a proper uArch capability
    // check.
auto chip = xegpu::getChipStr(writeOp);
if (chip != "pvc" && chip != "bmg") {
      // Lower to a scattered store op if the target HW doesn't have 2D block
      // store support.
      // TODO: Add support for out-of-bounds access.
if (writeOp.hasOutOfBoundsDim())
return failure();
return lowerToScatteredStoreOp(writeOp, rewriter);
}
// Perform common data transfer checks.
VectorType vecTy = writeOp.getVectorType();
if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy)))
return failure();
AffineMap map = writeOp.getPermutationMap();
if (!map.isMinorIdentity())
return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
auto descType = xegpu::TensorDescType::get(
vecTy.getShape(), vecTy.getElementType(),
/*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
xegpu::MemorySpace::Global);
xegpu::CreateNdDescOp ndDesc =
createNdDescriptor(rewriter, loc, descType,
dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
writeOp.getIndices());
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
auto storeOp =
xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(writeOp, storeOp);
return success();
}
};
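// Lowers vector.gather to xegpu.load_gather over linearized element offsets,
// followed by an arith.select with the original mask to apply the
// pass-through values.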
struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
using OpRewritePattern<vector::GatherOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
PatternRewriter &rewriter) const override {
auto srcTy = dyn_cast<MemRefType>(gatherOp.getBase().getType());
if (!srcTy)
return rewriter.notifyMatchFailure(gatherOp, "Expects memref source");
Location loc = gatherOp.getLoc();
VectorType vectorType = gatherOp.getVectorType();
auto meta = computeMemrefMeta(gatherOp, rewriter);
if (meta.first.empty())
return rewriter.notifyMatchFailure(gatherOp, "Failed to compute strides");
Value localOffsets =
computeOffsets(rewriter, gatherOp, meta.first, meta.second);
Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);
auto xeGatherOp = xegpu::LoadGatherOp::create(
rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
/*chunk_size=*/IntegerAttr{},
/*l1_hint=*/xegpu::CachePolicyAttr{},
/*l2_hint=*/xegpu::CachePolicyAttr{},
/*l3_hint=*/xegpu::CachePolicyAttr{});
auto selectOp =
arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
xeGatherOp.getResult(), gatherOp.getPassThru());
rewriter.replaceOp(gatherOp, selectOp.getResult());
return success();
}
};
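// Lowers vector.scatter to xegpu.store_scatter over linearized element
// offsets, reusing the original mask.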
struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
using OpRewritePattern<vector::ScatterOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ScatterOp scatterOp,
PatternRewriter &rewriter) const override {
auto srcTy = dyn_cast<MemRefType>(scatterOp.getBase().getType());
if (!srcTy)
return rewriter.notifyMatchFailure(scatterOp, "Expects memref source");
Location loc = scatterOp.getLoc();
auto meta = computeMemrefMeta(scatterOp, rewriter);
if (meta.first.empty())
return rewriter.notifyMatchFailure(scatterOp,
"Failed to compute strides");
Value localOffsets =
computeOffsets(rewriter, scatterOp, meta.first, meta.second);
Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);
xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
flatMemref, localOffsets, scatterOp.getMask(),
/*chunk_size=*/IntegerAttr{},
/*l1_hint=*/xegpu::CachePolicyAttr{},
/*l2_hint=*/xegpu::CachePolicyAttr{},
/*l3_hint=*/xegpu::CachePolicyAttr{});
rewriter.eraseOp(scatterOp);
return success();
}
};
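// Lowers vector.load to an nd-block load (xegpu.create_nd_tdesc +
// xegpu.load_nd). Boundary checking is enabled only for vectors of rank > 1.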
struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::LoadOp loadOp,
PatternRewriter &rewriter) const override {
Location loc = loadOp.getLoc();
VectorType vecTy = loadOp.getResult().getType();
if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
return failure();
// Boundary check is available only for block instructions.
bool boundaryCheck = vecTy.getRank() > 1;
auto descType = xegpu::TensorDescType::get(
vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
boundaryCheck, xegpu::MemorySpace::Global);
xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
auto loadNdOp = xegpu::LoadNdOp::create(
rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(loadOp, loadNdOp);
return success();
}
};
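// Lowers vector.store to an nd-block store (xegpu.create_nd_tdesc +
// xegpu.store_nd). Boundary checking is enabled only for vectors of rank > 1.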
struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::StoreOp storeOp,
PatternRewriter &rewriter) const override {
Location loc = storeOp.getLoc();
TypedValue<VectorType> vector = storeOp.getValueToStore();
VectorType vecTy = vector.getType();
if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
return failure();
// Boundary check is available only for block instructions.
bool boundaryCheck = vecTy.getRank() > 1;
auto descType = xegpu::TensorDescType::get(
vecTy.getShape(), vecTy.getElementType(),
/*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
auto storeNdOp =
xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(storeOp, storeNdOp);
return success();
}
};
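// Lowers a 2D row-major vector.contract with an additive combining kind to
// xegpu.dpas. VNNI packing, if needed, is applied as a separate lowering step.
//
// Illustrative only (shapes are hypothetical); the emitted op looks roughly
// like:
//   %res = xegpu.dpas %lhs, %rhs, %acc
//            : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32>
//            -> vector<8x16xf32>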
struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
Location loc = contractOp.getLoc();
if (contractOp.getKind() != vector::CombiningKind::ADD)
return rewriter.notifyMatchFailure(contractOp,
"Expects add combining kind");
TypedValue<Type> acc = contractOp.getAcc();
VectorType accType = dyn_cast<VectorType>(acc.getType());
if (!accType || accType.getRank() != 2)
return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
// Accept only plain 2D data layout.
// VNNI packing is applied to DPAS as a separate lowering step.
TypedValue<VectorType> lhs = contractOp.getLhs();
TypedValue<VectorType> rhs = contractOp.getRhs();
if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
return rewriter.notifyMatchFailure(contractOp,
"Expects lhs and rhs 2D vectors");
if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");
auto dpasOp = xegpu::DpasOp::create(rewriter, loc,
TypeRange{contractOp.getResultType()},
ValueRange{lhs, rhs, acc});
rewriter.replaceOp(contractOp, dpasOp);
return success();
}
};
struct ConvertVectorToXeGPUPass
: public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateVectorToXeGPUConversionPatterns(patterns);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
};
} // namespace
void mlir::populateVectorToXeGPUConversionPatterns(
RewritePatternSet &patterns) {
patterns
.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
ScatterLowering, GatherLowering, StoreLowering, ContractionLowering>(
patterns.getContext());
}