| //===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements target-independent rewrites and utilities to lower the |
| // 'vector.shape_cast' operation. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/UB//IR/UBOps.h" |
| #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" |
| #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" |
| #include "mlir/Dialect/Vector/Utils/VectorUtils.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/Location.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/IR/TypeUtilities.h" |
| #include <numeric> |
| |
| #define DEBUG_TYPE "vector-shape-cast-lowering" |
| |
| using namespace mlir; |
| |
| /// Perform the inplace update |
| /// rhs <- lhs + rhs |
| /// |
/// where `rhs` is a number expressed in mixed base `base` with most significant
/// dimensions on the left. For example, if `rhs` is {a,b,c} and `base` is
/// {5,3,2} then `rhs` has value a*3*2 + b*2 + c.
| /// |
| /// Some examples where `base` is {5,3,2}: |
| /// rhs = {0,0,0}, lhs = 1 --> rhs = {0,0,1} |
| /// rhs = {0,0,1}, lhs = 1 --> rhs = {0,1,0} |
| /// rhs = {0,0,0}, lhs = 25 --> rhs = {4,0,1} |
| /// |
| /// Invalid: |
| /// rhs = {0,0,2}, lhs = 1 : rhs not in base {5,3,2} |
| /// |
| /// Overflows not handled correctly: |
| /// rhs = {4,2,1}, lhs = 2 --> rhs = {0,0,0} (not {0,0,1}) |
| static void inplaceAdd(int64_t lhs, ArrayRef<int64_t> base, |
| MutableArrayRef<int64_t> rhs) { |
| |
  // For dimensions in [rhs.size() - 1, ..., 2, 1, 0]:
| for (int dim : llvm::reverse(llvm::seq<int>(0, rhs.size()))) { |
| int64_t dimBase = base[dim]; |
| assert(rhs[dim] < dimBase && "rhs not in base"); |
| |
| int64_t incremented = rhs[dim] + lhs; |
| |
    // If the incremented value exceeds the dimension base, we must spill to the
    // next most significant dimension and repeat (we might need to spill to
    // more significant dimensions multiple times).
| lhs = incremented / dimBase; |
| rhs[dim] = incremented % dimBase; |
| if (lhs == 0) |
| break; |
| } |
| } |
| |
| namespace { |
| |
| /// shape_cast is converted to a sequence of extract, extract_strided_slice, |
| /// insert_strided_slice, and insert operations. The running example will be: |
| /// |
| /// %0 = vector.shape_cast %arg0 : |
| /// vector<2x2x3x4x7x11xi8> to vector<8x6x7x11xi8> |
| /// |
| /// In this example the source and result shapes share a common suffix of 7x11. |
| /// This means we can always decompose the shape_cast into extract, insert, and |
| /// their strided equivalents, on vectors with shape suffix 7x11. |
| /// |
/// The greatest common divisor (gcd) of the dimensions immediately preceding
/// the common suffix, 4 in the source and 6 in the result, is gcd(4,6) = 2. The
/// algorithm implemented here operates on vectors whose shapes are multiples of
/// (what we define as) the 'atomic shape', 2x7x11. The atomic shape is
/// `gcd` x `common-suffix`.
| /// |
///      vector<2x2x3x4x7x11xi8> to vector<8x6x7x11xi8>
///                   |  \__/            |  \__/
///                   |    |             |    |
///                   |    +-------------|----+------> common suffix of 7x11
///                   |                  |
///                   +------------------+-----------> gcd(4,6) is 2
///
///                   atomic shape  =  2x7x11
///
| /// The decomposition implemented in this pattern consists of a sequence of |
| /// repeated steps: |
| /// |
| /// (1) Extract vectors from the suffix of the source. |
| /// In our example this is 2x2x3x4x7x11 -> 4x7x11. |
| /// |
| /// (2) Do extract_strided_slice down to the atomic shape. |
| /// In our example this is 4x7x11 -> 2x7x11. |
| /// |
| /// (3) Do insert_strided_slice to the suffix of the result. |
| /// In our example this is 2x7x11 -> 6x7x11. |
| /// |
///  (4) Insert these vectors into the result vector.
| /// In our example this is 6x7x11 -> 8x6x7x11. |
| /// |
| /// These steps occur with different periods. In this example |
| /// (1) occurs 12 times, |
| /// (2) and (3) occur 24 times, and |
| /// (4) occurs 8 times. |
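///
/// As a smaller end-to-end illustration, consider the hypothetical cast
/// (not the running example)
///
///   %0 = vector.shape_cast %arg0 : vector<3x4xi8> to vector<2x6xi8>
///
/// Here the common suffix is empty, gcd(4,6) = 2, and the atomic shape is
/// simply 2. There are 6 atomic slices, so (1) occurs 3 times, (2) and (3)
/// occur 6 times, and (4) occurs 2 times. The first row of the result is
/// built roughly as follows (SSA names are illustrative, and the IR is shown
/// before any folding):
///
///   %res = ub.poison : vector<2x6xi8>
///   %row = ub.poison : vector<6xi8>
///   %e0 = vector.extract %arg0[0] : vector<4xi8> from vector<3x4xi8>
///   %a0 = vector.extract_strided_slice %e0
///           {offsets = [0], sizes = [2], strides = [1]} : vector<4xi8> to vector<2xi8>
///   %r0 = vector.insert_strided_slice %a0, %row
///           {offsets = [0], strides = [1]} : vector<2xi8> into vector<6xi8>
///   %a1 = vector.extract_strided_slice %e0
///           {offsets = [2], sizes = [2], strides = [1]} : vector<4xi8> to vector<2xi8>
///   %r1 = vector.insert_strided_slice %a1, %r0
///           {offsets = [2], strides = [1]} : vector<2xi8> into vector<6xi8>
///   %e1 = vector.extract %arg0[1] : vector<4xi8> from vector<3x4xi8>
///   %a2 = vector.extract_strided_slice %e1
///           {offsets = [0], sizes = [2], strides = [1]} : vector<4xi8> to vector<2xi8>
///   %r2 = vector.insert_strided_slice %a2, %r1
///           {offsets = [4], strides = [1]} : vector<2xi8> into vector<6xi8>
///   %f0 = vector.insert %r2, %res[0] : vector<6xi8> into vector<2x6xi8>
///
/// The second row of the result is built analogously from the second half of
/// %e1 and both halves of a final extract of %arg0[2].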
| /// |
/// Two special cases are handled independently in this pattern:
///  (i)  A shape_cast that just inserts/removes leading 1s.
///  (ii) A shape_cast where the gcd is 1.
///
/// For these two cases, more compact IR is generated by bypassing the generic
/// algorithm described above.
| /// |
| class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> { |
| |
  // Case (i) of the description above.
| // Assumes source and result shapes are identical up to some leading ones. |
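  //
  // For example (a hypothetical input), lowering
  //   %0 = vector.shape_cast %s : vector<4x5xi8> to vector<1x1x4x5xi8>
  // conceptually produces, modulo a trivial extract of the full source (which
  // folds away), the pair
  //   %p = ub.poison : vector<1x1x4x5xi8>
  //   %r = vector.insert %s, %p [0, 0] : vector<4x5xi8> into vector<1x1x4x5xi8>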
| static LogicalResult leadingOnesLowering(vector::ShapeCastOp shapeCast, |
| PatternRewriter &rewriter) { |
| |
| const Location loc = shapeCast.getLoc(); |
| const VectorType sourceType = shapeCast.getSourceVectorType(); |
| const VectorType resultType = shapeCast.getResultVectorType(); |
| |
| const int64_t sourceRank = sourceType.getRank(); |
| const int64_t resultRank = resultType.getRank(); |
| const int64_t delta = sourceRank - resultRank; |
| const int64_t sourceLeading = delta > 0 ? delta : 0; |
| const int64_t resultLeading = delta > 0 ? 0 : -delta; |
| |
| const Value source = shapeCast.getSource(); |
| const Value poison = ub::PoisonOp::create(rewriter, loc, resultType); |
| const Value extracted = vector::ExtractOp::create( |
| rewriter, loc, source, SmallVector<int64_t>(sourceLeading, 0)); |
| const Value result = |
| vector::InsertOp::create(rewriter, loc, extracted, poison, |
| SmallVector<int64_t>(resultLeading, 0)); |
| |
| rewriter.replaceOp(shapeCast, result); |
| return success(); |
| } |
| |
  // Case (ii) of the description above.
| // Assumes a shape_cast where the suffix shape of the source starting at |
| // `sourceDim` and the suffix shape of the result starting at `resultDim` are |
| // identical. |
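  //
  // For example (a hypothetical input with gcd 1), lowering
  //   %0 = vector.shape_cast %s : vector<2x3x5xi8> to vector<3x2x5xi8>
  // with sourceDim = resultDim = 2 moves the six vector<5xi8> subvectors one
  // at a time, roughly (SSA names illustrative):
  //   %p  = ub.poison : vector<3x2x5xi8>
  //   %e0 = vector.extract %s[0, 0] : vector<5xi8> from vector<2x3x5xi8>
  //   %r0 = vector.insert %e0, %p [0, 0] : vector<5xi8> into vector<3x2x5xi8>
  //   %e1 = vector.extract %s[0, 1] : vector<5xi8> from vector<2x3x5xi8>
  //   %r1 = vector.insert %e1, %r0 [0, 1] : vector<5xi8> into vector<3x2x5xi8>
  //   ...
  // and so on, with the extract index stepping in base (2,3) and the insert
  // index stepping in base (3,2).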
| static LogicalResult noStridedSliceLowering(vector::ShapeCastOp shapeCast, |
| int64_t sourceDim, |
| int64_t resultDim, |
| PatternRewriter &rewriter) { |
| |
| const Location loc = shapeCast.getLoc(); |
| |
| const Value source = shapeCast.getSource(); |
| const ArrayRef<int64_t> sourceShape = |
| shapeCast.getSourceVectorType().getShape(); |
| |
| const VectorType resultType = shapeCast.getResultVectorType(); |
| const ArrayRef<int64_t> resultShape = resultType.getShape(); |
| |
| const int64_t nSlices = |
| std::accumulate(sourceShape.begin(), sourceShape.begin() + sourceDim, 1, |
| std::multiplies<int64_t>()); |
| |
| SmallVector<int64_t> extractIndex(sourceDim, 0); |
| SmallVector<int64_t> insertIndex(resultDim, 0); |
| Value result = ub::PoisonOp::create(rewriter, loc, resultType); |
| |
| for (int i = 0; i < nSlices; ++i) { |
| Value extracted = |
| vector::ExtractOp::create(rewriter, loc, source, extractIndex); |
| |
| result = vector::InsertOp::create(rewriter, loc, extracted, result, |
| insertIndex); |
| |
| inplaceAdd(1, sourceShape.take_front(sourceDim), extractIndex); |
| inplaceAdd(1, resultShape.take_front(resultDim), insertIndex); |
| } |
| rewriter.replaceOp(shapeCast, result); |
| return success(); |
| } |
| |
| public: |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(vector::ShapeCastOp op, |
| PatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| VectorType sourceType = op.getSourceVectorType(); |
| VectorType resultType = op.getResultVectorType(); |
| |
| if (sourceType.isScalable() || resultType.isScalable()) |
| return rewriter.notifyMatchFailure( |
| op, |
| "shape_cast where vectors are scalable not handled by this pattern"); |
| |
| const ArrayRef<int64_t> sourceShape = sourceType.getShape(); |
| const ArrayRef<int64_t> resultShape = resultType.getShape(); |
| const int64_t sourceRank = sourceType.getRank(); |
| const int64_t resultRank = resultType.getRank(); |
| const int64_t numElms = sourceType.getNumElements(); |
| const Value source = op.getSource(); |
| |
    // Find the first dimension (counting from the end) in the source and
    // result respectively where the dimension sizes differ. Using the running
    // example:
    //
    //   dimensions:  [0 1 2 3 4 5 ]    [0 1 2 3 ]
    //   shapes:      (2,2,3,4,7,11) -> (8,6,7,11)
    //                       ^             ^
    //                       |             |
    //    sourceSuffixStartDim is 3        |
    //                                     |
    //            resultSuffixStartDim is 1
| int64_t sourceSuffixStartDim = sourceRank - 1; |
| int64_t resultSuffixStartDim = resultRank - 1; |
| while (sourceSuffixStartDim >= 0 && resultSuffixStartDim >= 0 && |
| (sourceType.getDimSize(sourceSuffixStartDim) == |
| resultType.getDimSize(resultSuffixStartDim))) { |
| --sourceSuffixStartDim; |
| --resultSuffixStartDim; |
| } |
| |
| // This is the case (i) where there are just some leading ones to contend |
| // with in the source or result. It can be handled with a single |
| // extract/insert pair. |
| if (resultSuffixStartDim < 0 || sourceSuffixStartDim < 0) |
| return leadingOnesLowering(op, rewriter); |
| |
| const int64_t sourceSuffixStartDimSize = |
| sourceType.getDimSize(sourceSuffixStartDim); |
| const int64_t resultSuffixStartDimSize = |
| resultType.getDimSize(resultSuffixStartDim); |
| const int64_t greatestCommonDivisor = |
| std::gcd(sourceSuffixStartDimSize, resultSuffixStartDimSize); |
| const int64_t stridedSliceRank = sourceRank - sourceSuffixStartDim; |
| const size_t extractPeriod = |
| sourceSuffixStartDimSize / greatestCommonDivisor; |
| const size_t insertPeriod = |
| resultSuffixStartDimSize / greatestCommonDivisor; |
| |
| SmallVector<int64_t> atomicShape(sourceShape.begin() + sourceSuffixStartDim, |
| sourceShape.end()); |
| atomicShape[0] = greatestCommonDivisor; |
| |
| const int64_t numAtomicElms = std::accumulate( |
| atomicShape.begin(), atomicShape.end(), 1, std::multiplies<int64_t>()); |
| const size_t nAtomicSlices = numElms / numAtomicElms; |
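    // With the running example: extractPeriod = 4/2 = 2, insertPeriod = 6/2 = 3,
    // and nAtomicSlices = 3696/154 = 24, i.e. 12 extracts, 24 strided
    // extract/insert pairs, and 8 inserts.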
| |
    // This is case (ii), where the greatest common divisor is 1. More compact
    // IR is generated in this case if we just extract and insert the
    // (sub)vectors directly. In other words, we don't use extract_strided_slice
    // and insert_strided_slice.
| if (greatestCommonDivisor == 1) |
| return noStridedSliceLowering(op, sourceSuffixStartDim + 1, |
| resultSuffixStartDim + 1, rewriter); |
| |
    // The type of the insert_strided_slice result.
| const ArrayRef<int64_t> insertStridedShape = |
| resultShape.drop_front(resultSuffixStartDim); |
| const VectorType insertStridedType = |
| VectorType::get(insertStridedShape, resultType.getElementType()); |
| |
| SmallVector<int64_t> extractIndex(sourceSuffixStartDim, 0); |
| SmallVector<int64_t> insertIndex(resultSuffixStartDim, 0); |
| SmallVector<int64_t> extractOffsets(stridedSliceRank, 0); |
| SmallVector<int64_t> insertOffsets(stridedSliceRank, 0); |
| const SmallVector<int64_t> sizes(stridedSliceRank, 1); |
| |
| Value extracted = {}; |
| Value extractedStrided = {}; |
| Value insertedSlice = {}; |
| Value result = ub::PoisonOp::create(rewriter, loc, resultType); |
| const Value partResult = |
| ub::PoisonOp::create(rewriter, loc, insertStridedType); |
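    // Note: `partResult` is reused as the starting (poison) value each time a
    // new slice of the result is begun (when insertStridedPhase == 0 below).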
| |
| for (size_t i = 0; i < nAtomicSlices; ++i) { |
| |
| const size_t extractStridedPhase = i % extractPeriod; |
| const size_t insertStridedPhase = i % insertPeriod; |
| |
| // vector.extract |
| if (extractStridedPhase == 0) { |
| extracted = |
| vector::ExtractOp::create(rewriter, loc, source, extractIndex); |
| inplaceAdd(1, sourceShape.take_front(sourceSuffixStartDim), |
| extractIndex); |
| } |
| |
| // vector.extract_strided_slice |
| extractOffsets[0] = extractStridedPhase * greatestCommonDivisor; |
| extractedStrided = vector::ExtractStridedSliceOp::create( |
| rewriter, loc, extracted, extractOffsets, atomicShape, sizes); |
| |
| // vector.insert_strided_slice |
| if (insertStridedPhase == 0) { |
| insertedSlice = partResult; |
| } |
| insertOffsets[0] = insertStridedPhase * greatestCommonDivisor; |
| insertedSlice = vector::InsertStridedSliceOp::create( |
| rewriter, loc, extractedStrided, insertedSlice, insertOffsets, sizes); |
| |
| // vector.insert |
| if (insertStridedPhase + 1 == insertPeriod) { |
| result = vector::InsertOp::create(rewriter, loc, insertedSlice, result, |
| insertIndex); |
| inplaceAdd(1, resultType.getShape().take_front(resultSuffixStartDim), |
| insertIndex); |
| } |
| } |
| rewriter.replaceOp(op, result); |
| return success(); |
| } |
| }; |
| |
| /// A shape_cast lowering for scalable vectors with a single trailing scalable |
| /// dimension. This is similar to the general shape_cast lowering but makes use |
| /// of vector.scalable.insert and vector.scalable.extract to move elements a |
| /// subvector at a time. |
| /// |
| /// E.g.: |
| /// ``` |
| /// // Flatten scalable vector |
| /// %0 = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32> |
| /// ``` |
| /// is rewritten to: |
| /// ``` |
| /// // Flatten scalable vector |
/// %c = ub.poison : vector<[8]xi32>
| /// %0 = vector.extract %arg0[0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32> |
| /// %1 = vector.scalable.insert %0, %c[0] : vector<[4]xi32> into vector<[8]xi32> |
| /// %2 = vector.extract %arg0[1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32> |
| /// %3 = vector.scalable.insert %2, %1[4] : vector<[4]xi32> into vector<[8]xi32> |
| /// ``` |
| /// or: |
| /// ``` |
| /// // Un-flatten scalable vector |
| /// %0 = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32> |
| /// ``` |
| /// is rewritten to: |
| /// ``` |
| /// // Un-flatten scalable vector |
/// %c = ub.poison : vector<2x1x[4]xi32>
| /// %0 = vector.scalable.extract %arg0[0] : vector<[4]xi32> from vector<[8]xi32> |
| /// %1 = vector.insert %0, %c [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32> |
| /// %2 = vector.scalable.extract %arg0[4] : vector<[4]xi32> from vector<[8]xi32> |
| /// %3 = vector.insert %2, %1 [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32> |
| /// ``` |
| class ScalableShapeCastOpRewritePattern |
| : public OpRewritePattern<vector::ShapeCastOp> { |
| public: |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(vector::ShapeCastOp op, |
| PatternRewriter &rewriter) const override { |
| |
| Location loc = op.getLoc(); |
| auto sourceVectorType = op.getSourceVectorType(); |
| auto resultVectorType = op.getResultVectorType(); |
| auto srcRank = sourceVectorType.getRank(); |
| auto resRank = resultVectorType.getRank(); |
| |
    // This pattern can only lower shape_casts where both the source and result
    // types have a single trailing scalable dimension. This is because there is
    // no legal representation of other scalable types in LLVM (and likely won't
    // be soon). There are also (currently) no operations that can index or
    // extract from >= 2-D scalable vectors or scalable vectors of fixed
    // vectors.
| if (!isTrailingDimScalable(sourceVectorType) || |
| !isTrailingDimScalable(resultVectorType)) { |
| return rewriter.notifyMatchFailure( |
| op, "trailing dims are not scalable, not handled by this pattern"); |
| } |
| |
    // The sizes of the trailing dimensions of the source and result vectors,
    // the size of the subvector to move, and the number of elements in the
    // vectors. These are "min" sizes as they are the sizes when vscale == 1.
| auto minSourceTrailingSize = sourceVectorType.getShape().back(); |
| auto minResultTrailingSize = resultVectorType.getShape().back(); |
| auto minExtractionSize = |
| std::min(minSourceTrailingSize, minResultTrailingSize); |
| int64_t minNumElts = 1; |
| for (auto size : sourceVectorType.getShape()) |
| minNumElts *= size; |
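    // E.g. for the flattening example in the class comment
    // (vector<2x1x[4]xi32> to vector<[8]xi32>): minExtractionSize = 4 and
    // minNumElts = 8, so the loop below moves two [4]-element subvectors.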
| |
    // The subvector type to move from the source to the result. Note that this
    // is a scalable vector. This rewrite will generate code in terms of the
    // "min" size (the vscale == 1 case), which scales to any vscale.
| auto extractionVectorType = VectorType::get( |
| {minExtractionSize}, sourceVectorType.getElementType(), {true}); |
| |
| Value result = ub::PoisonOp::create(rewriter, loc, resultVectorType); |
| SmallVector<int64_t> srcIdx(srcRank, 0); |
| SmallVector<int64_t> resIdx(resRank, 0); |
| |
| // TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils) |
| // once D150000 lands. |
| Value currentResultScalableVector; |
| Value currentSourceScalableVector; |
| for (int64_t i = 0; i < minNumElts; i += minExtractionSize) { |
| // 1. Extract a scalable subvector from the source vector. |
| if (!currentSourceScalableVector) { |
| if (srcRank != 1) { |
| currentSourceScalableVector = |
| vector::ExtractOp::create(rewriter, loc, op.getSource(), |
| llvm::ArrayRef(srcIdx).drop_back()); |
| } else { |
| currentSourceScalableVector = op.getSource(); |
| } |
| } |
| Value sourceSubVector = currentSourceScalableVector; |
| if (minExtractionSize < minSourceTrailingSize) { |
| sourceSubVector = vector::ScalableExtractOp::create( |
| rewriter, loc, extractionVectorType, sourceSubVector, |
| srcIdx.back()); |
| } |
| |
| // 2. Insert the scalable subvector into the result vector. |
| if (!currentResultScalableVector) { |
| if (minExtractionSize == minResultTrailingSize) { |
| currentResultScalableVector = sourceSubVector; |
| } else if (resRank != 1) { |
| currentResultScalableVector = vector::ExtractOp::create( |
| rewriter, loc, result, llvm::ArrayRef(resIdx).drop_back()); |
| } else { |
| currentResultScalableVector = result; |
| } |
| } |
| if (minExtractionSize < minResultTrailingSize) { |
| currentResultScalableVector = vector::ScalableInsertOp::create( |
| rewriter, loc, sourceSubVector, currentResultScalableVector, |
| resIdx.back()); |
| } |
| |
| // 3. Update the source and result scalable vectors if needed. |
| if (resIdx.back() + minExtractionSize >= minResultTrailingSize && |
| currentResultScalableVector != result) { |
| // Finished row of result. Insert complete scalable vector into result |
| // (n-D) vector. |
| result = vector::InsertOp::create(rewriter, loc, |
| currentResultScalableVector, result, |
| llvm::ArrayRef(resIdx).drop_back()); |
| currentResultScalableVector = {}; |
| } |
| if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) { |
| // Finished row of source. |
| currentSourceScalableVector = {}; |
| } |
| |
      // 4. Increment the insert/extract indices, stepping by minExtractionSize
      // for the trailing dimension.
| inplaceAdd(minExtractionSize, sourceVectorType.getShape(), srcIdx); |
| inplaceAdd(minExtractionSize, resultVectorType.getShape(), resIdx); |
| } |
| |
| rewriter.replaceOp(op, result); |
| return success(); |
| } |
| |
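  // Returns true iff the trailing dimension of `type` is scalable and no other
  // dimension is, e.g. true for vector<2x[4]xi32>, but false for
  // vector<[2]x4xi32> and for vector<4xi32>.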
| static bool isTrailingDimScalable(VectorType type) { |
| return type.getRank() >= 1 && type.getScalableDims().back() && |
| !llvm::is_contained(type.getScalableDims().drop_back(), true); |
| } |
| }; |
| |
| } // namespace |
| |
| void mlir::vector::populateVectorShapeCastLoweringPatterns( |
| RewritePatternSet &patterns, PatternBenefit benefit) { |
| patterns.add<ShapeCastOpRewritePattern, ScalableShapeCastOpRewritePattern>( |
| patterns.getContext(), benefit); |
| } |