| //===- LowerVectorMultiReduction.cpp - Lower `vector.multi_reduction` op --===// |
| // |
| /// Part of the LLVM Project, under the Apache License v2.0 with LLVM |
| /// Exceptions. See https://llvm.org/LICENSE.txt for license information. |
| /// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements target-independent rewrites and utilities to lower the |
| // 'vector.multi_reduction' operation. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" |
| #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" |
| #include "mlir/Dialect/Vector/Transforms/Passes.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/TypeUtilities.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| |
| namespace mlir { |
| namespace vector { |
| #define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION |
| #include "mlir/Dialect/Vector/Transforms/Passes.h.inc" |
| } // namespace vector |
| } // namespace mlir |
| |
| #define DEBUG_TYPE "vector-multi-reduction" |
| |
| using namespace mlir; |
| |
| namespace { |
| /// This file implements the following transformations as composable atomic |
| /// patterns. |
| |
| /// Converts vector.multi_reduction into inner-most/outer-most reduction form |
| /// by using vector.transpose |
| class InnerOuterDimReductionConversion |
| : public OpRewritePattern<vector::MultiDimReductionOp> { |
| public: |
| using OpRewritePattern::OpRewritePattern; |
| |
| explicit InnerOuterDimReductionConversion( |
| MLIRContext *context, vector::VectorMultiReductionLowering options, |
| PatternBenefit benefit = 1) |
| : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context, benefit), |
| useInnerDimsForReduction( |
| options == vector::VectorMultiReductionLowering::InnerReduction) {} |
| |
| LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, |
| PatternRewriter &rewriter) const override { |
| // Vector mask setup. |
| OpBuilder::InsertionGuard guard(rewriter); |
| auto maskableOp = |
| cast<vector::MaskableOpInterface>(multiReductionOp.getOperation()); |
| Operation *rootOp; |
| if (maskableOp.isMasked()) { |
| rewriter.setInsertionPoint(maskableOp.getMaskingOp()); |
| rootOp = maskableOp.getMaskingOp(); |
| } else { |
| rootOp = multiReductionOp; |
| } |
| |
| auto src = multiReductionOp.getSource(); |
| auto loc = multiReductionOp.getLoc(); |
| auto srcRank = multiReductionOp.getSourceVectorType().getRank(); |
| |
| // Separate reduction and parallel dims |
| auto reductionDimsRange = |
| multiReductionOp.getReductionDims().getAsValueRange<IntegerAttr>(); |
| auto reductionDims = llvm::to_vector<4>(llvm::map_range( |
| reductionDimsRange, [](const APInt &a) { return a.getZExtValue(); })); |
| llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(), |
| reductionDims.end()); |
| int64_t reductionSize = reductionDims.size(); |
| SmallVector<int64_t, 4> parallelDims; |
| for (int64_t i = 0; i < srcRank; ++i) |
| if (!reductionDimsSet.contains(i)) |
| parallelDims.push_back(i); |
| |
| // Add transpose only if inner-most/outer-most dimensions are not parallel |
| // and there are parallel dims. |
| if (parallelDims.empty()) |
| return failure(); |
| if (useInnerDimsForReduction && |
| (parallelDims == |
| llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size())))) |
| return failure(); |
| |
| if (!useInnerDimsForReduction && |
| (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>( |
| reductionDims.size(), |
| parallelDims.size() + reductionDims.size())))) |
| return failure(); |
| |
| SmallVector<int64_t, 4> indices; |
| if (useInnerDimsForReduction) { |
| indices.append(parallelDims.begin(), parallelDims.end()); |
| indices.append(reductionDims.begin(), reductionDims.end()); |
| } else { |
| indices.append(reductionDims.begin(), reductionDims.end()); |
| indices.append(parallelDims.begin(), parallelDims.end()); |
| } |
| |
| // If masked, transpose the original mask. |
| Value transposedMask; |
| if (maskableOp.isMasked()) { |
| transposedMask = rewriter.create<vector::TransposeOp>( |
| loc, maskableOp.getMaskingOp().getMask(), indices); |
| } |
| |
| // Transpose reduction source. |
| auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices); |
| SmallVector<bool> reductionMask(srcRank, false); |
| for (int i = 0; i < reductionSize; ++i) { |
| if (useInnerDimsForReduction) |
| reductionMask[srcRank - i - 1] = true; |
| else |
| reductionMask[i] = true; |
| } |
| |
| Operation *newMultiRedOp = rewriter.create<vector::MultiDimReductionOp>( |
| multiReductionOp.getLoc(), transposeOp.getResult(), |
| multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind()); |
| newMultiRedOp = |
| mlir::vector::maskOperation(rewriter, newMultiRedOp, transposedMask); |
| |
| rewriter.replaceOp(rootOp, newMultiRedOp->getResult(0)); |
| return success(); |
| } |
| |
| private: |
| const bool useInnerDimsForReduction; |
| }; |
| |
| /// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction |
| /// dimensions are either inner most or outer most. |
| class ReduceMultiDimReductionRank |
| : public OpRewritePattern<vector::MultiDimReductionOp> { |
| public: |
| using OpRewritePattern::OpRewritePattern; |
| |
| explicit ReduceMultiDimReductionRank( |
| MLIRContext *context, vector::VectorMultiReductionLowering options, |
| PatternBenefit benefit = 1) |
| : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context, benefit), |
| useInnerDimsForReduction( |
| options == vector::VectorMultiReductionLowering::InnerReduction) {} |
| |
| LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, |
| PatternRewriter &rewriter) const override { |
| // Vector mask setup. |
| OpBuilder::InsertionGuard guard(rewriter); |
| auto maskableOp = |
| cast<vector::MaskableOpInterface>(multiReductionOp.getOperation()); |
| Operation *rootOp; |
| if (maskableOp.isMasked()) { |
| rewriter.setInsertionPoint(maskableOp.getMaskingOp()); |
| rootOp = maskableOp.getMaskingOp(); |
| } else { |
| rootOp = multiReductionOp; |
| } |
| |
| auto srcRank = multiReductionOp.getSourceVectorType().getRank(); |
| auto srcShape = multiReductionOp.getSourceVectorType().getShape(); |
| auto srcScalableDims = |
| multiReductionOp.getSourceVectorType().getScalableDims(); |
| auto loc = multiReductionOp.getLoc(); |
| |
| // If rank less than 2, nothing to do. |
| if (srcRank < 2) |
| return failure(); |
| |
| // Allow only 1 scalable dimensions. Otherwise we could end-up with e.g. |
| // `vscale * vscale` that's currently not modelled. |
| if (llvm::count(srcScalableDims, true) > 1) |
| return failure(); |
| |
| // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail. |
| SmallVector<bool> reductionMask = multiReductionOp.getReductionMask(); |
| if (srcRank == 2 && reductionMask.front() != reductionMask.back()) |
| return failure(); |
| |
| // 1. Separate reduction and parallel dims. |
| SmallVector<int64_t, 4> parallelDims, parallelShapes; |
| SmallVector<bool, 4> parallelScalableDims; |
| SmallVector<int64_t, 4> reductionDims, reductionShapes; |
| bool isReductionDimScalable = false; |
| for (const auto &it : llvm::enumerate(reductionMask)) { |
| int64_t i = it.index(); |
| bool isReduction = it.value(); |
| if (isReduction) { |
| reductionDims.push_back(i); |
| reductionShapes.push_back(srcShape[i]); |
| isReductionDimScalable |= srcScalableDims[i]; |
| } else { |
| parallelDims.push_back(i); |
| parallelShapes.push_back(srcShape[i]); |
| parallelScalableDims.push_back(srcScalableDims[i]); |
| } |
| } |
| |
| // 2. Compute flattened parallel and reduction sizes. |
| int flattenedParallelDim = 0; |
| int flattenedReductionDim = 0; |
| if (!parallelShapes.empty()) { |
| flattenedParallelDim = 1; |
| for (auto d : parallelShapes) |
| flattenedParallelDim *= d; |
| } |
| if (!reductionShapes.empty()) { |
| flattenedReductionDim = 1; |
| for (auto d : reductionShapes) |
| flattenedReductionDim *= d; |
| } |
| // We must at least have some parallel or some reduction. |
| assert((flattenedParallelDim || flattenedReductionDim) && |
| "expected at least one parallel or reduction dim"); |
| |
| // 3. Fail if reduction/parallel dims are not contiguous. |
| // Check parallelDims are exactly [0 .. size). |
| int64_t counter = 0; |
| if (useInnerDimsForReduction && |
| llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; })) |
| return failure(); |
| // Check parallelDims are exactly {reductionDims.size()} + [0 .. size). |
| counter = reductionDims.size(); |
| if (!useInnerDimsForReduction && |
| llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; })) |
| return failure(); |
| |
| // 4. Shape cast to collapse consecutive parallel (resp. reduction dim) into |
| // a single parallel (resp. reduction) dim. |
| SmallVector<bool, 2> mask; |
| SmallVector<bool, 2> scalableDims; |
| SmallVector<int64_t, 2> vectorShape; |
| bool isParallelDimScalable = llvm::is_contained(parallelScalableDims, true); |
| if (flattenedParallelDim) { |
| mask.push_back(false); |
| vectorShape.push_back(flattenedParallelDim); |
| scalableDims.push_back(isParallelDimScalable); |
| } |
| if (flattenedReductionDim) { |
| mask.push_back(true); |
| vectorShape.push_back(flattenedReductionDim); |
| scalableDims.push_back(isReductionDimScalable); |
| } |
| if (!useInnerDimsForReduction && vectorShape.size() == 2) { |
| std::swap(mask.front(), mask.back()); |
| std::swap(vectorShape.front(), vectorShape.back()); |
| std::swap(scalableDims.front(), scalableDims.back()); |
| } |
| |
| Value newVectorMask; |
| if (maskableOp.isMasked()) { |
| Value vectorMask = maskableOp.getMaskingOp().getMask(); |
| auto maskCastedType = VectorType::get( |
| vectorShape, |
| llvm::cast<VectorType>(vectorMask.getType()).getElementType()); |
| newVectorMask = |
| rewriter.create<vector::ShapeCastOp>(loc, maskCastedType, vectorMask); |
| } |
| |
| auto castedType = VectorType::get( |
| vectorShape, multiReductionOp.getSourceVectorType().getElementType(), |
| scalableDims); |
| Value cast = rewriter.create<vector::ShapeCastOp>( |
| loc, castedType, multiReductionOp.getSource()); |
| |
| Value acc = multiReductionOp.getAcc(); |
| if (flattenedParallelDim) { |
| auto accType = VectorType::get( |
| {flattenedParallelDim}, |
| multiReductionOp.getSourceVectorType().getElementType(), |
| /*scalableDims=*/{isParallelDimScalable}); |
| acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc); |
| } |
| // 6. Creates the flattened form of vector.multi_reduction with inner/outer |
| // most dim as reduction. |
| Operation *newMultiDimRedOp = rewriter.create<vector::MultiDimReductionOp>( |
| loc, cast, acc, mask, multiReductionOp.getKind()); |
| newMultiDimRedOp = |
| mlir::vector::maskOperation(rewriter, newMultiDimRedOp, newVectorMask); |
| |
| // 7. If there are no parallel shapes, the result is a scalar. |
| // TODO: support 0-d vectors when available. |
| if (parallelShapes.empty()) { |
| rewriter.replaceOp(rootOp, newMultiDimRedOp->getResult(0)); |
| return success(); |
| } |
| |
| // 8. Creates shape cast for the output n-D -> 2-D. |
| VectorType outputCastedType = VectorType::get( |
| parallelShapes, multiReductionOp.getSourceVectorType().getElementType(), |
| parallelScalableDims); |
| rewriter.replaceOpWithNewOp<vector::ShapeCastOp>( |
| rootOp, outputCastedType, newMultiDimRedOp->getResult(0)); |
| return success(); |
| } |
| |
| private: |
| const bool useInnerDimsForReduction; |
| }; |
| |
| /// Unrolls vector.multi_reduction with outermost reductions |
| /// and combines results |
| struct TwoDimMultiReductionToElementWise |
| : public OpRewritePattern<vector::MultiDimReductionOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, |
| PatternRewriter &rewriter) const override { |
| auto maskableOp = |
| cast<vector::MaskableOpInterface>(multiReductionOp.getOperation()); |
| if (maskableOp.isMasked()) |
| // TODO: Support masking. |
| return failure(); |
| |
| auto srcRank = multiReductionOp.getSourceVectorType().getRank(); |
| // Rank-2 ["parallel", "reduce"] or bail. |
| if (srcRank != 2) |
| return failure(); |
| |
| if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0)) |
| return failure(); |
| |
| auto loc = multiReductionOp.getLoc(); |
| ArrayRef<int64_t> srcShape = |
| multiReductionOp.getSourceVectorType().getShape(); |
| |
| Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType()); |
| if (!elementType.isIntOrIndexOrFloat()) |
| return failure(); |
| |
| Value result = multiReductionOp.getAcc(); |
| for (int64_t i = 0; i < srcShape[0]; i++) { |
| auto operand = rewriter.create<vector::ExtractOp>( |
| loc, multiReductionOp.getSource(), i); |
| result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(), |
| operand, result); |
| } |
| |
| rewriter.replaceOp(multiReductionOp, result); |
| return success(); |
| } |
| }; |
| |
| /// Converts 2d vector.multi_reduction with inner most reduction dimension into |
| /// a sequence of vector.reduction ops. |
| struct TwoDimMultiReductionToReduction |
| : public OpRewritePattern<vector::MultiDimReductionOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, |
| PatternRewriter &rewriter) const override { |
| auto srcRank = multiReductionOp.getSourceVectorType().getRank(); |
| if (srcRank != 2) |
| return failure(); |
| |
| if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1)) |
| return failure(); |
| |
| // Vector mask setup. |
| OpBuilder::InsertionGuard guard(rewriter); |
| auto maskableOp = |
| cast<vector::MaskableOpInterface>(multiReductionOp.getOperation()); |
| Operation *rootOp; |
| if (maskableOp.isMasked()) { |
| rewriter.setInsertionPoint(maskableOp.getMaskingOp()); |
| rootOp = maskableOp.getMaskingOp(); |
| } else { |
| rootOp = multiReductionOp; |
| } |
| |
| auto loc = multiReductionOp.getLoc(); |
| Value result = rewriter.create<arith::ConstantOp>( |
| loc, multiReductionOp.getDestType(), |
| rewriter.getZeroAttr(multiReductionOp.getDestType())); |
| int outerDim = multiReductionOp.getSourceVectorType().getShape()[0]; |
| |
| for (int i = 0; i < outerDim; ++i) { |
| auto v = rewriter.create<vector::ExtractOp>( |
| loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i}); |
| auto acc = rewriter.create<vector::ExtractOp>( |
| loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i}); |
| Operation *reductionOp = rewriter.create<vector::ReductionOp>( |
| loc, multiReductionOp.getKind(), v, acc); |
| |
| // If masked, slice the mask and mask the new reduction operation. |
| if (maskableOp.isMasked()) { |
| Value mask = rewriter.create<vector::ExtractOp>( |
| loc, maskableOp.getMaskingOp().getMask(), ArrayRef<int64_t>{i}); |
| reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask); |
| } |
| |
| result = rewriter.create<vector::InsertElementOp>( |
| loc, reductionOp->getResult(0), result, |
| rewriter.create<arith::ConstantIndexOp>(loc, i)); |
| } |
| |
| rewriter.replaceOp(rootOp, result); |
| return success(); |
| } |
| }; |
| |
| /// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d |
| /// form with both a single parallel and reduction dimension. |
| /// This is achieved with a simple vector.shape_cast that inserts a leading 1. |
| /// The case with a single parallel dimension is a noop and folds away |
| /// separately. |
| struct OneDimMultiReductionToTwoDim |
| : public OpRewritePattern<vector::MultiDimReductionOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, |
| PatternRewriter &rewriter) const override { |
| auto srcRank = multiReductionOp.getSourceVectorType().getRank(); |
| // Rank-1 or bail. |
| if (srcRank != 1) |
| return failure(); |
| |
| // Vector mask setup. |
| OpBuilder::InsertionGuard guard(rewriter); |
| auto maskableOp = |
| cast<vector::MaskableOpInterface>(multiReductionOp.getOperation()); |
| Operation *rootOp; |
| Value mask; |
| if (maskableOp.isMasked()) { |
| rewriter.setInsertionPoint(maskableOp.getMaskingOp()); |
| rootOp = maskableOp.getMaskingOp(); |
| mask = maskableOp.getMaskingOp().getMask(); |
| } else { |
| rootOp = multiReductionOp; |
| } |
| |
| auto loc = multiReductionOp.getLoc(); |
| auto srcVectorType = multiReductionOp.getSourceVectorType(); |
| auto srcShape = srcVectorType.getShape(); |
| auto castedType = VectorType::get( |
| ArrayRef<int64_t>{1, srcShape.back()}, srcVectorType.getElementType(), |
| ArrayRef<bool>{false, srcVectorType.getScalableDims().back()}); |
| |
| auto accType = |
| VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType()); |
| assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) && |
| "multi_reduction with a single dimension expects a scalar result"); |
| |
| // If the unique dim is reduced and we insert a parallel in front, we need a |
| // {false, true} mask. |
| SmallVector<bool, 2> reductionMask{false, true}; |
| |
| /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0) |
| Value cast = rewriter.create<vector::ShapeCastOp>( |
| loc, castedType, multiReductionOp.getSource()); |
| Value castAcc = rewriter.create<vector::BroadcastOp>( |
| loc, accType, multiReductionOp.getAcc()); |
| Value castMask; |
| if (maskableOp.isMasked()) { |
| auto maskType = llvm::cast<VectorType>(mask.getType()); |
| auto castMaskType = VectorType::get( |
| ArrayRef<int64_t>{1, maskType.getShape().back()}, |
| maskType.getElementType(), |
| ArrayRef<bool>{false, maskType.getScalableDims().back()}); |
| castMask = rewriter.create<vector::BroadcastOp>(loc, castMaskType, mask); |
| } |
| |
| Operation *newOp = rewriter.create<vector::MultiDimReductionOp>( |
| loc, cast, castAcc, reductionMask, multiReductionOp.getKind()); |
| newOp = vector::maskOperation(rewriter, newOp, castMask); |
| |
| rewriter.replaceOpWithNewOp<vector::ExtractOp>(rootOp, newOp->getResult(0), |
| ArrayRef<int64_t>{0}); |
| return success(); |
| } |
| }; |
| |
| struct LowerVectorMultiReductionPass |
| : public vector::impl::LowerVectorMultiReductionBase< |
| LowerVectorMultiReductionPass> { |
| LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) { |
| this->loweringStrategy = option; |
| } |
| |
| void runOnOperation() override { |
| Operation *op = getOperation(); |
| MLIRContext *context = op->getContext(); |
| |
| RewritePatternSet loweringPatterns(context); |
| populateVectorMultiReductionLoweringPatterns(loweringPatterns, |
| this->loweringStrategy); |
| |
| if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns)))) |
| signalPassFailure(); |
| } |
| |
| void getDependentDialects(DialectRegistry ®istry) const override { |
| registry.insert<vector::VectorDialect>(); |
| } |
| }; |
| |
| } // namespace |
| |
| void mlir::vector::populateVectorMultiReductionLoweringPatterns( |
| RewritePatternSet &patterns, VectorMultiReductionLowering options, |
| PatternBenefit benefit) { |
| patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>( |
| patterns.getContext(), options, benefit); |
| patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext(), benefit); |
| if (options == VectorMultiReductionLowering ::InnerReduction) |
| patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(), |
| benefit); |
| else |
| patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(), |
| benefit); |
| } |
| |
| std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass( |
| vector::VectorMultiReductionLowering option) { |
| return std::make_unique<LowerVectorMultiReductionPass>(option); |
| } |