| //===- VectorLegalization.cpp - Legalize vectors for lowering to ArmSME ---===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This pass legalizes vector operations so they can be lowered to ArmSME. |
| // |
| // Note: In the context of this pass 'tile' always refers to an SME tile. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Arith/Utils/Utils.h" |
| #include "mlir/Dialect/ArmSME/IR/ArmSME.h" |
| #include "mlir/Dialect/ArmSME/Transforms/Passes.h" |
| #include "mlir/Dialect/ArmSME/Utils/Utils.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" |
| #include "mlir/Dialect/Func/Transforms/FuncConversions.h" |
| #include "mlir/Dialect/Index/IR/IndexDialect.h" |
| #include "mlir/Dialect/Index/IR/IndexOps.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/SCF/IR/SCF.h" |
| #include "mlir/Dialect/SCF/Transforms/Patterns.h" |
| #include "mlir/Dialect/Utils/IndexingUtils.h" |
| #include "mlir/Dialect/Vector/Utils/VectorUtils.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| |
| #define DEBUG_TYPE "arm-sme-vector-legalization" |
| |
| namespace mlir::arm_sme { |
| #define GEN_PASS_DEF_VECTORLEGALIZATION |
| #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc" |
| } // namespace mlir::arm_sme |
| |
| using namespace mlir; |
| using namespace mlir::arm_sme; |
| |
| namespace { |
| |
| //===----------------------------------------------------------------------===// |
| // Decomposition of vector operations larger than an SME tile |
| //===----------------------------------------------------------------------===// |
| |
| // Common match failure reasons. |
static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple(
    "op vector size is not a multiple of SME tiles");
| static constexpr StringLiteral kMatchFailureUnsupportedMaskOp( |
| "op mask is unsupported for legalization/decomposition"); |
| static constexpr StringLiteral |
| kMatchFailureNonPermutationMap("op affine map is not a permutation"); |
| static constexpr StringLiteral kMatchFailureNotIllegalToLegal( |
| "expected transpose from illegal type to legal type"); |
| |
/// An SMESubTile represents a single SME-sized sub-tile from decomposing a
/// larger vector type. (`row`, `col`) is the position of the sub-tile within
/// the original vector type. For example, for an [8]x[8] vector decomposed
/// into four [4]x[4] sub-tiles, we would have:
| /// |
| /// 8 x vscale |
| /// ┌─────────────┬─────────────┐ |
| /// │(0,0) │(0,4) │ |
| /// │ │ │ |
| /// ├─────────────┼─────────────┤ 8 x vscale |
| /// │(4,0) │(4,4) │ |
| /// │ │ │ |
| /// └─────────────┴─────────────┘ |
| struct SMESubTile { |
  // Note: (row, col) are in units of vscale (as SME tiles are scalable).
| int row{0}; |
| int col{0}; |
| // The SME tile type. |
| VectorType type; |
| }; |
| |
/// Adds a constant elementwise scalable offset to `indices` (`indices` and
/// `scalableOffsets` must be of equal length). For example, in the 2D case
/// this would return:
///   { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale }
| SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder, |
| Location loc, |
| ValueRange indices, |
| ArrayRef<int> scalableOffsets) { |
| auto vscale = builder.create<vector::VectorScaleOp>(loc); |
| return llvm::map_to_vector( |
| llvm::zip_equal(indices, scalableOffsets), [&](auto pair) -> Value { |
| auto [index, base] = pair; |
| auto offset = builder.create<arith::MulIOp>( |
| loc, builder.create<arith::ConstantIndexOp>(loc, base), vscale); |
| return builder.create<arith::AddIOp>(loc, index, offset); |
| }); |
| } |
| |
| /// Adjusts `indices` (e.g. from a load/store) for a larger vector type to |
| /// indices for one of the SME sub-tiles it will decompose into. |
| /// |
| /// For example, if you were to decompose an 8x8 load into four 4x4 tiles, the |
| /// indices for each tile would need to be adjusted as follows: |
| /// |
///   initial indices = [a,b], initial size = 8x8, target size = 4x4
| /// ┌─────────────┬─────────────┐ |
| /// │[a,b] │[a,b+4] │ |
| /// │ │ │ |
| /// ├─────────────┼─────────────┤ |
| /// │[a+4,b] │[a+4,b+4] │ |
| /// │ │ │ |
| /// └─────────────┴─────────────┘ |
| SmallVector<Value, 2> getSMESubTileIndices(OpBuilder &builder, Location loc, |
| ValueRange indices, |
| SMESubTile smeTile) { |
| return addConstantScalableOffset(builder, loc, indices, |
| {smeTile.row, smeTile.col}); |
| } |
| |
/// Returns true if `mask` is generated by an operation that can be decomposed
/// for SME. Currently, that is either no mask, or a mask created by
/// `vector.create_mask`.
/// TODO: Add support for vector.constant_mask once required for SME.
| bool isSupportedMaskOp(Value mask) { |
| return !mask || mask.getDefiningOp<vector::CreateMaskOp>(); |
| } |
| |
| /// Extracts a mask for an SME sub-tile from the mask of a larger vector type. |
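///
/// Conceptual example (illustrative: the sub-tile position, types, and SSA
/// names are invented for exposition; the actual IR adds a negative scalable
/// offset rather than subtracting):
///
/// ```mlir
/// // Mask of the larger [8]x[8] vector:
/// %mask = vector.create_mask %a, %b : vector<[8]x[8]xi1>
/// // Mask for the [4]x[4] sub-tile at position (4, 4):
/// %rows = arith.subi %a, %c4_vscale : index
/// %cols = arith.subi %b, %c4_vscale : index
/// %sub_mask = vector.create_mask %rows, %cols : vector<[4]x[4]xi1>
/// ```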
| Value extractSMEMask(OpBuilder &builder, Location loc, Value mask, |
| SMESubTile smeTile) { |
| assert(isSupportedMaskOp(mask)); |
| if (!mask) |
| return Value{}; |
| auto createMask = mask.getDefiningOp<vector::CreateMaskOp>(); |
  // The operands of `vector.create_mask` (from a 2D perspective) are the
  // coordinates where the mask ends, so we subtract this sub-tile's start
  // position from the mask operands to get the operands for the sub-tile's
  // mask.
| auto smeTileMaskDims = addConstantScalableOffset( |
| builder, loc, createMask.getOperands(), {-smeTile.row, -smeTile.col}); |
| auto smeTileCreateMask = builder.create<vector::CreateMaskOp>( |
| loc, smeTile.type.clone(builder.getI1Type()), smeTileMaskDims); |
| return smeTileCreateMask.getResult(); |
| } |
| |
/// Constructs an iterator that returns each SME tile (with coordinates)
/// contained within a VectorType. For example, when decomposing an [8]x[8]
/// vector into [4]x[4] tiles, the iterator would yield tiles with coordinates
/// (0, 0), (0, 4), (4, 0), and (4, 4).
| auto decomposeToSMETiles(OpBuilder &builder, VectorType type, |
| VectorType smeTileType, |
| bool transposeIndices = false) { |
| return llvm::map_range( |
| StaticTileOffsetRange( |
| type.getShape(), |
| {std::min(type.getDimSize(0), smeTileType.getDimSize(0)), |
| std::min(type.getDimSize(1), smeTileType.getDimSize(1))}), |
| [=](auto indices) { |
| int row = int(indices[0]); |
| int col = int(indices[1]); |
| if (transposeIndices) |
| std::swap(row, col); |
| return SMESubTile{row, col, smeTileType}; |
| }); |
| } |
| |
| /// Returns the number of SME tiles that fit into the (2D-scalable) vector type |
| /// `type`. |
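///
/// For example, assuming f32 (so the minimum tile slice length is 4 elements),
/// a vector<[8]x[8]xf32> contains (8 * 8) / (4 * 4) = 4 SME tiles.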
| int getNumberOfSMETilesForVectorType(VectorType type) { |
| assert(isMultipleOfSMETileVectorType(type) && |
| "`type` not multiple of SME tiles"); |
| int64_t vectorRows = type.getDimSize(0); |
| int64_t vectorCols = type.getDimSize(1); |
| auto elementType = type.getElementType(); |
| unsigned minNumElts = getSMETileSliceMinNumElts(elementType); |
| return (vectorRows * vectorCols) / (minNumElts * minNumElts); |
| } |
| |
| /// Legalize `arith.constant dense<value>` splat operations to fit within SME |
| /// tiles by decomposing them into tile-sized operations. |
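///
/// Illustrative example (assuming f32, so SME tiles are vector<[4]x[4]xf32>):
///
/// BEFORE:
/// ```mlir
/// %const = arith.constant dense<1.0> : vector<[8]x[8]xf32>
/// ```
///
/// AFTER (a single tile-sized splat that replaces the original value for each
/// of its four sub-tiles):
/// ```mlir
/// %const = arith.constant dense<1.0> : vector<[4]x[4]xf32>
/// ```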
| struct LegalizeArithConstantOpsByDecomposition |
| : public OpConversionPattern<arith::ConstantOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto vectorType = dyn_cast<VectorType>(constantOp.getType()); |
| auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr()); |
| if (!vectorType || !denseAttr || !denseAttr.isSplat()) |
| return failure(); |
| |
| if (!isMultipleOfSMETileVectorType(vectorType)) |
| return rewriter.notifyMatchFailure(constantOp, |
| kMatchFailureNotSMETileTypeMultiple); |
| |
| auto smeTileType = getSMETileTypeForElement(vectorType.getElementType()); |
| auto tileCount = getNumberOfSMETilesForVectorType(vectorType); |
| auto tileSplat = rewriter.create<arith::ConstantOp>( |
| constantOp.getLoc(), denseAttr.resizeSplat(smeTileType)); |
| SmallVector<Value> repl(tileCount, tileSplat); |
| rewriter.replaceOpWithMultiple(constantOp, {repl}); |
| |
| return success(); |
| } |
| }; |
| |
| /// Legalize `vector.outerproduct` operations to fit within SME tiles by |
| /// decomposing them into tile-sized operations. |
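///
/// Illustrative example (pseudo-MLIR; assuming f32, so SME tiles are
/// vector<[4]x[4]xf32>; SSA names are invented for exposition, and %acc04
/// stands for the decomposed accumulator tile at position (0, 4)):
///
/// BEFORE:
/// ```mlir
/// %op = vector.outerproduct %lhs, %rhs, %acc
///   : vector<[8]xf32>, vector<[8]xf32>
/// ```
///
/// AFTER (shown for just the (0, 4) sub-tile; there are four in total):
/// ```mlir
/// %lhs0 = vector.scalable.extract %lhs[0]
///   : vector<[4]xf32> from vector<[8]xf32>
/// %rhs4 = vector.scalable.extract %rhs[4]
///   : vector<[4]xf32> from vector<[8]xf32>
/// %tile04 = vector.outerproduct %lhs0, %rhs4, %acc04
///   : vector<[4]xf32>, vector<[4]xf32>
/// ```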
| struct LegalizeVectorOuterProductOpsByDecomposition |
| : public OpConversionPattern<vector::OuterProductOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(vector::OuterProductOp outerProductOp, |
| OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto vectorType = outerProductOp.getResultVectorType(); |
| if (!isMultipleOfSMETileVectorType(vectorType)) |
| return rewriter.notifyMatchFailure(outerProductOp, |
| kMatchFailureNotSMETileTypeMultiple); |
| |
| Value mask; |
| Operation *rootOp = outerProductOp; |
| auto loc = outerProductOp.getLoc(); |
| if (outerProductOp.isMasked()) { |
| auto maskOp = outerProductOp.getMaskingOp(); |
| mask = maskOp.getMask(); |
| rootOp = maskOp; |
| rewriter.setInsertionPoint(rootOp); |
| } |
| |
| if (!isSupportedMaskOp(mask)) |
| return rewriter.notifyMatchFailure(outerProductOp, |
| kMatchFailureUnsupportedMaskOp); |
| |
| ValueRange accSMETiles = adaptor.getAcc(); |
| auto smeTileType = getSMETileTypeForElement(vectorType.getElementType()); |
| VectorType sliceType = VectorType::Builder(smeTileType).dropDim(0); |
| |
| SmallVector<Value> resultSMETiles; |
| for (auto [index, smeTile] : llvm::enumerate( |
| decomposeToSMETiles(rewriter, vectorType, smeTileType))) { |
| |
| auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile); |
| auto lhs = rewriter.create<vector::ScalableExtractOp>( |
| loc, sliceType, outerProductOp.getLhs(), smeTile.row); |
| auto rhs = rewriter.create<vector::ScalableExtractOp>( |
| loc, sliceType, outerProductOp.getRhs(), smeTile.col); |
| auto smeOuterProduct = rewriter.create<vector::OuterProductOp>( |
| loc, smeTileType, lhs, rhs, |
| !accSMETiles.empty() ? accSMETiles[index] : Value{}, |
| outerProductOp.getKind()); |
| |
| auto maskedOuterProduct = |
| vector::maskOperation(rewriter, smeOuterProduct, smeMask); |
| resultSMETiles.push_back(maskedOuterProduct->getResult(0)); |
| } |
| |
| rewriter.replaceOpWithMultiple(rootOp, {resultSMETiles}); |
| return success(); |
| } |
| }; |
| |
/// Workaround for `vector.mask`. We want to match on `vector.outerproduct` (to
/// get the help of the type conversion), but doing so results in the type
/// conversion adding target materializations in the `vector.mask` region
/// (which is invalid). This pattern instead matches on `vector.mask`, then
/// calls into the `vector.outerproduct` pattern above to work around the
/// issue.
| struct LegalizeMaskedVectorOuterProductOpsByDecomposition |
| : public OpConversionPattern<vector::MaskOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(vector::MaskOp maskOp, OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>( |
| maskOp.getMaskableOp())) { |
| LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(), |
| getContext()); |
| return static_cast<RewritePattern &>(pattern).matchAndRewrite( |
| outerProductOp, rewriter); |
| } |
| return failure(); |
| } |
| }; |
| |
| /// Legalize `vector.transfer_read` operations to fit within SME tiles by |
| /// decomposing them into tile-sized operations. |
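///
/// Illustrative example (pseudo-MLIR, padding elided; assuming f32, so SME
/// tiles are vector<[4]x[4]xf32>):
///
/// BEFORE:
/// ```mlir
/// %vec = vector.transfer_read %src[%y, %x]
///   : memref<?x?xf32>, vector<[8]x[8]xf32>
/// ```
///
/// AFTER (shown for just the (4, 0) sub-tile; there are four in total):
/// ```mlir
/// %c4_vscale = arith.muli %vscale, %c4 : index
/// %row = arith.addi %y, %c4_vscale : index
/// %tile10 = vector.transfer_read %src[%row, %x]
///   : memref<?x?xf32>, vector<[4]x[4]xf32>
/// ```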
| struct LegalizeTransferReadOpsByDecomposition |
| : public OpConversionPattern<vector::TransferReadOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(vector::TransferReadOp readOp, OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto vectorType = readOp.getVectorType(); |
| if (!isMultipleOfSMETileVectorType(vectorType)) |
| return rewriter.notifyMatchFailure(readOp, |
| kMatchFailureNotSMETileTypeMultiple); |
| |
| auto mask = readOp.getMask(); |
| if (!isSupportedMaskOp(mask)) |
| return rewriter.notifyMatchFailure(readOp, |
| kMatchFailureUnsupportedMaskOp); |
| |
| auto permutationMap = readOp.getPermutationMap(); |
| if (!permutationMap.isPermutation()) |
| return rewriter.notifyMatchFailure(readOp, |
| kMatchFailureNonPermutationMap); |
| |
| // Note: For 2D vector types the only non-identity permutation is a simple |
| // transpose [1, 0]. |
| bool transposed = !permutationMap.isIdentity(); |
| |
| auto loc = readOp.getLoc(); |
| auto smeTileType = getSMETileTypeForElement(vectorType.getElementType()); |
| |
| SmallVector<Value> resultSMETiles; |
| for (SMESubTile smeTile : |
| decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) { |
| auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile); |
| auto smeRead = rewriter.create<vector::TransferReadOp>( |
| loc, smeTileType, readOp.getBase(), |
| getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile), |
| readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask, |
| readOp.getInBoundsAttr()); |
| resultSMETiles.push_back(smeRead); |
| } |
| |
| rewriter.replaceOpWithMultiple(readOp, {resultSMETiles}); |
| return success(); |
| } |
| }; |
| |
| /// Legalize `vector.transfer_write` operations to fit within SME tiles by |
| /// decomposing them into tile-sized operations. |
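///
/// Illustrative example (pseudo-MLIR; assuming f32, so SME tiles are
/// vector<[4]x[4]xf32>; %tile01 stands for the decomposed input tile at
/// position (0, 4)):
///
/// BEFORE:
/// ```mlir
/// vector.transfer_write %vec, %dest[%y, %x]
///   : vector<[8]x[8]xf32>, memref<?x?xf32>
/// ```
///
/// AFTER (shown for just the (0, 4) sub-tile; there are four in total):
/// ```mlir
/// %c4_vscale = arith.muli %vscale, %c4 : index
/// %col = arith.addi %x, %c4_vscale : index
/// vector.transfer_write %tile01, %dest[%y, %col]
///   : vector<[4]x[4]xf32>, memref<?x?xf32>
/// ```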
| struct LegalizeTransferWriteOpsByDecomposition |
| : public OpConversionPattern<vector::TransferWriteOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto vectorType = writeOp.getVectorType(); |
| if (!isMultipleOfSMETileVectorType(vectorType)) |
| return rewriter.notifyMatchFailure(writeOp, |
| kMatchFailureNotSMETileTypeMultiple); |
| |
| auto mask = writeOp.getMask(); |
| if (!isSupportedMaskOp(mask)) |
| return rewriter.notifyMatchFailure(writeOp, |
| kMatchFailureUnsupportedMaskOp); |
| |
| auto permutationMap = writeOp.getPermutationMap(); |
| if (!permutationMap.isPermutation()) |
| return rewriter.notifyMatchFailure(writeOp, |
| kMatchFailureNonPermutationMap); |
| |
| // Note: For 2D vector types the only non-identity permutation is a simple |
| // transpose [1, 0]. |
| bool transposed = !permutationMap.isIdentity(); |
| |
| auto loc = writeOp.getLoc(); |
| auto smeTileType = getSMETileTypeForElement(vectorType.getElementType()); |
| auto inputSMETiles = adaptor.getValueToStore(); |
| |
| Value destTensorOrMemref = writeOp.getBase(); |
| for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles( |
| rewriter, vectorType, smeTileType, transposed))) { |
| auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile); |
| auto smeWrite = rewriter.create<vector::TransferWriteOp>( |
| loc, inputSMETiles[index], destTensorOrMemref, |
| getSMESubTileIndices(rewriter, loc, writeOp.getIndices(), smeTile), |
| writeOp.getPermutationMapAttr(), smeMask, writeOp.getInBoundsAttr()); |
| if (writeOp.hasPureTensorSemantics()) |
| destTensorOrMemref = smeWrite.getResult(); |
| } |
| |
| if (writeOp.hasPureTensorSemantics()) |
| rewriter.replaceOp(writeOp, destTensorOrMemref); |
| else |
| rewriter.eraseOp(writeOp); |
| |
| return success(); |
| } |
| }; |
| |
/// Legalize a multi-tile transfer_write as a single store loop. This is done
/// as part of type decomposition because at this level we know each tile write
/// is disjoint; that information is lost after decomposition (and would
/// require analysis to reconstruct).
| /// |
| /// Example (pseudo-MLIR): |
| /// |
| /// ``` |
| /// vector.transfer_write %vector, %dest[%y, %x], %mask |
| /// : vector<[16]x[8]xi16>, memref<?x?xi16> |
| /// ``` |
| /// Is rewritten to: |
| /// ``` |
| /// scf.for %slice_idx = %c0 to %c8_vscale step %c1 { |
| /// %upper_slice_mask = vector.extract %mask[%slice_idx] ─┐ |
| /// : vector<[8]xi1> from vector<[16]x[8]xi1> | |
| /// %upper_slice = vector.extract %upper_tile[%slice_idx] |- Store upper tile |
| /// : vector<[8]xi16> from vector<[8]x[8]xi16> | |
| /// vector.transfer_write %upper_slice, | |
| /// %dest[%slice_idx + %y, %x], %upper_slice_mask | |
| /// : vector<[8]xi16>, memref<?x?xi16> ┘ |
| /// %lower_slice_idx = %slice_idx + %c8_vscale ─┐ |
| /// %lower_slice_mask = vector.extract %mask[%lower_slice_idx] | |
| /// : vector<[8]xi1> from vector<[16]x[8]xi1> | |
| /// %lower_slice = vector.extract %lower_tile[%slice_idx] |- Store lower |
| /// : vector<[8]xi16> from vector<[8]x[8]xi16> | tile |
| /// vector.transfer_write %lower_slice, | |
| /// %dest[%lower_slice_idx + %y, %x], %lower_slice_mask | |
| /// : vector<[8]xi16>, memref<?x?xi16> ┘ |
| /// } |
| /// ``` |
| struct LegalizeMultiTileTransferWriteAsStoreLoop |
| : public OpConversionPattern<vector::TransferWriteOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| if (writeOp.hasPureTensorSemantics()) |
| return rewriter.notifyMatchFailure( |
| writeOp, "TODO: tensor semantics are unsupported"); |
| |
| auto permutationMap = writeOp.getPermutationMap(); |
| if (!permutationMap.isPermutation()) |
| return rewriter.notifyMatchFailure(writeOp, |
| kMatchFailureNonPermutationMap); |
| |
| bool transposed = !permutationMap.isIdentity(); |
| if (transposed) |
| return rewriter.notifyMatchFailure(writeOp, |
| "TODO: transpose unsupported"); |
| |
| auto vectorType = writeOp.getVectorType(); |
| if (!isMultipleOfSMETileVectorType(vectorType)) |
| return rewriter.notifyMatchFailure(writeOp, |
| kMatchFailureNotSMETileTypeMultiple); |
| |
| // Note: We also disallow masks where any dimension is > 16 because that |
| // prevents the masking from being lowered to use arm_sve.psel. |
| auto mask = writeOp.getMask(); |
| if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 || |
| vectorType.getDimSize(1) > 16))) |
| return rewriter.notifyMatchFailure(writeOp, |
| kMatchFailureUnsupportedMaskOp); |
| |
| auto loc = writeOp.getLoc(); |
| auto createVscaleMultiple = |
| vector::makeVscaleConstantBuilder(rewriter, loc); |
| |
| // Get SME tile and slice types. |
| auto smeTileType = getSMETileTypeForElement(vectorType.getElementType()); |
| auto minTileSlices = smeTileType.getDimSize(0); |
| VectorType sliceMaskType = |
| VectorType::get(minTileSlices, rewriter.getI1Type(), true); |
| |
| // Create loop over all tile slices. |
| auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
| auto upperBound = createVscaleMultiple(minTileSlices); |
| auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); |
| auto storeLoop = |
| rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step); |
| rewriter.setInsertionPointToStart(storeLoop.getBody()); |
| |
| // For each sub-tile of the multi-tile `vectorType`. |
| auto inputSMETiles = adaptor.getValueToStore(); |
| auto tileSliceIndex = storeLoop.getInductionVar(); |
| for (auto [index, smeTile] : llvm::enumerate( |
| decomposeToSMETiles(rewriter, vectorType, smeTileType))) { |
| // The coordinates of the tile within `vectorType`. |
| auto tileRow = createVscaleMultiple(smeTile.row); |
| auto tileCol = createVscaleMultiple(smeTile.col); |
| |
| // The current slice of `vectorType` we are processing. |
| auto sliceIndex = |
| rewriter.create<arith::AddIOp>(loc, tileRow, tileSliceIndex); |
| |
| // Where in the destination memref the current slice will be stored. |
| auto storeRow = rewriter.create<arith::AddIOp>(loc, sliceIndex, |
| writeOp.getIndices()[0]); |
| auto storeCol = |
| rewriter.create<arith::AddIOp>(loc, tileCol, writeOp.getIndices()[1]); |
| |
| // Extract the mask for the current slice. |
| Value sliceMask = nullptr; |
| if (mask) { |
| sliceMask = rewriter.create<vector::ExtractOp>( |
| loc, mask, OpFoldResult(sliceIndex)); |
| if (sliceMaskType != sliceMask.getType()) |
| sliceMask = rewriter.create<vector::ScalableExtractOp>( |
| loc, sliceMaskType, sliceMask, smeTile.col); |
| } |
| |
| // Extract and store the current slice. |
| Value tile = inputSMETiles[index]; |
| auto slice = |
| rewriter.create<vector::ExtractOp>(loc, tile, tileSliceIndex); |
| rewriter.create<vector::TransferWriteOp>( |
| loc, slice, writeOp.getBase(), ValueRange{storeRow, storeCol}, |
| AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)), |
| sliceMask, |
| rewriter.getBoolArrayAttr( |
| ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front())); |
| } |
| |
| rewriter.eraseOp(writeOp); |
| return success(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // ArmSME-specific fixup canonicalizations/folds |
| //===----------------------------------------------------------------------===// |
| |
/// Folds an extract from a 3D `vector.create_mask` (which is a vector of
/// SME-like masks) into a compare and a 2D `vector.create_mask`. This is
/// necessary for the mask to be lowered to ArmSME.
| /// |
| /// Example: |
| /// |
| /// BEFORE: |
| /// ```mlir |
| /// %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1> |
| /// %subMask = vector.extract %mask[2] |
| /// : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1> |
| /// ``` |
| /// |
| /// AFTER: |
| /// ```mlir |
| /// %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index |
| /// %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index |
| /// %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1> |
| /// ``` |
| struct FoldExtractFromVectorOfSMELikeCreateMasks |
| : public OpRewritePattern<vector::ExtractOp> { |
| using OpRewritePattern<vector::ExtractOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(vector::ExtractOp extractOp, |
| PatternRewriter &rewriter) const override { |
| auto loc = extractOp.getLoc(); |
| auto createMaskOp = |
| extractOp.getVector().getDefiningOp<vector::CreateMaskOp>(); |
| if (!createMaskOp) |
| return rewriter.notifyMatchFailure( |
| extractOp, "extract not from vector.create_mask op"); |
| |
| VectorType extractedMaskType = |
| llvm::dyn_cast<VectorType>(extractOp.getResult().getType()); |
| if (!extractedMaskType) |
| return rewriter.notifyMatchFailure(extractOp, |
| "extracted type is not a vector type"); |
| |
| auto numScalable = extractedMaskType.getNumScalableDims(); |
| if (numScalable != 2) |
| return rewriter.notifyMatchFailure( |
| extractOp, "expected extracted type to be an SME-like mask"); |
| |
| // TODO: Support multiple extraction indices. |
| if (extractOp.getStaticPosition().size() != 1) |
| return rewriter.notifyMatchFailure( |
| extractOp, "only a single extraction index is supported"); |
| |
| auto frontMaskDim = createMaskOp.getOperand(0); |
| if (frontMaskDim.getDefiningOp<arith::ConstantOp>()) |
| return rewriter.notifyMatchFailure( |
| extractOp, |
| "constant vector.create_masks dims should be folded elsewhere"); |
| |
| auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
| auto extractionIndex = getValueOrCreateConstantIndexOp( |
| rewriter, loc, extractOp.getMixedPosition()[0]); |
| auto extractionInTrueRegion = rewriter.create<arith::CmpIOp>( |
| loc, rewriter.getI1Type(), arith::CmpIPredicate::slt, extractionIndex, |
| frontMaskDim); |
| auto newMaskFrontDim = rewriter.create<arith::SelectOp>( |
| loc, extractionInTrueRegion, createMaskOp.getOperand(1), zero); |
| |
| rewriter.replaceOpWithNewOp<vector::CreateMaskOp>( |
| extractOp, extractedMaskType, |
| ValueRange{newMaskFrontDim, createMaskOp.getOperand(2)}); |
| return success(); |
| } |
| }; |
| |
/// Returns true if `vType` is a vector type where no fixed dimension comes
/// after a scalable dimension.
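///
/// For example, vector<[4]x[4]xf32> and vector<2x[8]xf32> are legal, but
/// vector<[4]x2xf32> is not (a fixed dim follows a scalable dim).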
| bool isLegalVectorType(VectorType vType) { |
| bool seenFixedDim = false; |
| for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) { |
| seenFixedDim |= !scalableFlag; |
| if (seenFixedDim && scalableFlag) |
| return false; |
| } |
| return true; |
| } |
| |
/// Lifts an illegal vector.transpose of a vector.transfer_read into a
/// memref.subview + memref.transpose, followed by a legal read.
///
/// 'Illegal' here means a vector type with a leading scalable dimension and a
/// trailing fixed dimension, which has no valid lowering.
///
/// The memref.transpose is a metadata-only transpose that produces a strided
/// memref, which eventually becomes a loop reading individual elements.
| /// |
| /// Example: |
| /// |
| /// BEFORE: |
| /// ```mlir |
| /// %illegalRead = vector.transfer_read %memref[%a, %b] |
| /// : memref<?x?xf32>, vector<[8]x4xf32> |
| /// %legalType = vector.transpose %illegalRead, [1, 0] |
| /// : vector<[8]x4xf32> to vector<4x[8]xf32> |
| /// ``` |
| /// |
| /// AFTER: |
| /// ```mlir |
| /// %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1] |
| /// : memref<?x?xf32> to memref<?x?xf32> |
| /// %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0) |
| /// : memref<?x?xf32> to memref<?x?xf32> |
| /// %legalType = vector.transfer_read %transpose[%c0, %c0] |
| /// : memref<?x?xf32>, vector<4x[8]xf32> |
| /// ``` |
| struct LiftIllegalVectorTransposeToMemory |
| : public OpRewritePattern<vector::TransposeOp> { |
| using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; |
| |
| static Value getExtensionSource(Operation *op) { |
| if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op)) |
| return op->getOperand(0); |
| return {}; |
| } |
| |
| LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, |
| PatternRewriter &rewriter) const override { |
| auto sourceType = transposeOp.getSourceVectorType(); |
| auto resultType = transposeOp.getResultVectorType(); |
| if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType)) |
| return rewriter.notifyMatchFailure(transposeOp, |
| kMatchFailureNotIllegalToLegal); |
| |
| // Look through extend for transfer_read. |
| Value maybeRead = transposeOp.getVector(); |
| auto *transposeSourceOp = maybeRead.getDefiningOp(); |
| Operation *extendOp = nullptr; |
| if (Value extendSource = getExtensionSource(transposeSourceOp)) { |
| maybeRead = extendSource; |
| extendOp = transposeSourceOp; |
| } |
| |
| auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>(); |
| if (!illegalRead) |
| return rewriter.notifyMatchFailure( |
| transposeOp, |
| "expected source to be (possibly extended) transfer_read"); |
| |
| if (!illegalRead.getPermutationMap().isIdentity()) |
| return rewriter.notifyMatchFailure( |
| illegalRead, "expected read to have identity permutation map"); |
| |
| auto loc = transposeOp.getLoc(); |
| auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
| auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1); |
| |
| // Create a subview that matches the size of the illegal read vector type. |
| auto readType = illegalRead.getVectorType(); |
| auto readSizes = llvm::map_to_vector( |
| llvm::zip_equal(readType.getShape(), readType.getScalableDims()), |
| [&](auto dim) -> Value { |
| auto [size, isScalable] = dim; |
| auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size); |
| if (!isScalable) |
| return dimSize; |
| auto vscale = rewriter.create<vector::VectorScaleOp>(loc); |
| return rewriter.create<arith::MulIOp>(loc, vscale, dimSize); |
| }); |
| SmallVector<Value> strides(readType.getRank(), Value(one)); |
| auto readSubview = rewriter.create<memref::SubViewOp>( |
| loc, illegalRead.getBase(), illegalRead.getIndices(), readSizes, |
| strides); |
| |
| // Apply the transpose to all values/attributes of the transfer_read: |
| // - The mask |
| Value mask = illegalRead.getMask(); |
| if (mask) { |
| // Note: The transpose for the mask should fold into the |
| // vector.create_mask/constant_mask op, which will then become legal. |
| mask = rewriter.create<vector::TransposeOp>(loc, mask, |
| transposeOp.getPermutation()); |
| } |
| // - The source memref |
| mlir::AffineMap transposeMap = AffineMap::getPermutationMap( |
| transposeOp.getPermutation(), getContext()); |
| auto transposedSubview = rewriter.create<memref::TransposeOp>( |
| loc, readSubview, AffineMapAttr::get(transposeMap)); |
| ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr(); |
| // - The `in_bounds` attribute |
| if (inBoundsAttr) { |
| SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(), |
| inBoundsAttr.end()); |
| applyPermutationToVector(inBoundsValues, transposeOp.getPermutation()); |
| inBoundsAttr = rewriter.getArrayAttr(inBoundsValues); |
| } |
| |
| VectorType legalReadType = resultType.clone(readType.getElementType()); |
| // Note: The indices are all zero as the subview is already offset. |
| SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero); |
| auto legalRead = rewriter.create<vector::TransferReadOp>( |
| loc, legalReadType, transposedSubview, readIndices, |
| illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask, |
| inBoundsAttr); |
| |
| // Replace the transpose with the new read, extending the result if |
| // necessary. |
| rewriter.replaceOp(transposeOp, [&]() -> Operation * { |
| if (extendOp) |
| return rewriter.create(loc, extendOp->getName().getIdentifier(), |
| Value(legalRead), resultType); |
| return legalRead; |
| }()); |
| |
| return success(); |
| } |
| }; |
| |
/// A rewrite to turn unit-dim transpose-like vector.shape_casts into
/// vector.transposes. The shape_cast must be from an illegal vector type to a
/// legal one (as defined by isLegalVectorType).
///
/// The reasoning is that if shape_casts of illegal types remain by the time we
/// reach this pass, they are unlikely to cancel out. Turning them into
/// transposes gives LiftIllegalVectorTransposeToMemory a chance to eliminate
/// them.
| /// |
| /// Example: |
| /// |
| /// BEFORE: |
| /// ```mlir |
| /// %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32> |
| /// ``` |
| /// |
| /// AFTER: |
| /// ```mlir |
/// %0 = vector.transpose %a, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
| /// ``` |
| struct ConvertIllegalShapeCastOpsToTransposes |
| : public OpRewritePattern<vector::ShapeCastOp> { |
| using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, |
| PatternRewriter &rewriter) const override { |
| auto sourceType = shapeCastOp.getSourceVectorType(); |
| auto resultType = shapeCastOp.getResultVectorType(); |
| if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType)) |
| return rewriter.notifyMatchFailure(shapeCastOp, |
| kMatchFailureNotIllegalToLegal); |
| |
| // Note: If we know that `sourceType` is an illegal vector type (and 2D) |
| // then dim 0 is scalable and dim 1 is fixed. |
| if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1) |
| return rewriter.notifyMatchFailure( |
| shapeCastOp, "expected source to be a 2D scalable vector with a " |
| "trailing unit dim"); |
| |
| auto loc = shapeCastOp.getLoc(); |
| auto transpose = rewriter.create<vector::TransposeOp>( |
| loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0}); |
| |
| if (resultType.getRank() == 1) |
| rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType, |
| transpose); |
| else |
| rewriter.replaceOp(shapeCastOp, transpose); |
| |
| return success(); |
| } |
| }; |
| |
/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead
/// use the ZA state. This is a workaround to support these transposes when ZA
/// is available.
| /// |
| /// Example: |
| /// |
| /// BEFORE: |
| /// ```mlir |
| /// %transpose = vector.transpose %vec, [1, 0] |
| /// : vector<2x[4]xf32> to vector<[4]x2xf32> |
| /// vector.transfer_write %transpose, %dest[%y, %x] |
| /// : vector<[4]x2xf32>, memref<?x?xf32> |
| /// ``` |
| /// |
| /// AFTER: |
| /// ```mlir |
| /// %0 = arm_sme.get_tile : vector<[4]x[4]xf32> |
| /// %1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32> |
| /// %2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32> |
| /// %3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32> |
| /// %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32> |
| /// %c4_vscale = arith.muli %vscale, %c4 : index |
| /// %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1> |
| /// vector.transfer_write %4, %dest[%y, %x], %mask |
| /// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} |
| /// : vector<[4]x[4]xf32>, memref<?x?xf32> |
| /// ``` |
| /// |
| /// Values larger than a single tile are supported via decomposition. |
| struct LowerIllegalTransposeStoreViaZA |
| : public OpRewritePattern<vector::TransferWriteOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, |
| PatternRewriter &rewriter) const override { |
| if (!isSupportedMaskOp(writeOp.getMask())) |
| return rewriter.notifyMatchFailure(writeOp, |
| kMatchFailureUnsupportedMaskOp); |
| |
| auto permutationMap = writeOp.getPermutationMap(); |
| if (!permutationMap.isIdentity()) |
| return rewriter.notifyMatchFailure(writeOp, |
| kMatchFailureNonPermutationMap); |
| |
| auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>(); |
| if (!transposeOp) |
| return failure(); |
| |
| auto sourceType = transposeOp.getSourceVectorType(); |
| auto resultType = transposeOp.getResultVectorType(); |
| |
| if (resultType.getRank() != 2) |
| return rewriter.notifyMatchFailure(transposeOp, "TransposeOp not rank 2"); |
| |
| if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType)) |
| return rewriter.notifyMatchFailure( |
| transposeOp, "not illegal/unsupported SVE transpose"); |
| |
| auto smeTileType = getSMETileTypeForElement(resultType.getElementType()); |
| VectorType smeSliceType = VectorType::Builder(smeTileType).dropDim(0); |
| |
| if (sourceType.getDimSize(0) <= 1 || |
| sourceType.getDimSize(1) % smeSliceType.getDimSize(0) != 0) |
| return rewriter.notifyMatchFailure(writeOp, "unsupported source shape"); |
| |
| auto loc = writeOp.getLoc(); |
| auto createVscaleMultiple = |
| vector::makeVscaleConstantBuilder(rewriter, loc); |
| |
| auto transposeMap = AffineMapAttr::get( |
| AffineMap::getPermutationMap(ArrayRef<int64_t>{1, 0}, getContext())); |
| |
| // Note: We need to use `get_tile` as there's no vector-level `undef`. |
| Value undefTile = rewriter.create<arm_sme::GetTileOp>(loc, smeTileType); |
| Value destTensorOrMemref = writeOp.getBase(); |
| auto numSlicesPerTile = |
| std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0)); |
| auto numSlices = |
| rewriter.create<arith::ConstantIndexOp>(loc, numSlicesPerTile); |
| for (auto [index, smeTile] : llvm::enumerate( |
| decomposeToSMETiles(rewriter, sourceType, smeTileType))) { |
| // 1. _Deliberately_ drop a scalable dimension and insert a fixed number |
| // of slices from the source type into the SME tile. Without checking |
| // vscale (and emitting multiple implementations) we can't make use of the |
| // rows of the tile after 1*vscale rows. |
| Value tile = undefTile; |
| for (int d = 0; d < numSlicesPerTile; ++d) { |
| Value vector = rewriter.create<vector::ExtractOp>( |
| loc, transposeOp.getVector(), |
| rewriter.getIndexAttr(d + smeTile.row)); |
| if (vector.getType() != smeSliceType) { |
| vector = rewriter.create<vector::ScalableExtractOp>( |
| loc, smeSliceType, vector, smeTile.col); |
| } |
| tile = rewriter.create<vector::InsertOp>(loc, vector, tile, d); |
| } |
| |
| // 2. Transpose the tile position. |
| auto transposedRow = createVscaleMultiple(smeTile.col); |
| auto transposedCol = |
| rewriter.create<arith::ConstantIndexOp>(loc, smeTile.row); |
| |
| // 3. Compute mask for tile store. |
| Value maskRows; |
| Value maskCols; |
| if (auto mask = writeOp.getMask()) { |
| auto createMask = mask.getDefiningOp<vector::CreateMaskOp>(); |
| maskRows = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(0), |
| transposedRow); |
| maskCols = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(1), |
| transposedCol); |
| maskCols = rewriter.create<index::MinSOp>(loc, maskCols, numSlices); |
| } else { |
| maskRows = createVscaleMultiple(smeTileType.getDimSize(0)); |
| maskCols = numSlices; |
| } |
| auto subMask = rewriter.create<vector::CreateMaskOp>( |
| loc, smeTileType.clone(rewriter.getI1Type()), |
| ValueRange{maskRows, maskCols}); |
| |
| // 4. Emit a transposed tile write. |
| auto writeIndices = writeOp.getIndices(); |
| Value destRow = |
| rewriter.create<arith::AddIOp>(loc, transposedRow, writeIndices[0]); |
| Value destCol = |
| rewriter.create<arith::AddIOp>(loc, transposedCol, writeIndices[1]); |
| auto smeWrite = rewriter.create<vector::TransferWriteOp>( |
| loc, tile, destTensorOrMemref, ValueRange{destRow, destCol}, |
| transposeMap, subMask, writeOp.getInBounds()); |
| |
| if (writeOp.hasPureTensorSemantics()) |
| destTensorOrMemref = smeWrite.getResult(); |
| } |
| |
| if (writeOp.hasPureTensorSemantics()) |
| rewriter.replaceOp(writeOp, destTensorOrMemref); |
| else |
| rewriter.eraseOp(writeOp); |
| |
| return success(); |
| } |
| }; |
| |
| struct VectorLegalizationPass |
| : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> { |
| void runOnOperation() override { |
| auto *context = &getContext(); |
| TypeConverter converter; |
| RewritePatternSet patterns(context); |
| converter.addConversion([](Type type) { return type; }); |
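    // 1 -> N type conversion: a vector type that is a multiple of an SME tile
    // decomposes into `smeTileCount` tile-sized vectors. For example (assuming
    // f32): vector<[8]x[8]xf32> becomes four vector<[4]x[4]xf32> values.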
| converter.addConversion( |
| [](VectorType vectorType, |
| SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> { |
| if (!isMultipleOfSMETileVectorType(vectorType)) |
| return std::nullopt; |
| auto smeTileCount = getNumberOfSMETilesForVectorType(vectorType); |
| auto smeTileType = |
| getSMETileTypeForElement(vectorType.getElementType()); |
| types = SmallVector<Type>(smeTileCount, smeTileType); |
| return success(); |
| }); |
| |
| // Apply preprocessing patterns. |
| RewritePatternSet rewritePatterns(context); |
| rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks, |
| LiftIllegalVectorTransposeToMemory, |
| ConvertIllegalShapeCastOpsToTransposes, |
| LowerIllegalTransposeStoreViaZA>(context); |
| if (failed( |
| applyPatternsGreedily(getOperation(), std::move(rewritePatterns)))) |
| return signalPassFailure(); |
| |
| // Note: These two patterns are added with a high benefit to ensure: |
| // - Masked outer products are handled before unmasked ones |
| // - Multi-tile writes are lowered as a store loop (if possible) |
| patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition, |
| LegalizeMultiTileTransferWriteAsStoreLoop>(converter, context, |
| /*benefit=*/1024); |
| patterns.add<LegalizeArithConstantOpsByDecomposition, |
| LegalizeVectorOuterProductOpsByDecomposition, |
| LegalizeTransferReadOpsByDecomposition, |
| LegalizeTransferWriteOpsByDecomposition>(converter, context); |
| populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, |
| converter); |
| populateCallOpTypeConversionPattern(patterns, converter); |
| populateReturnOpTypeConversionPattern(patterns, converter); |
| scf::populateSCFStructuralTypeConversions(converter, patterns); |
| |
| ConversionTarget target(getContext()); |
| target.markUnknownOpDynamicallyLegal( |
| [&](Operation *op) { return converter.isLegal(op); }); |
| target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { |
| return converter.isSignatureLegal(op.getFunctionType()); |
| }); |
| if (failed(applyPartialConversion(getOperation(), target, |
| std::move(patterns)))) |
| return signalPassFailure(); |
| } |
| }; |
| |
| } // namespace |
| |
| std::unique_ptr<Pass> mlir::arm_sme::createVectorLegalizationPass() { |
| return std::make_unique<VectorLegalizationPass>(); |
| } |