| //===- LoweringPatterns.h - Vector rewrite patterns --------*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H |
| #define MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H |
| |
| #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" |
| |
| namespace mlir { |
| class RewritePatternSet; |
| |
| namespace vector { |
| |
| //===----------------------------------------------------------------------===// |
| // Lowering pattern populate functions |
| //===----------------------------------------------------------------------===// |
| |
| /// Populate the pattern set with the following patterns: |
| /// |
| /// [OuterProductOpLowering] |
| /// Progressively lower a `vector.outerproduct` to linearized |
| /// `vector.extract` + `vector.fma` + `vector.insert`. |
| /// |
| /// [ContractionOpLowering] |
| /// Progressive lowering of ContractionOp. |
| /// One: |
| /// %x = vector.contract with at least one free/batch dimension |
| /// is replaced by: |
| /// %a = vector.contract with one less free/batch dimension |
| /// %b = vector.contract with one less free/batch dimension |
| /// |
| /// [ContractionOpToMatmulOpLowering] |
| /// Progressively lower a `vector.contract` with row-major matmul semantics to |
| /// linearized `vector.shape_cast` + `vector.matmul` on the way to |
| /// `llvm.matrix.multiply`. |
| /// |
| /// [ContractionOpToDotLowering] |
| /// Progressively lower a `vector.contract` with row-major matmul semantics to |
| /// linearized `vector.extract` + `vector.reduce` + `vector.insert`. |
| /// |
| /// [ContractionOpToOuterProductOpLowering] |
| /// Progressively lower a `vector.contract` with row-major matmul semantics to |
| /// linearized `vector.extract` + `vector.outerproduct` + `vector.insert`. |
| void populateVectorContractLoweringPatterns( |
| RewritePatternSet &patterns, VectorTransformsOptions options, |
| PatternBenefit benefit = 1, bool disableOuterProductLowering = false); |
| |
| /// Populate the pattern set with the following patterns: |
| /// |
| /// [OuterProductOpLowering] |
| /// Progressively lower a `vector.outerproduct` to linearized |
| /// `vector.extract` + `vector.fma` + `vector.insert`. |
| void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns, |
| PatternBenefit benefit = 1); |
| |
| /// Collect a set of patterns to convert vector.multi_reduction op into |
| /// a sequence of vector.reduction ops. The patterns comprise: |
| /// |
| /// [InnerOuterDimReductionConversion] |
| /// Rewrites vector.multi_reduction such that all reduction dimensions are |
| /// either innermost or outermost, by adding the proper vector.transpose |
| /// operations. |
| /// |
| /// [ReduceMultiDimReductionRank] |
| /// Once in innermost or outermost reduction |
| /// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction, |
| /// by introducing vector.shape_cast ops to collapse + multi-reduce + expand |
| /// back. |
| /// |
| /// [TwoDimMultiReductionToElementWise] |
| /// Once in 2-D vector.multi_reduction form, with an **outermost** reduction |
| /// dimension, unroll the outer dimension to obtain a sequence of 1-D vector |
| /// ops. This also has an opportunity for tree-reduction (in the future). |
| /// |
| /// [TwoDimMultiReductionToReduction] |
| /// Once in 2-D vector.multi_reduction form, with an **innermost** reduction |
| /// dimension, unroll the outer dimension to obtain a sequence of extract + |
| /// vector.reduction + insert. This can further lower to horizontal reduction |
| /// ops. |
| /// |
| /// [OneDimMultiReductionToTwoDim] |
| /// For cases that reduce to 1-D vector<k> reduction (and are thus missing |
| /// either a parallel or a reduction), we lift them back up to 2-D with a simple |
| /// vector.shape_cast to vector<1xk> so that the other patterns can kick in, |
| /// thus fully exiting out of the vector.multi_reduction abstraction. |
| void populateVectorMultiReductionLoweringPatterns( |
| RewritePatternSet &patterns, VectorMultiReductionLowering options, |
| PatternBenefit benefit = 1); |
| |
| /// Populate the pattern set with the following patterns: |
| /// |
| /// [TransferReadToVectorLoadLowering] |
| /// Progressive lowering of BroadcastOp to ExtractOp + InsertOp + lower-D |
| /// BroadcastOp until dim 1. |
| void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns, |
| PatternBenefit benefit = 1); |
| |
| /// Populate the pattern set with the following patterns: |
| /// |
| /// [CreateMaskOp] |
| /// Progressive lowering of CreateMaskOp to lower-D CreateMaskOp until dim 1. |
| /// |
| /// [ConstantMaskOp] |
| /// Progressive lowering of ConstantMaskOp to lower-D ConstantMaskOp until |
| /// dim 1. |
| void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns, |
| PatternBenefit benefit = 1); |
| |
| /// Collects patterns that lower scalar vector transfer ops to memref loads and |
| /// stores when beneficial. If `allowMultipleUses` is set to true, the patterns |
| /// are applied to vector transfer reads with any number of uses. Otherwise, |
| /// only vector transfer reads with a single use will be lowered. |
| void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns, |
| PatternBenefit benefit, |
| bool allowMultipleUses); |
| |
| /// Populate the pattern set with the following patterns: |
| /// |
| /// [ShapeCastOp2DDownCastRewritePattern] |
| /// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D |
| /// vectors progressively. |
| /// |
| /// [ShapeCastOp2DUpCastRewritePattern] |
| /// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D |
| /// vectors progressively. |
| /// |
| /// [ShapeCastOpRewritePattern] |
| /// Reference lowering to fully unrolled sequences of single element ExtractOp + |
| /// InsertOp. Note that applying this pattern can almost always be considered a |
| /// performance bug. |
| void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, |
| PatternBenefit benefit = 1); |
| |
| /// Populate the pattern set with the following patterns: |
| /// |
| /// [TransposeOpLowering] |
| /// |
| /// [TransposeOp2DToShuffleLowering] |
| /// |
| void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, |
| VectorTransformsOptions options, |
| PatternBenefit benefit = 1); |
| |
| /// Populate the pattern set with the following patterns: |
| /// |
| /// [TransferReadToVectorLoadLowering] |
| /// Progressive lowering of transfer_read.This pattern supports lowering of |
| /// `vector.transfer_read` to a combination of `vector.load` and |
| /// `vector.broadcast` |
| /// |
| /// [TransferWriteToVectorStoreLowering] |
| /// Progressive lowering of transfer_write. This pattern supports lowering of |
| /// `vector.transfer_write` to `vector.store` |
| /// |
| /// [VectorLoadToMemrefLoadLowering] |
| /// Replace a 0-d vector.load with a memref.load + vector.broadcast. |
| /// |
| /// [VectorStoreToMemrefStoreLowering] |
| /// Replace a 0-d vector.store with a vector.extractelement + memref.store. |
| /// |
| /// These patterns lower transfer ops to simpler ops like `vector.load`, |
| /// `vector.store` and `vector.broadcast`. Only transfers with a transfer rank |
| /// of a most `maxTransferRank` are lowered. This is useful when combined with |
| /// VectorToSCF, which reduces the rank of vector transfer ops. |
| void populateVectorTransferLoweringPatterns( |
| RewritePatternSet &patterns, |
| std::optional<unsigned> maxTransferRank = std::nullopt, |
| PatternBenefit benefit = 1); |
| |
| /// Collect a set of transfer read/write lowering patterns that simplify the |
| /// permutation map (e.g., converting it to a minor identity map) by inserting |
| /// broadcasts and transposes. More specifically: |
| /// |
| /// [TransferReadPermutationLowering] |
| /// Lower transfer_read op with permutation into a transfer_read with a |
| /// permutation map composed of leading zeros followed by a minor identity + |
| /// vector.transpose op. |
| /// Ex: |
| /// vector.transfer_read ... |
| /// permutation_map: (d0, d1, d2) -> (0, d1) |
| /// into: |
| /// %v = vector.transfer_read ... |
| /// permutation_map: (d0, d1, d2) -> (d1, 0) |
| /// vector.transpose %v, [1, 0] |
| /// |
| /// vector.transfer_read ... |
| /// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3) |
| /// into: |
| /// %v = vector.transfer_read ... |
| /// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3) |
| /// vector.transpose %v, [0, 1, 3, 2, 4] |
| /// Note that an alternative is to transform it to linalg.transpose + |
| /// vector.transfer_read to do the transpose in memory instead. |
| /// |
| /// [TransferWritePermutationLowering] |
| /// Lower transfer_write op with permutation into a transfer_write with a |
| /// minor identity permutation map. (transfer_write ops cannot have broadcasts.) |
| /// Ex: |
| /// vector.transfer_write %v ... |
| /// permutation_map: (d0, d1, d2) -> (d2, d0, d1) |
| /// into: |
| /// %tmp = vector.transpose %v, [2, 0, 1] |
| /// vector.transfer_write %tmp ... |
| /// permutation_map: (d0, d1, d2) -> (d0, d1, d2) |
| /// |
| /// vector.transfer_write %v ... |
| /// permutation_map: (d0, d1, d2, d3) -> (d3, d2) |
| /// into: |
| /// %tmp = vector.transpose %v, [1, 0] |
| /// %v = vector.transfer_write %tmp ... |
| /// permutation_map: (d0, d1, d2, d3) -> (d2, d3) |
| /// |
| /// [TransferOpReduceRank] |
| /// Lower transfer_read op with broadcast in the leading dimensions into |
| /// transfer_read of lower rank + vector.broadcast. |
| /// Ex: vector.transfer_read ... |
| /// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3) |
| /// into: |
| /// %v = vector.transfer_read ... |
| /// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3) |
| /// vector.broadcast %v |
| void populateVectorTransferPermutationMapLoweringPatterns( |
| RewritePatternSet &patterns, PatternBenefit benefit = 1); |
| |
| /// Populate the pattern set with the following patterns: |
| /// |
| /// [ScanToArithOps] |
| /// Convert vector.scan op into arith ops and vector.insert_strided_slice / |
| /// vector.extract_strided_slice. |
| void populateVectorScanLoweringPatterns(RewritePatternSet &patterns, |
| PatternBenefit benefit = 1); |
| |
| /// Populate the pattern set with the following patterns: |
| /// |
| /// [FlattenGather] |
| /// Flattens 2 or more dimensional `vector.gather` ops by unrolling the |
| /// outermost dimension. |
| /// |
| /// [Gather1DToConditionalLoads] |
| /// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or |
| /// `tensor.extract`s. To avoid out-of-bounds memory accesses, these |
| /// loads/extracts are made conditional using `scf.if` ops. |
| void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, |
| PatternBenefit benefit = 1); |
| |
| /// Populates instances of `MaskOpRewritePattern` to lower masked operations |
| /// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and |
| /// not its nested `MaskableOpInterface`. |
| void populateVectorMaskLoweringPatternsForSideEffectingOps( |
| RewritePatternSet &patterns); |
| |
| /// Populate the pattern set with the following patterns: |
| /// |
| /// [VectorMaskedLoadOpConverter] |
| /// Turns vector.maskedload to scf.if + memref.load |
| /// |
| /// [VectorMaskedStoreOpConverter] |
| /// Turns vector.maskedstore to scf.if + memref.store |
| void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns, |
| PatternBenefit benefit = 1); |
| |
| /// Populate the pattern set with the following patterns: |
| /// |
| /// [UnrollInterleaveOp] |
| /// A one-shot unrolling of InterleaveOp to (one or more) ExtractOp + |
| /// InterleaveOp (of `targetRank`) + InsertOp. |
| void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns, |
| int64_t targetRank = 1, |
| PatternBenefit benefit = 1); |
| |
| void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns, |
| PatternBenefit benefit = 1); |
| |
| } // namespace vector |
| } // namespace mlir |
| #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H |