| //===- CodegenStrategy.h - Linalg programmable codegen strategy -*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #ifndef MLIR_DIALECT_LINALG_TRANSFORMS_CODEGENSTRATEGY_H_ |
| #define MLIR_DIALECT_LINALG_TRANSFORMS_CODEGENSTRATEGY_H_ |
| |
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Pass/PassManager.h"

#include <cassert>
#include <memory>
#include <string>
#include <utility>
| |
| namespace mlir { |
| |
| class FuncOp; |
| |
| namespace linalg { |
| |
| /// Abstract Transformation class applied in a sequence that also handles state |
| /// through markers. |
| struct Transformation { |
| explicit Transformation(LinalgTransformationFilter::FilterFunction f) |
| : filter(f) {} |
| virtual ~Transformation() = default; |
| virtual void addToPassPipeline(OpPassManager &pm, |
| LinalgTransformationFilter m) const = 0; |
| LinalgTransformationFilter::FilterFunction filter = nullptr; |
| }; |
| |
| /// Represent one application of LinalgStrategyTileAndFusePass. |
| struct TileAndFuse : public Transformation { |
| TileAndFuse(StringRef name, linalg::LinalgTilingAndFusionOptions options, |
| LinalgTransformationFilter::FilterFunction f = nullptr) |
| : Transformation(f), opName(name), options(options) {} |
| |
| void addToPassPipeline(OpPassManager &pm, |
| LinalgTransformationFilter m) const override { |
| pm.addPass(createLinalgStrategyTileAndFusePass(opName, options, m)); |
| } |
| |
| private: |
| std::string opName; |
| linalg::LinalgTilingAndFusionOptions options; |
| }; |
| |
| /// Represent one application of LinalgStrategyTilePass. |
| struct Tile : public Transformation { |
| Tile(StringRef name, linalg::LinalgTilingOptions options, |
| LinalgTransformationFilter::FilterFunction f = nullptr) |
| : Transformation(f), opName(name), options(options) {} |
| |
| void addToPassPipeline(OpPassManager &pm, |
| LinalgTransformationFilter m) const override { |
| pm.addPass(createLinalgStrategyTilePass(opName, options, m)); |
| } |
| |
| private: |
| std::string opName; |
| linalg::LinalgTilingOptions options; |
| }; |
| |
| /// Represent one application of LinalgStrategyPadPass. |
| struct Pad : public Transformation { |
| Pad(StringRef name, linalg::LinalgPaddingOptions options, |
| LinalgTransformationFilter::FilterFunction f = nullptr) |
| : Transformation(f), opName(name), options(options) {} |
| |
| void addToPassPipeline(OpPassManager &pm, |
| LinalgTransformationFilter m) const override { |
| pm.addPass(createLinalgStrategyPadPass(opName, options, m)); |
| } |
| |
| private: |
| std::string opName; |
| linalg::LinalgPaddingOptions options; |
| }; |
| |
| /// Represent one application of createLinalgStrategyPromotePass. |
| struct Promote : public Transformation { |
| Promote(StringRef name, linalg::LinalgPromotionOptions options, |
| LinalgTransformationFilter::FilterFunction f = nullptr) |
| : Transformation(f), opName(name), options(options) {} |
| |
| void addToPassPipeline(OpPassManager &pm, |
| LinalgTransformationFilter m) const override { |
| pm.addPass(createLinalgStrategyPromotePass(opName, options, m)); |
| } |
| |
| private: |
| std::string opName; |
| linalg::LinalgPromotionOptions options; |
| }; |
| |
| /// Represent one application of createLinalgStrategyGeneralizePass. |
| struct Generalize : public Transformation { |
| explicit Generalize(StringRef name, |
| LinalgTransformationFilter::FilterFunction f = nullptr) |
| : Transformation(f), opName(name) {} |
| |
| void addToPassPipeline(OpPassManager &pm, |
| LinalgTransformationFilter m) const override { |
| pm.addPass(createLinalgStrategyGeneralizePass(opName, m)); |
| } |
| |
| private: |
| std::string opName; |
| }; |
| |
| /// Represent one application of createLinalgStrategyInterchangePass. |
| struct Interchange : public Transformation { |
| explicit Interchange(ArrayRef<int64_t> iteratorInterchange, |
| LinalgTransformationFilter::FilterFunction f = nullptr) |
| : Transformation(f), iteratorInterchange(iteratorInterchange.begin(), |
| iteratorInterchange.end()) {} |
| |
| void addToPassPipeline(OpPassManager &pm, |
| LinalgTransformationFilter m) const override { |
| pm.addPass(createLinalgStrategyInterchangePass(iteratorInterchange, m)); |
| } |
| |
| private: |
| SmallVector<int64_t> iteratorInterchange; |
| }; |
| |
| /// Represent one application of createLinalgStrategyVectorizePass. |
| struct Vectorize : public Transformation { |
| explicit Vectorize(linalg::LinalgVectorizationOptions options, |
| LinalgTransformationFilter::FilterFunction f = nullptr, |
| bool padVectorize = false) |
| : Transformation(f), opName(), options(options), |
| vectorizePadding(padVectorize) {} |
| |
| Vectorize(StringRef name, linalg::LinalgVectorizationOptions options, |
| LinalgTransformationFilter::FilterFunction f = nullptr, |
| bool padVectorize = false) |
| : Transformation(f), opName(name), options(options), |
| vectorizePadding(padVectorize) {} |
| |
| void addToPassPipeline(OpPassManager &pm, |
| LinalgTransformationFilter m) const override { |
| pm.addPass(createLinalgStrategyVectorizePass(opName, options, m, |
| vectorizePadding)); |
| } |
| |
| private: |
| std::string opName; |
| linalg::LinalgVectorizationOptions options; |
| bool vectorizePadding; |
| }; |
| |
| /// Represent one application of createLinalgStrategyLowerVectorsPass. |
| struct VectorLowering : public Transformation { |
| explicit VectorLowering( |
| linalg::LinalgVectorLoweringOptions options, |
| LinalgTransformationFilter::FilterFunction f = nullptr) |
| : Transformation(f), options(options) {} |
| |
| void addToPassPipeline(OpPassManager &pm, |
| LinalgTransformationFilter m) const override { |
| pm.addPass(createLinalgStrategyLowerVectorsPass(options, m)); |
| } |
| |
| private: |
| linalg::LinalgVectorLoweringOptions options; |
| }; |
| |
| /// Codegen strategy controls how a Linalg op is progressively lowered. |
| struct CodegenStrategy { |
| /// Append a pattern to tile the Op `opName` and fuse its producers with |
| /// tiling and fusion `options`. |
| CodegenStrategy & |
| tileAndFuse(StringRef opName, LinalgTilingAndFusionOptions options, |
| LinalgTransformationFilter::FilterFunction f = nullptr) { |
| transformationSequence.emplace_back( |
| std::make_unique<TileAndFuse>(opName, options, f)); |
| return *this; |
| } |
| /// Conditionally append a pattern to tile the Op `opName` and fuse its |
| /// producers with tiling and fusion `options`. |
| CodegenStrategy & |
| tileAndFuseIf(bool b, StringRef opName, LinalgTilingAndFusionOptions options, |
| LinalgTransformationFilter::FilterFunction f = nullptr) { |
| return b ? tileAndFuse(opName, options, f) : *this; |
| } |
| /// Append a pattern to add a level of tiling for Op `opName` with tiling |
| /// `options`. |
| CodegenStrategy & |
| tile(StringRef opName, linalg::LinalgTilingOptions options, |
| LinalgTransformationFilter::FilterFunction f = nullptr) { |
| transformationSequence.emplace_back( |
| std::make_unique<Tile>(opName, options, f)); |
| return *this; |
| } |
| /// Conditionally append a pattern to add a level of tiling for |
| /// `LinalgOpType` with tiling `options`. |
| CodegenStrategy & |
| tileIf(bool b, StringRef opName, linalg::LinalgTilingOptions options, |
| LinalgTransformationFilter::FilterFunction f = nullptr) { |
| return b ? tile(opName, options, f) : *this; |
| } |
| /// Append a pattern to pad and hoist the operands of Op `opName` with padding |
| /// `options`. |
| CodegenStrategy &pad(StringRef opName, linalg::LinalgPaddingOptions options, |
| LinalgTransformationFilter::FilterFunction f = nullptr) { |
| transformationSequence.emplace_back( |
| std::make_unique<Pad>(opName, options, f)); |
| return *this; |
| } |
| /// Conditionally append a pattern to pad and hoist the operands of Op |
| /// `opName` with padding `options`. |
| CodegenStrategy & |
| padIf(bool b, StringRef opName, linalg::LinalgPaddingOptions options, |
| LinalgTransformationFilter::FilterFunction f = nullptr) { |
| return b ? pad(opName, options, f) : *this; |
| } |
| /// Append a pattern to add a level of promotion for `LinalgOpType` with |
| /// promotion `options`. |
| CodegenStrategy & |
| promote(StringRef opName, linalg::LinalgPromotionOptions options, |
| LinalgTransformationFilter::FilterFunction f = nullptr) { |
| transformationSequence.emplace_back( |
| std::make_unique<Promote>(opName, options, f)); |
| return *this; |
| } |
| /// Conditionally append a pattern to add a level of promotion for |
| /// `LinalgOpType` with promotion `options`. |
| CodegenStrategy & |
| promoteIf(bool b, StringRef opName, linalg::LinalgPromotionOptions options, |
| LinalgTransformationFilter::FilterFunction f = nullptr) { |
| return b ? promote(opName, options, f) : *this; |
| return *this; |
| } |
| /// Append a pattern to generalize named operations. |
| CodegenStrategy & |
| generalize(StringRef opName, |
| LinalgTransformationFilter::FilterFunction f = nullptr) { |
| transformationSequence.emplace_back( |
| std::make_unique<Generalize>(opName, f)); |
| return *this; |
| } |
| /// Conditionally append a pattern to generalize named operations. |
| CodegenStrategy & |
| generalizeIf(bool b, StringRef opName, |
| LinalgTransformationFilter::FilterFunction f = nullptr) { |
| return b ? generalize(opName, f) : *this; |
| return *this; |
| } |
| /// Append a pattern to interchange iterators. |
| CodegenStrategy & |
| interchange(ArrayRef<int64_t> iteratorInterchange, |
| LinalgTransformationFilter::FilterFunction f = nullptr) { |
| transformationSequence.emplace_back( |
| std::make_unique<Interchange>(iteratorInterchange, f)); |
| return *this; |
| } |
| /// Conditionally append a pattern to interchange iterators. |
| CodegenStrategy & |
| interchangeIf(bool b, ArrayRef<int64_t> iteratorInterchange, |
| LinalgTransformationFilter::FilterFunction f = nullptr) { |
| return b ? interchange(iteratorInterchange, f) : *this; |
| return *this; |
| } |
| /// Append a pattern to rewrite `LinalgOpType` as a vector operation. |
| CodegenStrategy & |
| vectorize(StringRef opName, |
| LinalgTransformationFilter::FilterFunction f = nullptr, |
| bool vectorizePadding = false) { |
| assert(!opName.empty() && "expected an op name"); |
| transformationSequence.emplace_back(std::make_unique<Vectorize>( |
| opName, linalg::LinalgVectorizationOptions(), f, vectorizePadding)); |
| return *this; |
| } |
| /// Conditionally append a pattern to rewrite `LinalgOpType` as a vector |
| /// operation. |
| CodegenStrategy & |
| vectorizeIf(bool b, StringRef opName, |
| LinalgTransformationFilter::FilterFunction f = nullptr, |
| bool vectorizePadding = false) { |
| return b ? vectorize(opName, f, vectorizePadding) : *this; |
| return *this; |
| } |
| /// Append a pattern to lower all vector operations. |
| CodegenStrategy &vectorLowering(LinalgVectorLoweringOptions options) { |
| transformationSequence.emplace_back( |
| std::make_unique<VectorLowering>(options)); |
| return *this; |
| } |
| /// Configure the post staged-patterns global enabling passes options. |
| CodegenStrategy & |
| setVectorTransferToSCFOptions(LinalgEnablingOptions options) { |
| linalgEnablingOptions = options; |
| return *this; |
| } |
| |
| /// Apply the transformation patterns in sequence with cleanup |
| /// transformations interleaved. |
| void configurePassPipeline(OpPassManager &pm, MLIRContext *context, |
| bool addEnablePass = true) const; |
| |
| private: |
| LogicalResult postPatternTransforms(Operation *func) const; |
| |
| LinalgEnablingOptions linalgEnablingOptions; |
| SmallVector<std::unique_ptr<Transformation>, 4> transformationSequence; |
| }; |
| |
| } // namespace linalg |
| } // namespace mlir |
| |
| #endif // MLIR_DIALECT_LINALG_TRANSFORMS_CODEGENSTRATEGY_H_ |