//===- CodegenStrategy.h - Linalg programmable codegen strategy -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_CODEGENSTRATEGY_H_
#define MLIR_DIALECT_LINALG_TRANSFORMS_CODEGENSTRATEGY_H_

#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Pass/PassManager.h"

namespace mlir {

class FuncOp;

namespace linalg {

/// Abstract Transformation class applied in a sequence that also handles state
/// through markers.
struct Transformation {
  explicit Transformation(LinalgTransformationFilter::FilterFunction f)
      : filter(f) {}
  virtual ~Transformation() = default;
  virtual void addToPassPipeline(OpPassManager &pm,
                                 LinalgTransformationFilter m) const = 0;
  LinalgTransformationFilter::FilterFunction filter = nullptr;
};

/// Strategy step that appends one LinalgStrategyTileAndFusePass, tiling the op
/// named `opName` and fusing its producers as configured by `options`.
struct TileAndFuse : public Transformation {
  TileAndFuse(StringRef name, linalg::LinalgTilingAndFusionOptions opts,
              LinalgTransformationFilter::FilterFunction f = nullptr)
      : Transformation(f), opName(name), options(opts) {}

  void addToPassPipeline(OpPassManager &pm,
                         LinalgTransformationFilter stepFilter) const override {
    pm.addPass(createLinalgStrategyTileAndFusePass(opName, options, stepFilter));
  }

private:
  /// Name of the op this step targets.
  std::string opName;
  /// Tiling-and-fusion configuration forwarded to the pass.
  linalg::LinalgTilingAndFusionOptions options;
};

/// Strategy step that appends one LinalgStrategyTilePass for the op named
/// `opName`, using the tiling configuration `options`.
struct Tile : public Transformation {
  Tile(StringRef name, linalg::LinalgTilingOptions opts,
       LinalgTransformationFilter::FilterFunction f = nullptr)
      : Transformation(f), opName(name), options(opts) {}

  void addToPassPipeline(OpPassManager &pm,
                         LinalgTransformationFilter stepFilter) const override {
    pm.addPass(createLinalgStrategyTilePass(opName, options, stepFilter));
  }

private:
  /// Name of the op this step targets.
  std::string opName;
  /// Tiling configuration forwarded to the pass.
  linalg::LinalgTilingOptions options;
};

/// Strategy step that appends one LinalgStrategyPadPass for the op named
/// `opName`, using the padding configuration `options`.
struct Pad : public Transformation {
  Pad(StringRef name, linalg::LinalgPaddingOptions opts,
      LinalgTransformationFilter::FilterFunction f = nullptr)
      : Transformation(f), opName(name), options(opts) {}

  void addToPassPipeline(OpPassManager &pm,
                         LinalgTransformationFilter stepFilter) const override {
    pm.addPass(createLinalgStrategyPadPass(opName, options, stepFilter));
  }

private:
  /// Name of the op this step targets.
  std::string opName;
  /// Padding configuration forwarded to the pass.
  linalg::LinalgPaddingOptions options;
};

/// Strategy step that appends the pass built by
/// createLinalgStrategyPromotePass for the op named `opName`, using the
/// promotion configuration `options`.
struct Promote : public Transformation {
  Promote(StringRef name, linalg::LinalgPromotionOptions opts,
          LinalgTransformationFilter::FilterFunction f = nullptr)
      : Transformation(f), opName(name), options(opts) {}

  void addToPassPipeline(OpPassManager &pm,
                         LinalgTransformationFilter stepFilter) const override {
    pm.addPass(createLinalgStrategyPromotePass(opName, options, stepFilter));
  }

private:
  /// Name of the op this step targets.
  std::string opName;
  /// Promotion configuration forwarded to the pass.
  linalg::LinalgPromotionOptions options;
};

/// Strategy step that appends the pass built by
/// createLinalgStrategyGeneralizePass for the op named `opName`. It takes no
/// options beyond the op name and the optional filter.
struct Generalize : public Transformation {
  explicit Generalize(StringRef name,
                      LinalgTransformationFilter::FilterFunction f = nullptr)
      : Transformation(f), opName(name) {}

  void addToPassPipeline(OpPassManager &pm,
                         LinalgTransformationFilter stepFilter) const override {
    pm.addPass(createLinalgStrategyGeneralizePass(opName, stepFilter));
  }

private:
  /// Name of the op this step targets.
  std::string opName;
};

/// Strategy step that appends the pass built by
/// createLinalgStrategyInterchangePass with the stored iterator interchange.
struct Interchange : public Transformation {
  /// Copies `interchangeVector` into owned storage, so the caller's ArrayRef
  /// does not need to outlive this object.
  explicit Interchange(ArrayRef<int64_t> interchangeVector,
                       LinalgTransformationFilter::FilterFunction f = nullptr)
      : Transformation(f),
        iteratorInterchange(interchangeVector.begin(),
                            interchangeVector.end()) {}

  void addToPassPipeline(OpPassManager &pm,
                         LinalgTransformationFilter stepFilter) const override {
    pm.addPass(
        createLinalgStrategyInterchangePass(iteratorInterchange, stepFilter));
  }

private:
  /// Owned copy of the iterator interchange given at construction.
  SmallVector<int64_t> iteratorInterchange;
};

/// Strategy step that appends the pass built by
/// createLinalgStrategyVectorizePass. `vectorizePadding` is forwarded to that
/// pass unchanged.
struct Vectorize : public Transformation {
  /// Build a step whose op name is empty. NOTE(review): presumably an empty
  /// name means the pass does not restrict by op name; confirm in the pass.
  explicit Vectorize(linalg::LinalgVectorizationOptions options,
                     LinalgTransformationFilter::FilterFunction f = nullptr,
                     bool padVectorize = false)
      : Vectorize(StringRef(), options, f, padVectorize) {}

  /// Build a step targeting the op named `name`.
  Vectorize(StringRef name, linalg::LinalgVectorizationOptions opts,
            LinalgTransformationFilter::FilterFunction f = nullptr,
            bool padVectorize = false)
      : Transformation(f), opName(name), options(opts),
        vectorizePadding(padVectorize) {}

  void addToPassPipeline(OpPassManager &pm,
                         LinalgTransformationFilter stepFilter) const override {
    pm.addPass(createLinalgStrategyVectorizePass(opName, options, stepFilter,
                                                 vectorizePadding));
  }

private:
  /// Name of the op this step targets; empty for the name-less constructor.
  std::string opName;
  /// Vectorization configuration forwarded to the pass.
  linalg::LinalgVectorizationOptions options;
  /// Forwarded as the last argument of createLinalgStrategyVectorizePass.
  bool vectorizePadding;
};

/// Represent one application of createLinalgStrategyLowerVectorsPass.
struct VectorLowering : public Transformation {
  explicit VectorLowering(
      linalg::LinalgVectorLoweringOptions options,
      LinalgTransformationFilter::FilterFunction f = nullptr)
      : Transformation(f), options(options) {}

  void addToPassPipeline(OpPassManager &pm,
                         LinalgTransformationFilter m) const override {
    pm.addPass(createLinalgStrategyLowerVectorsPass(options, m));
  }

private:
  linalg::LinalgVectorLoweringOptions options;
};

/// Codegen strategy controls how a Linalg op is progressively lowered.
struct CodegenStrategy {
  /// Append a pattern to tile the Op `opName` and fuse its producers with
  /// tiling and fusion `options`.
  CodegenStrategy &
  tileAndFuse(StringRef opName, LinalgTilingAndFusionOptions options,
              LinalgTransformationFilter::FilterFunction f = nullptr) {
    transformationSequence.emplace_back(
        std::make_unique<TileAndFuse>(opName, options, f));
    return *this;
  }
  /// Conditionally append a pattern to tile the Op `opName` and fuse its
  /// producers with tiling and fusion `options`.
  CodegenStrategy &
  tileAndFuseIf(bool b, StringRef opName, LinalgTilingAndFusionOptions options,
                LinalgTransformationFilter::FilterFunction f = nullptr) {
    return b ? tileAndFuse(opName, options, f) : *this;
  }
  /// Append a pattern to add a level of tiling for Op `opName` with tiling
  /// `options`.
  CodegenStrategy &
  tile(StringRef opName, linalg::LinalgTilingOptions options,
       LinalgTransformationFilter::FilterFunction f = nullptr) {
    transformationSequence.emplace_back(
        std::make_unique<Tile>(opName, options, f));
    return *this;
  }
  /// Conditionally append a pattern to add a level of tiling for
  /// `LinalgOpType` with tiling `options`.
  CodegenStrategy &
  tileIf(bool b, StringRef opName, linalg::LinalgTilingOptions options,
         LinalgTransformationFilter::FilterFunction f = nullptr) {
    return b ? tile(opName, options, f) : *this;
  }
  /// Append a pattern to pad and hoist the operands of Op `opName` with padding
  /// `options`.
  CodegenStrategy &pad(StringRef opName, linalg::LinalgPaddingOptions options,
                       LinalgTransformationFilter::FilterFunction f = nullptr) {
    transformationSequence.emplace_back(
        std::make_unique<Pad>(opName, options, f));
    return *this;
  }
  /// Conditionally append a pattern to pad and hoist the operands of Op
  /// `opName` with padding `options`.
  CodegenStrategy &
  padIf(bool b, StringRef opName, linalg::LinalgPaddingOptions options,
        LinalgTransformationFilter::FilterFunction f = nullptr) {
    return b ? pad(opName, options, f) : *this;
  }
  /// Append a pattern to add a level of promotion for `LinalgOpType` with
  /// promotion `options`.
  CodegenStrategy &
  promote(StringRef opName, linalg::LinalgPromotionOptions options,
          LinalgTransformationFilter::FilterFunction f = nullptr) {
    transformationSequence.emplace_back(
        std::make_unique<Promote>(opName, options, f));
    return *this;
  }
  /// Conditionally append a pattern to add a level of promotion for
  /// `LinalgOpType` with promotion `options`.
  CodegenStrategy &
  promoteIf(bool b, StringRef opName, linalg::LinalgPromotionOptions options,
            LinalgTransformationFilter::FilterFunction f = nullptr) {
    return b ? promote(opName, options, f) : *this;
    return *this;
  }
  /// Append a pattern to generalize named operations.
  CodegenStrategy &
  generalize(StringRef opName,
             LinalgTransformationFilter::FilterFunction f = nullptr) {
    transformationSequence.emplace_back(
        std::make_unique<Generalize>(opName, f));
    return *this;
  }
  /// Conditionally append a pattern to generalize named operations.
  CodegenStrategy &
  generalizeIf(bool b, StringRef opName,
               LinalgTransformationFilter::FilterFunction f = nullptr) {
    return b ? generalize(opName, f) : *this;
    return *this;
  }
  /// Append a pattern to interchange iterators.
  CodegenStrategy &
  interchange(ArrayRef<int64_t> iteratorInterchange,
              LinalgTransformationFilter::FilterFunction f = nullptr) {
    transformationSequence.emplace_back(
        std::make_unique<Interchange>(iteratorInterchange, f));
    return *this;
  }
  /// Conditionally append a pattern to interchange iterators.
  CodegenStrategy &
  interchangeIf(bool b, ArrayRef<int64_t> iteratorInterchange,
                LinalgTransformationFilter::FilterFunction f = nullptr) {
    return b ? interchange(iteratorInterchange, f) : *this;
    return *this;
  }
  /// Append a pattern to rewrite `LinalgOpType` as a vector operation.
  CodegenStrategy &
  vectorize(StringRef opName,
            LinalgTransformationFilter::FilterFunction f = nullptr,
            bool vectorizePadding = false) {
    transformationSequence.emplace_back(std::make_unique<Vectorize>(
        opName, linalg::LinalgVectorizationOptions(), f, vectorizePadding));
    return *this;
  }
  /// Conditionally append a pattern to rewrite `LinalgOpType` as a vector
  /// operation.
  CodegenStrategy &
  vectorizeIf(bool b, StringRef opName,
              LinalgTransformationFilter::FilterFunction f = nullptr,
              bool vectorizePadding = false) {
    return b ? vectorize(opName, f, vectorizePadding) : *this;
    return *this;
  }
  /// Append a pattern to lower all vector operations.
  CodegenStrategy &vectorLowering(LinalgVectorLoweringOptions options) {
    transformationSequence.emplace_back(
        std::make_unique<VectorLowering>(options));
    return *this;
  }
  /// Configure the post staged-patterns global enabling passes options.
  CodegenStrategy &
  setVectorTransferToSCFOptions(LinalgEnablingOptions options) {
    linalgEnablingOptions = options;
    return *this;
  }

  /// Apply the transformation patterns in sequence with cleanup
  /// transformations interleaved.
  void configurePassPipeline(OpPassManager &pm, MLIRContext *context,
                             bool addEnablePass = true) const;

private:
  LogicalResult postPatternTransforms(Operation *func) const;

  LinalgEnablingOptions linalgEnablingOptions;
  SmallVector<std::unique_ptr<Transformation>, 4> transformationSequence;
};

} // namespace linalg
} // namespace mlir

#endif // MLIR_DIALECT_LINALG_TRANSFORMS_CODEGENSTRATEGY_H_
