| //===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements patterns to do vector unrolling and vector distribution. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| #include "mlir/Dialect/Utils/IndexingUtils.h" |
| #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" |
| #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" |
| #include "mlir/Interfaces/VectorInterfaces.h" |
| #include "llvm/ADT/MapVector.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/Support/DebugLog.h" |
| #include "llvm/Support/InterleavedRange.h" |
| #include <optional> |
| |
| #define DEBUG_TYPE "vector-unroll" |
| |
| using namespace mlir; |
| using namespace mlir::vector; |
| |
| /// Compute the indices of the slice `index` for a transfer op. |
| static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets, |
| ArrayRef<Value> indices, |
| AffineMap permutationMap, |
| Location loc, |
| OpBuilder &builder) { |
| MLIRContext *ctx = builder.getContext(); |
| auto isBroadcast = [](AffineExpr expr) { |
| if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) |
| return constExpr.getValue() == 0; |
| return false; |
| }; |
| // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'. |
| SmallVector<Value> slicedIndices(indices); |
| for (const auto &dim : llvm::enumerate(permutationMap.getResults())) { |
| if (isBroadcast(dim.value())) |
| continue; |
| unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition(); |
| auto expr = getAffineDimExpr(0, builder.getContext()) + |
| getAffineConstantExpr(elementOffsets[dim.index()], ctx); |
| auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); |
| slicedIndices[pos] = |
| affine::AffineApplyOp::create(builder, loc, map, indices[pos]); |
| } |
| return slicedIndices; |
| } |
| |
| // Compute the new indices by adding `offsets` to `originalIndices`. |
| // If m < n (m = offsets.size(), n = originalIndices.size()), |
| // then only the trailing m values in `originalIndices` are updated. |
| static SmallVector<Value> sliceLoadStoreIndices(PatternRewriter &rewriter, |
| Location loc, |
| OperandRange originalIndices, |
| ArrayRef<int64_t> offsets) { |
| assert(offsets.size() <= originalIndices.size() && |
| "Offsets should not exceed the number of original indices"); |
| SmallVector<Value> indices(originalIndices); |
| |
| auto start = indices.size() - offsets.size(); |
| for (auto [i, offset] : llvm::enumerate(offsets)) { |
| if (offset != 0) { |
| indices[start + i] = arith::AddIOp::create( |
| rewriter, loc, originalIndices[start + i], |
| arith::ConstantIndexOp::create(rewriter, loc, offset)); |
| } |
| } |
| return indices; |
| } |
| |
| // Clones `op` into a new operations that takes `operands` and returns |
| // `resultTypes`. |
| static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, |
| Operation *op, |
| ArrayRef<Value> operands, |
| ArrayRef<Type> resultTypes) { |
| return builder.create(loc, op->getName().getIdentifier(), operands, |
| resultTypes, op->getAttrs()); |
| } |
| |
| /// Return the target shape for unrolling for the given `op`. Return |
| /// std::nullopt if the op shouldn't be or cannot be unrolled. |
| static std::optional<SmallVector<int64_t>> |
| getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) { |
| LDBG() << "Get unroll shape for op " << op->getName().getStringRef(); |
| if (options.filterConstraint && failed(options.filterConstraint(op))) { |
| LDBG() << "--no filter constraint -> BAIL"; |
| return std::nullopt; |
| } |
| assert(options.nativeShape && |
| "vector unrolling expects the native shape or native" |
| "shape call back function to be set"); |
| auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op); |
| if (!unrollableVectorOp) { |
| LDBG() << "--not an unrollable op -> BAIL"; |
| return std::nullopt; |
| } |
| auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll(); |
| if (!maybeUnrollShape) { |
| LDBG() << "--could not get shape of op " << *op << " -> BAIL"; |
| return std::nullopt; |
| } |
| LDBG() << "--vector op shape: " << llvm::interleaved(*maybeUnrollShape); |
| |
| std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op); |
| if (!targetShape) { |
| LDBG() << "--no unrolling target shape defined " << *op << "-> SKIP"; |
| return std::nullopt; |
| } |
| LDBG() << "--target shape: " << llvm::interleaved(*targetShape); |
| |
| auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape); |
| if (!maybeShapeRatio) { |
| LDBG() << "--could not compute integral shape ratio -> BAIL"; |
| return std::nullopt; |
| } |
| if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) { |
| LDBG() << "--no unrolling needed -> SKIP"; |
| return std::nullopt; |
| } |
| LDBG() << "--found an integral shape ratio to unroll to -> SUCCESS"; |
| return targetShape; |
| } |
| |
| static SmallVector<int64_t> |
| getUnrollOrder(unsigned numLoops, Operation *op, |
| const vector::UnrollVectorOptions &options) { |
| SmallVector<int64_t> loopOrder = |
| llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops))); |
| if (options.traversalOrderCallback != nullptr) { |
| std::optional<SmallVector<int64_t>> order = |
| options.traversalOrderCallback(op); |
| if (order) { |
| loopOrder = std::move(*order); |
| } |
| } |
| return loopOrder; |
| } |
| |
| namespace { |
| |
| struct UnrollTransferReadPattern |
| : public OpRewritePattern<vector::TransferReadOp> { |
| UnrollTransferReadPattern(MLIRContext *context, |
| const vector::UnrollVectorOptions &options, |
| PatternBenefit benefit = 1) |
| : OpRewritePattern<vector::TransferReadOp>(context, benefit), |
| options(options) {} |
| |
| LogicalResult matchAndRewrite(vector::TransferReadOp readOp, |
| PatternRewriter &rewriter) const override { |
| // TODO: support 0-d corner case. |
| if (readOp.getTransferRank() == 0) |
| return failure(); |
| if (readOp.getMask()) |
| return failure(); |
| auto targetShape = getTargetShape(options, readOp); |
| if (!targetShape) |
| return failure(); |
| auto sourceVectorType = readOp.getVectorType(); |
| SmallVector<int64_t> strides(targetShape->size(), 1); |
| Location loc = readOp.getLoc(); |
| ArrayRef<int64_t> originalSize = sourceVectorType.getShape(); |
| |
| // Prepare the result vector; |
| Value result = |
| arith::ConstantOp::create(rewriter, loc, sourceVectorType, |
| rewriter.getZeroAttr(sourceVectorType)); |
| auto targetType = |
| VectorType::get(*targetShape, sourceVectorType.getElementType()); |
| SmallVector<Value> originalIndices(readOp.getIndices().begin(), |
| readOp.getIndices().end()); |
| SmallVector<int64_t> loopOrder = |
| getUnrollOrder(originalSize.size(), readOp, options); |
| for (SmallVector<int64_t> elementOffsets : |
| StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { |
| SmallVector<Value> indices = |
| sliceTransferIndices(elementOffsets, originalIndices, |
| readOp.getPermutationMap(), loc, rewriter); |
| auto slicedRead = vector::TransferReadOp::create( |
| rewriter, loc, targetType, readOp.getBase(), indices, |
| readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(), |
| readOp.getInBoundsAttr()); |
| |
| result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
| loc, slicedRead, result, elementOffsets, strides); |
| } |
| rewriter.replaceOp(readOp, result); |
| return success(); |
| } |
| |
| private: |
| vector::UnrollVectorOptions options; |
| }; |
| |
| struct UnrollTransferWritePattern |
| : public OpRewritePattern<vector::TransferWriteOp> { |
| UnrollTransferWritePattern(MLIRContext *context, |
| const vector::UnrollVectorOptions &options, |
| PatternBenefit benefit = 1) |
| : OpRewritePattern<vector::TransferWriteOp>(context, benefit), |
| options(options) {} |
| |
| LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, |
| PatternRewriter &rewriter) const override { |
| // TODO: support 0-d corner case. |
| if (writeOp.getTransferRank() == 0) |
| return failure(); |
| |
| if (writeOp.getMask()) |
| return failure(); |
| auto targetShape = getTargetShape(options, writeOp); |
| if (!targetShape) |
| return failure(); |
| auto sourceVectorType = writeOp.getVectorType(); |
| SmallVector<int64_t> strides(targetShape->size(), 1); |
| Location loc = writeOp.getLoc(); |
| ArrayRef<int64_t> originalSize = sourceVectorType.getShape(); |
| // Bail-out if rank(source) != rank(target). The main limitation here is the |
| // fact that `ExtractStridedSlice` requires the rank for the input and |
| // output to match. If needed, we can relax this later. |
| if (originalSize.size() != targetShape->size()) |
| return rewriter.notifyMatchFailure( |
| writeOp, |
| "expected source input vector rank to match target shape rank"); |
| |
| SmallVector<Value> originalIndices(writeOp.getIndices().begin(), |
| writeOp.getIndices().end()); |
| SmallVector<int64_t> loopOrder = |
| getUnrollOrder(originalSize.size(), writeOp, options); |
| Value resultTensor; |
| for (SmallVector<int64_t> elementOffsets : |
| StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { |
| Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| loc, writeOp.getVector(), elementOffsets, *targetShape, strides); |
| SmallVector<Value> indices = |
| sliceTransferIndices(elementOffsets, originalIndices, |
| writeOp.getPermutationMap(), loc, rewriter); |
| Operation *slicedWrite = vector::TransferWriteOp::create( |
| rewriter, loc, slicedVector, |
| resultTensor ? resultTensor : writeOp.getBase(), indices, |
| writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr()); |
| // For the tensor case update the destination for the next transfer write. |
| if (!slicedWrite->getResults().empty()) |
| resultTensor = slicedWrite->getResult(0); |
| } |
| if (resultTensor) |
| rewriter.replaceOp(writeOp, resultTensor); |
| else |
| rewriter.eraseOp(writeOp); |
| return success(); |
| } |
| |
| private: |
| vector::UnrollVectorOptions options; |
| }; |
| |
| struct OffsetMapInfo { |
| static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; } |
| |
| static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; } |
| |
| static unsigned getHashValue(const SmallVector<int64_t> &v) { |
| return static_cast<unsigned>(llvm::hash_combine_range(v)); |
| } |
| |
| static bool isEqual(const SmallVector<int64_t> &lhs, |
| const SmallVector<int64_t> &rhs) { |
| return lhs == rhs; |
| } |
| }; |
| |
| struct UnrollContractionPattern |
| : public OpRewritePattern<vector::ContractionOp> { |
| UnrollContractionPattern(MLIRContext *context, |
| const vector::UnrollVectorOptions &options, |
| PatternBenefit benefit = 1) |
| : OpRewritePattern<vector::ContractionOp>(context, benefit), |
| options(options) {} |
| |
| LogicalResult matchAndRewrite(vector::ContractionOp contractOp, |
| PatternRewriter &rewriter) const override { |
| auto targetShape = getTargetShape(options, contractOp); |
| if (!targetShape) |
| return failure(); |
| auto dstVecType = cast<VectorType>(contractOp.getResultType()); |
| SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll(); |
| |
| Location loc = contractOp.getLoc(); |
| unsigned accIndex = vector::ContractionOp::getAccOperandIndex(); |
| AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex]; |
| llvm::MapVector< |
| SmallVector<int64_t>, Value, |
| llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>> |
| accCache; |
| |
| SmallVector<int64_t> loopOrder = getUnrollOrder( |
| contractOp.getIteratorTypes().size(), contractOp, options); |
| |
| for (SmallVector<int64_t> offsets : |
| StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { |
| SmallVector<Value> slicesOperands(contractOp.getNumOperands()); |
| |
| // Helper to compute the new shape of each operand and extract the slice. |
| auto extractOperand = [&](unsigned index, Value operand, |
| AffineMap permutationMap, |
| ArrayRef<int64_t> operandOffets) { |
| SmallVector<int64_t> operandShape = applyPermutationMap( |
| permutationMap, ArrayRef<int64_t>(*targetShape)); |
| SmallVector<int64_t> operandStrides(operandOffets.size(), 1); |
| slicesOperands[index] = |
| rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| loc, operand, operandOffets, operandShape, operandStrides); |
| }; |
| |
| // Extract the new lhs operand. |
| AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0]; |
| SmallVector<int64_t> lhsOffets = |
| applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets)); |
| extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets); |
| |
| // Extract the new rhs operand. |
| AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1]; |
| SmallVector<int64_t> rhsOffets = |
| applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets)); |
| extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets); |
| |
| AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2]; |
| SmallVector<int64_t> accOffets = |
| applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets)); |
| // If a version of the accumulator has already been computed, use it |
| // otherwise extract the first version from the original operand. |
| auto *accIt = accCache.find(accOffets); |
| if (accIt != accCache.end()) |
| slicesOperands[2] = accIt->second; |
| else |
| extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets); |
| |
| SmallVector<int64_t> dstShape = |
| applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape)); |
| auto targetType = VectorType::get(dstShape, dstVecType.getElementType()); |
| Operation *newOp = cloneOpWithOperandsAndTypes( |
| rewriter, loc, contractOp, slicesOperands, targetType); |
| |
| SmallVector<int64_t> dstOffets = |
| applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets)); |
| // Save the accumulated value untill all the loops are unrolled since |
| // reduction loop keep updating the accumulator. |
| accCache[dstOffets] = newOp->getResult(0); |
| } |
| // Assemble back the accumulator into a single vector. |
| Value result = arith::ConstantOp::create(rewriter, loc, dstVecType, |
| rewriter.getZeroAttr(dstVecType)); |
| for (const auto &it : accCache) { |
| SmallVector<int64_t> dstStrides(it.first.size(), 1); |
| result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
| loc, it.second, result, it.first, dstStrides); |
| } |
| rewriter.replaceOp(contractOp, result); |
| return success(); |
| } |
| |
| private: |
| vector::UnrollVectorOptions options; |
| }; |
| |
| struct UnrollMultiReductionPattern |
| : public OpRewritePattern<vector::MultiDimReductionOp> { |
| UnrollMultiReductionPattern(MLIRContext *context, |
| const vector::UnrollVectorOptions &options, |
| PatternBenefit benefit = 1) |
| : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit), |
| options(options) {} |
| |
| LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp, |
| PatternRewriter &rewriter) const override { |
| std::optional<SmallVector<int64_t>> targetShape = |
| getTargetShape(options, reductionOp); |
| if (!targetShape) |
| return failure(); |
| SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll(); |
| Location loc = reductionOp.getLoc(); |
| auto resultType = reductionOp->getResult(0).getType(); |
| |
| // Handle scalar result case: all dimensions are reduced. |
| // Each source tile is reduced to a scalar, and partial results are |
| // chained through the accumulator operand. |
| if (resultType.isIntOrFloat()) { |
| Value accumulator = reductionOp.getAcc(); |
| for (SmallVector<int64_t> offsets : |
| StaticTileOffsetRange(originalSize, *targetShape)) { |
| SmallVector<int64_t> operandStrides(offsets.size(), 1); |
| Value slicedOperand = |
| rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| loc, reductionOp.getSource(), offsets, *targetShape, |
| operandStrides); |
| Operation *newOp = cloneOpWithOperandsAndTypes( |
| rewriter, loc, reductionOp, {slicedOperand, accumulator}, |
| resultType); |
| accumulator = newOp->getResult(0); |
| } |
| rewriter.replaceOp(reductionOp, accumulator); |
| return success(); |
| } |
| |
| // Vector result case. |
| llvm::MapVector< |
| SmallVector<int64_t>, Value, |
| llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>> |
| accCache; |
| |
| // Stride of the ratios, this gives us the offsets of sliceCount in a basis |
| // of multiples of the targetShape. |
| for (SmallVector<int64_t> offsets : |
| StaticTileOffsetRange(originalSize, *targetShape)) { |
| SmallVector<Value> operands; |
| SmallVector<int64_t> operandStrides(offsets.size(), 1); |
| Value slicedOperand = |
| rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| loc, reductionOp.getSource(), offsets, *targetShape, |
| operandStrides); |
| operands.push_back(slicedOperand); |
| SmallVector<int64_t> dstShape; |
| SmallVector<int64_t> destOffset; |
| for (size_t i : llvm::seq(size_t(0), targetShape->size())) { |
| if (!reductionOp.isReducedDim(i)) { |
| destOffset.push_back(offsets[i]); |
| dstShape.push_back((*targetShape)[i]); |
| } |
| } |
| Value acc; |
| SmallVector<int64_t> accStrides(destOffset.size(), 1); |
| // If a version of the accumulator has already been computed, use it |
| // otherwise extract the first version from the original operand. |
| auto *accIt = accCache.find(destOffset); |
| if (accIt != accCache.end()) |
| acc = accIt->second; |
| else |
| acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| loc, reductionOp.getAcc(), destOffset, dstShape, accStrides); |
| operands.push_back(acc); |
| auto targetType = VectorType::get( |
| dstShape, reductionOp.getSourceVectorType().getElementType()); |
| Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp, |
| operands, targetType); |
| Value result = newOp->getResult(0); |
| accCache[destOffset] = result; |
| } |
| // Assemble back the accumulator into a single vector. |
| Value result = arith::ConstantOp::create( |
| rewriter, loc, reductionOp.getDestType(), |
| rewriter.getZeroAttr(reductionOp.getDestType())); |
| for (const auto &it : accCache) { |
| SmallVector<int64_t> dstStrides(it.first.size(), 1); |
| result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
| loc, it.second, result, it.first, dstStrides); |
| } |
| rewriter.replaceOp(reductionOp, result); |
| return success(); |
| } |
| |
| private: |
| vector::UnrollVectorOptions options; |
| }; |
| |
| struct UnrollElementwisePattern : public RewritePattern { |
| UnrollElementwisePattern(MLIRContext *context, |
| const vector::UnrollVectorOptions &options, |
| PatternBenefit benefit = 1) |
| : RewritePattern(MatchAnyOpTypeTag(), benefit, context), |
| options(options) {} |
| |
| LogicalResult matchAndRewrite(Operation *op, |
| PatternRewriter &rewriter) const override { |
| if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1) |
| return failure(); |
| auto targetShape = getTargetShape(options, op); |
| if (!targetShape) |
| return failure(); |
| int64_t targetShapeRank = targetShape->size(); |
| auto dstVecType = cast<VectorType>(op->getResult(0).getType()); |
| SmallVector<int64_t> originalSize = |
| *cast<VectorUnrollOpInterface>(op).getShapeForUnroll(); |
| int64_t originalShapeRank = originalSize.size(); |
| |
| Location loc = op->getLoc(); |
| |
| // Handle rank mismatch by adding leading unit dimensions to targetShape |
| SmallVector<int64_t> adjustedTargetShape(originalShapeRank); |
| int64_t rankDiff = originalShapeRank - targetShapeRank; |
| std::fill(adjustedTargetShape.begin(), |
| adjustedTargetShape.begin() + rankDiff, 1); |
| std::copy(targetShape->begin(), targetShape->end(), |
| adjustedTargetShape.begin() + rankDiff); |
| |
| int64_t adjustedTargetShapeRank = adjustedTargetShape.size(); |
| // Prepare the result vector. |
| Value result = arith::ConstantOp::create(rewriter, loc, dstVecType, |
| rewriter.getZeroAttr(dstVecType)); |
| SmallVector<int64_t> strides(adjustedTargetShapeRank, 1); |
| VectorType unrolledVecType = |
| VectorType::get(*targetShape, dstVecType.getElementType()); |
| |
| // Create the unrolled computation. |
| for (SmallVector<int64_t> offsets : |
| StaticTileOffsetRange(originalSize, adjustedTargetShape)) { |
| SmallVector<Value> extractOperands; |
| for (OpOperand &operand : op->getOpOperands()) { |
| auto vecType = dyn_cast<VectorType>(operand.get().getType()); |
| if (!vecType) { |
| extractOperands.push_back(operand.get()); |
| continue; |
| } |
| Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| loc, operand.get(), offsets, adjustedTargetShape, strides); |
| |
| // Reshape to remove leading unit dims if needed |
| if (adjustedTargetShapeRank > targetShapeRank) { |
| extracted = rewriter.createOrFold<vector::ShapeCastOp>( |
| loc, VectorType::get(*targetShape, vecType.getElementType()), |
| extracted); |
| } |
| extractOperands.push_back(extracted); |
| } |
| |
| Operation *newOp = cloneOpWithOperandsAndTypes( |
| rewriter, loc, op, extractOperands, unrolledVecType); |
| |
| Value computeResult = newOp->getResult(0); |
| |
| // Use strides sized to targetShape for proper insertion |
| SmallVector<int64_t> insertStrides = |
| (adjustedTargetShapeRank > targetShapeRank) |
| ? SmallVector<int64_t>(targetShapeRank, 1) |
| : strides; |
| |
| result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
| loc, computeResult, result, offsets, insertStrides); |
| } |
| rewriter.replaceOp(op, result); |
| return success(); |
| } |
| |
| private: |
| vector::UnrollVectorOptions options; |
| }; |
| |
| struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> { |
| UnrollReductionPattern(MLIRContext *context, |
| const vector::UnrollVectorOptions &options, |
| PatternBenefit benefit = 1) |
| : OpRewritePattern<vector::ReductionOp>(context, benefit), |
| options(options) {} |
| |
| LogicalResult matchAndRewrite(vector::ReductionOp reductionOp, |
| PatternRewriter &rewriter) const override { |
| std::optional<SmallVector<int64_t>> targetShape = |
| getTargetShape(options, reductionOp); |
| if (!targetShape) |
| return failure(); |
| SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll(); |
| |
| // Create unrolled vector reduction. |
| Location loc = reductionOp.getLoc(); |
| Value accumulator = nullptr; |
| for (SmallVector<int64_t> offsets : |
| StaticTileOffsetRange(originalSize, *targetShape)) { |
| SmallVector<int64_t> strides(offsets.size(), 1); |
| Value slicedOperand = |
| rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| loc, reductionOp.getVector(), offsets, *targetShape, strides); |
| Operation *newOp = cloneOpWithOperandsAndTypes( |
| rewriter, loc, reductionOp, slicedOperand, reductionOp.getType()); |
| Value result = newOp->getResult(0); |
| |
| if (!accumulator) { |
| // This is the first reduction. |
| accumulator = result; |
| } else { |
| // On subsequent reduction, combine with the accumulator. |
| accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(), |
| accumulator, result); |
| } |
| } |
| |
| rewriter.replaceOp(reductionOp, accumulator); |
| return success(); |
| } |
| |
| private: |
| const vector::UnrollVectorOptions options; |
| }; |
| |
| struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> { |
| UnrollTransposePattern(MLIRContext *context, |
| const vector::UnrollVectorOptions &options, |
| PatternBenefit benefit = 1) |
| : OpRewritePattern<vector::TransposeOp>(context, benefit), |
| options(options) {} |
| |
| LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, |
| PatternRewriter &rewriter) const override { |
| if (transposeOp.getResultVectorType().getRank() == 0) |
| return failure(); |
| auto targetShape = getTargetShape(options, transposeOp); |
| if (!targetShape) |
| return failure(); |
| auto originalVectorType = transposeOp.getResultVectorType(); |
| SmallVector<int64_t> strides(targetShape->size(), 1); |
| Location loc = transposeOp.getLoc(); |
| ArrayRef<int64_t> originalSize = originalVectorType.getShape(); |
| |
| // Prepare the result vector; |
| Value result = |
| arith::ConstantOp::create(rewriter, loc, originalVectorType, |
| rewriter.getZeroAttr(originalVectorType)); |
| ArrayRef<int64_t> permutation = transposeOp.getPermutation(); |
| |
| // Unroll the computation. |
| for (SmallVector<int64_t> elementOffsets : |
| StaticTileOffsetRange(originalSize, *targetShape)) { |
| SmallVector<int64_t> permutedOffsets(elementOffsets.size()); |
| SmallVector<int64_t> permutedShape(elementOffsets.size()); |
| // Compute the source offsets and shape. |
| for (auto indices : llvm::enumerate(permutation)) { |
| permutedOffsets[indices.value()] = elementOffsets[indices.index()]; |
| permutedShape[indices.value()] = (*targetShape)[indices.index()]; |
| } |
| Value slicedOperand = |
| rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| loc, transposeOp.getVector(), permutedOffsets, permutedShape, |
| strides); |
| Value transposedSlice = rewriter.createOrFold<vector::TransposeOp>( |
| loc, slicedOperand, permutation); |
| result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
| loc, transposedSlice, result, elementOffsets, strides); |
| } |
| rewriter.replaceOp(transposeOp, result); |
| return success(); |
| } |
| |
| private: |
| vector::UnrollVectorOptions options; |
| }; |
| |
| struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> { |
| UnrollGatherPattern(MLIRContext *context, |
| const vector::UnrollVectorOptions &options, |
| PatternBenefit benefit = 1) |
| : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) { |
| } |
| |
| LogicalResult matchAndRewrite(vector::GatherOp gatherOp, |
| PatternRewriter &rewriter) const override { |
| VectorType sourceVectorType = gatherOp.getVectorType(); |
| if (sourceVectorType.getRank() == 0) |
| return failure(); |
| auto targetShape = getTargetShape(options, gatherOp); |
| if (!targetShape) |
| return failure(); |
| SmallVector<int64_t> strides(targetShape->size(), 1); |
| Location loc = gatherOp.getLoc(); |
| ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape(); |
| |
| // Prepare the result vector; |
| Value result = |
| arith::ConstantOp::create(rewriter, loc, sourceVectorType, |
| rewriter.getZeroAttr(sourceVectorType)); |
| auto targetType = |
| VectorType::get(*targetShape, sourceVectorType.getElementType()); |
| |
| SmallVector<int64_t> loopOrder = |
| getUnrollOrder(originalSize.size(), gatherOp, options); |
| for (SmallVector<int64_t> elementOffsets : |
| StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { |
| // To get the unrolled gather, extract the same slice based on the |
| // decomposed shape from each of the index, mask, and pass-through |
| // vectors. |
| Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides); |
| Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| loc, gatherOp.getMask(), elementOffsets, *targetShape, strides); |
| Value passThruSubVec = |
| rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| loc, gatherOp.getPassThru(), elementOffsets, *targetShape, |
| strides); |
| auto slicedGather = vector::GatherOp::create( |
| rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(), |
| indexSubVec, maskSubVec, passThruSubVec); |
| |
| result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
| loc, slicedGather, result, elementOffsets, strides); |
| } |
| rewriter.replaceOp(gatherOp, result); |
| return success(); |
| } |
| |
| private: |
| vector::UnrollVectorOptions options; |
| }; |
| |
| struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> { |
| UnrollLoadPattern(MLIRContext *context, |
| const vector::UnrollVectorOptions &options, |
| PatternBenefit benefit = 1) |
| : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {} |
| |
| LogicalResult matchAndRewrite(vector::LoadOp loadOp, |
| PatternRewriter &rewriter) const override { |
| VectorType vecType = loadOp.getVectorType(); |
| |
| auto targetShape = getTargetShape(options, loadOp); |
| if (!targetShape) |
| return failure(); |
| |
| Location loc = loadOp.getLoc(); |
| ArrayRef<int64_t> originalShape = vecType.getShape(); |
| SmallVector<int64_t> strides(targetShape->size(), 1); |
| |
| Value result = arith::ConstantOp::create(rewriter, loc, vecType, |
| rewriter.getZeroAttr(vecType)); |
| |
| SmallVector<int64_t> loopOrder = |
| getUnrollOrder(originalShape.size(), loadOp, options); |
| |
| auto targetVecType = |
| VectorType::get(*targetShape, vecType.getElementType()); |
| |
| for (SmallVector<int64_t> offsets : |
| StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) { |
| SmallVector<Value> indices = |
| sliceLoadStoreIndices(rewriter, loc, loadOp.getIndices(), offsets); |
| Value slicedLoad = vector::LoadOp::create(rewriter, loc, targetVecType, |
| loadOp.getBase(), indices); |
| result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
| loc, slicedLoad, result, offsets, strides); |
| } |
| rewriter.replaceOp(loadOp, result); |
| return success(); |
| } |
| |
| private: |
| vector::UnrollVectorOptions options; |
| }; |
| |
| struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> { |
| UnrollStorePattern(MLIRContext *context, |
| const vector::UnrollVectorOptions &options, |
| PatternBenefit benefit = 1) |
| : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {} |
| |
| LogicalResult matchAndRewrite(vector::StoreOp storeOp, |
| PatternRewriter &rewriter) const override { |
| VectorType vecType = storeOp.getVectorType(); |
| |
| auto targetShape = getTargetShape(options, storeOp); |
| if (!targetShape) |
| return failure(); |
| |
| Location loc = storeOp.getLoc(); |
| ArrayRef<int64_t> originalShape = vecType.getShape(); |
| SmallVector<int64_t> strides(targetShape->size(), 1); |
| |
| Value base = storeOp.getBase(); |
| Value vector = storeOp.getValueToStore(); |
| |
| SmallVector<int64_t> loopOrder = |
| getUnrollOrder(originalShape.size(), storeOp, options); |
| |
| for (SmallVector<int64_t> offsets : |
| StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) { |
| SmallVector<Value> indices = |
| sliceLoadStoreIndices(rewriter, loc, storeOp.getIndices(), offsets); |
| Value slice = rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| loc, vector, offsets, *targetShape, strides); |
| vector::StoreOp::create(rewriter, loc, slice, base, indices); |
| } |
| rewriter.eraseOp(storeOp); |
| return success(); |
| } |
| |
| private: |
| vector::UnrollVectorOptions options; |
| }; |
| |
| struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> { |
| UnrollBroadcastPattern(MLIRContext *context, |
| const vector::UnrollVectorOptions &options, |
| PatternBenefit benefit = 1) |
| : OpRewritePattern<vector::BroadcastOp>(context, benefit), |
| options(options) {} |
| |
| LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp, |
| PatternRewriter &rewriter) const override { |
| auto targetShape = getTargetShape(options, broadcastOp); |
| if (!targetShape) |
| return failure(); |
| |
| Location loc = broadcastOp.getLoc(); |
| VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()); |
| VectorType resType = broadcastOp.getResultVectorType(); |
| VectorType targetType = |
| resType.cloneWith(*targetShape, resType.getElementType()); |
| Value result = arith::ConstantOp::create(rewriter, loc, resType, |
| rewriter.getZeroAttr(resType)); |
| |
| SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll(); |
| SmallVector<int64_t> strides(originalShape.size(), 1); |
| |
| for (SmallVector<int64_t> offsets : |
| StaticTileOffsetRange(originalShape, *targetShape)) { |
| Value newSrc; |
| if (!srcType) { |
| // Scalar to vector broadcast. |
| newSrc = broadcastOp.getSource(); |
| } else { |
| // Vector to vector broadcast. |
| int64_t rank = srcType.getRank(); |
| SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end()); |
| SmallVector<int64_t> srcShape(targetShape->end() - rank, |
| targetShape->end()); |
| SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end()); |
| // adjust the offset and shape for src if the corresponding dim is 1. |
| for (int64_t i = 0; i < rank; ++i) { |
| if (srcType.getDimSize(i) == 1) { |
| srcOffsets[i] = 0; |
| srcShape[i] = 1; |
| } |
| } |
| newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides); |
| } |
| |
| Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, broadcastOp, |
| newSrc, targetType); |
| |
| result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
| loc, newOp->getResult(0), result, offsets, strides); |
| } |
| |
| rewriter.replaceOp(broadcastOp, result); |
| return success(); |
| } |
| |
| private: |
| vector::UnrollVectorOptions options; |
| }; |
| |
| /// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the |
| /// outermost dimension of the operand. For example: |
| /// |
| /// ``` |
| /// %0:4 = vector.to_elements %v : vector<2x2xf32> |
| /// |
| /// ==> |
| /// |
| /// %v0 = vector.extract %v[0] : vector<2x2xf32> from vector<2x2x2xf32> |
| /// %v1 = vector.extract %v[1] : vector<2x2xf32> from vector<2x2x2xf32> |
| /// %0:4 = vector.to_elements %v0 : vector<2x2xf32> |
| /// %1:4 = vector.to_elements %v1 : vector<2x2xf32> |
| /// ``` |
| /// |
| /// When this pattern is applied until a fixed-point is reached, |
| /// this will produce a sequence of 1-d from_elements |
| /// ops. |
| struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> { |
| UnrollToElements(MLIRContext *context, |
| const vector::UnrollVectorOptions &options, |
| PatternBenefit benefit = 1) |
| : OpRewritePattern<vector::ToElementsOp>(context, benefit), |
| options(options) {} |
| |
| LogicalResult matchAndRewrite(vector::ToElementsOp op, |
| PatternRewriter &rewriter) const override { |
| |
| TypedValue<VectorType> source = op.getSource(); |
| FailureOr<SmallVector<Value>> result = |
| vector::unrollVectorValue(source, rewriter); |
| if (failed(result)) { |
| return failure(); |
| } |
| SmallVector<Value> vectors = *result; |
| |
| SmallVector<Value> results; |
| for (Value vector : vectors) { |
| auto subElements = |
| vector::ToElementsOp::create(rewriter, op.getLoc(), vector); |
| llvm::append_range(results, subElements.getResults()); |
| } |
| rewriter.replaceOp(op, results); |
| return success(); |
| } |
| |
| private: |
| vector::UnrollVectorOptions options; |
| }; |
| |
| /// This pattern unrolls `vector.step` operations according to the provided |
| /// target unroll shape. It decomposes a large step vector into smaller step |
| /// vectors (segments) and assembles the result by inserting each computed |
| /// segment into the appropriate offset of the original vector. |
| /// |
| /// The pattern does not support scalable vectors and will fail to match them. |
| /// |
| /// For each segment, it adds the base step vector and the segment's offset, |
| /// then inserts the result into the output vector at the corresponding |
| /// position. |
| /// |
| /// Example: |
| /// Given a step operation: |
| /// %0 = vector.step : vector<8xindex> |
| /// |
| /// and a target unroll shape of <4>, the pattern produces: |
| /// |
| /// %base = vector.step : vector<4xindex> |
| /// %zero = arith.constant dense<0> : vector<8xindex> |
| /// %result0 = vector.insert_strided_slice %base, %zero |
| /// {offsets = [0], strides = [1]} : vector<4xindex> into vector<8xindex> |
| /// %offset = arith.constant dense<4> : vector<4xindex> |
| /// %segment1 = arith.addi %base, %offset : vector<4xindex> |
| /// %result1 = vector.insert_strided_slice %segment1, %result0 |
| /// {offsets = [4], strides = [1]} : vector<4xindex> into vector<8xindex> |
| /// |
| struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> { |
| UnrollStepPattern(MLIRContext *context, |
| const vector::UnrollVectorOptions &options, |
| PatternBenefit benefit = 1) |
| : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {} |
| |
| LogicalResult matchAndRewrite(vector::StepOp stepOp, |
| PatternRewriter &rewriter) const override { |
| std::optional<SmallVector<int64_t>> targetShape = |
| getTargetShape(options, stepOp); |
| if (!targetShape) |
| return failure(); |
| |
| VectorType vecType = stepOp.getType(); |
| if (vecType.isScalable()) { |
| // Scalable vectors are not supported by this pattern. |
| return failure(); |
| } |
| int64_t originalSize = vecType.getShape()[0]; |
| Location loc = stepOp.getLoc(); |
| SmallVector<int64_t> strides(1, 1); |
| |
| Value result = arith::ConstantOp::create(rewriter, loc, vecType, |
| rewriter.getZeroAttr(vecType)); |
| |
| auto targetVecType = |
| VectorType::get(*targetShape, vecType.getElementType()); |
| Value baseStep = vector::StepOp::create(rewriter, loc, targetVecType); |
| for (const SmallVector<int64_t> &offsets : |
| StaticTileOffsetRange({originalSize}, *targetShape)) { |
| Value bcastOffset = arith::ConstantOp::create( |
| rewriter, loc, targetVecType, |
| DenseElementsAttr::get( |
| targetVecType, |
| IntegerAttr::get(targetVecType.getElementType(), offsets[0]))); |
| Value tileStep = |
| arith::AddIOp::create(rewriter, loc, baseStep, bcastOffset); |
| |
| result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
| loc, tileStep, result, offsets, strides); |
| } |
| rewriter.replaceOp(stepOp, result); |
| return success(); |
| } |
| |
| private: |
| vector::UnrollVectorOptions options; |
| }; |
| |
| /// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the |
| /// outermost dimension. For example: |
| /// ``` |
| /// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32> |
| /// |
| /// ==> |
| /// |
| /// %0 = ub.poison : vector<2x3xf32> |
| /// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32> |
| /// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32> |
| /// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32> |
| /// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32> |
| /// ``` |
| /// |
| /// When this pattern is applied until a fixed-point is reached, |
| /// this will produce a sequence of 1-d from_elements |
| /// ops. |
| struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> { |
| UnrollFromElements(MLIRContext *context, |
| const vector::UnrollVectorOptions &options, |
| PatternBenefit benefit = 1) |
| : OpRewritePattern<vector::FromElementsOp>(context, benefit), |
| options(options) {} |
| |
| LogicalResult matchAndRewrite(vector::FromElementsOp op, |
| PatternRewriter &rewriter) const override { |
| ValueRange allElements = op.getElements(); |
| |
| auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc, |
| VectorType subTy, int64_t index) { |
| size_t subTyNumElements = subTy.getNumElements(); |
| assert((index + 1) * subTyNumElements <= allElements.size() && |
| "out of bounds"); |
| ValueRange subElements = |
| allElements.slice(index * subTyNumElements, subTyNumElements); |
| return vector::FromElementsOp::create(rewriter, loc, subTy, subElements); |
| }; |
| |
| return unrollVectorOp(op, rewriter, unrollFromElementsFn); |
| } |
| |
| private: |
| vector::UnrollVectorOptions options; |
| }; |
| |
| /// This pattern unrolls `vector.create_mask` operations into smaller mask |
| /// operations based on the target unroll shape. Each unrolled slice computes |
| /// its local mask size in each dimension (d) as: |
| /// min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d]). |
| /// Example: |
| /// Given a create_mask operation: |
| /// %0 = vector.create_mask %c6, %c10 : vector<8x16xi1> // mask first 6x10 |
| /// elements |
| /// |
| /// and a target unroll shape of <4x8>, the pattern produces: |
| /// |
| /// %false = arith.constant dense<false> : vector<8x16xi1> |
| /// |
| /// Slice [0,0]: |
| /// mask size = min(max(6-0, 0), 4) x min(max(10-0, 0), 8) = 4x8 |
| /// %mask00 = vector.create_mask %c4, %c8 : vector<4x8xi1> |
| /// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1] |
| /// : vector<4x8xi1> into vector<8x16xi1> |
| /// Slice [0,8]: |
| /// mask size = min(max(6-0, 0), 4) x min(max(10-8, 0), 8) = 4x2 |
| /// %mask01 = vector.create_mask %c4, %c2 : vector<4x8xi1> |
| /// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1] |
| /// : vector<4x8xi1> into vector<8x16xi1> |
| /// Slice [4,0]: |
| /// mask size = min(max(6-4, 0), 4) x min(max(10-0, 0), 8) = 2x8 |
| /// %mask10 = vector.create_mask %c2, %c8 : vector<4x8xi1> |
| /// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1] |
| /// : vector<4x8xi1> into vector<8x16xi1> |
| /// Slice [4,8]: |
| /// mask size = min(max(6-4, 0), 4) x min(max(10-8, 0), 8) = 2x2 |
| /// %mask11 = vector.create_mask %c2, %c2 : vector<4x8xi1> |
| /// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1] |
| /// : vector<4x8xi1> into vector<8x16xi1> |
| struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> { |
| UnrollCreateMaskPattern(MLIRContext *context, |
| const vector::UnrollVectorOptions &options, |
| PatternBenefit benefit = 1) |
| : OpRewritePattern<vector::CreateMaskOp>(context, benefit), |
| options(options) {} |
| |
| LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp, |
| PatternRewriter &rewriter) const override { |
| auto targetShape = getTargetShape(options, createMaskOp); |
| if (!targetShape) |
| return failure(); |
| |
| VectorType resultType = createMaskOp.getVectorType(); |
| SmallVector<int64_t> originalSize = *createMaskOp.getShapeForUnroll(); |
| Location loc = createMaskOp.getLoc(); |
| |
| Value result = arith::ConstantOp::create(rewriter, loc, resultType, |
| rewriter.getZeroAttr(resultType)); |
| VectorType targetVectorType = |
| VectorType::get(*targetShape, rewriter.getI1Type()); |
| SmallVector<int64_t> strides(targetShape->size(), 1); |
| |
| // In each dimension (d), each unrolled vector computes its mask size as: |
| // min(max(originalMaskOperands[d] - offset[d], 0), unrolledDimSize[d]). |
| for (SmallVector<int64_t> offsets : |
| StaticTileOffsetRange(originalSize, *targetShape)) { |
| SmallVector<Value> unrolledOperands; |
| |
| for (auto [i, originalMaskOperand] : |
| llvm::enumerate(createMaskOp.getOperands())) { |
| Value offsetVal = |
| arith::ConstantIndexOp::create(rewriter, loc, offsets[i]); |
| Value adjustedMaskSize = rewriter.createOrFold<arith::SubIOp>( |
| loc, originalMaskOperand, offsetVal); |
| Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); |
| Value unrolledDimSize = |
| arith::ConstantIndexOp::create(rewriter, loc, (*targetShape)[i]); |
| Value nonNegative = |
| rewriter.createOrFold<arith::MaxSIOp>(loc, adjustedMaskSize, zero); |
| Value unrolledOperand = rewriter.createOrFold<arith::MinSIOp>( |
| loc, nonNegative, unrolledDimSize); |
| unrolledOperands.push_back(unrolledOperand); |
| } |
| |
| auto unrolledMask = rewriter.createOrFold<vector::CreateMaskOp>( |
| loc, targetVectorType, unrolledOperands); |
| result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
| loc, unrolledMask, result, offsets, strides); |
| } |
| rewriter.replaceOp(createMaskOp, result); |
| return success(); |
| } |
| |
| private: |
| vector::UnrollVectorOptions options; |
| }; |
| |
| /// This pattern unrolls `vector.constant_mask` operations into smaller mask |
| /// operations based on the target unroll shape. Each unrolled slice computes |
| /// whether its elements should be masked based on the original mask dimensions |
| /// and the slice's offset position. |
| /// |
| /// Example: |
| /// Given a constant_mask operation: |
| /// %0 = vector.constant_mask [6, 10] : vector<8x16xi1> |
| /// |
| /// and a target unroll shape of <4x8>, the pattern produces: |
| /// |
| /// %false = arith.constant dense<false> : vector<8x16xi1> |
| /// |
| /// Slice [0,0]: elements [0:4, 0:8] - fully within [6, 10] bounds |
| /// %mask00 = vector.constant_mask [4, 8] : vector<4x8xi1> |
| /// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1] |
| /// : vector<4x8xi1> into vector<8x16xi1> |
| /// |
| /// Slice [0,8]: elements [0:4, 8:16] - partially within bounds |
| /// %mask01 = vector.constant_mask [4, 2] : vector<4x8xi1> |
| /// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1] |
| /// : vector<4x8xi1> into vector<8x16xi1> |
| /// |
| /// Slice [4,0]: elements [4:8, 0:8] - partially within bounds |
| /// %mask10 = vector.constant_mask [2, 8] : vector<4x8xi1> |
| /// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1] |
| /// : vector<4x8xi1> into vector<8x16xi1> |
| /// |
| /// Slice [4,8]: elements [4:8, 8:16] - partially within bounds |
| /// %mask11 = vector.constant_mask [2, 2] : vector<4x8xi1> |
| /// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1] |
| /// : vector<4x8xi1> into vector<8x16xi1> |
| struct UnrollConstantMaskPattern |
| : public OpRewritePattern<vector::ConstantMaskOp> { |
| UnrollConstantMaskPattern(MLIRContext *context, |
| const vector::UnrollVectorOptions &options, |
| PatternBenefit benefit = 1) |
| : OpRewritePattern<vector::ConstantMaskOp>(context, benefit), |
| options(options) {} |
| |
| LogicalResult matchAndRewrite(vector::ConstantMaskOp constantMaskOp, |
| PatternRewriter &rewriter) const override { |
| std::optional<SmallVector<int64_t>> targetShape = |
| getTargetShape(options, constantMaskOp); |
| if (!targetShape) |
| return failure(); |
| |
| VectorType resultType = constantMaskOp.getVectorType(); |
| SmallVector<int64_t> originalSize = *constantMaskOp.getShapeForUnroll(); |
| Location loc = constantMaskOp.getLoc(); |
| |
| Value result = arith::ConstantOp::create(rewriter, loc, resultType, |
| rewriter.getZeroAttr(resultType)); |
| VectorType targetVectorType = |
| VectorType::get(*targetShape, rewriter.getI1Type()); |
| SmallVector<int64_t> strides(targetShape->size(), 1); |
| |
| // In each dimension (d), each unrolled vector computes its mask size as: |
| // min(max(originalMaskDim[d] - offset[d], 0), unrolledDimSize[d]). |
| for (const SmallVector<int64_t> &offsets : |
| StaticTileOffsetRange(originalSize, *targetShape)) { |
| SmallVector<int64_t> unrolledMaskDims; |
| |
| for (auto [i, originalMaskDim] : |
| llvm::enumerate(constantMaskOp.getMaskDimSizes())) { |
| // Calculate how many elements in this dimension should be masked |
| // for this particular slice |
| int64_t adjustedMaskSize = |
| std::max(originalMaskDim - offsets[i], static_cast<int64_t>(0)); |
| int64_t unrolledMaskDim = |
| std::min(adjustedMaskSize, static_cast<int64_t>((*targetShape)[i])); |
| unrolledMaskDims.push_back(unrolledMaskDim); |
| } |
| |
| auto unrolledMask = rewriter.createOrFold<vector::ConstantMaskOp>( |
| loc, targetVectorType, unrolledMaskDims); |
| result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
| loc, unrolledMask, result, offsets, strides); |
| } |
| rewriter.replaceOp(constantMaskOp, result); |
| return success(); |
| } |
| |
| private: |
| vector::UnrollVectorOptions options; |
| }; |
| |
| /// Checks whether extractShape is a contiguous slice of shape. |
| /// For extractShape to be contiguous in shape: |
| /// 1) All but the leading dimension of extractShape and shape must match |
| /// exactly. 2) The total number of elements in shape must be evenly divisible |
| /// by |
| /// the total number of elements in extractShape. |
| /// Examples: |
| /// isContiguous([4, 4], [8, 4]) == true |
| /// isContiguous([2, 4], [8, 4]) == true |
| /// isContiguous([2, 2], [8, 4]) == false |
| /// Removes leading unit dimensions to handle cases like: |
| /// isContiguous([1, 16], [1, 32]) == true |
| static bool isContiguous(ArrayRef<int64_t> extractShape, |
| ArrayRef<int64_t> shape) { |
| |
| if (extractShape.empty() || shape.empty() || |
| extractShape.size() > shape.size()) |
| return false; |
| |
| while (extractShape.size() > 1 && extractShape.front() == 1) |
| extractShape = extractShape.drop_front(); |
| |
| while (shape.size() > 1 && shape.front() == 1) { |
| shape = shape.drop_front(); |
| } |
| |
| size_t rankDiff = shape.size() - extractShape.size(); |
| if (!llvm::equal(extractShape.drop_front(), shape.drop_front(rankDiff + 1))) |
| return false; |
| |
| int64_t extractElements = ShapedType::getNumElements(extractShape); |
| int64_t shapeElements = ShapedType::getNumElements(shape); |
| return shapeElements % extractElements == 0; |
| } |
| |
| /// Determines what shape to use with `vector.extract_strided_slice` to extract |
| /// a contiguous memory region from a source vector. The extraction must be |
| /// contiguous and contain exactly the specified number of elements. If such an |
| /// extraction shape cannot be determined, returns std::nullopt. |
| /// EXAMPLE 1: |
| /// sourceShape = [16], targetElements = 8 |
| /// Working right-to-left: |
| /// - Take min(8, 16) = 8 from only dim → extractShape = [8], |
| /// remaining = 8/8 = 1 |
| /// Result: [8] |
| /// |
| /// EXAMPLE 2: |
| /// sourceShape = [4, 4], targetElements = 8 |
| /// Working right-to-left: |
| /// - Take min(8, 4) = 4 from last dim → extractShape = [4], |
| /// remaining = 8/4 = 2 |
| /// - Take min(2, 4) = 2 from first dim → extractShape = [2, 4], |
| /// remaining = 2/2 = 1 |
| /// Result: [2, 4] |
| static std::optional<SmallVector<int64_t>> |
| calculateSourceExtractShape(ArrayRef<int64_t> sourceShape, |
| int64_t targetElements) { |
| SmallVector<int64_t> extractShape; |
| int64_t remainingElements = targetElements; |
| |
| // Build extract shape from innermost dimension outward to ensure contiguity. |
| for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) { |
| int64_t takeFromDim = std::min(remainingElements, sourceShape[i]); |
| extractShape.insert(extractShape.begin(), takeFromDim); |
| |
| if (remainingElements % takeFromDim != 0) |
| return std::nullopt; // Not evenly divisible. |
| remainingElements /= takeFromDim; |
| } |
| |
| // Fill remaining dimensions with 1. |
| while (extractShape.size() < sourceShape.size()) |
| extractShape.insert(extractShape.begin(), 1); |
| |
| if (ShapedType::getNumElements(extractShape) != targetElements) |
| return std::nullopt; |
| |
| return extractShape; |
| } |
| |
| // Convert result offsets to source offsets via linear position. |
| static SmallVector<int64_t> |
| calculateSourceOffsets(ArrayRef<int64_t> resultOffsets, |
| ArrayRef<int64_t> sourceShape, |
| ArrayRef<int64_t> resultShape) { |
| // Convert result offsets to linear position. |
| int64_t linearIndex = linearize(resultOffsets, computeStrides(resultShape)); |
| // Convert linear position to source offsets. |
| return delinearize(linearIndex, computeStrides(sourceShape)); |
| } |
| |
| /// This pattern unrolls `vector.shape_cast` operations according to the |
| /// provided target unroll shape. It unrolls a large shape cast into smaller |
| /// shape casts by extracting contiguous slices from the source vector, casting |
| /// each slice to the target shape, and assembling the result by inserting each |
| /// computed segment into the appropriate offset of the result vector. |
| /// |
| /// This pattern only applies when contiguous slices can be extracted from the |
| /// source vector and inserted into the result vector such that each slice |
| /// remains a valid vector (and not decompose to scalars). In these cases, the |
| /// unrolling proceeds as: |
| /// vector.extract_strided_slice -> vector.shape_cast (on the slice) -> |
| /// vector.insert_strided_slice. |
| /// |
| /// Example: |
| /// Given a shape cast operation: |
| /// %0 = vector.shape_cast %src : vector<8x2xf32> to vector<4x4xf32> |
| /// |
| /// and a target unroll shape of <2x4>, the pattern produces: |
| /// |
| /// %zero = arith.constant dense<0.0> : vector<4x4xf32> |
| /// %s0 = vector.extract_strided_slice %src [0, 0], [4, 2], [1, 1] |
| /// : vector<8x2xf32> to vector<4x2xf32> |
| /// %sc0 = vector.shape_cast %s0 : vector<4x2xf32> to vector<2x4xf32> |
| /// %i0 = vector.insert_strided_slice %sc0, %zero [0, 0], [1, 1] |
| /// : vector<2x4xf32> into vector<4x4xf32> |
| /// %s1 = vector.extract_strided_slice %src [4, 0], [4, 2], [1, 1] |
| /// : vector<8x2xf32> to vector<4x2xf32> |
| /// %sc1 = vector.shape_cast %s1 : vector<4x2xf32> to vector<2x4xf32> |
| /// %i1 = vector.insert_strided_slice %sc1, %i0 [2, 0], [1, 1] |
| /// : vector<2x4xf32> into vector<4x4xf32> |
| /// |
| struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> { |
| UnrollShapeCastPattern(MLIRContext *context, |
| const vector::UnrollVectorOptions &options, |
| PatternBenefit benefit = 1) |
| : OpRewritePattern<vector::ShapeCastOp>(context, benefit), |
| options(options) {} |
| |
| LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, |
| PatternRewriter &rewriter) const override { |
| std::optional<SmallVector<int64_t>> targetShape = |
| getTargetShape(options, shapeCastOp); |
| if (!targetShape) |
| return failure(); |
| |
| VectorType sourceType = shapeCastOp.getSourceVectorType(); |
| VectorType resultType = shapeCastOp.getResultVectorType(); |
| ArrayRef<int64_t> sourceShape = sourceType.getShape(); |
| ArrayRef<int64_t> resultShape = resultType.getShape(); |
| |
| if (!isContiguous(*targetShape, resultShape)) |
| return rewriter.notifyMatchFailure( |
| shapeCastOp, "Only supports cases where target shape is " |
| "contiguous in result vector shape"); |
| |
| int64_t targetElements = ShapedType::getNumElements(*targetShape); |
| |
| // Calculate the shape to extract from source. |
| std::optional<SmallVector<int64_t>> extractShape = |
| calculateSourceExtractShape(sourceShape, targetElements); |
| if (!extractShape) |
| return rewriter.notifyMatchFailure( |
| shapeCastOp, |
| "cannot extract target number of elements contiguously from source"); |
| |
| Location loc = shapeCastOp.getLoc(); |
| |
| // Create result vector initialized to zero. |
| Value result = arith::ConstantOp::create(rewriter, loc, resultType, |
| rewriter.getZeroAttr(resultType)); |
| |
| VectorType targetType = |
| VectorType::get(*targetShape, sourceType.getElementType()); |
| |
| SmallVector<int64_t> extractStrides(extractShape->size(), 1); |
| SmallVector<int64_t> insertStrides(targetShape->size(), 1); |
| |
| for (SmallVector<int64_t> resultOffsets : |
| StaticTileOffsetRange(resultShape, *targetShape)) { |
| SmallVector<int64_t> sourceOffsets = |
| calculateSourceOffsets(resultOffsets, sourceShape, resultShape); |
| Value sourceChunk = rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| loc, shapeCastOp.getSource(), sourceOffsets, *extractShape, |
| extractStrides); |
| Value targetChunk = rewriter.createOrFold<vector::ShapeCastOp>( |
| loc, targetType, sourceChunk); |
| result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
| loc, targetChunk, result, resultOffsets, insertStrides); |
| } |
| |
| rewriter.replaceOp(shapeCastOp, result); |
| return success(); |
| } |
| |
| private: |
| vector::UnrollVectorOptions options; |
| }; |
| |
| } // namespace |
| |
| void mlir::vector::populateVectorUnrollPatterns( |
| RewritePatternSet &patterns, const UnrollVectorOptions &options, |
| PatternBenefit benefit) { |
| patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern, |
| UnrollContractionPattern, UnrollElementwisePattern, |
| UnrollReductionPattern, UnrollMultiReductionPattern, |
| UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern, |
| UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements, |
| UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern, |
| UnrollCreateMaskPattern, UnrollConstantMaskPattern>( |
| patterns.getContext(), options, benefit); |
| } |
| |
| void mlir::vector::populateVectorToElementsUnrollPatterns( |
| RewritePatternSet &patterns, PatternBenefit benefit) { |
| patterns.add<UnrollToElements>(patterns.getContext(), UnrollVectorOptions(), |
| benefit); |
| } |
| |
| void mlir::vector::populateVectorFromElementsUnrollPatterns( |
| RewritePatternSet &patterns, PatternBenefit benefit) { |
| patterns.add<UnrollFromElements>(patterns.getContext(), UnrollVectorOptions(), |
| benefit); |
| } |