| //===- XeGPUWgToSgDistribute.cpp - XeGPU Workgroup to Subgroup Pass -------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| #include "mlir/Dialect/XeGPU/Transforms/Passes.h" |
| |
| #include "mlir/Dialect/Affine/Utils.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Arith/Utils/Utils.h" |
| #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
| #include "mlir/Dialect/Index/IR/IndexDialect.h" |
| #include "mlir/Dialect/Index/IR/IndexOps.h" |
| #include "mlir/Dialect/Math/IR/Math.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/SCF/Transforms/Patterns.h" |
| #include "mlir/Dialect/Utils/IndexingUtils.h" |
| #include "mlir/Dialect/XeGPU/IR/XeGPU.h" |
| #include "mlir/Dialect/XeGPU/Transforms/Transforms.h" |
| #include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h" |
| #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| #include <optional> |
| |
| namespace mlir { |
| namespace xegpu { |
| #define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE |
| #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc" |
| } // namespace xegpu |
| } // namespace mlir |
| |
| using namespace mlir; |
| |
| namespace { |
| |
// Retrieves the sg_id_range attribute (RangeAttr) from the nearest enclosing
// scf.if op that specifies it, if any.
| static xegpu::RangeAttr getRangeSpecAttr(Operation *op) { |
| Operation *parent = op->getParentOfType<scf::IfOp>(); |
| while (parent) { |
| if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>( |
| parent->getAttr("sg_id_range"))) |
| return attr; |
| parent = parent->getParentOfType<scf::IfOp>(); |
| } |
| return {}; |
| } |
| |
| static std::pair<SmallVector<int64_t>, int> |
| getSgShapeAndCount(ArrayRef<int64_t> shape, |
| xegpu::DistributeLayoutAttr layout) { |
| int count = 1; |
| SmallVector<int64_t> sgShape(shape); |
| if (layout && layout.isForWorkgroup()) { |
| SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt(); |
| if (!layout.getEffectiveSgDataAsInt().empty()) |
| sgShape = layout.getEffectiveSgDataAsInt(); |
| else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout)) |
| sgShape = *maybeDerivedSgData; |
| SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape); |
| // Clamp distUnit to the original shape to handle cases where data is |
| // shared among subgroups, which may cause distUnit to exceed the original |
| // shape. |
| for (size_t i = 0; i < distUnit.size(); ++i) |
| distUnit[i] = std::min(shape[i], distUnit[i]); |
| count = computeProduct(shape) / computeProduct(distUnit); |
| } |
| return std::make_pair(sgShape, count); |
| } |
| |
/// Utility helper for deriving the list of offsets for each sub-TensorDesc
/// or sub-MemDesc to be accessed by the current subgroup (sgId), based on the
/// associated distribute layout attribute, the data shape, the subgroup id,
/// and the original offsets of the op.
| template < |
| typename OpType, |
| typename = std::enable_if_t<llvm::is_one_of< |
| OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp, |
| xegpu::PrefetchNdOp, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>> |
| static LogicalResult |
| genOffsetsList(ConversionPatternRewriter &rewriter, OpType op, |
| SmallVector<SmallVector<OpFoldResult>> &offsetsList) { |
| Location loc = op.getLoc(); |
| SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets(); |
  // Not applicable to ops without offset operands.
| if (origOffsets.empty()) |
| return failure(); |
| |
  // LoadMatrixOp/StoreMatrixOp carry the layout attribute directly; the Nd
  // ops expose the layout of their tensor descriptor via getDescLayoutAttr().
| xegpu::DistributeLayoutAttr layout; |
| if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp> || |
| std::is_same_v<OpType, xegpu::StoreMatrixOp>) { |
| layout = op.getLayoutAttr(); |
| } else { |
| layout = op.getDescLayoutAttr(); |
| } |
| |
| // not applicable to ops without workgroup layout attributes |
| if (!layout || !layout.isForWorkgroup()) |
| return failure(); |
| |
| Value sgId = |
| gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); |
| |
| // verify and adjust the sgId if the range specifier is present |
| xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op); |
| if (sgIdRange) { |
| int64_t startOfRange = sgIdRange.getStart().getInt(); |
| int64_t endOfRange = sgIdRange.getEnd().getInt(); |
| // verify the RangeAttr against the layout attribute |
| if (layout.getNumSubgroups() != endOfRange - startOfRange) |
| return rewriter.notifyMatchFailure( |
| op, "sg_layout size must match the sg_id_range"); |
| // adjust the sgId if necessary |
| if (startOfRange > 0) { |
| Value startOfRangeVal = |
| arith::ConstantIndexOp::create(rewriter, loc, startOfRange); |
| sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal); |
| } |
| } |
| |
| // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory |
| // descriptors to be accessed, based on the layout information. |
| ArrayRef<int64_t> wgShape = op.getDataShape(); |
| auto maybeDescOffsets = |
| layout.computeDistributedCoords(rewriter, loc, sgId, wgShape); |
| if (failed(maybeDescOffsets)) |
| return failure(); |
| |
| // Compute the final global offsets for each accessed sub-tensor |
| // or sub-memory descriptor. |
| for (const auto &sgOffsets : *maybeDescOffsets) { |
| SmallVector<OpFoldResult> newOffsets = xegpu::addWithRightAligned( |
| rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets); |
| offsetsList.push_back(std::move(newOffsets)); |
| } |
| |
| return success(); |
| } |
| |
| /// This pattern transforms the CreateNdDescOp to create a subgroup descriptor |
| /// from a workgroup descriptor. It replaces the offsets and sizes with |
| /// appropriate values for the subgroup. |
| /// It uses round-robin assignment to distribute the work to the subgroups. |
/// The following create_nd_tdesc operation:
| /// %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32> |
| /// -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4], |
| /// sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> |
/// is converted to 9 subgroup-level operations, based on sg_layout and
/// sg_data:
| /// %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> -> |
| /// !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], |
| /// lane_data = [1, 1]>> |
| /// |
| /// The sg_layout and sg_data attributes are dropped after the pass as they are |
| /// no longer needed. |
| /// |
| /// 24x24 matrix distribution example: |
| /// sg_layout = [4, 4], sg_data = [2, 2] |
| /// Each 8x8 matrix within the 24x24 matrix is called a distribution unit. |
| /// dist_unit_shape = [8, 8] --> sg_layout[i] * sg_data[i] |
| /// |
| /// +------------------------+ |
| /// | 8x8 | 8x8 | 8x8 | <- 3 tiles across |
| /// |-----+-----+-----| |
| /// | 8x8 | 8x8 | 8x8 | <- 3 tiles down |
| /// |-----+-----+-----| |
| /// | 8x8 | 8x8 | 8x8 | |
| /// +------------------------+ |
| /// |
| /// Each 8x8 tile is further subdivided among subgroups: |
| /// +------------------------+ |
| /// | 2x2 2x2 2x2 2x2 | <- 4 subgroups across (each handles 2 columns) |
| /// | 2x2 2x2 2x2 2x2 | <- 4 subgroups down (each handles 2 rows) |
| /// | 2x2 2x2 2x2 2x2 | |
| /// | 2x2 2x2 2x2 2x2 | |
| /// +------------------------+ |
| /// |
| /// Since the 24x24 matrix is divided into 8x8 distribution units, there will be |
| /// 9 distribution units (3x3) in total. Hence the 9 subgroup level operations. |
| |
/// The pass currently keeps the entire distribution logic in the
/// WgToSgCreateNdOp pattern; all the other patterns simply follow it.
/// TODO: Decouple the distribution logic from WgToSgCreateNdOp for all the
/// ops in the pass.
| struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> { |
| using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| SmallVector<SmallVector<OpFoldResult>> offsetsList; |
| if (failed(genOffsetsList(rewriter, op, offsetsList))) |
| return failure(); |
| |
| MLIRContext *ctx = op.getContext(); |
| xegpu::TensorDescType tdescTy = op.getType(); |
| ArrayRef<int64_t> wgShape = tdescTy.getShape(); |
| Type elemTy = tdescTy.getElementType(); |
| xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr(); |
| SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; |
| auto newTdescTy = |
| xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(), |
| layout.dropSgLayoutAndData()); |
| |
| SmallVector<Value> newOps; |
| for (auto offsets : offsetsList) { |
| auto newOp = xegpu::CreateNdDescOp::create( |
| rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets, |
| op.getMixedSizes(), op.getMixedStrides()); |
| |
| newOps.push_back(newOp); |
| } |
| rewriter.replaceOpWithMultiple(op, {newOps}); |
| |
| return success(); |
| } |
| }; |
| |
// This pattern transforms a CreateNdDescOp without offsets into subgroup
// descriptors derived from the workgroup descriptor.
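// For example (illustrative shapes and layouts), the offset-less
//   %tdesc = xegpu.create_nd_tdesc %src : memref<24x24xf32>
//     -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
//        sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
// becomes `count` subgroup-level descriptors of the sg_data shape:
//   %tdesc = xegpu.create_nd_tdesc %src : memref<24x24xf32>
//     -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
//        lane_data = [1, 1]>>
// The per-subgroup offsets are supplied later by the consuming load/store
// patterns.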
| struct WgToSgCreateNdOpNoOffset |
| : public OpConversionPattern<xegpu::CreateNdDescOp> { |
| using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| |
| // Check no offsets are specified. |
| if (!op.getMixedOffsets().empty()) |
| return failure(); |
| |
| Location loc = op.getLoc(); |
| MLIRContext *ctx = op.getContext(); |
| xegpu::TensorDescType tdescTy = op.getType(); |
| auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout()); |
| if (!layout || !layout.isForWorkgroup()) |
| return failure(); |
| |
| Type elemTy = tdescTy.getElementType(); |
| ArrayRef<int64_t> wgShape = tdescTy.getShape(); |
| |
| SmallVector<int64_t> sgShape; |
| int count; |
| std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout); |
| xegpu::TensorDescType newTdescTy = |
| xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(), |
| layout.dropSgLayoutAndData()); |
| |
| SmallVector<Value> newCreateNdOps(count); |
| std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() { |
| return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy, |
| op.getSource(), op.getMixedSizes(), |
| op.getMixedStrides()); |
| }); |
| |
| rewriter.replaceOpWithMultiple(op, {newCreateNdOps}); |
| return success(); |
| } |
| }; |
| |
| /// This pattern transforms the LoadNdOp to load subgroup data. |
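/// For example (illustrative types), a workgroup-level
///   %ld = xegpu.load_nd %tdesc
///     : !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
///        sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
///     -> vector<24x24xf32>
/// is rewritten into one load per distributed descriptor:
///   %ld = xegpu.load_nd %t
///     : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
///        lane_data = [1, 1]>> -> vector<2x2xf32>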
| struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> { |
| using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| if (!op.getMixedOffsets().empty()) |
| return failure(); |
| |
| SmallVector<Value> newLoadOps; |
| for (auto src : adaptor.getTensorDesc()) { |
| xegpu::TensorDescType tdescTy = |
| dyn_cast<xegpu::TensorDescType>(src.getType()); |
| ArrayRef<int64_t> srcShape = tdescTy.getShape(); |
| VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType()); |
| auto newLoadOp = xegpu::LoadNdOp::create( |
| rewriter, op.getLoc(), newResTy, src, |
| xegpu::dropSgLayoutAndDataOnAttrs(op->getAttrs())); |
| newLoadOps.push_back(newLoadOp); |
| } |
| rewriter.replaceOpWithMultiple(op, {newLoadOps}); |
| return mlir::success(); |
| } |
| }; |
| |
/// This pattern transforms the StoreNdOp to store to a subgroup descriptor.
/// It creates a StoreNdOp to store the updated values to the new subgroup
/// source tensor descriptors.
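/// For example (illustrative types), a workgroup-level
///   xegpu.store_nd %val, %tdesc : vector<24x24xf32>,
///     !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
///        sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
/// becomes one store per distributed value/descriptor pair:
///   xegpu.store_nd %v, %t : vector<2x2xf32>,
///     !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
///        lane_data = [1, 1]>>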
| struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> { |
| using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| if (!op.getMixedOffsets().empty()) |
| return failure(); |
| |
| for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc())) |
| xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, t, op.getL1HintAttr(), |
| op.getL2HintAttr(), op.getL3HintAttr()); |
| |
| rewriter.eraseOp(op); |
| return success(); |
| } |
| }; |
| |
| // This pattern transforms the LoadNdOp with explicit offsets to load |
| // subgroup data. |
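// For example (illustrative; the per-subgroup offsets are produced by
// genOffsetsList from the layout and the subgroup id), a workgroup-level
//   %ld = xegpu.load_nd %tdesc[0, 0]
//     : !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
//        sg_data = [2, 2]>> -> vector<24x24xf32>
// becomes one load per distribution unit accessed by the subgroup:
//   %ld = xegpu.load_nd %tdesc[%off0, %off1]
//     : !xegpu.tensor_desc<2x2xf32> -> vector<2x2xf32>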
| struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> { |
| using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| |
| SmallVector<SmallVector<OpFoldResult>> offsetsList; |
| if (failed(genOffsetsList(rewriter, op, offsetsList))) |
| return failure(); |
| |
| xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); |
| if (layout) |
| layout = layout.dropSgLayoutAndData(); |
| SmallVector<Value> newOps; |
| for (auto [tdesc, offsets] : |
| llvm::zip(adaptor.getTensorDesc(), offsetsList)) { |
| auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType()); |
| VectorType newResTy = |
| VectorType::get(tdescTy.getShape(), tdescTy.getElementType()); |
| auto newOp = xegpu::LoadNdOp::create( |
| rewriter, op.getLoc(), newResTy, tdesc, offsets, |
| /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(), |
| op.getL2HintAttr(), op.getL3HintAttr(), layout); |
| newOps.push_back(newOp); |
| } |
| rewriter.replaceOpWithMultiple(op, {newOps}); |
| |
| return success(); |
| } |
| }; |
| |
| // This pattern transforms the StoreNdOp with explicit offsets to store |
| // subgroup data. |
| struct WgToSgStoreNdOpWithOffset |
| : public OpConversionPattern<xegpu::StoreNdOp> { |
| using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| SmallVector<SmallVector<OpFoldResult>> offsetsList; |
| if (failed(genOffsetsList(rewriter, op, offsetsList))) |
| return failure(); |
| |
| xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); |
| if (layout) |
| layout = layout.dropSgLayoutAndData(); |
| for (auto [v, tdesc, offsets] : |
| llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) { |
| xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets, |
| op.getL1HintAttr(), op.getL2HintAttr(), |
| op.getL3HintAttr(), layout); |
| } |
| rewriter.eraseOp(op); |
| |
| return success(); |
| } |
| }; |
| |
| // This pattern transforms the PrefetchNdOp with explicit offsets to prefetch |
| // subgroup data. |
| struct WgToSgPrefetchNdOpWithOffset |
| : public OpConversionPattern<xegpu::PrefetchNdOp> { |
| using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| SmallVector<SmallVector<OpFoldResult>> offsetsList; |
| if (failed(genOffsetsList(rewriter, op, offsetsList))) |
| return failure(); |
| |
| xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); |
| if (layout) |
| layout = layout.dropSgLayoutAndData(); |
| for (auto [tdesc, offsets] : |
| llvm::zip(adaptor.getTensorDesc(), offsetsList)) { |
| xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets, |
| op.getL1HintAttr(), op.getL2HintAttr(), |
| op.getL3HintAttr(), layout); |
| } |
| rewriter.eraseOp(op); |
| |
| return success(); |
| } |
| }; |
| |
/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
/// subgroup descriptor. It creates an UpdateNdOffsetOp to update the offsets
/// of the new subgroup source tensor descriptors.
| struct WgToSgUpdateNdOffsetOp |
| : public OpConversionPattern<xegpu::UpdateNdOffsetOp> { |
| using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| llvm::SmallVector<Value> newUpdateTileOffsetOps; |
| for (auto tDesc : adaptor.getTensorDesc()) { |
| auto newUpdateTileOffsetOp = xegpu::UpdateNdOffsetOp::create( |
| rewriter, op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(), |
| op.getConstOffsets()); |
| newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp); |
| } |
| |
| rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps}); |
| return success(); |
| } |
| }; |
| |
| /// This pattern transforms the DpasOp to work at subgroup level. |
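/// For example (illustrative shapes), a workgroup-level
///   %d = xegpu.dpas %a, %b : vector<128x128xf16>, vector<128x128xf16>
///     -> vector<128x128xf32>
/// whose A and B layouts use sg_layout = [8, 8] and sg_data = [16, 16]
/// becomes, per subgroup, a dpas over the distributed tiles:
///   %d = xegpu.dpas %a, %b : vector<16x16xf16>, vector<16x16xf16>
///     -> vector<16x16xf32>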
| struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> { |
| using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| VectorType resultTy = op.getResult().getType(); |
| if (resultTy.getRank() != 2) |
| return failure(); |
| |
| auto layoutCd = op.getLayoutCdAttr(); |
| auto layoutA = op.getLayoutAAttr(); |
| auto layoutB = op.getLayoutBAttr(); |
| if (!layoutCd || !layoutA || !layoutB) |
| return failure(); |
| size_t i = 0; |
| SmallVector<Value> newDpasOps; |
| for (auto aVec : adaptor.getLhs()) { |
| for (auto bVec : adaptor.getRhs()) { |
| |
| llvm::SmallVector<Value> operands({aVec, bVec}); |
| Value tmpC; |
| if (op.getAcc()) { |
| tmpC = adaptor.getAcc()[i++]; |
| operands.push_back(tmpC); |
| } |
| |
| ArrayRef<int64_t> aVecShape = |
| llvm::cast<VectorType>(aVec.getType()).getShape(); |
| ArrayRef<int64_t> bVecShape = |
| llvm::cast<VectorType>(bVec.getType()).getShape(); |
| VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]}, |
| resultTy.getElementType()); |
| auto newDpasOp = xegpu::DpasOp::create(rewriter, loc, resTy, operands); |
| newDpasOp.setLayoutCdAttr(layoutCd.dropSgLayoutAndData()); |
| newDpasOp.setLayoutAAttr(layoutA.dropSgLayoutAndData()); |
| newDpasOp.setLayoutBAttr(layoutB.dropSgLayoutAndData()); |
| |
| newDpasOps.push_back(newDpasOp); |
| } |
| } |
| rewriter.replaceOpWithMultiple(op, {newDpasOps}); |
| return success(); |
| } |
| }; |
| |
| /// This pattern transforms the PrefetchNdOp to prefetch the subgroup data. |
| struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> { |
| using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| |
| int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size()); |
| if ((offsetSize != 0) || op.getConstOffsetsAttr()) |
| return failure(); |
| |
| for (auto src : adaptor.getTensorDesc()) |
| xegpu::PrefetchNdOp::create( |
| rewriter, op.getLoc(), TypeRange(), src, |
| xegpu::dropSgLayoutAndDataOnAttrs(op->getAttrs())); |
| rewriter.eraseOp(op); |
| return success(); |
| } |
| }; |
| |
| /// This pattern transforms vector.broadcast ops to work at subgroup level. |
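/// For example (illustrative), with a result layout of sg_layout = [4, 8] and
/// sg_data = [32, 32]:
///   %b = vector.broadcast %s : f32 to vector<128x256xf32>
/// becomes, per subgroup,
///   %b = vector.broadcast %s : f32 to vector<32x32xf32>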
| struct WgToSgVectorBroadcastOp |
| : public OpConversionPattern<vector::BroadcastOp> { |
| using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| |
| VectorType resultType = op.getResult().getType(); |
| ArrayRef<int64_t> wgShape = resultType.getShape(); |
| |
| xegpu::DistributeLayoutAttr layout = |
| xegpu::getTemporaryLayout(llvm::cast<OpResult>(op.getResult())); |
| if (!layout || !layout.isForWorkgroup()) |
| return failure(); |
| |
| SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; |
| VectorType newResultType = |
| VectorType::get(sgShape, resultType.getElementType()); |
| |
| if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout)) |
| return failure(); |
| |
| SmallVector<Value> newBroadcastOps; |
| for (auto operand : adaptor.getOperands().front()) { |
| auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(), |
| newResultType, operand); |
| |
| newBroadcastOps.push_back(newBroadcast.getResult()); |
| } |
| rewriter.replaceOpWithMultiple(op, {newBroadcastOps}); |
| return success(); |
| } |
| }; |
| |
| // This pattern transforms elementwise ops to work at subgroup level. |
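// For example (illustrative), with a result layout of sg_layout = [8, 4] and
// sg_data = [32, 32]:
//   %r = arith.addf %a, %b : vector<256x128xf32>
// becomes one op per distributed operand tuple:
//   %r = arith.addf %a, %b : vector<32x32xf32>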
| struct WgToSgElementwiseOp : public ConversionPattern { |
| WgToSgElementwiseOp(MLIRContext *ctx) |
| : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {} |
| |
| LogicalResult |
| matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| // Only match ops with elementwise trait and single result. |
| if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1) |
| return failure(); |
| |
| auto resultType = dyn_cast<VectorType>(op->getResult(0).getType()); |
| assert(resultType && "Expected result to be a VectorType"); |
| |
| ArrayRef<int64_t> wgShape = resultType.getShape(); |
| |
| xegpu::DistributeLayoutAttr layout = |
| xegpu::getTemporaryLayout(llvm::cast<OpResult>(op->getResult(0))); |
| if (!layout || !layout.isForWorkgroup()) |
| return failure(); |
| |
| SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; |
| |
| size_t numVariants = operands.empty() ? 0 : operands.front().size(); |
| |
| if (llvm::any_of(operands, [&](const ValueRange &operandVec) { |
| return operandVec.size() != numVariants; |
| })) |
| return failure(); |
| |
| SmallVector<Value> newResults; |
| VectorType newResultType = |
| VectorType::get(sgShape, resultType.getElementType()); |
| |
| for (size_t i = 0; i < numVariants; ++i) { |
| SmallVector<Value> opOperands; |
| for (auto &operandVec : operands) |
| opOperands.push_back(operandVec[i]); |
| |
| OperationState state(op->getLoc(), op->getName()); |
| state.addOperands(opOperands); |
| state.addTypes(newResultType); |
| state.addAttributes(op->getAttrs()); |
| Operation *newOp = rewriter.create(state); |
| xegpu::removeLayoutAttrs(newOp); |
| newResults.push_back(newOp->getResult(0)); |
| } |
| |
| rewriter.replaceOpWithMultiple(op, {newResults}); |
| return success(); |
| } |
| }; |
| |
| // clang-format off |
| // Pattern for lowering ConvertLayoutOp based on sg_layout and sg_data. |
| // If input_layout and target_layout have identical sg_layout and sg_data, |
| // the op is rewritten to a subgroup-level ConvertLayoutOp with these fields |
| // dropped. For example: |
| // #a = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]> |
| // #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]> |
| // xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32> |
| // becomes: |
| // #a = #xegpu.layout<inst_data = [16, 16]> |
| // #b = #xegpu.layout<inst_data = [8, 16]> |
| // xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<16x16xf32> |
| // (vector<16x16xf32> is determined by sg_data = [16, 16]) |
| // |
| // If sg_layout or sg_data differ, SLM is used to redistribute data across subgroups. |
| // For example: |
| // #a = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 16], inst_data = [16, 16]> |
| // #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 32], inst_data = [8, 16]> |
| // xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32> |
| // is lowered to: |
| // #a = #xegpu.layout<inst_data = [16, 16]> |
| // #b = #xegpu.layout<inst_data = [8, 16]> |
| // store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, mem_desc<32x64xf32> |
| // %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32> |
| // xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32> |
| // clang-format on |
| struct WgToSgConvertLayoutOp |
| : public OpConversionPattern<xegpu::ConvertLayoutOp> { |
| using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| |
| VectorType resultType = op.getResult().getType(); |
| ArrayRef<int64_t> wgShape = resultType.getShape(); |
| auto inputLayout = op.getInputLayout(); |
| auto targetLayout = op.getTargetLayout(); |
| |
| if (!inputLayout || !targetLayout || !inputLayout.isForWorkgroup() || |
| !targetLayout.isForWorkgroup()) |
| return rewriter.notifyMatchFailure( |
| op, "Input and target layouts must have subgroup layout"); |
| |
| SmallVector<int64_t> inputSgLayout = |
| inputLayout.getEffectiveSgLayoutAsInt(); |
| SmallVector<int64_t> inputSgData = inputLayout.getEffectiveSgDataAsInt(); |
| SmallVector<int64_t> targetSgLayout = |
| targetLayout.getEffectiveSgLayoutAsInt(); |
| SmallVector<int64_t> targetSgData = targetLayout.getEffectiveSgDataAsInt(); |
| |
| // Fast path: if sg_layout and sg_data are identical, no SLM needed |
| if (inputLayout.isCompatibleWith(targetLayout, |
| xegpu::LayoutKind::Subgroup)) { |
| inputLayout = inputLayout.dropSgLayoutAndData(); |
| targetLayout = targetLayout.dropSgLayoutAndData(); |
| |
| SmallVector<Value> newOps(adaptor.getSource()); |
| if (inputLayout && targetLayout) { |
| for (auto [i, src] : llvm::enumerate(adaptor.getSource())) { |
| auto newOp = xegpu::ConvertLayoutOp::create( |
| rewriter, loc, src.getType(), src, inputLayout, targetLayout); |
| newOps[i] = newOp; |
| } |
| } |
| rewriter.replaceOpWithMultiple(op, {newOps}); |
| return success(); |
| } |
| |
| // SLM path: layouts differ, need cross-subgroup data redistribution |
| Type elemTy = cast<VectorType>(op.getSource().getType()).getElementType(); |
| |
| SmallVector<int64_t> slmShape = llvm::to_vector(wgShape); |
| |
| // Calculate SLM size requirements |
| auto bitWidth = elemTy.getIntOrFloatBitWidth(); |
| auto bytesPerElement = bitWidth / 8; |
| auto slmSize = computeProduct(slmShape) * bytesPerElement; |
| |
| // Allocate SLM |
| auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3); |
| auto slm = memref::AllocaOp::create(rewriter, loc, slmTy); |
| |
| auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape, |
| elemTy, nullptr); |
| auto memDesc = |
| xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm); |
| |
| auto sgId = gpu::SubgroupIdOp::create(rewriter, loc, |
| rewriter.getIndexType(), nullptr); |
| |
| // STORE PHASE: Each subgroup stores in SLM using input layout |
| auto storeCoords = inputLayout.computeDistributedCoords( |
| rewriter, loc, sgId.getResult(), wgShape); |
| if (failed(storeCoords)) |
| return failure(); |
| |
| // Store to SLM |
| for (auto [src, coords] : llvm::zip(adaptor.getSource(), *storeCoords)) { |
      SmallVector<OpFoldResult> storeMatrixOffsets = getAsOpFoldResult(coords);
| xegpu::StoreMatrixOp::create(rewriter, loc, src, memDesc.getResult(), |
| storeMatrixOffsets, nullptr /*layout*/); |
| } |
| |
| gpu::BarrierOp::create(rewriter, loc); |
| |
| // LOAD PHASE: Each target subgroup loads from SLM using target layout |
| auto loadCoords = targetLayout.computeDistributedCoords( |
| rewriter, loc, sgId.getResult(), wgShape); |
| if (failed(loadCoords)) |
| return failure(); |
| |
| VectorType loadType = VectorType::get(targetSgData, elemTy); |
| |
| // Load vectors from SLM |
| SmallVector<Value> finalResults; |
| for (auto coords : *loadCoords) { |
      SmallVector<OpFoldResult> loadMatrixOffsets = getAsOpFoldResult(coords);
| auto loadOp = xegpu::LoadMatrixOp::create( |
| rewriter, loc, loadType, memDesc.getResult(), loadMatrixOffsets, |
| targetLayout.dropSgLayoutAndData()); |
| |
| finalResults.push_back(loadOp.getResult()); |
| } |
| |
| rewriter.replaceOpWithMultiple(op, {finalResults}); |
| return success(); |
| } |
| }; |
| |
// Handles UnrealizedConversionCastOp generated during
// SCFStructuralTypeConversions (step 1). This op may appear as either a
// target or source materialization for Vector values, e.g.:
//   1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ...
//   2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32>
// It can be either a 1:N or an N:1 cast. In both cases, the pattern simply
// forwards the inputs to the outputs via the 1:1 or 1:N interface.
// For example, the following scf::ForOp
| // ``` |
| // %for = scf.for ... iter_args(%arg1 = %0)->(vector<128x128xf16>) { |
| // %n = use(%arg1): vector<128x128xf16> |
| // scf.yield %n : vector<128x128xf16> |
| // } |
| // ``` |
| // Could be converted to: |
| // ``` |
| // %1 = unrealized_conversion_cast %0 |
| // : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16> |
| // %for:2 = scf.for ... iter_args(%arg1 = %1#1, %arg2 = %1#2) |
| // -> (vector<16x16xf16>, vector<16x16xf16) { |
| // %m = unrealized_conversion_cast %arg1, %arg2 |
| // : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16> |
| // %n = use(%m): vector<128x128xf16> |
| // %b = unrealized_conversion_cast %n |
| // : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16> |
| // scf.yield %b#1, %b#2 : vector<16x16xf16>, vector<16x16xf16> |
| // } |
| // %cast = unrealized_conversion_cast %for:2 |
| // : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16> |
| // ``` |
| // TODO: remove it when context-aware type converter is ready. |
| struct UnrealizedConversionCastOpPattern |
| : public OpConversionPattern<mlir::UnrealizedConversionCastOp> { |
| using OpConversionPattern< |
| mlir::UnrealizedConversionCastOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs()); |
| |
| auto inputTy = dyn_cast<VectorType>(inputs[0].getType()); |
| auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType()); |
| |
| if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) || |
| !llvm::all_equal(ValueRange(inputs).getTypes())) |
| return failure(); |
| |
| // Handles the case "cast %1 : vector<256xf32> to vector<16xf32>, ...". |
| // It is generated by source materialization (e.g., inits to scf forOp). |
| // The input values provided by the adaptor should already be distributed, |
| // and their types should correspond exactly to the result types of the |
| // operation. |
| if (op.getNumOperands() == 1 && |
| llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) { |
| rewriter.replaceOp(op, inputs); |
| return success(); |
| } |
| |
| // Handles the case "cast %1 : vector<16xf32>, ... to vector<256xf32>". |
| // It is generated by target materialization (e.g., arguments/results |
| // of scf forOp). All input values must have the same vector type, and |
| // their shape must be evenly divisible by the output vector's shape |
| // (determined by the nature of the workgroup to subgroup distribution). |
    // TODO: it is not safe to do such forwarding, since the N:1 cast could
    // also originate from other sources.
| if (op.getNumResults() == 1 && |
| computeShapeRatio(outputTy.getShape(), inputTy.getShape())) { |
| rewriter.replaceOpWithMultiple(op, {inputs}); |
| return success(); |
| } |
| |
| return mlir::failure(); |
| } |
| }; |
| |
// This pattern distributes an arith.constant op into subgroup-level constants.
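// For example (illustrative), with a result layout of sg_layout = [4] and
// sg_data = [8], the splat
//   %c = arith.constant dense<1.0> : vector<32xf32>
// becomes, per subgroup,
//   %c = arith.constant dense<1.0> : vector<8xf32>
// Non-splat index constants with constant strides are rebuilt below from a
// base tile plus a per-subgroup offset.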
| struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { |
| using OpConversionPattern<arith::ConstantOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue()); |
| auto vecType = dyn_cast<VectorType>(op.getType()); |
| if (!vecAttr || !vecType) |
| return failure(); |
| |
| xegpu::DistributeLayoutAttr layout = |
| xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult())); |
| if (!layout || !layout.isForWorkgroup()) |
| return failure(); |
| |
| ArrayRef<int64_t> wgShape = vecType.getShape(); |
| SmallVector<int64_t> sgShape; |
| int count; |
| std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout); |
| |
| auto newType = VectorType::get(sgShape, vecType.getElementType()); |
| Location loc = op.getLoc(); |
| auto eltType = vecType.getElementType(); |
| |
| if (vecAttr.isSplat()) { |
| // Splat: single value for all subgroups |
| Attribute singleVal = vecAttr.getSplatValue<Attribute>(); |
| auto sgAttr = DenseElementsAttr::get(newType, singleVal); |
| auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr); |
| rewriter.replaceOp(op, cstOp); |
| return success(); |
| } else if (sgShape == wgShape) { // if the entire vector is shared by all |
| // subgroups, don't distribute |
| auto newConstOp = |
| arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr); |
| rewriter.replaceOp(op, newConstOp); |
| return success(); |
| } else { |
      // Non-splat constant: only 1D and 2D index vectors with constant
      // strides are supported.
      // TODO: support other cases that require SLM access.
| if (!eltType.isIndex()) |
| return rewriter.notifyMatchFailure( |
| op, "Unsupported element type for non-splat constant op."); |
| |
| if (wgShape.size() > 2) |
| return rewriter.notifyMatchFailure( |
| op, "Only 1D & 2D vector constant supported"); |
| |
| SmallVector<Attribute> values(vecAttr.getValues<Attribute>()); |
| int64_t rowStride = 0, colStride = 0; |
| int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0]; |
| int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1]; |
| |
| // Compute colStride and rowStride, and check for constant strides. |
| if (cols > 1) { |
| colStride = cast<IntegerAttr>(values[1]).getInt() - |
| cast<IntegerAttr>(values[0]).getInt(); |
| } |
| if (rows > 1) { |
| rowStride = cast<IntegerAttr>(values[cols]).getInt() - |
| cast<IntegerAttr>(values[0]).getInt(); |
| } |
| |
| for (int64_t r = 0; r < rows; ++r) { |
| for (int64_t c = 0; c < cols; ++c) { |
| int64_t idx = r * cols + c; |
| // Check column stride |
| if (c > 0 && cols > 1) { |
| int64_t prevIdx = r * cols + (c - 1); |
| int64_t diff = cast<IntegerAttr>(values[idx]).getInt() - |
| cast<IntegerAttr>(values[prevIdx]).getInt(); |
| if (diff != colStride) |
| return rewriter.notifyMatchFailure( |
| op, "Non-constant column stride in constant op."); |
| } |
| // Check row stride |
| if (r > 0 && rows > 1) { |
| int64_t prevIdx = (r - 1) * cols + c; |
| int64_t diff = cast<IntegerAttr>(values[idx]).getInt() - |
| cast<IntegerAttr>(values[prevIdx]).getInt(); |
| if (diff != rowStride) |
| return rewriter.notifyMatchFailure( |
| op, "Non-constant row stride in constant op."); |
| } |
| } |
| } |
| |
| // Create a constant for the base tile. |
| // For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix. |
| // For 1D case, extract the first sgShape[0] elements. |
| SmallVector<Attribute> baseTileValues; |
| int baseTileCols = sgShape[sgShape.size() - 1]; |
| int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0]; |
| for (int64_t r = 0; r < baseTileRows; ++r) { |
| for (int64_t c = 0; c < baseTileCols; ++c) { |
| baseTileValues.push_back(values[r * cols + c]); |
| } |
| } |
| |
| auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType), |
| baseTileValues); |
| auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr); |
| |
| // Get subgroup id |
| Value sgId = |
| gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); |
| auto sgOffsets = |
| layout.computeDistributedCoords(rewriter, loc, sgId, wgShape); |
| if (failed(sgOffsets)) |
| return failure(); |
| |
| SmallVector<Value, 2> strideConsts; |
| strideConsts.push_back( |
| arith::ConstantIndexOp::create(rewriter, loc, colStride)); |
| if (rows > 1) |
| strideConsts.insert( |
| strideConsts.begin(), |
| arith::ConstantIndexOp::create(rewriter, loc, rowStride)); |
| |
| SmallVector<Value> newConstOps; |
| for (auto offsets : *sgOffsets) { |
| // Multiply offset with stride, broadcast it and add to baseConstVec |
| Value mulOffset = arith::ConstantIndexOp::create(rewriter, loc, 0); |
| for (size_t i = 0; i < strideConsts.size(); ++i) { |
| Value mul = |
| arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(), |
| offsets[i], strideConsts[i]); |
| mulOffset = arith::AddIOp::create( |
| rewriter, loc, rewriter.getIndexType(), mulOffset, mul); |
| } |
| // Broadcast to baseConstVec size |
| auto bcastOffset = vector::BroadcastOp::create( |
| rewriter, loc, baseConstVec.getType(), mulOffset); |
| auto finalConst = |
| arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset); |
| newConstOps.push_back(finalConst); |
| } |
| rewriter.replaceOpWithMultiple(op, {newConstOps}); |
| return success(); |
| } |
| } |
| }; |
| |
// This pattern transforms the LoadGatherOp with explicit offsets to load
// subgroup data.
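// For example (illustrative), a gather producing vector<256xf32> from
// vector<256xindex> offsets and a vector<256xi1> mask, with a layout of
// sg_layout = [8] and sg_data = [32], becomes one subgroup-level gather
// producing vector<32xf32> from the already-distributed vector<32xindex>
// offsets and vector<32xi1> mask.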
| struct WgToSgLoadGatherOpWithOffset |
| : public OpConversionPattern<xegpu::LoadGatherOp> { |
| using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| |
| if (!op.getOffsets()) |
| return failure(); |
| |
| Location loc = op.getLoc(); |
| VectorType resultType = dyn_cast<VectorType>(op.getResult().getType()); |
| if (!resultType) |
| return failure(); |
| ArrayRef<int64_t> wgShape = resultType.getShape(); |
| |
| xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); |
| |
| if (!layout || !layout.isForWorkgroup()) |
| return failure(); |
| |
| SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; |
| |
| // The offsets need to be distributed |
| auto offsetsVecType = |
| dyn_cast<VectorType>(adaptor.getOffsets().front().getType()); |
| auto maskVecType = |
| dyn_cast<VectorType>(adaptor.getMask().front().getType()); |
| if (!offsetsVecType || !maskVecType || |
| offsetsVecType.getShape() != maskVecType.getShape()) { |
| return rewriter.notifyMatchFailure(op, |
| "offsets have not been distributed"); |
| } |
| |
| SmallVector<Value> newLoadOps; |
| auto chunkSizeAttr = |
| rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1)); |
| VectorType newTy = VectorType::get(sgShape, resultType.getElementType()); |
| for (auto [offsets, mask] : |
| llvm::zip(adaptor.getOffsets(), adaptor.getMask())) { |
| auto newLayout = layout.dropSgLayoutAndData(); |
| auto newLoadOp = xegpu::LoadGatherOp::create( |
| rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr, |
| op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(), |
| newLayout); |
| newLoadOps.push_back(newLoadOp); |
| } |
| rewriter.replaceOpWithMultiple(op, {newLoadOps}); |
| return success(); |
| } |
| }; |
| |
| // This pattern transforms the StoreScatterOp with explicit offsets to store |
| // subgroup data |
| struct WgToSgStoreScatterOpWithOffset |
| : public OpConversionPattern<xegpu::StoreScatterOp> { |
| using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| |
| if (!op.getOffsets()) |
| return failure(); |
| |
| Location loc = op.getLoc(); |
| VectorType valueType = dyn_cast<VectorType>(op.getValue().getType()); |
| if (!valueType) |
| return failure(); |
| |
| xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); |
| |
| if (!layout || !layout.isForWorkgroup()) |
| return failure(); |
| |
| // The offsets need to be distributed |
| auto offsetsVecType = |
| dyn_cast<VectorType>(adaptor.getOffsets().front().getType()); |
| auto maskVecType = |
| dyn_cast<VectorType>(adaptor.getMask().front().getType()); |
| if (!offsetsVecType || !maskVecType || |
| offsetsVecType.getShape() != maskVecType.getShape()) { |
| return rewriter.notifyMatchFailure(op, |
| "offsets have not been distributed"); |
| } |
| |
| auto chunkSizeOpt = op.getChunkSize(); |
| int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1; |
| auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize); |
| for (auto [val, offs, mask] : llvm::zip( |
| adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) { |
| xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs, |
| mask, chunkSizeAttr, op.getL1HintAttr(), |
| op.getL2HintAttr(), op.getL3HintAttr(), |
| layout.dropSgLayoutAndData()); |
| } |
| rewriter.eraseOp(op); |
| return success(); |
| } |
| }; |
| |
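// This pattern transforms the LoadMatrixOp to load subgroup-sized tiles from
// the memory descriptor, using per-subgroup offsets computed by
// genOffsetsList from the layout and the subgroup id.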
| struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> { |
| using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| |
| SmallVector<SmallVector<OpFoldResult>> offsetsList; |
| if (failed(genOffsetsList(rewriter, op, offsetsList))) |
| return failure(); |
| |
| ArrayRef<int64_t> wgShape = op.getDataShape(); |
| VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType()); |
| assert(valueTy && "the value type must be vector type!"); |
| Type elemTy = valueTy.getElementType(); |
| |
| xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); |
| SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; |
| VectorType newResTy = VectorType::get(sgShape, elemTy); |
| SmallVector<Value> newOps; |
| for (auto offsets : offsetsList) { |
| auto newOp = xegpu::LoadMatrixOp::create(rewriter, op.getLoc(), newResTy, |
| op.getMemDesc(), offsets, |
| layout.dropSgLayoutAndData()); |
| newOps.push_back(newOp); |
| } |
| rewriter.replaceOpWithMultiple(op, {newOps}); |
| |
| return success(); |
| } |
| }; |
| |
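// This pattern transforms the StoreMatrixOp to store subgroup-sized tiles to
// the memory descriptor, using per-subgroup offsets computed by
// genOffsetsList from the layout and the subgroup id.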
| struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> { |
| using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| |
| SmallVector<SmallVector<OpFoldResult>> offsetsList; |
| if (failed(genOffsetsList(rewriter, op, offsetsList))) |
| return failure(); |
| |
| xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); |
| for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList)) |
| xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), v, op.getMemDesc(), |
| offsets, layout.dropSgLayoutAndData()); |
| rewriter.eraseOp(op); |
| return success(); |
| } |
| }; |
| |
// This pattern distributes vector.step ops to work at the subgroup level.
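// For example (illustrative), with a result layout of sg_layout = [4] and
// sg_data = [8]:
//   %s = vector.step : vector<32xindex>
// becomes, per subgroup, the base steps plus the subgroup's element offset:
//   %base = vector.step : vector<8xindex>
//   %off  = vector.broadcast %sgOff : index to vector<8xindex>
//   %s    = arith.addi %base, %off : vector<8xindex>
// where %sgOff is the distributed coordinate derived from the subgroup id.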
| struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> { |
| using OpConversionPattern<vector::StepOp>::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| xegpu::DistributeLayoutAttr layout = |
| xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult())); |
| if (!layout || !layout.isForWorkgroup()) |
| return failure(); |
| |
| Location loc = op.getLoc(); |
| VectorType type = op.getResult().getType(); |
| auto wgShape = type.getShape(); |
    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;

    Value sgId =
        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
    auto sgOffsets =
        layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
    if (failed(sgOffsets))
      return failure();

    VectorType newTy = type.cloneWith(sgShape, type.getElementType());
| auto steps = vector::StepOp::create(rewriter, loc, newTy); |
| SmallVector<Value> newOps; |
| for (auto offsets : *sgOffsets) { |
| // Broadcast the offset scalar to a vector & add to the base steps |
| auto bcastOffset = |
| vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]); |
| auto finalSteps = |
| arith::AddIOp::create(rewriter, loc, steps, bcastOffset); |
| newOps.push_back(finalSteps); |
| } |
| |
| rewriter.replaceOpWithMultiple(op, {newOps}); |
| return success(); |
| } |
| }; |
| |
| // This pattern transforms vector.shape_cast ops to work at subgroup level. |
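// For example (illustrative), a unit-dimension-expanding cast feeding a
// broadcast, such as
//   %e = vector.shape_cast %v : vector<32xf32> to vector<1x32xf32>
// with a result layout of sg_layout = [1, 4] and sg_data = [1, 8] (the source
// layout being a slice of it), becomes, per subgroup,
//   %e = vector.shape_cast %v : vector<8xf32> to vector<1x8xf32>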
| struct WgToSgVectorShapeCastOp |
| : public OpConversionPattern<vector::ShapeCastOp> { |
| using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| |
| VectorType resultType = dyn_cast<VectorType>(op.getResult().getType()); |
| if (!resultType) |
| return failure(); |
| |
| ArrayRef<int64_t> wgShape = resultType.getShape(); |
| xegpu::DistributeLayoutAttr layout = |
| xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult())); |
| if (!layout || !layout.isForWorkgroup()) |
| return failure(); |
| |
    // Check that srcShape and destShape, if they differ, only differ by the
    // expansion of unit dimensions.
| auto srcType = dyn_cast<VectorType>(op.getSource().getType()); |
| if (!srcType) |
| return failure(); |
| |
| ArrayRef<int64_t> srcShape = srcType.getShape(); |
| |
| xegpu::DistributeLayoutAttr layoutToDistribute = layout; |
| SmallVector<int64_t> expandedUnitDims; |
| if (xegpu::matchUnitDimExpansion(srcShape, wgShape, expandedUnitDims)) { |
| xegpu::DistributeLayoutAttr sourceLayout = |
| xegpu::getTemporaryLayout(op->getOpOperand(0)); |
| |
| auto usedByBroadcastOp = [](vector::ShapeCastOp op) { |
| return llvm::all_of(op.getResult().getUsers(), [](Operation *user) { |
| return isa<vector::BroadcastOp>(user); |
| }); |
| }; |
| |
| if (!usedByBroadcastOp(op)) |
| return rewriter.notifyMatchFailure( |
| op, "ShapeCast ops that expand unit dimensions and are used by " |
| "non-broadcast operations are not supported."); |
| |
| if (!sourceLayout.isSliceOf(layout)) |
| return rewriter.notifyMatchFailure( |
| op, "The ShapeCast op only expands dimensions, the result layout " |
| "must be a slice of the input layout, or vice versa."); |
| layoutToDistribute = layoutToDistribute.setUnitDimData(expandedUnitDims); |
| layoutToDistribute = |
| layoutToDistribute.setUnitDimLayout(expandedUnitDims); |
| } |
| |
| SmallVector<int64_t> sgShape = |
| getSgShapeAndCount(wgShape, layoutToDistribute).first; |
| VectorType newResultType = |
| VectorType::get(sgShape, resultType.getElementType()); |
| |
| SmallVector<Value> newShapeCastOps; |
| for (auto src : adaptor.getSource()) { |
| auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(), |
| newResultType, src); |
| newShapeCastOps.push_back(newShapeCast.getResult()); |
| } |
| |
| rewriter.replaceOpWithMultiple(op, {newShapeCastOps}); |
| return success(); |
| } |
| }; |
| |
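// Creates a vector constant holding a neutral starting value for the given
// combining kind (e.g. 0 for ADD, 1 for MUL, +/-infinity for floating-point
// min/max, and the corresponding integer limits for integer min/max).
// Returns nullptr if no suitable neutral value is known for the element type.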
| static Value createAccumulator(ConversionPatternRewriter &rewriter, |
| Location loc, VectorType type, |
| vector::CombiningKind kind) { |
| Type elemTy = type.getElementType(); |
| |
| switch (kind) { |
| case vector::CombiningKind::ADD: |
| case vector::CombiningKind::XOR: |
| case vector::CombiningKind::OR: |
| return arith::ConstantOp::create( |
| rewriter, loc, type, |
| DenseElementsAttr::get(type, rewriter.getZeroAttr(elemTy))); |
| |
| case vector::CombiningKind::MUL: |
| case vector::CombiningKind::AND: |
| return arith::ConstantOp::create( |
| rewriter, loc, type, |
| DenseElementsAttr::get(type, rewriter.getOneAttr(elemTy))); |
| |
| case vector::CombiningKind::MINSI: |
| // Use max signed int value for signed integer min |
| if (auto intTy = dyn_cast<IntegerType>(elemTy)) { |
| auto maxVal = APInt::getSignedMaxValue(intTy.getWidth()); |
| return arith::ConstantOp::create( |
| rewriter, loc, type, |
| DenseElementsAttr::get(type, |
| rewriter.getIntegerAttr(elemTy, maxVal))); |
| } |
| return nullptr; |
| |
| case vector::CombiningKind::MINUI: |
| if (auto intTy = dyn_cast<IntegerType>(elemTy)) { |
| auto maxVal = APInt::getMaxValue(intTy.getWidth()); |
| return arith::ConstantOp::create( |
| rewriter, loc, type, |
| DenseElementsAttr::get(type, |
| rewriter.getIntegerAttr(elemTy, maxVal))); |
| } |
| return nullptr; |
| |
| case vector::CombiningKind::MAXSI: |
| if (auto intTy = dyn_cast<IntegerType>(elemTy)) { |
| auto minVal = APInt::getSignedMinValue(intTy.getWidth()); |
| return arith::ConstantOp::create( |
| rewriter, loc, type, |
| DenseElementsAttr::get(type, |
| rewriter.getIntegerAttr(elemTy, minVal))); |
| } |
| return nullptr; |
| |
| case vector::CombiningKind::MAXUI: |
| return arith::ConstantOp::create( |
| rewriter, loc, type, |
| DenseElementsAttr::get(type, rewriter.getZeroAttr(elemTy))); |
| |
| case vector::CombiningKind::MINNUMF: |
| case vector::CombiningKind::MINIMUMF: |
| // Use +infinity for float min operations |
| if (auto floatTy = dyn_cast<FloatType>(elemTy)) { |
| auto posInf = APFloat::getInf(floatTy.getFloatSemantics()); |
| return arith::ConstantOp::create( |
| rewriter, loc, type, |
| DenseElementsAttr::get(type, rewriter.getFloatAttr(elemTy, posInf))); |
| } |
| return nullptr; |
| |
| case vector::CombiningKind::MAXNUMF: |
| case vector::CombiningKind::MAXIMUMF: |
| // Use -infinity for float max operations |
| if (auto floatTy = dyn_cast<FloatType>(elemTy)) { |
| auto negInf = APFloat::getInf(floatTy.getFloatSemantics(), true); |
| return arith::ConstantOp::create( |
| rewriter, loc, type, |
| DenseElementsAttr::get(type, rewriter.getFloatAttr(elemTy, negInf))); |
| } |
| return nullptr; |
| } |
| return nullptr; |
| } |
| |
| /// This pattern transforms vector.multi_dim_reduction operations from |
| /// workgroup-level to subgroup-level execution with support for multiple |
| /// reduction dimensions. |
| /// |
| /// Steps include: |
/// 1. LOCAL REDUCTION:
///    - Each subgroup performs a local reduction on its data slice.
///    - Uses the reduction's neutral element as the accumulator to avoid
///      double-counting during the cross-subgroup phase.
///
/// 2. CROSS-SUBGROUP:
///    - Determines if cross-subgroup reduction is needed (when sg_layout > 1
///      in reduction dims & sgData[reduction dims] < wgData[reduction dims]).
///    - If not needed, adds the original accumulator and returns the local
///      results.
| /// |
| /// 3. SHARED LOCAL MEMORY (SLM) PHASE (when cross-subgroup reduction needed): |
| /// a) SLM Layout Design: |
| /// - Rows: subgroups participating in reduction (product of sg_layout in |
| /// reduction dims) |
| /// - Cols: total result elements across non-reduction dimensions |
| /// |
| /// b) Store Phase: |
| /// - Each subgroup stores its local reduction result to SLM |
| /// - Row offset: linearized index of subgroup in reduction dimensions |
| /// - Col offset: linearized index of subgroup in non-reduction dimensions |
| /// |
| /// c) Load and Final Reduction Phase: |
| /// - Each subgroup loads a column of data (all reduction participants for |
| /// its position) |
| /// - Performs final reduction along the loaded dimension |
| /// - Adds original accumulator to get final result |
| /// |
| struct WgToSgMultiDimReductionOp |
| : public OpConversionPattern<vector::MultiDimReductionOp> { |
| using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| |
| VectorType srcType = op.getSourceVectorType(); |
| VectorType dstType = dyn_cast<VectorType>(op.getResult().getType()); |
| if (!dstType) |
| return failure(); |
| |
| auto originalSrcShape = srcType.getShape(); |
| auto originalDstShape = dstType.getShape(); |
| int srcVecRank = originalSrcShape.size(); |
| |
| xegpu::DistributeLayoutAttr layout = |
| xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult())); |
| if (!layout || !layout.isForWorkgroup()) |
| return failure(); |
| |
| auto reductionDims = llvm::to_vector(op.getReductionDims()); |
| |
| // Get sg_layout and sg_data from the parent layout |
| SmallVector<int64_t> sgLayout; |
| SmallVector<int64_t> sgData; |
| if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) { |
| sgLayout = sliceAttr.getParent().getEffectiveSgLayoutAsInt(); |
| sgData = sliceAttr.getParent().getEffectiveSgDataAsInt(); |
| } else |
| return rewriter.notifyMatchFailure( |
| op, "Reduction should have SliceAttr layout"); |
| |
| Type elemTy = dstType.getElementType(); |
| |
    // Step 1: perform local subgroup reductions with a neutral accumulator.
| SmallVector<Value> localReductions; |
| SmallVector<int64_t> sgDstShape = |
| getSgShapeAndCount(originalDstShape, layout).first; |
| auto sgSrcs = adaptor.getSource(); |
| auto sgSrcType = dyn_cast<VectorType>(sgSrcs.front().getType()); |
| SmallVector<int64_t> sgSrcShape(sgSrcType.getShape().begin(), |
| sgSrcType.getShape().end()); |
| |
| VectorType newDstType = VectorType::get(sgDstShape, elemTy); |
| for (auto sgSrc : sgSrcs) { |
      // Create a neutral-element accumulator for the local reduction.
| auto neutralLocalAcc = |
| createAccumulator(rewriter, loc, newDstType, op.getKind()); |
      // Local reduction with the neutral accumulator.
| auto localReduce = vector::MultiDimReductionOp::create( |
| rewriter, loc, newDstType, op.getKind(), sgSrc, neutralLocalAcc, |
| reductionDims); |
| localReductions.push_back(localReduce.getResult()); |
| } |
| |
| // Check if cross-subgroup reduction is needed for any reduction dimension |
| SmallVector<int64_t> crossSgReductionDims; |
| for (int64_t reductionDim : reductionDims) { |
| bool needsCrossSubgroupReduction = |
| (sgLayout[reductionDim] > 1) && |
| (sgData[reductionDim] < originalSrcShape[reductionDim]); |
| |
| if (needsCrossSubgroupReduction) { |
| crossSgReductionDims.push_back(reductionDim); |
| } |
| } |
| |
| // If no cross-subgroup reduction needed, add accumulator and return |
| if (crossSgReductionDims.empty()) { |
| SmallVector<Value> results; |
| for (auto localResult : localReductions) { |
| auto finalResult = vector::makeArithReduction( |
| rewriter, loc, op.getKind(), localResult, adaptor.getAcc()[0]); |
| results.push_back(finalResult); |
| } |
| rewriter.replaceOpWithMultiple(op, {results}); |
| return success(); |
| } |
| |
| // Step 2: cross-subgroup reduction using SLM |
| auto slmStoreDataShape = sgSrcShape; |
| for (int64_t dim : reductionDims) |
| slmStoreDataShape[dim] = 1; |
| VectorType slmStoreDataType = VectorType::get(slmStoreDataShape, elemTy); |
| Value slmStoreData = vector::ShapeCastOp::create( |
| rewriter, loc, slmStoreDataType, localReductions[0]); |
| |
| SmallVector<int64_t> slmShape(originalSrcShape.begin(), |
| originalSrcShape.end()); |
    // For each reduction dimension, SLM stores one partial result per
    // subgroup.
| for (int64_t dim : reductionDims) |
| slmShape[dim] = sgLayout[dim]; |
| |
    // Step 3: Allocate SLM for the partial results.
| auto bitWidth = elemTy.getIntOrFloatBitWidth(); |
| auto bytesPerElement = bitWidth / 8; |
| auto slmSize = computeProduct(slmShape) * bytesPerElement; |
| auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3); |
| auto slm = memref::AllocaOp::create(rewriter, loc, slmTy); |
| |
| auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape, |
| elemTy, nullptr); |
| auto memDesc = |
| xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm); |
| |
    // Distributing more than one local reduction result per subgroup is not
    // yet supported.
| if (localReductions.size() > 1) { |
| return rewriter.notifyMatchFailure( |
| op, |
| "Multiple local reductions not supported in current implementation."); |
| } |
| |
| // Step 4: Store local results to SLM |
| auto sgId = gpu::SubgroupIdOp::create(rewriter, loc, |
| rewriter.getIndexType(), nullptr); |
| |
| // Convert sgLayout to Values for delinearizeIndex |
| SmallVector<Value> sgLayoutValues; |
| for (int64_t dim : sgLayout) |
| sgLayoutValues.push_back( |
| arith::ConstantIndexOp::create(rewriter, loc, dim)); |
| |
| auto sgIdsResult = affine::delinearizeIndex(rewriter, loc, sgId.getResult(), |
| sgLayoutValues); |
| if (failed(sgIdsResult)) |
| return failure(); |
| SmallVector<Value> sgIds = *sgIdsResult; |
| |
| auto getSlmOffsets = [&](int64_t reductionDimStride) { |
| SmallVector<OpFoldResult> offsets; |
| offsets.reserve(srcVecRank); |
| for (int i = 0; i < srcVecRank; ++i) { |
| Value dimVal = sgIds[i]; |
| int64_t sgDataStride = (llvm::is_contained(reductionDims, i)) |
| ? reductionDimStride |
| : sgSrcShape[i]; |
| Value strideVal = |
| arith::ConstantIndexOp::create(rewriter, loc, sgDataStride); |
| Value offsetVal = |
| arith::MulIOp::create(rewriter, loc, dimVal, strideVal); |
| offsets.push_back(offsetVal); |
| } |
| return offsets; |
| }; |
| |
| SmallVector<OpFoldResult> slmStoreOffsets = |
| getSlmOffsets(/*reductionDimStride=*/1); |
| |
| xegpu::StoreMatrixOp::create(rewriter, loc, slmStoreData, |
| memDesc.getResult(), slmStoreOffsets, |
| /*layout=*/nullptr); |
| |
| gpu::BarrierOp::create(rewriter, loc); |
| |
| // Step 5: Load from SLM for final reduction |
| SmallVector<int64_t> slmLoadDataShape(sgSrcShape.begin(), sgSrcShape.end()); |
| for (int64_t dim : reductionDims) |
| slmLoadDataShape[dim] = slmShape[dim]; |
| |
| SmallVector<OpFoldResult> slmLoadOffsets = |
| getSlmOffsets(/*reductionDimStride=*/0); |
| |
| VectorType slmLoadType = VectorType::get(slmLoadDataShape, elemTy); |
| auto slmLoadOp = xegpu::LoadMatrixOp::create( |
| rewriter, loc, slmLoadType, memDesc.getResult(), slmLoadOffsets, |
| /*layout=*/nullptr); |
| |
| // Step 6: Perform the final cross-subgroup reduction with a neutral |
| // accumulator |
| auto neutralFinalAcc = |
| createAccumulator(rewriter, loc, newDstType, op.getKind()); |
| |
| auto finalReduce = vector::MultiDimReductionOp::create( |
| rewriter, loc, newDstType, op.getKind(), slmLoadOp.getResult(), |
| neutralFinalAcc, reductionDims); |
| |
| // Step 7: Combine the result with the original accumulator using the |
| // reduction kind |
| auto finalResult = vector::makeArithReduction(rewriter, loc, op.getKind(), |
| finalReduce.getResult(), |
| adaptor.getAcc()[0]); |
| |
| rewriter.replaceOp(op, finalResult); |
| return success(); |
| } |
| }; |
| |
| // This pattern transforms vector.transpose ops to work at subgroup level. |
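| // For example (a sketch; concrete shapes are determined by the sg_layout and |
| // sg_data attributes attached to the op): |
| //   %t = vector.transpose %v, [1, 0] |
| //       : vector<256x128xf32> to vector<128x256xf32> |
| // with a result layout of sg_layout = [8, 4], sg_data = [16, 64] (and the |
| // transposed layout on the source) is rewritten so that each subgroup |
| // transposes its own vector<64x16xf32> fragment into a vector<16x64xf32>. |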
| struct WgToSgVectorTransposeOp |
| : public OpConversionPattern<vector::TransposeOp> { |
| using OpConversionPattern<vector::TransposeOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| VectorType resultType = op.getResultVectorType(); |
| |
| ArrayRef<int64_t> wgShape = resultType.getShape(); |
| xegpu::DistributeLayoutAttr layout = |
| xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult())); |
| if (!layout || !layout.isForWorkgroup()) |
| return failure(); |
| // TODO-LayoutRefactor: handle the case using getTemporaryLayout |
| xegpu::DistributeLayoutAttr sourceLayout = |
| xegpu::getDistributeLayoutAttr(op.getVector()); |
| if (!sourceLayout || !sourceLayout.isForWorkgroup()) |
| return failure(); |
| |
| SmallVector<int64_t> sourceSgLayout = |
| sourceLayout.getEffectiveSgLayoutAsInt(); |
| SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt(); |
| |
| ArrayRef<int64_t> permutation = op.getPermutation(); |
| size_t permutationSize = permutation.size(); |
| if (sourceSgLayout.size() != permutationSize || |
| resultSgLayout.size() != permutationSize) { |
| return rewriter.notifyMatchFailure( |
| op, "Layouts and permutation must have the same rank"); |
| } |
| |
| // Check that sgLayout, sgData & order are properly transposed for source |
| // and result |
| if (!layout.isTransposeOf(sourceLayout, permutation, |
| xegpu::LayoutKind::Subgroup)) |
| return rewriter.notifyMatchFailure( |
| op, "Result layout is not a valid transpose of source layout " |
| "according to permutation"); |
| |
| SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; |
| VectorType newResultType = |
| VectorType::get(sgShape, resultType.getElementType()); |
| |
| SmallVector<Value> newTransposeOps; |
| for (auto src : adaptor.getVector()) { |
| auto newTranspose = vector::TransposeOp::create( |
| rewriter, op.getLoc(), newResultType, src, permutation); |
| newTransposeOps.push_back(newTranspose.getResult()); |
| } |
| rewriter.replaceOpWithMultiple(op, {newTransposeOps}); |
| return success(); |
| } |
| }; |
| |
| // Distribute vector.constant_mask and vector.create_mask ops to work at the |
| // subgroup level; both are rewritten into per-subgroup vector.create_mask ops. |
| template <typename MaskOpType> |
| struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> { |
| using OpConversionPattern<MaskOpType>::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| MaskOpType op, |
| typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| xegpu::DistributeLayoutAttr layout = |
| xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult())); |
| if (!layout || !layout.isForWorkgroup()) |
| return failure(); |
| |
| Location loc = op.getLoc(); |
| VectorType type = op.getResult().getType(); |
| auto wgShape = type.getShape(); |
| |
| SmallVector<Value> wgMaskDimSizes; |
| if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) { |
| for (int64_t maskSize : op.getMaskDimSizes()) { |
| wgMaskDimSizes.push_back( |
| arith::ConstantIndexOp::create(rewriter, loc, maskSize)); |
| } |
| } else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) { |
| wgMaskDimSizes = llvm::to_vector(op.getOperands()); |
| } |
| |
| Value sgId = |
| gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); |
| auto sgOffsets = |
| layout.computeDistributedCoords(rewriter, loc, sgId, wgShape); |
| if (failed(sgOffsets)) |
| return failure(); |
| |
| SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; |
| VectorType resultType = VectorType::get(sgShape, type.getElementType()); |
| |
| // In each dimension, each subgroup computes its local mask size as: |
| // min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d]) |
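| // For example (assuming wgMaskDimSize = 100, sgDimSize = 32, and subgroup |
| // offsets 0, 32, 64, 96), the per-subgroup mask sizes are 32, 32, 32, and 4. |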
| SmallVector<Value> newCreateMaskOps; |
| for (auto offsetSet : *sgOffsets) { |
| SmallVector<Value> maskOperands; |
| |
| for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) { |
| Value dimSizeVal = |
| arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]); |
| Value offset = offsetSet[i]; |
| Value adjustedMaskSize = |
| arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset); |
| Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); |
| Value nonNegative = |
| arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero); |
| Value sgMaskSize = |
| arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal); |
| maskOperands.push_back(sgMaskSize); |
| } |
| |
| auto newCreateMaskOp = |
| vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands); |
| newCreateMaskOps.push_back(newCreateMaskOp.getResult()); |
| } |
| |
| rewriter.replaceOpWithMultiple(op, {newCreateMaskOps}); |
| return success(); |
| } |
| }; |
| |
| using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>; |
| using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>; |
| } // namespace |
| |
| namespace mlir { |
| namespace xegpu { |
| void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) { |
| patterns |
| .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp, |
| WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset, |
| WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp, |
| WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern, |
| WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp, |
| WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset, |
| WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp, |
| WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp, |
| WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp, |
| WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>( |
| patterns.getContext()); |
| } |
| } // namespace xegpu |
| } // namespace mlir |
| |
| namespace { |
| struct XeGPUWgToSgDistributePass |
| : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> { |
| void runOnOperation() override; |
| }; |
| } // namespace |
| |
| void XeGPUWgToSgDistributePass::runOnOperation() { |
| |
| Operation *op = getOperation(); |
| if (!xegpu::recoverTemporaryLayouts(op)) { |
| signalPassFailure(); |
| return; |
| } |
| |
| // Track the UnrealizedConversionCastOps that already exist before this pass |
| // runs, so the legality rule below can distinguish them from casts |
| // materialized during the conversion. |
| SmallVector<Operation *> existingCastOps; |
| getOperation()->walk([&](UnrealizedConversionCastOp castOp) { |
| existingCastOps.push_back(castOp.getOperation()); |
| }); |
| |
| { |
| // Step 1: Apply SCFStructuralTypeConversions to SCF operations with |
| // VectorType operands. This first converts such operands to |
| // RankedTensorType, propagates the layout attribute into the encoding |
| // attribute, and finally converts the RankedTensorType to VectorType based |
| // on the encoding. |
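| // |
| // For example (a sketch with assumed shapes): an scf.for iter_arg of type |
| // vector<256x256xf32> whose layout has sg_layout = [8, 4] and sg_data = |
| // [32, 32] is modeled in the interim as a ranked tensor carrying that layout |
| // as its encoding, and is finally materialized as two vector<32x32xf32> |
| // values per subgroup (see getSgShapeAndCount). |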
| |
| TypeConverter converter; |
| converter.addConversion([&](Type type) -> Type { return type; }); |
| converter.addConversion( |
| [&](RankedTensorType type, |
| SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> { |
| // Only convert RankedTensorTypes that carry an XeGPU layout encoding. |
| // Plain tensors (e.g. tensor<?xi32>) have no XeGPU encoding and must |
| // not be converted: VectorType does not support dynamic dimensions. |
| auto encoding = |
| dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding()); |
| if (!encoding) |
| return std::nullopt; |
| |
| Type elemTy = type.getElementType(); |
| ArrayRef<int64_t> shape = type.getShape(); |
| |
| auto [subShape, count] = getSgShapeAndCount(shape, encoding); |
| |
| auto newTy = VectorType::get(subShape, elemTy); |
| result.append(count, newTy); |
| return success(); |
| }); |
| |
| xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(), |
| converter); |
| } |
| |
| // Step 2: Perform workgroup to subgroup distribution for TensorDesc values, |
| // as well as XeGPU, Arith, and Vector operations. |
| MLIRContext *ctx = &getContext(); |
| RewritePatternSet patterns(ctx); |
| ConversionTarget target(*ctx); |
| TypeConverter converter; |
| converter.addConversion([&](Type type) -> Type { return type; }); |
| converter.addConversion( |
| [&](xegpu::TensorDescType type, |
| SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> { |
| Type elemTy = type.getElementType(); |
| ArrayRef<int64_t> shape = type.getShape(); |
| |
| xegpu::LayoutAttr layout = type.getLayoutAttr(); |
| auto [subShape, count] = getSgShapeAndCount(shape, layout); |
| |
| if (layout) |
| layout = layout.dropSgLayoutAndData(); |
| |
| auto newTy = xegpu::TensorDescType::get( |
| type.getContext(), subShape, elemTy, type.getEncoding(), layout); |
| result.append(count, newTy); |
| return success(); |
| }); |
| |
| auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType { |
| if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op)) |
| return createOp.getType(); |
| if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op)) |
| return loadOp.getTensorDescType(); |
| if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op)) |
| return storeOp.getTensorDescType(); |
| if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op)) |
| return updateOp.getType(); |
| if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op)) |
| return prefetchOp.getTensorDescType(); |
| return xegpu::TensorDescType(); |
| }; |
| |
| auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool { |
| return !layout || !layout.isForWorkgroup(); |
| }; |
| |
| target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp, |
| xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp, |
| xegpu::PrefetchNdOp>([=](Operation *op) -> bool { |
| auto tdescTy = getTensorDescType(op); |
| auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout()); |
| return isLegal(layout); |
| }); |
| |
| target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool { |
| auto layout = op.getLayoutCdAttr(); |
| return isLegal(layout); |
| }); |
| |
| target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>( |
| [=](xegpu::LoadMatrixOp op) -> bool { |
| return isLegal(op.getLayoutAttr()); |
| }); |
| |
| target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>( |
| [=](xegpu::StoreMatrixOp op) -> bool { |
| return isLegal(op.getLayoutAttr()); |
| }); |
| |
| target.addDynamicallyLegalOp<arith::ConstantOp>( |
| [=](arith::ConstantOp op) -> bool { |
| auto vecType = dyn_cast<VectorType>(op.getType()); |
| if (!vecType) |
| return true; |
| |
| auto layout = |
| xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult())); |
| return isLegal(layout); |
| }); |
| |
| target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp, |
| vector::TransposeOp, vector::BroadcastOp, |
| vector::MultiDimReductionOp, |
| vector::ConstantMaskOp, vector::CreateMaskOp>( |
| [=](Operation *op) -> bool { |
| // Check for either a SliceAttr or LayoutAttr on the result. |
| auto layout = |
| xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0))); |
| return isLegal(layout); |
| }); |
| |
| target.addDynamicallyLegalOp<xegpu::LoadGatherOp>( |
| [=](xegpu::LoadGatherOp op) -> bool { |
| auto layout = op.getLayoutAttr(); |
| return isLegal(layout); |
| }); |
| |
| target.addDynamicallyLegalOp<xegpu::StoreScatterOp>( |
| [=](xegpu::StoreScatterOp op) -> bool { |
| auto layout = op.getLayoutAttr(); |
| return isLegal(layout); |
| }); |
| |
| target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>( |
| [=](xegpu::ConvertLayoutOp op) -> bool { |
| return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout()); |
| }); |
| |
| target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>( |
| [=](Operation *op) -> std::optional<bool> { |
| // Only handle elementwise mappable ops |
| if (!OpTrait::hasElementwiseMappableTraits(op)) |
| return true; |
| |
| VectorType resultType = |
| dyn_cast<VectorType>(op->getResult(0).getType()); |
| if (!resultType) |
| return true; |
| |
| // Check if all operands are vectors of the same shape |
| // TODO: Support other types. |
| for (Value operand : op->getOperands()) { |
| VectorType operandType = dyn_cast<VectorType>(operand.getType()); |
| if (!operandType || operandType.getShape() != resultType.getShape()) { |
| return true; |
| } |
| } |
| |
| xegpu::DistributeLayoutAttr layout = |
| xegpu::getTemporaryLayout(op->getResult(0)); |
| return isLegal(layout); |
| }); |
| |
| target.addDynamicallyLegalOp<UnrealizedConversionCastOp>( |
| [=](UnrealizedConversionCastOp op) { |
| return llvm::is_contained(existingCastOps, op.getOperation()); |
| }); |
| |
| target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); |
| |
| scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns, |
| target); |
| xegpu::populateXeGPUWgToSgDistributePatterns(patterns); |
| if (failed( |
| applyPartialConversion(getOperation(), target, std::move(patterns)))) |
| return signalPassFailure(); |
| |
| // Remove the temporary layout attributes that were attached to SCF / |
| // region-branch ops during conversion; they are no longer needed at |
| // subgroup level. |
| getOperation()->walk([](Operation *op) { |
| if (!isa<RegionBranchOpInterface, RegionBranchTerminatorOpInterface>(op)) |
| return; |
| |
| SmallVector<StringAttr> attrsToRemove; |
| for (auto namedAttr : op->getDiscardableAttrs()) { |
| if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue())) |
| attrsToRemove.push_back(namedAttr.getName()); |
| } |
| for (auto attrName : attrsToRemove) |
| op->removeDiscardableAttr(attrName); |
| }); |
| } |