blob: 75a87f84b3da858d7dcfb2050a206bbe09597a8a [file] [edit]
//===- XeGPUSgToLaneDistribute.cpp - XeGPU SG to Lane Pass ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUSGTOLANEDISTRIBUTE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir
using namespace mlir;
#define DEBUG_TYPE "xegpu-sg-to-lane-distribute"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
namespace {
/// Casts the vector value `v` to the vector type `expectedTy`.
///
/// - If the types already match, `v` is returned unchanged.
/// - If both types have the same element type and the same number of
///   elements (i.e. only the shape differs), a `vector.shape_cast` is created.
/// - Otherwise an `UnrealizedConversionCastOp` from `v` to `expectedTy` is
///   created and its single result is returned.
static Value castValueTo(ConversionPatternRewriter &rewriter,
                         TypedValue<VectorType> v, VectorType expectedTy) {
  VectorType sourceTy = v.getType();
  // If the type matches, simply return the value itself.
  if (sourceTy == expectedTy)
    return v;
  // A shape cast is only legal when the element type is preserved; comparing
  // the element count alone would also accept e.g. i1 -> index vectors of the
  // same size and build an invalid vector.shape_cast.
  if (sourceTy.getElementType() == expectedTy.getElementType() &&
      sourceTy.getNumElements() == expectedTy.getNumElements())
    return vector::ShapeCastOp::create(rewriter, v.getLoc(), expectedTy, v);
  // Else create an unrealized cast.
  auto newOp = UnrealizedConversionCastOp::create(rewriter, v.getLoc(),
                                                  expectedTy, ValueRange{v});
  return newOp.getResult(0);
}
/// Returns true when a subgroup-level vector::MultiDimReductionOp is in the
/// form expected by this file:
///   * its result carries a temporary layout that is a subgroup layout,
///   * its result is either a scalar (int/float) or a vector whose type can be
///     distributed to lanes with that layout, and
///   * it reduces exactly one dimension.
static bool isValidSubgroupMultiReductionOp(vector::MultiDimReductionOp op) {
  auto layout = xegpu::getTemporaryLayout(op->getOpResult(0));
  // A missing or non-subgroup layout disqualifies the op.
  if (!layout || !layout.isForSubgroup())
    return false;

  // Scalar results (e.g., vector<32xf32> to f32) need no further type check;
  // vector results must be distributable with the result layout.
  if (!op.getType().isIntOrFloat()) {
    VectorType vecResultTy = dyn_cast<VectorType>(op.getType());
    if (!vecResultTy)
      return false;
    FailureOr<VectorType> distResultTy =
        getDistVecTypeBasedOnLaneLayout(layout, vecResultTy);
    if (failed(distResultTy))
      return false;
  }

  // In both cases exactly one reduction dimension is required.
  return op.getReductionDims().size() == 1;
}
/// Reports whether a vector::MultiDimReductionOp performs a lane-local
/// reduction, i.e. every lane reduces its own data. This is detected by the
/// result vector type differing from the lane-distributed result vector type
/// computed from the result layout.
///
/// The op must satisfy isValidSubgroupMultiReductionOp (asserted).
static bool isReductionLaneLocal(vector::MultiDimReductionOp op) {
  assert(isValidSubgroupMultiReductionOp(op) && "Expecting a valid subgroup "
                                                "MultiDimReductionOp");
  VectorType sgResultTy = dyn_cast<VectorType>(op.getType());
  auto resultLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
  auto laneResultTy = getDistVecTypeBasedOnLaneLayout(resultLayout, sgResultTy);
  // Lane-local iff distribution actually changed the result type.
  return sgResultTy != laneResultTy.value();
}
/// Returns the indices of the dimensions whose size differs between
/// `originalType` and `distributedType`, in increasing order. Both types must
/// have the same rank (asserted).
static SmallVector<int64_t> getDistributedDims(VectorType originalType,
                                               VectorType distributedType) {
  assert(originalType.getRank() == distributedType.getRank() &&
         "original and distributed vector types must have the same rank");
  const int64_t rank = originalType.getRank();
  SmallVector<int64_t> changedDims;
  for (int64_t dim = 0; dim != rank; ++dim) {
    const bool isDistributed =
        originalType.getDimSize(dim) != distributedType.getDimSize(dim);
    if (isDistributed)
      changedDims.push_back(dim);
  }
  return changedDims;
}
/// Rewrites a subgroup-level CreateNdDesc op into a lane-level CreateNdDesc
/// op. The new op reuses the original operands and attributes; only the
/// layout is dropped from the tensor descriptor result type.
struct SgToLaneCreateNdDesc
    : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::TensorDescType sgDescTy = op.getType();
    // A descriptor without a layout is already in lane-level form.
    if (!sgDescTy.getLayout())
      return failure();

    auto laneDescTy = sgDescTy.dropLayouts();
    auto laneDescOp = xegpu::CreateNdDescOp::create(
        rewriter, op.getLoc(), laneDescTy, op.getOperands(), op->getAttrs());
    rewriter.replaceOp(op, laneDescOp.getResult());
    return success();
  }
};
/// Distributes a subgroup-level LoadNd op to lane-level LoadNd op.
///
/// The new LoadNd produces the vector type returned by
/// `xegpu::getDistributedVectorType` for the tensor descriptor; its result is
/// then cast (see `castValueTo`) to the type computed from the lane layout by
/// `xegpu::getDistVecTypeBasedOnLaneLayout`, which replaces the original
/// result.
///
/// The pattern fails to match when:
///   * the op has no anchor layout,
///   * the tensor descriptor layout differs from the anchor layout,
///   * no uArch can be determined for the op,
///   * either lane vector type cannot be computed, or
///   * the anchor layout is not a `xegpu::LayoutAttr`.
struct SgToLaneLoadNd : public OpConversionPattern<xegpu::LoadNdOp> {
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
    // If no layout, nothing to do.
    if (!layout)
      return failure();
    // Check if the layout attached to the tensor descriptor is same as the
    // anchor layout. Otherwise, this is a conflict.
    if (op.getTensorDescType().getLayout() != layout)
      return rewriter.notifyMatchFailure(
          op, "conflicting layout attributes on tensor descriptor and anchor");
    auto uArch = getUArch(xegpu::getChipStr(op).value_or(""));
    if (!uArch)
      return rewriter.notifyMatchFailure(
          op, "xegpu::LoadNdOp require target attribute attached to "
              "determine transpose "
              "requirement");
    auto supportedLaneResultTyOrFailure =
        xegpu::getDistributedVectorType(op.getTensorDescType());
    auto expectedLaneResultTyOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(layout, op.getType());
    if (failed(supportedLaneResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute the lane vector type for LoadNdOp");
    if (failed(expectedLaneResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute expected lane vector type from lane layout");
    // The packed/transpose queries below take a plain LayoutAttr. Reject other
    // layout kinds with a match failure (before creating any op) instead of
    // tripping the assertion inside `cast`.
    auto laneLayout = dyn_cast<xegpu::LayoutAttr>(layout);
    if (!laneLayout)
      return rewriter.notifyMatchFailure(
          op, "anchor layout of LoadNdOp must be a xegpu::LayoutAttr");

    auto newOp = xegpu::LoadNdOp::create(
        rewriter, op.getLoc(), supportedLaneResultTyOrFailure.value(),
        adaptor.getTensorDesc(), op.getMixedOffsets(), op.getPackedAttr(),
        op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
        op.getL3HintAttr(), /**layout**/ nullptr);
    // Set the packed attribute if the layout requires it.
    newOp.setPacked(xegpu::requirePacked(laneLayout));
    // Set the transpose attribute if the layout requires it.
    if (xegpu::requireTranspose(laneLayout, uArch))
      newOp.setTranspose(DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
                                       expectedLaneResultTyOrFailure.value()));
    return success();
  }
};
/// Rewrites a subgroup-level StoreNd op into a lane-level StoreNd op. The
/// converted value operand is cast (see `castValueTo`) to the lane vector type
/// derived from the tensor descriptor before being stored; the new op carries
/// no layout attribute.
struct SgToLaneStoreNd : public OpConversionPattern<xegpu::StoreNdOp> {
  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr anchorLayout = op.getAnchorLayout();
    // Ops without an anchor layout are left alone.
    if (!anchorLayout)
      return failure();

    // The tensor descriptor layout and the stored value layout must both
    // agree with the anchor layout; any mismatch is reported as a conflict.
    auto descTy = op.getTensorDescType();
    if (descTy.getLayout() != anchorLayout)
      return rewriter.notifyMatchFailure(
          op, "conflicting layout attributes on tensor descriptor and anchor");
    auto storedValueLayout =
        xegpu::getDistributeLayoutAttr(op->getOpOperand(0));
    if (storedValueLayout != anchorLayout)
      return rewriter.notifyMatchFailure(
          op, "conflicting layout attributes on value and anchor");

    auto laneValueTyOrFailure = xegpu::getDistributedVectorType(descTy);
    if (failed(laneValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          op,
          "unable to compute lane vector type for StoreNdOp value from tensor "
          "descriptor");

    // Bring the converted value into the type the lane-level store expects.
    Value laneValue = castValueTo(
        rewriter, cast<TypedValue<VectorType>>(adaptor.getValue()),
        laneValueTyOrFailure.value());
    xegpu::StoreNdOp::create(rewriter, op.getLoc(), laneValue,
                             adaptor.getTensorDesc(), op.getMixedOffsets(),
                             op.getL1HintAttr(), op.getL2HintAttr(),
                             op.getL3HintAttr(), /**layout**/ nullptr);
    rewriter.eraseOp(op);
    return success();
  }
};
/// Distributes a subgroup-level Dpas op to lane-level Dpas op. All inputs
/// and output of lane-level Dpas op are 1D. Necessary casts are added to
/// convert the inputs and output to/from 1D.
///
/// The pattern fails to match when any of the A/B/CD layouts is missing (or is
/// not a `xegpu::LayoutAttr`), when a lane vector type cannot be computed, or
/// when a uArch is available and the per-lane packed bit width of A or B is
/// not a multiple of the uArch packed format size.
struct SgToLaneDpas : public OpConversionPattern<xegpu::DpasOp> {
  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check if the op has A, B and CD layouts attached. The layout attributes
    // may be absent, so a null-tolerant dynamic cast is required here: a plain
    // `cast` would assert on a null attribute and make the check below dead.
    auto layoutA = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAAttr());
    auto layoutB = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutBAttr());
    auto layoutCd =
        dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutCdAttr());
    if (!layoutA || !layoutB || !layoutCd)
      return failure();

    auto laneResultTyOrFailure =
        xegpu::getDistributedVectorType(op.getType(), layoutCd);
    auto laneATypeOrFailure =
        xegpu::getDistributedVectorType(op.getLhs().getType(), layoutA);
    auto laneBTypeOrFailure =
        xegpu::getDistributedVectorType(op.getRhs().getType(), layoutB);
    auto expectedLaneResultTyOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(layoutCd, op.getType());
    if (failed(laneResultTyOrFailure) || failed(laneATypeOrFailure) ||
        failed(laneBTypeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "failed to calculate supported lane vector types for DpasOp "
              "from layouts");
    if (failed(expectedLaneResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute expected lane vector type for DpasOp from "
              "lane layout");

    // Validate bit widths match uArch packed format requirements. The check
    // is skipped when no uArch (or no matching instruction) is available.
    const uArch *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
    if (uArch) {
      const auto *uArchInstruction =
          dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
              uArch->getInstruction(
                  xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
      if (uArchInstruction) {
        auto laneAType = laneATypeOrFailure.value();
        auto laneBType = laneBTypeOrFailure.value();
        // Calculate total packed bit width = element bit width * vector size
        unsigned aPackedBitWidth =
            laneAType.getElementTypeBitWidth() * laneAType.getNumElements();
        unsigned bPackedBitWidth =
            laneBType.getElementTypeBitWidth() * laneBType.getNumElements();
        unsigned expectedABitSize = uArchInstruction->getPackedFormatBitSizeA();
        unsigned expectedBBitSize = uArchInstruction->getPackedFormatBitSizeB();
        if (aPackedBitWidth % expectedABitSize != 0)
          return rewriter.notifyMatchFailure(
              op,
              "A operand packed bit width must be a multiple of uArch packed "
              "format requirement");
        if (bPackedBitWidth % expectedBBitSize != 0)
          return rewriter.notifyMatchFailure(
              op,
              "B operand packed bit width must be a multiple of uArch packed "
              "format requirement");
      }
    }

    // Cast the converted operands to the lane types supported by the
    // lane-level Dpas and build the new op without layout attributes.
    auto newOp = xegpu::DpasOp::create(
        rewriter, op->getLoc(), laneResultTyOrFailure.value(),
        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getLhs()),
                    laneATypeOrFailure.value()),
        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getRhs()),
                    laneBTypeOrFailure.value()),
        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getAcc()),
                    laneResultTyOrFailure.value()),
        /** layoutA**/ nullptr,
        /** layoutB**/ nullptr, /** layoutCd**/ nullptr);
    // Cast the lane-level result to the type expected from the lane layout
    // before replacing the original op.
    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
                                       expectedLaneResultTyOrFailure.value()));
    return success();
  }
};
/// Rewrites any single-result op with elementwise-mappable traits into the
/// same op producing the lane-distributed vector type. The new op is built
/// generically from the original op name, the converted operands and all
/// original attributes except those holding a DistributeLayoutAttr.
///
/// Note: the `typeConverter` constructor argument is accepted but is not
/// forwarded to the ConversionPattern base class.
struct SgToLaneElementWise : public ConversionPattern {
  SgToLaneElementWise(TypeConverter &typeConverter, MLIRContext *ctx)
      : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Restrict to elementwise-mappable ops that yield exactly one result.
    const bool isSingleResultElementwise =
        OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1;
    if (!isSingleResultElementwise)
      return failure();

    auto sgResultTy = dyn_cast<VectorType>(op->getResult(0).getType());
    if (!sgResultTy)
      return rewriter.notifyMatchFailure(
          op, "operation result is not a vector type");

    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getTemporaryLayout(llvm::cast<OpResult>(op->getResult(0)));
    if (!resultLayout || !resultLayout.isForSubgroup())
      return rewriter.notifyMatchFailure(
          op, "operation result does not have subgroup distribute layout");

    auto laneResultTyOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, sgResultTy);
    if (failed(laneResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute lane vector type from the layout");
    VectorType laneResultTy = laneResultTyOrFailure.value();

    // Assemble the lane-level clone of the op.
    OperationState laneState(op->getLoc(), op->getName());
    laneState.addOperands(operands);
    laneState.addTypes(laneResultTy);
    // Forward every attribute that is not a distribute layout.
    for (auto namedAttr : op->getAttrs()) {
      if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
        continue;
      laneState.addAttribute(namedAttr.getName(), namedAttr.getValue());
    }

    Operation *laneOp = rewriter.create(laneState);
    rewriter.replaceOp(op, laneOp->getResult(0));
    return success();
  }
};
/// Rewrites a subgroup-level arith ConstantOp with a splat vector value into
/// a lane-level arith ConstantOp holding the same splat value in the
/// lane-distributed vector type.
struct SgToLaneArithConstant : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Non-vector constants are not handled by this pattern.
    auto sgResultTy = dyn_cast<VectorType>(op.getType());
    if (!sgResultTy)
      return failure();

    // The value must be a splat so it can be re-created for any shape.
    auto splatAttr = dyn_cast<SplatElementsAttr>(op.getValue());
    if (!splatAttr)
      return rewriter.notifyMatchFailure(
          op, "only dense splat vector constants are supported");

    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getTemporaryLayout(llvm::cast<OpResult>(op.getResult()));
    if (!resultLayout || !resultLayout.isForSubgroup())
      return rewriter.notifyMatchFailure(
          op, "operation result does not have subgroup distribute layout");

    auto laneResultTyOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, sgResultTy);
    if (failed(laneResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute lane vector type from the layout");
    VectorType laneResultTy = laneResultTyOrFailure.value();

    // Rebuild the splat with the lane-level shape.
    auto splatScalar = splatAttr.getSplatValue<Attribute>();
    auto laneSplatAttr = DenseElementsAttr::get(laneResultTy, splatScalar);
    auto laneConstant = arith::ConstantOp::create(rewriter, op.getLoc(),
                                                  laneResultTy, laneSplatAttr);
    rewriter.replaceOp(op, laneConstant.getResult());
    return success();
  }
};
/// Rewrites a subgroup-level PrefetchNd op into a lane-level PrefetchNd op:
/// same offsets and cache hints, converted tensor descriptor, no layout.
struct SgToLanePrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only ops that still carry an anchor layout need rewriting.
    xegpu::DistributeLayoutAttr anchorLayout = op.getAnchorLayout();
    if (!anchorLayout)
      return failure();

    Location loc = op.getLoc();
    xegpu::PrefetchNdOp::create(rewriter, loc, adaptor.getTensorDesc(),
                                op.getMixedOffsets(), op.getL1HintAttr(),
                                op.getL2HintAttr(), op.getL3HintAttr(),
                                /**layout**/ nullptr);
    rewriter.eraseOp(op);
    return success();
  }
};
/// Rewrites a subgroup-level LoadGather (xegpu.load) op into its lane-level
/// form. Offsets, mask and result of the new op are flattened to 1D; the
/// result is cast back to the lane-distributed type afterwards if needed.
///
/// Example 1 (1D, no chunk size):
///   layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
///   %mask = producer_op : vector<16xi1>
///   %offset = producer_op : vector<16xindex>
///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
///     vector<16xindex>, vector<16xi1> -> vector<16xf16>
/// Distributed to:
///   %mask = producer_op : vector<1xi1>
///   %offset = producer_op : vector<1xindex>
///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
///     vector<1xindex>, vector<1xi1> -> vector<1xf16>
///
/// Example 2 (2D with chunk size, same mask & offset):
///   layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
///   %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
///     memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
/// Distributed to:
///   %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
///     memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
///
/// Example 3 (3D with leading unit dims):
///   layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>
///   %mask = producer_op : vector<1x1x16xi1>
///   %offset = producer_op : vector<1x1x16xindex>
///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
///     vector<1x1x16xindex>, vector<1x1x16xi1> -> vector<1x1x16xf16>
/// Distributed to:
///   %mask = producer_op : vector<1x1x1xi1>
///   %offset = producer_op : vector<1x1x1xindex>
///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
///     vector<1xindex>, vector<1xi1> -> vector<1xf16>
struct SgToLaneLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
  using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::LoadGatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr anchorLayout = op.getAnchorLayout();
    if (!anchorLayout)
      return failure();
    VectorType sgResultTy = op.getValueType();
    if (!sgResultTy)
      return failure();

    // Without a chunk size the payload is effectively 1D, otherwise 2D. All
    // dimensions in front of that payload must be unit dimensions.
    int chunkSize = op.getChunkSize().value_or(1);
    int payloadRank = (chunkSize == 1) ? 1 : 2;
    ArrayRef<int64_t> leadingDims =
        sgResultTy.getShape().take_front(sgResultTy.getRank() - payloadRank);
    if (!llvm::all_of(leadingDims, [](int64_t d) { return d == 1; }))
      return rewriter.notifyMatchFailure(
          op, "Only unit dimensions allowed for the leading "
              "dimensions of the load vector!");

    auto laneResultTyOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(anchorLayout, sgResultTy);
    if (failed(laneResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute expected lane vector type from lane layout");
    VectorType laneResultTy = laneResultTyOrFailure.value();
    VectorType flatResultTy = VectorType::get({laneResultTy.getNumElements()},
                                              laneResultTy.getElementType());

    // Casts a vector value to the 1D vector type with the same element type
    // and element count.
    auto flattenTo1D = [&rewriter](Value vec) -> Value {
      auto vecTy = cast<VectorType>(vec.getType());
      VectorType flatTy =
          VectorType::get({vecTy.getNumElements()}, vecTy.getElementType());
      return castValueTo(rewriter, cast<TypedValue<VectorType>>(vec), flatTy);
    };

    // Offsets first, then mask, so both match the 1D result type.
    Value flatOffsets = flattenTo1D(adaptor.getOffsets());
    Value flatMask = flattenTo1D(adaptor.getMask());

    auto laneLoad = xegpu::LoadGatherOp::create(
        rewriter, op.getLoc(), flatResultTy, adaptor.getSource(), flatOffsets,
        flatMask, op.getChunkSizeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
        op.getL3HintAttr(), /*layout=*/nullptr);

    // Restore the lane-distributed shape when it is not already 1D.
    Value replacement = laneLoad->getResult(0);
    if (flatResultTy != laneResultTy)
      replacement = castValueTo(
          rewriter, cast<TypedValue<VectorType>>(replacement), laneResultTy);
    rewriter.replaceOp(op, replacement);
    return success();
  }
};
/// Rewrites a subgroup-level vector.reduction op into its lane-level form.
/// The converted (per-lane) source vector is handed to
/// `xegpu::subgroupReduction` together with the subgroup size taken from the
/// layout; an accumulator, if present, is folded into the reduced value with
/// `vector::makeArithReduction`.
struct SgToLaneVectorReduction
    : public OpConversionPattern<vector::ReductionOp> {
  using OpConversionPattern<vector::ReductionOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcLayout = xegpu::getDistributeLayoutAttr(op.getVector());
    // Nothing to do without a subgroup layout on the source vector.
    if (!srcLayout || !srcLayout.isForSubgroup())
      return failure();

    VectorType srcTy = op.getSourceVectorType();
    // The source must be a rank-1 vector ...
    if (srcTy.getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "Only rank 1 reductions can be distributed.");
    // ... and the layout must have the same rank as the vector.
    if (srcLayout.getRank() != srcTy.getRank())
      return rewriter.notifyMatchFailure(
          op, "Layout rank does not match vector rank.");

    // The lane layout provides the subgroup size.
    int64_t subgroupSize = srcLayout.getEffectiveLaneLayoutAsInt()[0];
    const uArch *targetUArch = getUArch(xegpu::getChipStr(op).value_or(""));
    if (!targetUArch)
      return rewriter.notifyMatchFailure(
          op, "xegpu::ReductionOp require target attribute attached to "
              "determine subgroup size");

    // The layout must agree with the target subgroup size, and the vector
    // length must be a multiple of it.
    const bool layoutMatchesTarget =
        subgroupSize == targetUArch->getSubgroupSize();
    if (!layoutMatchesTarget || srcTy.getShape()[0] % subgroupSize != 0)
      return rewriter.notifyMatchFailure(op,
                                         "Invalid layout or reduction vector "
                                         "dimension must match subgroup size.");

    if (!op.getType().isIntOrFloat())
      return rewriter.notifyMatchFailure(
          op, "Reduction distribution currently only supports floats and "
              "integer types.");

    // Reduce the per-lane portion across the lanes of the subgroup.
    Value perLaneVec = adaptor.getVector();
    Value reduced = xegpu::subgroupReduction(op.getLoc(), rewriter, perLaneVec,
                                             op.getKind(), subgroupSize);
    // Fold in the accumulator when the op has one.
    if (Value acc = adaptor.getAcc())
      reduced = vector::makeArithReduction(rewriter, op.getLoc(), op.getKind(),
                                           reduced, acc);
    rewriter.replaceOp(op, reduced);
    return success();
  }
};
/// This pattern distributes a subgroup-level vector.multi_reduction op with a
/// single reduction dimension to lane-level. Three cases are handled:
///   * scalar result: lowered through `xegpu::subgroupReduction`, then
///     combined with the accumulator if one is present;
///   * lane-local reduction (see `isReductionLaneLocal`): lowered through
///     `xegpu::lowerToVectorReductions`;
///   * otherwise: lowered through `xegpu::lowerCrossLaneReductionToShuffles`.
///
/// The pattern fails to match when the op does not have exactly one reduction
/// dimension, or when a source of rank > 2 has non-unit leading dimensions.
struct SgToLaneMultiDimReduction
    : public OpConversionPattern<vector::MultiDimReductionOp> {
  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MultiDimReductionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ArrayRef<int64_t> reductionDims = op.getReductionDims();
    // Reject (rather than only assert on) unsupported ops: the code below
    // indexes `reductionDims[0]`, which would be out of bounds in builds
    // without assertions.
    if (reductionDims.size() != 1)
      return rewriter.notifyMatchFailure(
          op, "Expecting single reduction dimension for subgroup multi "
              "reduction op");

    // For rank > 2, ensure leading dimensions are unit.
    VectorType sourceType = op.getSourceVectorType();
    int64_t rank = sourceType.getRank();
    if (rank > 2) {
      ArrayRef<int64_t> shape = sourceType.getShape();
      if (llvm::any_of(shape.take_front(rank - 2),
                       [](int64_t d) { return d != 1; }))
        return rewriter.notifyMatchFailure(
            op, "only unit leading dimensions are supported for "
                "multi_reduction with rank > 2");
    }

    int64_t reductionDim = reductionDims[0];
    Value result;
    if (op.getType().isIntOrFloat()) {
      // Handle scalar result: full reduction of a distributed vector to a
      // scalar. First do a local vector reduction, then cross-lane shuffles.
      int64_t reductionDimSize = sourceType.getShape()[reductionDim];
      result =
          xegpu::subgroupReduction(op.getLoc(), rewriter, adaptor.getSource(),
                                   op.getKind(), reductionDimSize);
      // Combine with accumulator if present.
      if (adaptor.getAcc())
        result = vector::makeArithReduction(rewriter, op.getLoc(), op.getKind(),
                                            result, adaptor.getAcc());
    } else if (isReductionLaneLocal(op)) {
      // For lane-local reduction, lower to a sequence of vector.reduction ops
      // over 1D slices extracted from the distributed source vector. This is
      // required so we dont have 2D source vectors at xegpu-linearize.
      result = xegpu::lowerToVectorReductions(
          cast<TypedValue<VectorType>>(adaptor.getSource()),
          cast<TypedValue<VectorType>>(adaptor.getAcc()), op.getKind(),
          reductionDim, op.getLoc(), rewriter);
    } else {
      // The reduction dimension is distributed across lanes: the size passed
      // here is that of the original (subgroup-level) source vector.
      int64_t reductionDimSize = sourceType.getShape()[reductionDim];
      result = xegpu::lowerCrossLaneReductionToShuffles(
          cast<TypedValue<VectorType>>(adaptor.getSource()),
          cast<TypedValue<VectorType>>(adaptor.getAcc()), op.getKind(),
          reductionDim, reductionDimSize, op.getLoc(), rewriter);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
/// Computes the per-lane coordinates used by matrix ops that do not use
/// subgroup_block_io. A `gpu.lane_id` op is created, the layout computes the
/// distributed coordinates for `payloadShape` from it, and the (single) set
/// of coordinates is added to `origOffsets` via `xegpu::addWithRightAligned`.
///
/// Returns an empty vector if the layout fails to compute the coordinates.
static SmallVector<Value> computeDistributedCoordsForMatrixOp(
    ConversionPatternRewriter &rewriter, Location loc,
    xegpu::DistributeLayoutAttr layout, ArrayRef<int64_t> payloadShape,
    ValueRange origOffsets) {
  Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
                                       /*upperBound=*/mlir::IntegerAttr());
  auto coordSetsOrFailure =
      layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
  if (failed(coordSetsOrFailure))
    return {};

  assert(coordSetsOrFailure.value().size() == 1 &&
         "Expected one set of distributed offsets");

  // Offset the lane coordinates by the original op offsets.
  SmallVector<OpFoldResult> combinedOffsets = xegpu::addWithRightAligned(
      rewriter, loc, getAsOpFoldResult(coordSetsOrFailure.value()[0]),
      getAsOpFoldResult(origOffsets));
  return llvm::map_to_vector(combinedOffsets, llvm::CastTo<Value>);
}
/// Rewrites a subgroup-level LoadMatrix op into a lane-level LoadMatrix op
/// that yields the lane-distributed payload type. Unless the op has the
/// subgroup_block_io attribute, the offsets are replaced by per-lane
/// coordinates from `computeDistributedCoordsForMatrixOp`. All offsets of the
/// new op are passed as dynamic values.
struct SgToLaneLoadMatrix : public OpConversionPattern<xegpu::LoadMatrixOp> {
  using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::LoadMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto payloadLayout = op.getLayoutAttr();
    // Ops without a layout are left untouched.
    if (!payloadLayout)
      return failure();

    VectorType sgPayloadTy = dyn_cast<VectorType>(op.getResult().getType());
    if (!sgPayloadTy)
      return rewriter.notifyMatchFailure(
          op, "the matrix op payload must be a vector type");

    auto mixedOffsets = op.getMixedOffsets();
    if (mixedOffsets.empty())
      return rewriter.notifyMatchFailure(op, "the load op must have offsets");

    FailureOr<VectorType> lanePayloadTyOrFailure =
        getDistVecTypeBasedOnLaneLayout(payloadLayout, sgPayloadTy);
    if (failed(lanePayloadTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "Failed to distribute matrix op payload based on layout.");

    // Materialize all offsets as values; these are used directly when the op
    // has the subgroup_block_io attribute.
    auto loc = op.getLoc();
    SmallVector<Value> origOffsetValues =
        vector::getAsValues(rewriter, loc, mixedOffsets);
    SmallVector<Value> laneCoords = origOffsetValues;
    if (!op.getSubgroupBlockIoAttr()) {
      laneCoords = computeDistributedCoordsForMatrixOp(
          rewriter, loc, payloadLayout, sgPayloadTy.getShape(),
          origOffsetValues);
      if (laneCoords.empty())
        return rewriter.notifyMatchFailure(
            op, "Failed to compute distributed coordinates.");
    }

    // Every constant-offset slot is marked dynamic since the coordinates are
    // supplied as SSA values.
    SmallVector<int64_t> dynamicOffsets(op.getConstOffsets().size(),
                                        ShapedType::kDynamic);
    DenseI64ArrayAttr dynamicOffsetsAttr =
        rewriter.getDenseI64ArrayAttr(dynamicOffsets);

    auto laneLoad = xegpu::LoadMatrixOp::create(
        rewriter, loc, *lanePayloadTyOrFailure, adaptor.getMemDesc(),
        ValueRange(laneCoords), dynamicOffsetsAttr, op.getSubgroupBlockIoAttr(),
        xegpu::DistributeLayoutAttr{});
    rewriter.replaceOp(op, laneLoad.getResult());
    return success();
  }
};
/// Rewrites a subgroup-level vector.transpose op into a lane-level
/// vector.transpose on the converted source with the same permutation. The
/// new result is cast (see `castValueTo`) to the lane-distributed result type.
/// Requires layouts on both source and result, with the result layout being
/// the lane-level transpose of the source layout.
struct SgToLaneVectorTranspose
    : public OpConversionPattern<vector::TransposeOp> {
  using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransposeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr srcLayout =
        xegpu::getTemporaryLayout(op->getOpOperand(0));
    xegpu::DistributeLayoutAttr dstLayout =
        xegpu::getTemporaryLayout(op->getOpResult(0));
    const bool bothLayoutsPresent = srcLayout && dstLayout;
    if (!bothLayoutsPresent)
      return rewriter.notifyMatchFailure(
          op, "the source or result vector of the transpose op lacks layout "
              "attribute");

    // The two layouts have to be consistent with the op's permutation.
    ArrayRef<int64_t> permutation = op.getPermutation();
    if (!dstLayout.isTransposeOf(srcLayout, permutation,
                                 xegpu::LayoutKind::Lane))
      return rewriter.notifyMatchFailure(
          op, "the source or result vector layouts must be transposes of "
              "each other");

    FailureOr<VectorType> laneResultTyOrFailure =
        getDistVecTypeBasedOnLaneLayout(dstLayout, op.getResultVectorType());
    if (failed(laneResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "Failed to distribute the result vector type in "
              "vector::Transpose op");

    auto laneTranspose = vector::TransposeOp::create(
        rewriter, op.getLoc(), adaptor.getVector(), permutation);
    Value replacement = castValueTo(rewriter, laneTranspose.getResult(),
                                    laneResultTyOrFailure.value());
    rewriter.replaceOp(op, replacement);
    return success();
  }
};
/// Rewrites a subgroup-level vector.bitcast op into a lane-level
/// vector.bitcast of the converted source to the lane-distributed result
/// type. Bitcast only impacts the innermost dimension of the source/result
/// vectors.
struct SgToLaneVectorBitcast : public OpConversionPattern<vector::BitCastOp> {
  using OpConversionPattern<vector::BitCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::BitCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The result layout drives the distribution of the result type.
    xegpu::DistributeLayoutAttr dstLayout =
        xegpu::getTemporaryLayout(op->getOpResult(0));
    if (!dstLayout)
      return rewriter.notifyMatchFailure(
          op, "result vector of the bitcast op lacks layout attribute");

    FailureOr<VectorType> laneResultTyOrFailure =
        getDistVecTypeBasedOnLaneLayout(dstLayout, op.getResultVectorType());
    if (failed(laneResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "Failed to distribute the result vector type in "
              "vector::BitCast op");

    VectorType laneResultTy = laneResultTyOrFailure.value();
    auto laneBitcast = vector::BitCastOp::create(rewriter, op.getLoc(),
                                                 laneResultTy,
                                                 adaptor.getSource());
    rewriter.replaceOp(op, laneBitcast.getResult());
    return success();
  }
};
/// Distributes a subgroup-level vector.create_mask or vector.constant_mask op
/// to lane-level. Uses `computeDistributedCoords()` to obtain the
/// coordinates each lane owns, then compares each coordinate against the
/// original mask bounds using `arith.cmpi slt`. The per-element boolean
/// results are assembled into the distributed mask vector.
///
/// For multi-dimensional masks, the element is in-bounds when ALL dimensions
/// satisfy `coord[i] < bound[i]`.
///
/// The bounds of a vector.create_mask are taken from the adaptor (i.e. the
/// operands as remapped by the conversion driver); the bounds of a
/// vector.constant_mask are materialized as `arith.constant` index values.
///
/// Example (1D):
///   layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
///   %mask = vector.create_mask %m0 : vector<16xi1>
/// For lane k, computeDistributedCoords gives coord = [k], so:
///   %in_bounds = arith.cmpi slt, %coord, %m0 → i1
///   %mask = vector.broadcast %in_bounds : i1 to vector<1xi1>
///
/// Example (2D):
///   layout = #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>
///   %mask = vector.create_mask %m0, %m1 : vector<8x4xi1>
/// Each WI owns a 1x2 slice. computeDistributedCoords returns 2 coords:
///   [[r0, c0], [r0, c1]]
/// For each coord: in_bounds = (r < m0) && (c < m1)
///   %mask = vector.from_elements %bit0, %bit1 : vector<1x2xi1>
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, vector::CreateMaskOp, vector::ConstantMaskOp>::value>>
struct SgToLaneCreateMask : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout =
        xegpu::getTemporaryLayout(op->getOpResult(0));
    if (!layout || !layout.isForSubgroup())
      return rewriter.notifyMatchFailure(
          op, "operation result does not have subgroup distribute layout");

    VectorType origType = op.getType();
    FailureOr<VectorType> distTypeOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, origType);
    if (failed(distTypeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute lane vector type from the layout");
    VectorType distType = distTypeOrFailure.value();
    Location loc = op.getLoc();

    // Materialize the original mask bounds as Values.
    SmallVector<Value> origBounds;
    if constexpr (std::is_same_v<OpType, vector::CreateMaskOp>) {
      // Inside a conversion pattern the operands must be read through the
      // adaptor so that values already remapped by the driver are used.
      auto remappedBounds = adaptor.getOperands();
      origBounds.append(remappedBounds.begin(), remappedBounds.end());
    } else {
      auto dimSizes = op.getMaskDimSizesAttr().asArrayRef();
      for (auto dimSize : dimSizes)
        origBounds.push_back(
            arith::ConstantIndexOp::create(rewriter, loc, dimSize).getResult());
    }

    ArrayRef<int64_t> origShape = origType.getShape();
    // Use computeDistributedCoords to get the coordinates each WI owns.
    Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
                                         /*upperBound=*/mlir::IntegerAttr());
    auto maybeCoordsVec =
        layout.computeDistributedCoords(rewriter, loc, laneId, origShape);
    if (failed(maybeCoordsVec))
      return rewriter.notifyMatchFailure(
          op, "failed to compute distributed coordinates from layout");
    SmallVector<SmallVector<Value>> coordsVec = maybeCoordsVec.value();

    int64_t numElements = distType.getNumElements();
    assert(static_cast<int64_t>(coordsVec.size()) == numElements &&
           "number of coordinate sets must match number of distributed "
           "elements");

    // For each element, compare all coordinates against bounds. The chain is
    // seeded with `true` and AND-ed with one comparison per dimension.
    Value trueVal =
        arith::ConstantIntOp::create(rewriter, loc, /*value=*/1, /*width=*/1);
    SmallVector<Value> maskBits;
    for (auto &coords : coordsVec) {
      Value inBounds = trueVal;
      for (size_t i = 0; i < coords.size(); ++i) {
        Value cmp = arith::CmpIOp::create(
            rewriter, loc, arith::CmpIPredicate::slt, coords[i], origBounds[i]);
        inBounds = arith::AndIOp::create(rewriter, loc, inBounds, cmp);
      }
      maskBits.push_back(inBounds);
    }

    // Build the distributed mask vector: a single bit is broadcast, several
    // bits are packed with vector.from_elements.
    Value result;
    if (numElements == 1) {
      result =
          vector::BroadcastOp::create(rewriter, loc, distType, maskBits[0]);
    } else {
      result =
          vector::FromElementsOp::create(rewriter, loc, distType, maskBits);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
/// Lowers a subgroup-level xegpu StoreMatrix op to its lane-level form.
///
/// The payload is cast to the lane-distributed vector type derived from the
/// op's layout. Unless the op carries the subgroup-block-io attribute, the
/// offsets are replaced by the coordinates returned from
/// `computeDistributedCoordsForMatrixOp`; with that attribute the original
/// offsets are reused. The new op is created without a layout and with all
/// constant offsets marked dynamic.
struct SgToLaneStoreMatrix : public OpConversionPattern<xegpu::StoreMatrixOp> {
  using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::StoreMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // An op without a layout is not a candidate for distribution.
    auto layout = op.getLayoutAttr();
    if (!layout)
      return failure();

    auto payloadTy = dyn_cast<VectorType>(op.getData().getType());
    if (!payloadTy)
      return rewriter.notifyMatchFailure(
          op, "the matrix op payload must be a vector type");

    Location loc = op.getLoc();
    auto mixedOffsets = op.getMixedOffsets();
    if (mixedOffsets.empty())
      return rewriter.notifyMatchFailure(op, "the store op must have offsets");

    FailureOr<VectorType> lanePayloadTy =
        getDistVecTypeBasedOnLaneLayout(layout, payloadTy);
    if (failed(lanePayloadTy))
      return rewriter.notifyMatchFailure(
          op, "Failed to distribute matrix op payload based on layout.");

    // Turn every offset into an SSA value, then pick the coordinates the
    // lane-level op will use.
    SmallVector<Value> offsetValues =
        vector::getAsValues(rewriter, loc, mixedOffsets);
    SmallVector<Value> laneCoords;
    if (op.getSubgroupBlockIoAttr()) {
      laneCoords = offsetValues;
    } else {
      laneCoords = computeDistributedCoordsForMatrixOp(
          rewriter, loc, layout, payloadTy.getShape(), offsetValues);
      if (laneCoords.empty())
        return rewriter.notifyMatchFailure(
            op, "Failed to compute distributed coordinates.");
    }

    // All offsets are now passed as values, so every static slot is dynamic.
    SmallVector<int64_t> dynamicOffsets(op.getConstOffsets().size(),
                                        ShapedType::kDynamic);
    DenseI64ArrayAttr dynamicOffsetsAttr =
        rewriter.getDenseI64ArrayAttr(dynamicOffsets);

    Value lanePayload =
        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getData()),
                    lanePayloadTy.value());
    xegpu::StoreMatrixOp::create(rewriter, loc, TypeRange{}, lanePayload,
                                 adaptor.getMemDesc(), ValueRange(laneCoords),
                                 dynamicOffsetsAttr,
                                 op.getSubgroupBlockIoAttr(),
                                 xegpu::DistributeLayoutAttr{});
    rewriter.eraseOp(op);
    return success();
  }
};
/// Distributes a subgroup-level StoreScatter (xegpu.store) op to
/// lane-level.
///
/// The stored value, the offsets and the mask are all flattened to 1D
/// vectors before the lane-level op is created. The pattern does not match
/// (rather than asserting) when the stored value is not a vector, when the
/// offsets are absent, or when the offsets or mask are not vectors.
///
/// Example 1 (1D, no chunk size):
///   layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
///   %mask = producer_op : vector<16xi1>
///   %offset = producer_op : vector<16xindex>
///   xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
///     memref<256xf16>, vector<16xindex>, vector<16xi1>
/// Distributed to:
///   %mask = producer_op : vector<1xi1>
///   %offset = producer_op : vector<1xindex>
///   xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
///     memref<256xf16>, vector<1xindex>, vector<1xi1>
///
/// Example 2 (2D with chunk size, same mask & offset):
///   layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
///     vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
/// Distributed to:
///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
///     vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
///
/// Example 3 (3D with leading unit dims):
///   layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>
///   %mask = producer_op : vector<1x1x16xi1>
///   %offset = producer_op : vector<1x1x16xindex>
///   xegpu.store %payload, %src[%offset], %mask : vector<1x1x16xf16>,
///     memref<256xf16>, vector<1x1x16xindex>, vector<1x1x16xi1>
/// Distributed to:
///   %mask = producer_op : vector<1x1x1xi1>
///   %offset = producer_op : vector<1x1x1xindex>
///   xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
///     memref<256xf16>, vector<1xindex>, vector<1xi1>
struct SgToLaneStoreScatter
    : public OpConversionPattern<xegpu::StoreScatterOp> {
  using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::StoreScatterOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
    if (!layout)
      return failure();
    VectorType origValueTy = op.getValueType();
    if (!origValueTy)
      return failure();

    // Check that all leading dimensions are unit dimensions.
    int chunkSize = op.getChunkSize().value_or(1);
    int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
    ArrayRef<int64_t> shape = origValueTy.getShape();
    if (llvm::any_of(shape.take_front(origValueTy.getRank() - effectiveVecRank),
                     [](int64_t d) { return d != 1; }))
      return rewriter.notifyMatchFailure(
          op, "Only unit dimensions allowed for the leading "
              "dimensions of the store vector!");

    auto distValueTyOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(layout, origValueTy);
    if (failed(distValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute expected lane vector type from lane layout");
    VectorType distValueTy = distValueTyOrFailure.value();

    // Validate the converted offsets and mask before creating any IR: the
    // flattening below requires vector-typed values, and a missing or
    // non-vector operand must produce a match failure instead of tripping a
    // cast assertion.
    Value distOffsets = adaptor.getOffsets();
    Value distMask = adaptor.getMask();
    auto distOffsetsTy =
        distOffsets ? dyn_cast<VectorType>(distOffsets.getType()) : VectorType();
    auto distMaskTy =
        distMask ? dyn_cast<VectorType>(distMask.getType()) : VectorType();
    if (!distOffsetsTy || !distMaskTy)
      return rewriter.notifyMatchFailure(
          op, "expecting vector-typed offsets and mask");

    // Flatten the value to 1D.
    VectorType distValueTy1D = VectorType::get({distValueTy.getNumElements()},
                                               distValueTy.getElementType());
    Value distValue = adaptor.getValue();
    if (distValue.getType() != distValueTy1D)
      distValue = castValueTo(rewriter, cast<TypedValue<VectorType>>(distValue),
                              distValueTy1D);

    // Flatten offsets and mask to 1D to match the 1D value type.
    VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
                                             distOffsetsTy.getElementType());
    distOffsets = castValueTo(
        rewriter, cast<TypedValue<VectorType>>(distOffsets), offsetsTy1D);
    VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
                                          distMaskTy.getElementType());
    distMask =
        castValueTo(rewriter, cast<TypedValue<VectorType>>(distMask), maskTy1D);

    // Recreate the store on the lane-level operands; cache hints and chunk
    // size are carried over and the layout is dropped.
    Value distDest = adaptor.getDest();
    xegpu::StoreScatterOp::create(rewriter, op.getLoc(), distValue, distDest,
                                  distOffsets, distMask, op.getChunkSizeAttr(),
                                  op.getL1HintAttr(), op.getL2HintAttr(),
                                  op.getL3HintAttr(), /*layout=*/nullptr);
    rewriter.eraseOp(op);
    return success();
  }
};
/// Lowers a subgroup-level vector.step to lane-level.
/// The layout is expected to have exactly one effective lane dimension.
/// The step op is resolved completely: for every lane_data-sized block owned
/// by the lane, the block's start coordinate and its successors are emitted
/// as individual index values and packed with vector.from_elements.
struct SgToLaneVectorStep : public OpConversionPattern<vector::StepOp> {
  using OpConversionPattern<vector::StepOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::StepOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout =
        xegpu::getTemporaryLayout(op->getResult(0));
    if (!layout || !layout.isForSubgroup())
      return rewriter.notifyMatchFailure(
          op, "the result vector of the step op lacks subgroup layout");

    Location loc = op.getLoc();
    VectorType sgVecTy = op.getResult().getType();
    auto laneVecTyOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(layout, sgVecTy);
    if (failed(laneVecTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute lane vector type from the layout");
    VectorType laneVecTy = laneVecTyOrFailure.value();

    Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
                                         /*upperBound=*/mlir::IntegerAttr());
    auto blockCoordsOrFailure = layout.computeDistributedCoords(
        rewriter, loc, laneId, sgVecTy.getShape());
    if (failed(blockCoordsOrFailure))
      return rewriter.notifyMatchFailure(
          op, "failed to compute lane data block coordinates");
    auto blockCoordsList = blockCoordsOrFailure.value();

    // One coordinate set is expected per lane_data block.
    int64_t blockLength = layout.getEffectiveLaneDataAsInt()[0];
    assert(static_cast<int64_t>(blockCoordsList.size()) ==
           laneVecTy.getNumElements() / blockLength);

    // For each lane_data block, reconstruct its sub-range from the range of
    // the SG-level vector.step. Example:
    //   vector.step {slice<layout<lane_layout=[2,4,2], lane_data=[1,2,1]>,
    //   dims=[0,2]>} : vector<16xindex>
    // Each logical lane holds 4 elements as 2 blocks of 2 elements each.
    // The blocks are round-robin distributed, so logical lane id 0
    // holds values [0,1, 8,9].
    SmallVector<Value> laneStepValues;
    for (auto &blockCoords : blockCoordsList) {
      Value blockStart = blockCoords[0];
      laneStepValues.push_back(blockStart);
      for (int64_t delta = 1; delta < blockLength; ++delta) {
        auto deltaCst = arith::ConstantIndexOp::create(rewriter, loc, delta);
        laneStepValues.push_back(
            arith::AddIOp::create(rewriter, loc, blockStart, deltaCst));
      }
    }
    assert(static_cast<int64_t>(laneStepValues.size()) ==
               laneVecTy.getNumElements() &&
           "Expecting the number of step values to match the number of "
           "elements in the vector");

    auto laneStep =
        vector::FromElementsOp::create(rewriter, loc, laneVecTy, laneStepValues);
    rewriter.replaceOp(op, laneStep);
    return success();
  }
};
/// Lowers a subgroup-level vector.extract to lane-level.
/// Scalar extraction is rejected; only extraction of a sub-vector is
/// handled. The op is recreated on the converted source with the original
/// position.
struct SgToLaneVectorExtract : public OpConversionPattern<vector::ExtractOp> {
  using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // A non-vector result means a scalar extract, which is out of scope.
    if (!isa<VectorType>(op.getType()))
      return rewriter.notifyMatchFailure(op, "scalar extract not supported");

    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getTemporaryLayout(op->getOpResult(0));
    if (!resultLayout || !resultLayout.isForSubgroup())
      return failure();

    // Distribution is assumed to happen on the innermost dimension only, so
    // every lane_layout entry except the last one has to be 1.
    auto lanesPerDim = resultLayout.getEffectiveLaneLayoutAsInt();
    ArrayRef<int64_t> outerLanes = ArrayRef<int64_t>(lanesPerDim).drop_back(1);
    const bool outerDimsAreUnit =
        llvm::all_of(outerLanes, [](int64_t lanes) { return lanes == 1; });
    if (!outerDimsAreUnit)
      return rewriter.notifyMatchFailure(
          op, "only innermost dimension distribution is supported for "
              "vector.extract");

    auto laneExtract = vector::ExtractOp::create(
        rewriter, op.getLoc(), adaptor.getSource(), op.getMixedPosition());
    rewriter.replaceOp(op, laneExtract.getResult());
    return success();
  }
};
/// Lowers a subgroup-level vector.shape_cast to lane-level by recreating it
/// on the converted source with the lane-distributed result type derived
/// from the result's layout.
struct SgToLaneVectorShapeCast
    : public OpConversionPattern<vector::ShapeCastOp> {
  using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ShapeCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout =
        xegpu::getTemporaryLayout(op->getOpResult(0));
    const bool hasSubgroupLayout = layout && layout.isForSubgroup();
    if (!hasSubgroupLayout)
      return rewriter.notifyMatchFailure(
          op, "the result vector of the shape_cast op lacks subgroup layout");

    auto laneResultTy = xegpu::getDistVecTypeBasedOnLaneLayout(
        layout, op.getResultVectorType());
    if (failed(laneResultTy))
      return rewriter.notifyMatchFailure(
          op, "failed to get distributed vector type for result");

    Value laneSource = adaptor.getSource();
    auto laneCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
                                                *laneResultTy, laneSource);
    rewriter.replaceOp(op, laneCast);
    return success();
  }
};
/// Lowers a subgroup-level vector.extract_strided_slice to lane-level.
/// Offsets, sizes and strides are first padded to the full source rank.
/// When the result is distributed (along exactly one dimension), the size
/// in that dimension becomes the distributed size and the offset is divided
/// by the effective subgroup size.
struct SgToLaneVectorExtractStridedSlice
    : public OpConversionPattern<vector::ExtractStridedSliceOp> {
  using OpConversionPattern<vector::ExtractStridedSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractStridedSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getTemporaryLayout(op->getOpResult(0));
    if (!resultLayout || !resultLayout.isForSubgroup())
      return failure();

    VectorType sgResultTy = op.getType();
    auto laneResultTyOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, sgResultTy);
    if (failed(laneResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute distributed vector type from lane layout");
    VectorType laneResultTy = *laneResultTyOrFailure;
    SmallVector<int64_t> distDims =
        getDistributedDims(sgResultTy, laneResultTy);

    // Start from the op's own attributes and extend them so that each one
    // has an entry for every source dimension.
    VectorType sgSourceTy = op.getSourceVectorType();
    int64_t sourceRank = sgSourceTy.getRank();
    SmallVector<Attribute> sizes = llvm::to_vector(op.getSizes().getValue());
    SmallVector<Attribute> offsets =
        llvm::to_vector(op.getOffsets().getValue());
    SmallVector<Attribute> strides =
        llvm::to_vector(op.getStrides().getValue());
    for (int64_t dim = op.getSizes().size(); dim < sourceRank; ++dim) {
      sizes.push_back(rewriter.getI64IntegerAttr(sgSourceTy.getDimSize(dim)));
      offsets.push_back(rewriter.getI64IntegerAttr(0));
      strides.push_back(rewriter.getI64IntegerAttr(1));
    }

    // A distributed result needs its size and offset rewritten along the
    // distributed dimension.
    if (!distDims.empty()) {
      if (distDims.size() != 1)
        return rewriter.notifyMatchFailure(
            op, "only single dimension distribution is supported");
      int64_t distDim = distDims[0];

      const auto *targetArch = getUArch(xegpu::getChipStr(op).value_or(""));
      if (!targetArch)
        return rewriter.notifyMatchFailure(
            op, "target attribute required to determine subgroup size");
      int subgroupSize = targetArch->getSubgroupSize();

      auto sourceLayout = xegpu::getTemporaryLayout(op->getOpOperand(0));
      if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
        return rewriter.notifyMatchFailure(
            op, "source of extract_strided_slice lacks distribution layout");

      int sourceDistDimSize = sgSourceTy.getShape()[distDim];
      auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
      // When fewer lanes than the subgroup size cover the distributed
      // dimension (and that lane count divides the subgroup size), the lane
      // count is used as the effective subgroup size.
      int64_t lanesAlongDistDim = sourceLaneLayout[distDim];
      if (lanesAlongDistDim < subgroupSize &&
          subgroupSize % lanesAlongDistDim == 0)
        subgroupSize = lanesAlongDistDim;
      if (sourceDistDimSize % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            op, "source size along distributed dim is not a multiple of "
                "subgroup size");

      // Only the distributed dimension must have unit lane_data;
      // non-distributed dimensions may have non-unit lane_data (e.g., packed
      // layouts).
      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
      const bool nonUnitLaneData =
          distDim < static_cast<int64_t>(sourceLaneData.size()) &&
          sourceLaneData[distDim] != 1;
      if (nonUnitLaneData)
        return rewriter.notifyMatchFailure(
            op, "expecting unit lane data along the distributed dimension");

      int64_t sgOffset = cast<IntegerAttr>(offsets[distDim]).getInt();
      if (sgOffset % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            op, "offset along distributed dim is not a multiple of "
                "subgroup size");

      // Rewrite size and offset for the distributed dimension.
      sizes[distDim] =
          rewriter.getI64IntegerAttr(laneResultTy.getDimSize(distDim));
      offsets[distDim] = rewriter.getI64IntegerAttr(sgOffset / subgroupSize);
    }

    MLIRContext *ctx = rewriter.getContext();
    auto laneSlice = vector::ExtractStridedSliceOp::create(
        rewriter, op.getLoc(), laneResultTy, adaptor.getSource(),
        ArrayAttr::get(ctx, offsets), ArrayAttr::get(ctx, sizes),
        ArrayAttr::get(ctx, strides));
    rewriter.replaceOp(op, laneSlice.getResult());
    return success();
  }
};
/// This pattern distributes a subgroup-level `vector.broadcast` op to
/// lane-level. The pattern supports three cases:
///
/// 1) Broadcast a low-rank vector to high-rank vector: The low-rank input
///    vector must have a slice layout of the result. If the distributed source
///    and target vector types are identical, this lowers to a no-op; otherwise,
///    it remains a broadcast but operates on distributed vectors. A source
///    layout that is not a slice of the result layout only triggers a warning.
///
/// 2) Broadcast a same-rank vector with identical layouts for source and
///    target: The source vector must have unit dimensions, and lane_data must
///    be unit size for those unit dims. The pattern does not match when the
///    source has no layout.
///
/// 3) Broadcast a scalar with no layout: This lowers to a broadcast from
///    scalar to distributed result type.
///
/// In all cases, if the converted source already has the distributed result
/// type the broadcast is dropped and the source is used directly.
///
/// Example 1 (low-rank to high-rank broadcast):
/// ```
///   %0 = "some_op"() {layout_result_0 =
///     #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
///     dims = [0]>} : () -> vector<16xf16>
///   %1 = vector.broadcast %0 {layout_result_0 =
///     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
///     : vector<16xf16> to vector<16x16xf16>
/// ```
/// is distributed to:
/// ```
///   %0 = "some_op"() : () -> vector<1xf16>
///   %1 = vector.broadcast %0 : vector<1xf16> to vector<16x1xf16>
/// ```
///
/// Example 2 (same-rank broadcast, no-op):
/// ```
///   %0 = "some_op"() {layout_result_0 =
///     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
///     : () -> vector<16x1xf16>
///   %1 = vector.broadcast %0 {layout_result_0 =
///     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
///     : vector<16x1xf16> to vector<16x16xf16>
/// ```
/// is distributed to (no-op, source already matches distributed result type):
/// ```
///   %0 = "some_op"() : () -> vector<16x1xf16>
///   // broadcast is eliminated, %0 is used directly
/// ```
///
/// Example 3 (scalar to vector broadcast):
/// ```
///   %0 = "some_op"() : () -> f16
///   %1 = vector.broadcast %0 {layout_result_0 =
///     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
///     : f16 to vector<16x16xf16>
/// ```
/// is distributed to:
/// ```
///   %0 = "some_op"() : f16
///   %1 = vector.broadcast %0 : f16 to vector<16x1xf16>
/// ```
struct SgToLaneBroadcast : public OpConversionPattern<vector::BroadcastOp> {
  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getTemporaryLayout(cast<OpResult>(op.getResult()));
    if (!resultLayout || !resultLayout.isForSubgroup())
      return rewriter.notifyMatchFailure(
          op, "result does not have subgroup distribute layout");

    VectorType destType = op.getResultVectorType();
    VectorType sourceType = dyn_cast<VectorType>(op.getSourceType());
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getTemporaryLayout(op->getOpOperand(0));

    if (sourceType) {
      int64_t rankDiff = destType.getRank() - sourceType.getRank();
      if (rankDiff > 0) {
        // Case 1: Low-rank to high-rank broadcast. A mismatching source
        // layout is reported as a warning only; the rewrite still proceeds.
        if (!sourceLayout || !sourceLayout.isSliceOf(resultLayout))
          op.emitWarning(
              "broadcast source layout must be a slice of result layout");
      } else if (rankDiff == 0) {
        // Case 2: Same-rank broadcast. The source layout is queried below,
        // so a missing layout has to be rejected before it is dereferenced.
        if (!sourceLayout)
          return rewriter.notifyMatchFailure(
              op, "same-rank broadcast source lacks a distribute layout");
        auto broadcastUnitDimsSet = op.computeBroadcastedUnitDims();
        SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
                                               broadcastUnitDimsSet.end());
        assert(sourceLayout.isEqualTo(
                   sourceLayout.setUnitDimData(broadcastUnitDims)) &&
               "The sg_data for unit dimensions should be set as 1");
      }
    } else {
      // Case 3: Scalar to vector broadcast.
      if (sourceLayout)
        return rewriter.notifyMatchFailure(
            op, "broadcast from scalar must not have a layout attribute");
    }

    auto destDistType =
        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
    if (failed(destDistType))
      return rewriter.notifyMatchFailure(
          op, "failed to distribute the result vector type");

    Value source = adaptor.getSource();
    // If the adapted source already matches the dest dist type, it's a no-op.
    if (source.getType() == destDistType.value()) {
      rewriter.replaceOp(op, source);
      return success();
    }
    auto newOp = vector::BroadcastOp::create(rewriter, op.getLoc(),
                                             destDistType.value(), source);
    rewriter.replaceOp(op, newOp);
    return success();
  }
};
/// Lowers a subgroup-level vector.insert_strided_slice to lane-level.
/// When the destination is distributed (along exactly one dimension), the
/// offset in that dimension is divided by the subgroup size; all other
/// offsets and the strides are kept as they are.
struct SgToLaneVectorInsertStridedSlice
    : public OpConversionPattern<vector::InsertStridedSliceOp> {
  using OpConversionPattern<vector::InsertStridedSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::InsertStridedSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getTemporaryLayout(op->getOpResult(0));
    if (!resultLayout || !resultLayout.isForSubgroup())
      return failure();

    VectorType sgDestTy = op.getDestVectorType();
    auto laneDestTyOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, sgDestTy);
    if (failed(laneDestTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute distributed vector type from lane layout");
    VectorType laneDestTy = *laneDestTyOrFailure;

    SmallVector<int64_t> distDims = getDistributedDims(sgDestTy, laneDestTy);
    SmallVector<Attribute> offsets =
        llvm::to_vector(op.getOffsets().getValue());

    if (!distDims.empty()) {
      if (distDims.size() != 1)
        return rewriter.notifyMatchFailure(
            op, "only single dimension distribution is supported");
      int64_t destDistDim = distDims[0];

      const auto *targetArch = getUArch(xegpu::getChipStr(op).value_or(""));
      if (!targetArch)
        return rewriter.notifyMatchFailure(
            op, "target attribute required to determine subgroup size");
      int subgroupSize = targetArch->getSubgroupSize();

      // The source occupies the trailing dims of the dest, so the
      // distributed dest dim maps to a source dim shifted by the rank
      // difference; a negative result means it lies outside the source.
      VectorType sgSourceTy = op.getSourceVectorType();
      int64_t rankGap = sgDestTy.getRank() - sgSourceTy.getRank();
      int64_t sourceDistDim = destDistDim - rankGap;
      if (sourceDistDim < 0)
        return rewriter.notifyMatchFailure(
            op, "distributed dimension must be in the last k dims of dest");

      auto destLayout = xegpu::getTemporaryLayout(op->getOpOperand(1));
      auto sourceLayout = xegpu::getTemporaryLayout(op->getOpOperand(0));
      if (!destLayout || !sourceLayout ||
          destLayout.getEffectiveLaneLayoutAsInt().empty() ||
          sourceLayout.getEffectiveLaneLayoutAsInt().empty())
        return rewriter.notifyMatchFailure(
            op, "source or dest of insert_strided_slice lacks distribution "
                "layout");

      // Only the distributed dimension must have unit lane_data;
      // non-distributed dimensions may have non-unit lane_data (e.g., packed
      // layouts).
      auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
      const bool destNonUnit =
          destDistDim < static_cast<int64_t>(destLaneData.size()) &&
          destLaneData[destDistDim] != 1;
      const bool sourceNonUnit =
          sourceDistDim < static_cast<int64_t>(sourceLaneData.size()) &&
          sourceLaneData[sourceDistDim] != 1;
      if (destNonUnit || sourceNonUnit)
        return rewriter.notifyMatchFailure(
            op, "expecting unit lane data along the distributed dimension");

      int64_t sourceDistDimSize = sgSourceTy.getDimSize(sourceDistDim);
      if (sourceDistDimSize % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            op, "source distributed dim size is not a multiple of "
                "subgroup size");

      int64_t sgOffset = cast<IntegerAttr>(offsets[destDistDim]).getInt();
      if (sgOffset % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            op, "offset along distributed dim is not a multiple of "
                "subgroup size");

      // Rewrite the offset for the distributed dimension.
      offsets[destDistDim] = rewriter.getI64IntegerAttr(sgOffset / subgroupSize);
    }

    auto laneInsert = vector::InsertStridedSliceOp::create(
        rewriter, op.getLoc(), laneDestTy, adaptor.getValueToStore(),
        adaptor.getDest(), ArrayAttr::get(rewriter.getContext(), offsets),
        op.getStrides());
    rewriter.replaceOp(op, laneInsert.getResult());
    return success();
  }
};
/// Lowers a subgroup-level vector.insert to lane-level.
/// Scalar insertion is rejected; only insertion of a sub-vector is handled.
/// The op is recreated on the converted value and destination with the
/// original position.
struct SgToLaneVectorInsert : public OpConversionPattern<vector::InsertOp> {
  using OpConversionPattern<vector::InsertOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // A non-vector value-to-store means a scalar insert, which is out of
    // scope.
    if (!isa<VectorType>(op.getValueToStoreType()))
      return rewriter.notifyMatchFailure(op, "scalar insert not supported");

    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getTemporaryLayout(op->getOpResult(0));
    if (!resultLayout || !resultLayout.isForSubgroup())
      return failure();

    // Every lane_layout entry except the last one has to be 1, i.e. the
    // outer dimensions (addressed by the insert position) are not
    // distributed.
    auto lanesPerDim = resultLayout.getEffectiveLaneLayoutAsInt();
    ArrayRef<int64_t> outerLanes = ArrayRef<int64_t>(lanesPerDim).drop_back(1);
    const bool outerDimsAreUnit =
        llvm::all_of(outerLanes, [](int64_t lanes) { return lanes == 1; });
    if (!outerDimsAreUnit)
      return rewriter.notifyMatchFailure(
          op, "only innermost dimension distribution is supported for "
              "vector.insert");

    auto laneInsert = vector::InsertOp::create(
        rewriter, op.getLoc(), adaptor.getValueToStore(), adaptor.getDest(),
        op.getMixedPosition());
    rewriter.replaceOp(op, laneInsert.getResult());
    return success();
  }
};
/// Redistributes `src` for a `convert_layout` whose only change is that the
/// `lane_layout` along the outer (distributed) dimension shrinks from
/// `currentLaneNum` to `targetLaneNum` lanes (a partial-subgroup
/// distribution). The intent is that each surviving lane gathers the values
/// that previously lived in the dropped lanes: the data is moved with
/// `gpu.shuffle` and appended to the lane-local data using `vector.shuffle`,
/// which doubles the outer dimension of the returned vector.
///
/// Only halving the lane count (a factor of two) is currently supported.
/// Returns the redistributed value on success. Returns failure when `src` is
/// not a rank-2 vector, when `targetLaneNum` is not positive or is not
/// exactly half of `currentLaneNum`, or when the total bit width of `src` is
/// not a multiple of 32.
static FailureOr<Value>
shuffleDataAsLaneLayoutChange(ConversionPatternRewriter &rewriter, Location loc,
                              Value src, int64_t currentLaneNum,
                              int64_t targetLaneNum) {
  auto srcVecTy = dyn_cast<VectorType>(src.getType());
  if (!srcVecTy || srcVecTy.getRank() != 2)
    return failure();

  // Only halving the lane count (factor of two) is supported for now.
  const bool halvesLaneCount =
      targetLaneNum > 0 && currentLaneNum == targetLaneNum * 2;
  if (!halvesLaneCount)
    return failure();

  // The data is shuffled as i32 words, so it must fill whole 32-bit words.
  int64_t totalBits =
      srcVecTy.getNumElements() * srcVecTy.getElementTypeBitWidth();
  if (totalBits % 32 != 0)
    return failure();

  // A vector cannot be shuffled across lanes directly, so:
  //   1. view the source as a flat vector of i32 words,
  //   2. start from a zero-initialized i32 vector of the same length,
  //   3. for every word: extract it, gpu.shuffle it, insert the result,
  //   4. view the gathered words with the source vector type again,
  //   5. vector.shuffle source and gathered data to concatenate them along
  //      the outer dimension.
  Type wordTy = rewriter.getI32Type();
  int64_t numWords = totalBits / 32;
  VectorType wordVecTy = VectorType::get({numWords}, wordTy);

  // Zero-initialized accumulator for the gathered words.
  Value gathered = arith::ConstantOp::create(
      rewriter, loc,
      DenseElementsAttr::get(wordVecTy, IntegerAttr::get(wordTy, 0)));

  VectorType flatVecTy =
      VectorType::get({srcVecTy.getNumElements()}, srcVecTy.getElementType());
  Value flatSrc = vector::ShapeCastOp::create(rewriter, loc, flatVecTy, src);
  Value srcWords = vector::BitCastOp::create(rewriter, loc, wordVecTy, flatSrc);

  // NOTE(review): if the two integer arguments of this gpu.shuffle builder
  // are (offset, width), an offset of 0 would make every lane read its own
  // value rather than a partner lane's — confirm against the builder and
  // the intended gather semantics.
  for (int64_t word = 0; word < numWords; ++word) {
    Value localWord = vector::ExtractOp::create(rewriter, loc, srcWords, word);
    Value shuffledWord =
        gpu::ShuffleOp::create(rewriter, loc, localWord, 0, targetLaneNum,
                               gpu::ShuffleMode::UP)
            .getResult(0);
    gathered =
        vector::InsertOp::create(rewriter, loc, shuffledWord, gathered, word);
  }
  gathered = vector::BitCastOp::create(rewriter, loc, flatVecTy, gathered);
  gathered = vector::ShapeCastOp::create(rewriter, loc, srcVecTy, gathered);

  // Identity indices over both inputs: lane-local rows first, gathered rows
  // second.
  const int64_t numConcatRows = srcVecTy.getShape()[0] * 2;
  SmallVector<int64_t> concatIndices;
  concatIndices.reserve(numConcatRows);
  for (int64_t row = 0; row < numConcatRows; ++row)
    concatIndices.push_back(row);

  Value concatenated =
      vector::ShuffleOp::create(rewriter, loc, src, gathered, concatIndices);
  return concatenated;
}
/// Lowers a subgroup-level ConvertLayout op to lane-level.
///
/// Three situations are handled:
///  - the result is an integer or float scalar: the op folds to its
///    (converted) source;
///  - input and target layouts are compatible at lane level: the op folds to
///    its (converted) source;
///  - rank-2 layouts with equal order and lane_data, a unit inner
///    lane_layout and a differing outer lane_layout (> 1 on the input): the
///    data is redistributed with `shuffleDataAsLaneLayoutChange`.
/// Everything else is reported as a match failure.
struct SgToLaneConvertLayout
    : public OpConversionPattern<xegpu::ConvertLayoutOp> {
  using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::ConvertLayoutOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto inputLayout = op.getInputLayoutAttr();
    auto targetLayout = op.getTargetLayoutAttr();
    Type valType = op.getResult().getType();

    // Scalars are not distributed, so the conversion is dropped. The source
    // is taken from the adaptor, like in the vector paths below, so that the
    // value remapped by the conversion driver is forwarded.
    if (valType.isIntOrFloat()) {
      rewriter.replaceOp(op, adaptor.getSource());
      return success();
    }

    auto resShape = cast<VectorType>(valType).getShape();
    SmallVector<int64_t> resShapeVec(resShape.begin(), resShape.end());
    // Equivalent layouts: the convert_layout is a no-op and folds to its
    // source.
    if (inputLayout.isCompatibleWith(targetLayout, resShapeVec,
                                     xegpu::LayoutKind::Lane)) {
      rewriter.replaceOp(op, adaptor.getSource());
      return success();
    }

    // Handle the special case where the conversion redistributes a value
    // across a fraction of the subgroup: the lane_layout shrinks along the
    // outer (distributed) dimension while lane_data stays the same. Only a
    // pure outer-dimension lane_layout change is supported, so the inner
    // lane_layout must be unit (making the outer dim the only distributed one)
    // and the outer lane_layout must be genuinely distributed (> 1), which
    // also rules out the degenerate [1, 1] layout.
    if (inputLayout.getEffectiveOrderAsInt() ==
            targetLayout.getEffectiveOrderAsInt() &&
        inputLayout.getRank() == 2 && targetLayout.getRank() == 2) {
      auto laneLayout = inputLayout.getEffectiveLaneLayoutAsInt();
      auto targetLaneLayout = targetLayout.getEffectiveLaneLayoutAsInt();
      auto laneData = inputLayout.getEffectiveLaneDataAsInt();
      auto targetLaneData = targetLayout.getEffectiveLaneDataAsInt();
      if (laneLayout.size() == 2 && targetLaneLayout.size() == 2 &&
          laneData == targetLaneData && laneLayout[1] == 1 &&
          targetLaneLayout[1] == 1 && laneLayout[0] > 1 &&
          laneLayout[0] != targetLaneLayout[0]) {
        // The helper itself rejects anything but halving the lane count; in
        // that case control falls through to the generic failure below.
        FailureOr<Value> res = shuffleDataAsLaneLayoutChange(
            rewriter, op.getLoc(), adaptor.getSource(), laneLayout[0],
            targetLaneLayout[0]);
        if (succeeded(res)) {
          rewriter.replaceOp(op, *res);
          return success();
        }
      }
    }
    return rewriter.notifyMatchFailure(
        op, "lowering incompatible convert_layout not yet supported");
  }
};
// Distributes `vector.interleave` trivially: the op is re-created on the
// already-converted operands and the original is replaced by its result.
struct SgToLaneVectorInterleave
    : public OpConversionPattern<vector::InterleaveOp> {
  using OpConversionPattern<vector::InterleaveOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::InterleaveOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto lhs = adaptor.getLhs();
    auto rhs = adaptor.getRhs();
    // The result type is left to the op builder, based on the new operands.
    auto interleaved = vector::InterleaveOp::create(rewriter, loc, lhs, rhs);
    rewriter.replaceOp(op, interleaved.getResult());
    return success();
  }
};
// Distributes `vector.deinterleave` trivially: the op is re-created on the
// already-converted source and the original is replaced by all of its results.
struct SgToLaneVectorDeinterleave
    : public OpConversionPattern<vector::DeinterleaveOp> {
  using OpConversionPattern<vector::DeinterleaveOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::DeinterleaveOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto source = adaptor.getSource();
    // The result types are left to the op builder, based on the new source.
    auto deinterleaved = vector::DeinterleaveOp::create(rewriter, loc, source);
    rewriter.replaceOp(op, deinterleaved.getResults());
    return success();
  }
};
/// Distributes a subgroup-level `xegpu::DpasMxOp` (scaled subgroup matrix
/// multiply-accumulate) to lane level.
///
/// The pattern
///   - requires a target uArch that supports
///     `SubgroupScaledMatrixMultiplyAcc`,
///   - requires A, B and C/D layouts (and a scale layout for every scale
///     operand that is present),
///   - derives the expected per-lane vector types from those layouts,
///   - checks that the per-lane A/B bit widths are multiples of the uArch
///     packed-format sizes,
///   - re-creates the op on operands cast to the expected per-lane types, with
///     all layout attributes dropped, and replaces the original with the new
///     result cast to the type derived from the lane layout.
/// Every unmet requirement is reported as a match failure before any IR is
/// created.
struct SgToLaneDpasMx : public OpConversionPattern<xegpu::DpasMxOp> {
  using OpConversionPattern<xegpu::DpasMxOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::DpasMxOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const uArch *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
    if (!uArch)
      return rewriter.notifyMatchFailure(op,
                                         "unable to determine target uArch");
    if (!uArch->isSupportedInstruction(
            xegpu::uArch::InstructionKind::SubgroupScaledMatrixMultiplyAcc))
      return rewriter.notifyMatchFailure(
          op, "target uArch does not support scaled subgroup mma");

    // Check if the op has A, B and CD layouts attached. A plain cast<> would
    // assert on a missing attribute and make the null check unreachable, so
    // use the null-tolerant dynamic cast.
    auto layoutA = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAAttr());
    auto layoutB = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutBAttr());
    auto layoutCd =
        dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutCdAttr());
    if (!layoutA || !layoutB || !layoutCd)
      return rewriter.notifyMatchFailure(
          op, "missing required layout attributes for DpasMxOp distribution");

    // Retrieve expected types, according to anchor layouts.
    auto expected1DTypeResult =
        xegpu::getDistributedVectorType(op.getType(), layoutCd);
    auto expected1DTypeA =
        xegpu::getDistributedVectorType(op.getA().getType(), layoutA);
    auto expected1DTypeB =
        xegpu::getDistributedVectorType(op.getB().getType(), layoutB);

    // The scale operands are optional; their expected types are only computed
    // (and their layouts only required) when the operand is present.
    VectorType expected1DTypeScaleA, expected1DTypeScaleB;
    if (op.getScaleA()) {
      auto layoutScaleA =
          dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAScaleAttr());
      if (!layoutScaleA)
        return rewriter.notifyMatchFailure(
            op, "missing required layout attribute for scale A");
      auto expected1DTypeScaleAOrFailure = xegpu::getDistributedVectorType(
          cast<VectorType>(op.getScaleA().getType()), layoutScaleA);
      if (failed(expected1DTypeScaleAOrFailure))
        return rewriter.notifyMatchFailure(
            op, "failed to calculate expected 1D vector type for scale A");
      expected1DTypeScaleA = expected1DTypeScaleAOrFailure.value();
    }
    if (op.getScaleB()) {
      auto layoutScaleB =
          dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutBScaleAttr());
      if (!layoutScaleB)
        return rewriter.notifyMatchFailure(
            op, "missing required layout attribute for scale B");
      auto expected1DTypeScaleBOrFailure = xegpu::getDistributedVectorType(
          cast<VectorType>(op.getScaleB().getType()), layoutScaleB);
      if (failed(expected1DTypeScaleBOrFailure))
        return rewriter.notifyMatchFailure(
            op, "failed to calculate expected 1D vector type for scale B");
      expected1DTypeScaleB = expected1DTypeScaleBOrFailure.value();
    }

    auto expectedNDTypeResult =
        xegpu::getDistVecTypeBasedOnLaneLayout(layoutCd, op.getType());
    if (failed(expected1DTypeResult) || failed(expected1DTypeA) ||
        failed(expected1DTypeB))
      return rewriter.notifyMatchFailure(
          op,
          "failed to calculate supported workitem 1D vector types for "
          "DpasMxOp from layouts");
    if (failed(expectedNDTypeResult))
      return rewriter.notifyMatchFailure(
          op, "unable to compute expected workitem vector type for DpasMxOp "
              "from lane layout");

    // Validate bit widths match uArch packed format requirements. The
    // instruction descriptor is dereferenced below, so a failed lookup is
    // reported as a match failure rather than only asserted.
    const auto *uArchInstruction = dyn_cast<
        xegpu::uArch::SubgroupScaledMatrixMultiplyAcc>(uArch->getInstruction(
        xegpu::uArch::InstructionKind::SubgroupScaledMatrixMultiplyAcc));
    if (!uArchInstruction)
      return rewriter.notifyMatchFailure(
          op, "target uArch does not describe scaled subgroup mma");

    auto wiResultType = expected1DTypeResult.value();
    auto wiAType = expected1DTypeA.value();
    auto wiBType = expected1DTypeB.value();
    // Calculate total packed bit width = element bit width * vector size
    unsigned aPackedBitWidth =
        wiAType.getElementTypeBitWidth() * wiAType.getNumElements();
    unsigned bPackedBitWidth =
        wiBType.getElementTypeBitWidth() * wiBType.getNumElements();
    if (aPackedBitWidth % uArchInstruction->getPackedFormatBitSizeA())
      return rewriter.notifyMatchFailure(
          op, "A operand packed bit width must be a multiple of uArch packed "
              "format requirement");
    if (bPackedBitWidth % uArchInstruction->getPackedFormatBitSizeB())
      return rewriter.notifyMatchFailure(
          op, "B operand packed bit width must be a multiple of uArch packed "
              "format requirement");

    // Cast the converted operands to the expected per-lane types. This is done
    // in separate statements so any cast ops are created in a fixed order
    // (A, B, acc, scale A, scale B) instead of the unspecified evaluation
    // order of function arguments. Absent optional operands stay null.
    Value a = castValueTo(
        rewriter, cast<TypedValue<VectorType>>(adaptor.getA()), wiAType);
    Value b = castValueTo(
        rewriter, cast<TypedValue<VectorType>>(adaptor.getB()), wiBType);
    Value acc;
    if (op.getAcc())
      acc = castValueTo(rewriter,
                        cast<TypedValue<VectorType>>(adaptor.getAcc()),
                        wiResultType);
    Value scaleA;
    if (op.getScaleA())
      scaleA = castValueTo(rewriter,
                           cast<TypedValue<VectorType>>(adaptor.getScaleA()),
                           expected1DTypeScaleA);
    Value scaleB;
    if (op.getScaleB())
      scaleB = castValueTo(rewriter,
                           cast<TypedValue<VectorType>>(adaptor.getScaleB()),
                           expected1DTypeScaleB);

    auto newOp = xegpu::DpasMxOp::create(
        rewriter, op->getLoc(), wiResultType, a, b, acc, scaleA, scaleB,
        /** layoutA**/ nullptr,
        /** layoutB**/ nullptr, /** layoutCd**/ nullptr,
        /** layoutAScale**/ nullptr, /** layoutBScale**/ nullptr);

    // Explicitly set the new types to enable correct type materializations.
    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
                                       expectedNDTypeResult.value()));
    return success();
  }
};
/// Pass wrapper around the subgroup-to-lane distribution conversion. The base
/// class comes from the generated pass declaration selected by
/// GEN_PASS_DEF_XEGPUSGTOLANEDISTRIBUTE; all work happens in
/// `runOnOperation`, which is defined out of line below.
struct XeGPUSgToLaneDistributePass
: public xegpu::impl::XeGPUSgToLaneDistributeBase<
XeGPUSgToLaneDistributePass> {
void runOnOperation() override;
};
} // namespace
void XeGPUSgToLaneDistributePass::runOnOperation() {
// Recover temporary operand layouts for usage in patterns.
Operation *root = getOperation();
if (!xegpu::recoverTemporaryLayouts(root)) {
signalPassFailure();
return;
}
// Collect existing UnrealizedConversionCastOps. These must be preserved.
llvm::SmallSetVector<UnrealizedConversionCastOp, 8> existingCasts;
root->walk(
[&](UnrealizedConversionCastOp castOp) { existingCasts.insert(castOp); });
// Perform a structural type conversion to convert structural ops to have WI
// types. This will insert UnrealizedConversionCastOps to make the IR
// valid.
{
ConversionTarget target(getContext());
TypeConverter typeConverter;
RewritePatternSet patterns(&getContext());
// Source (N:1) and target (1:1) materializations using
// UnrealizedConversionCastOp.
auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
Location loc) -> Value {
return UnrealizedConversionCastOp::create(builder, loc, type, inputs)
.getResult(0);
};
typeConverter.addSourceMaterialization(materializeCast);
typeConverter.addTargetMaterialization(materializeCast);
xegpu::populateXeGPUSgToLaneDistributeTypeConversions(typeConverter, root);
scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
patterns, target);
xegpu::populateXeGPUSgToLaneDistributeTypeConversionAndLegality(
typeConverter, patterns, target, root);
target.addLegalOp<UnrealizedConversionCastOp>();
(void)applyPartialConversion(root, target, std::move(patterns));
}
// Fold cancelling cast chains and erase dead casts.
xegpu::cleanupUnrealizedConversionCasts(root, existingCasts);
xegpu::removeTemporaryLayoutAttrs(getOperation());
}
/// Registers the type conversions used by the subgroup-to-lane distribution
/// on `typeConverter`:
///   - an identity fallback for every type,
///   - TensorDescType: the layout attribute, if any, is dropped,
///   - VectorType: distributed according to the lane layout.
/// `topLevelOp` is the operation handed to `precomputeLoopBlockArgTypes` to
/// pre-compute the types of SCF loop block arguments.
void xegpu::populateXeGPUSgToLaneDistributeTypeConversions(
    TypeConverter &typeConverter, Operation *topLevelOp) {
  // Identity fallback. The more specific conversions registered below override
  // it for TensorDescType and (distributing) VectorType.
  typeConverter.addConversion([](Type ty) -> Type { return ty; });

  // Tensor descriptors lose their layout attribute; descriptors without one
  // are returned unchanged.
  typeConverter.addConversion([](TensorDescType descTy) -> Type {
    if (!descTy.getLayoutAttr())
      return descTy;
    return descTy.dropLayouts();
  });

  // Vectors are distributed based on the lane layout (1:1 shape-changing
  // conversion). The callback yields the distributed shape with a count of 1,
  // or an empty shape with a count of 0 when no distributed type can be
  // computed.
  auto laneShapeAndCount = [](VectorType vecTy,
                              xegpu::DistributeLayoutAttr layout)
      -> std::pair<SmallVector<int64_t>, int> {
    auto laneVecTy = getDistVecTypeBasedOnLaneLayout(layout, vecTy);
    if (failed(laneVecTy))
      return {{}, 0};
    return {SmallVector<int64_t>(laneVecTy->getShape()), 1};
  };

  // The same callback drives both the pre-computed map for SCF loop block
  // arguments and the vector type conversion itself.
  auto precomputedLoopArgTypes =
      xegpu::precomputeLoopBlockArgTypes(topLevelOp, laneShapeAndCount);
  xegpu::addVectorTypeConversion(typeConverter, laneShapeAndCount,
                                 std::move(precomputedLoopArgTypes));
}
/// Sets up everything the subgroup-to-lane conversion needs on top of the type
/// conversions: the dynamic legality rules on `target` and the distribution
/// patterns in `patterns`.
///
/// Most rules follow one scheme: an op is legal (left alone) when it no longer
/// carries the layout information that marks it as a subgroup-level op.
/// Unknown ops are always legal.
void xegpu::populateXeGPUSgToLaneDistributeTypeConversionAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, Operation *topLevelOp) {
  populateXeGPUSgToLaneDistributeTypeConversions(typeConverter, topLevelOp);

  // Shared predicate: legal iff result #0 has no temporary layout attached.
  auto firstResultHasNoTemporaryLayout = [](Operation *op) -> bool {
    return !xegpu::getTemporaryLayout(op->getOpResult(0));
  };

  // --- XeGPU dialect -------------------------------------------------------
  // CreateNdDescOp is legal only if its result type has no layout attribute.
  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
      [](xegpu::CreateNdDescOp descOp) {
        return !descOp.getType().getLayoutAttr();
      });
  // convert_layout is never legal. Any other XeGPU op is legal unless it is
  // an anchor op that still has an anchor layout.
  target.addDynamicallyLegalDialect<xegpu::XeGPUDialect>([](Operation *op) {
    if (isa<xegpu::ConvertLayoutOp>(op))
      return false;
    if (auto anchorOp = dyn_cast<AnchorLayoutInterface>(op))
      return !anchorOp.getAnchorLayout();
    return true;
  });

  // --- arith / math --------------------------------------------------------
  // Arith constants are legal unless they produce a vector that has a
  // temporary layout attribute.
  target.addDynamicallyLegalOp<arith::ConstantOp>(
      [](arith::ConstantOp constOp) -> bool {
        if (!isa<VectorType>(constOp.getResult().getType()))
          return true;
        return !xegpu::getTemporaryLayout(
            dyn_cast<OpResult>(constOp.getResult()));
      });
  // In the math and arith dialects only elementwise ops with a single vector
  // result, whose operands are all vectors of the result shape, and which
  // have a result layout are converted; everything else is legal.
  target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
      [](Operation *op) -> std::optional<bool> {
        if (!OpTrait::hasElementwiseMappableTraits(op) ||
            op->getNumResults() != 1)
          return true;
        auto resultVecTy = dyn_cast<VectorType>(op->getResult(0).getType());
        if (!resultVecTy)
          return true;
        // A non-vector operand or a shape mismatch disqualifies the op.
        for (Value operand : op->getOperands()) {
          auto operandVecTy = dyn_cast<VectorType>(operand.getType());
          if (!operandVecTy ||
              operandVecTy.getShape() != resultVecTy.getShape())
            return true;
        }
        return !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
      });

  // --- vector dialect ------------------------------------------------------
  // vector::ReductionOp is legal only if its source has no distribute layout
  // attribute.
  target.addDynamicallyLegalOp<vector::ReductionOp>(
      [](vector::ReductionOp reductionOp) -> bool {
        return !xegpu::getDistributeLayoutAttr(reductionOp.getVector());
      });
  // vector::MultiDimReductionOp is legal unless it is a valid subgroup
  // multi-reduction.
  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
      [](vector::MultiDimReductionOp reductionOp) -> bool {
        return !isValidSubgroupMultiReductionOp(reductionOp);
      });
  // These ops are all judged solely by the temporary layout of result #0.
  target.addDynamicallyLegalOp<
      vector::CreateMaskOp, vector::ConstantMaskOp, vector::TransposeOp,
      vector::BitCastOp, vector::ShapeCastOp, vector::StepOp,
      vector::BroadcastOp, vector::InsertOp, vector::ExtractStridedSliceOp,
      vector::InsertStridedSliceOp, vector::InterleaveOp,
      vector::DeinterleaveOp>(firstResultHasNoTemporaryLayout);
  // vector.extract producing a non-vector is always legal; otherwise the
  // usual result-layout rule applies.
  target.addDynamicallyLegalOp<vector::ExtractOp>(
      [](vector::ExtractOp extractOp) -> bool {
        if (!isa<VectorType>(extractOp.getType()))
          return true;
        return !xegpu::getTemporaryLayout(extractOp->getOpResult(0));
      });

  // Anything not covered above is legal.
  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

  patterns.add<
      SgToLaneCreateNdDesc, SgToLaneLoadNd, SgToLaneStoreNd, SgToLaneDpas,
      SgToLaneElementWise, SgToLaneArithConstant, SgToLanePrefetchNd,
      SgToLaneLoadGather, SgToLaneStoreScatter, SgToLaneVectorReduction,
      SgToLaneMultiDimReduction, SgToLaneVectorExtract, SgToLaneVectorInsert,
      SgToLaneVectorExtractStridedSlice, SgToLaneVectorInsertStridedSlice,
      SgToLaneLoadMatrix, SgToLaneStoreMatrix, SgToLaneConvertLayout,
      SgToLaneVectorTranspose, SgToLaneVectorBitcast, SgToLaneVectorStep,
      SgToLaneVectorShapeCast, SgToLaneBroadcast,
      SgToLaneCreateMask<vector::CreateMaskOp>,
      SgToLaneCreateMask<vector::ConstantMaskOp>, SgToLaneVectorDeinterleave,
      SgToLaneVectorInterleave, SgToLaneDpasMx>(typeConverter,
                                                patterns.getContext());
}