| //===- LowerVectorMask.cpp - Lower 'vector.mask' operation ----------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements target-independent rewrites and utilities to lower the |
| // 'vector.mask' operation. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" |
| #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" |
| #include "mlir/Dialect/Vector/Transforms/Passes.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| |
| #define DEBUG_TYPE "lower-vector-mask" |
| |
| namespace mlir { |
| namespace vector { |
| #define GEN_PASS_DEF_LOWERVECTORMASKPASS |
| #include "mlir/Dialect/Vector/Transforms/Passes.h.inc" |
| } // namespace vector |
| } // namespace mlir |
| |
| using namespace mlir; |
| using namespace mlir::vector; |
| |
| //===----------------------------------------------------------------------===// |
| // populateVectorMaskOpLoweringPatterns |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| /// Progressive lowering of CreateMaskOp. |
| /// One: |
| /// %x = vector.create_mask %a, ... : vector<dx...> |
| /// is replaced by: |
| /// %l = vector.create_mask ... : vector<...> ; one lower rank |
| /// %0 = arith.cmpi "slt", %ci, %a | |
| /// %1 = select %0, %l, %zeroes | |
| /// %r = vector.insert %1, %pr [i] | d-times |
| /// %x = .... |
| /// until a one-dimensional vector is reached. |
| class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> { |
| public: |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(vector::CreateMaskOp op, |
| PatternRewriter &rewriter) const override { |
| auto dstType = cast<VectorType>(op.getResult().getType()); |
| int64_t rank = dstType.getRank(); |
| if (rank <= 1) |
| return rewriter.notifyMatchFailure( |
| op, "0-D and 1-D vectors are handled separately"); |
| |
| if (dstType.getScalableDims().front()) |
| return rewriter.notifyMatchFailure( |
| op, "Cannot unroll leading scalable dim in dstType"); |
| |
| auto loc = op.getLoc(); |
| int64_t dim = dstType.getDimSize(0); |
| Value idx = op.getOperand(0); |
| |
| VectorType lowType = VectorType::Builder(dstType).dropDim(0); |
| Value trueVal = rewriter.create<vector::CreateMaskOp>( |
| loc, lowType, op.getOperands().drop_front()); |
| Value falseVal = rewriter.create<arith::ConstantOp>( |
| loc, lowType, rewriter.getZeroAttr(lowType)); |
| Value result = rewriter.create<arith::ConstantOp>( |
| loc, dstType, rewriter.getZeroAttr(dstType)); |
| for (int64_t d = 0; d < dim; d++) { |
| Value bnd = |
| rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d)); |
| Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, |
| bnd, idx); |
| Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal); |
| result = rewriter.create<vector::InsertOp>(loc, sel, result, d); |
| } |
| rewriter.replaceOp(op, result); |
| return success(); |
| } |
| }; |
| |
| /// Progressive lowering of ConstantMaskOp. |
| /// One: |
| /// %x = vector.constant_mask [a,b] |
| /// is replaced by: |
| /// %z = zero-result |
| /// %l = vector.constant_mask [b] |
| /// %4 = vector.insert %l, %z[0] |
| /// .. |
| /// %x = vector.insert %l, %..[a-1] |
| /// until a one-dimensional vector is reached. All these operations |
| /// will be folded at LLVM IR level. |
| class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> { |
| public: |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(vector::ConstantMaskOp op, |
| PatternRewriter &rewriter) const override { |
| auto loc = op.getLoc(); |
| auto dstType = op.getType(); |
| auto dimSizes = op.getMaskDimSizes(); |
| int64_t rank = dstType.getRank(); |
| |
| if (rank == 0) { |
| assert(dimSizes.size() == 1 && |
| "Expected exactly one dim size for a 0-D vector"); |
| bool value = dimSizes.front() == 1; |
| rewriter.replaceOpWithNewOp<arith::ConstantOp>( |
| op, dstType, |
| DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()), |
| value)); |
| return success(); |
| } |
| |
| int64_t trueDimSize = dimSizes.front(); |
| |
| if (rank == 1) { |
| if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) { |
| // Use constant splat for 'all set' or 'none set' dims. |
| // This produces correct code for scalable dimensions (it will lower to |
| // a constant splat). |
| rewriter.replaceOpWithNewOp<arith::ConstantOp>( |
| op, DenseElementsAttr::get(dstType, trueDimSize != 0)); |
| } else { |
| // Express constant 1-D case in explicit vector form: |
| // [T,..,T,F,..,F]. |
| // Note: The verifier would reject this case for scalable vectors. |
| SmallVector<bool> values(dstType.getDimSize(0), false); |
| for (int64_t d = 0; d < trueDimSize; d++) |
| values[d] = true; |
| rewriter.replaceOpWithNewOp<arith::ConstantOp>( |
| op, dstType, rewriter.getBoolVectorAttr(values)); |
| } |
| return success(); |
| } |
| |
| if (dstType.getScalableDims().front()) |
| return rewriter.notifyMatchFailure( |
| op, "Cannot unroll leading scalable dim in dstType"); |
| |
| VectorType lowType = VectorType::Builder(dstType).dropDim(0); |
| Value trueVal = rewriter.create<vector::ConstantMaskOp>( |
| loc, lowType, dimSizes.drop_front()); |
| Value result = rewriter.create<arith::ConstantOp>( |
| loc, dstType, rewriter.getZeroAttr(dstType)); |
| for (int64_t d = 0; d < trueDimSize; d++) |
| result = rewriter.create<vector::InsertOp>(loc, trueVal, result, d); |
| |
| rewriter.replaceOp(op, result); |
| return success(); |
| } |
| }; |
| } // namespace |
| |
| void mlir::vector::populateVectorMaskOpLoweringPatterns( |
| RewritePatternSet &patterns, PatternBenefit benefit) { |
| patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>( |
| patterns.getContext(), benefit); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // populateVectorMaskLoweringPatternsForSideEffectingOps |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| /// The `MaskOpRewritePattern` implements a pattern that follows a two-fold |
| /// matching: |
| /// 1. It matches a `vector.mask` operation. |
| /// 2. It invokes `matchAndRewriteMaskableOp` on `MaskableOpInterface` nested |
| /// in the matched `vector.mask` operation. |
| /// |
| /// It is required that the replacement op in the pattern replaces the |
| /// `vector.mask` operation and not the nested `MaskableOpInterface`. This |
| /// approach allows having patterns that "stop" at every `vector.mask` operation |
| /// and actually match the traits of its the nested `MaskableOpInterface`. |
| template <class SourceOp> |
| struct MaskOpRewritePattern : OpRewritePattern<MaskOp> { |
| using OpRewritePattern<MaskOp>::OpRewritePattern; |
| |
| private: |
| LogicalResult matchAndRewrite(MaskOp maskOp, |
| PatternRewriter &rewriter) const final { |
| auto maskableOp = cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp()); |
| if (!maskableOp) |
| return failure(); |
| SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation()); |
| if (!sourceOp) |
| return failure(); |
| |
| return matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter); |
| } |
| |
| protected: |
| virtual LogicalResult |
| matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp, |
| PatternRewriter &rewriter) const = 0; |
| }; |
| |
| /// Lowers a masked `vector.transfer_read` operation. |
| struct MaskedTransferReadOpPattern |
| : public MaskOpRewritePattern<TransferReadOp> { |
| public: |
| using MaskOpRewritePattern<TransferReadOp>::MaskOpRewritePattern; |
| |
| LogicalResult |
| matchAndRewriteMaskableOp(TransferReadOp readOp, MaskingOpInterface maskingOp, |
| PatternRewriter &rewriter) const override { |
| // TODO: The 'vector.mask' passthru is a vector and 'vector.transfer_read' |
| // expects a scalar. We could only lower one to the other for cases where |
| // the passthru is a broadcast of a scalar. |
| if (maskingOp.hasPassthru()) |
| return rewriter.notifyMatchFailure( |
| maskingOp, "Can't lower passthru to vector.transfer_read"); |
| |
| // Replace the `vector.mask` operation. |
| rewriter.replaceOpWithNewOp<TransferReadOp>( |
| maskingOp.getOperation(), readOp.getVectorType(), readOp.getSource(), |
| readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(), |
| maskingOp.getMask(), readOp.getInBounds()); |
| return success(); |
| } |
| }; |
| |
| /// Lowers a masked `vector.transfer_write` operation. |
| struct MaskedTransferWriteOpPattern |
| : public MaskOpRewritePattern<TransferWriteOp> { |
| public: |
| using MaskOpRewritePattern<TransferWriteOp>::MaskOpRewritePattern; |
| |
| LogicalResult |
| matchAndRewriteMaskableOp(TransferWriteOp writeOp, |
| MaskingOpInterface maskingOp, |
| PatternRewriter &rewriter) const override { |
| Type resultType = |
| writeOp.getResult() ? writeOp.getResult().getType() : Type(); |
| |
| // Replace the `vector.mask` operation. |
| rewriter.replaceOpWithNewOp<TransferWriteOp>( |
| maskingOp.getOperation(), resultType, writeOp.getVector(), |
| writeOp.getSource(), writeOp.getIndices(), writeOp.getPermutationMap(), |
| maskingOp.getMask(), writeOp.getInBounds()); |
| return success(); |
| } |
| }; |
| |
| /// Lowers a masked `vector.gather` operation. |
| struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> { |
| public: |
| using MaskOpRewritePattern<GatherOp>::MaskOpRewritePattern; |
| |
| LogicalResult |
| matchAndRewriteMaskableOp(GatherOp gatherOp, MaskingOpInterface maskingOp, |
| PatternRewriter &rewriter) const override { |
| Value passthru = maskingOp.hasPassthru() |
| ? maskingOp.getPassthru() |
| : rewriter.create<arith::ConstantOp>( |
| gatherOp.getLoc(), |
| rewriter.getZeroAttr(gatherOp.getVectorType())); |
| |
| // Replace the `vector.mask` operation. |
| rewriter.replaceOpWithNewOp<GatherOp>( |
| maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(), |
| gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(), |
| passthru); |
| return success(); |
| } |
| }; |
| |
| struct LowerVectorMaskPass |
| : public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> { |
| using Base::Base; |
| |
| void runOnOperation() override { |
| Operation *op = getOperation(); |
| MLIRContext *context = op->getContext(); |
| |
| RewritePatternSet loweringPatterns(context); |
| populateVectorMaskLoweringPatternsForSideEffectingOps(loweringPatterns); |
| MaskOp::getCanonicalizationPatterns(loweringPatterns, context); |
| |
| if (failed(applyPatternsGreedily(op, std::move(loweringPatterns)))) |
| signalPassFailure(); |
| } |
| |
| void getDependentDialects(DialectRegistry ®istry) const override { |
| registry.insert<vector::VectorDialect>(); |
| } |
| }; |
| |
| } // namespace |
| |
| /// Populates instances of `MaskOpRewritePattern` to lower masked operations |
| /// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and |
| /// not its nested `MaskableOpInterface`. |
| void vector::populateVectorMaskLoweringPatternsForSideEffectingOps( |
| RewritePatternSet &patterns) { |
| patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern, |
| MaskedGatherOpPattern>(patterns.getContext()); |
| } |
| |
| std::unique_ptr<Pass> mlir::vector::createLowerVectorMaskPass() { |
| return std::make_unique<LowerVectorMaskPass>(); |
| } |