//===- BufferizableOpInterface.h - Comprehensive Bufferize ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZABLEOPINTERFACE_H_
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZABLEOPINTERFACE_H_

#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/SetVector.h"
#include <functional>
#include <memory>
#include <utility>
#include <vector>

namespace mlir {
class BlockAndValueMapping;
class DominanceInfo;
class FuncOp;

namespace linalg {
namespace comprehensive_bufferize {

/// Hard-coded buffer alignment value.
// NOTE(review): presumably expressed in bytes -- confirm against the users of
// this constant.
// TODO: Derive this from some hardware description instead of hard-coding it.
constexpr int64_t kBufferAlignments = 128;

/// Forward declaration; the struct is referenced by declarations that precede
/// its definition in this header.
struct BufferizationState;

/// Callback functions that are used to allocate/deallocate/copy memory buffers.
/// Comprehensive Bufferize provides default implementations of these functions.
// TODO: Could be replaced with a "bufferization strategy" object with virtual
// functions in the future.
struct AllocationCallbacks {
  /// Signature of the allocation callback: receives a builder, a location, the
  /// MemRefType to allocate and a list of Values, and returns the allocated
  /// Value wrapped in an Optional.
  // NOTE(review): the ArrayRef<Value> presumably carries the dynamic sizes and
  // an empty Optional presumably signals failure -- confirm with the default
  // implementation.
  using AllocationFn = std::function<Optional<Value>(
      OpBuilder &, Location, MemRefType, ArrayRef<Value>)>;

  /// Signature of the deallocation callback: receives a builder, a location
  /// and the Value to deallocate.
  using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;

  /// Signature of the copy callback: receives a builder, a location and two
  /// Values.
  // NOTE(review): the order of the two Values (source vs. destination) is not
  // visible from this header -- confirm with the default implementation.
  using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;

  /// Store the three callbacks. The parameters are taken by value and moved
  /// into the members so that no `std::function` is copied a second time.
  AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn,
                      MemCpyFn copyFn)
      : allocationFn(std::move(allocFn)), deallocationFn(std::move(deallocFn)),
        memCpyFn(std::move(copyFn)) {}

  /// A function that allocates memory.
  AllocationFn allocationFn;

  /// A function that deallocates memory. Must be allocated by `allocationFn`.
  DeallocationFn deallocationFn;

  /// A function that copies memory between two allocations.
  MemCpyFn memCpyFn;
};

/// Return default allocation callbacks.
///
/// The callbacks object is returned through a `std::unique_ptr`, i.e. the
/// caller receives ownership of it.
std::unique_ptr<AllocationCallbacks> defaultAllocationCallbacks();

/// PostAnalysisSteps can be registered with `BufferizationOptions` and are
/// executed after the analysis, but before bufferization. They can be used to
/// implement custom dialect-specific optimizations.
struct PostAnalysisStep {
  /// Virtual destructor: steps are owned and destroyed through pointers to
  /// this base class.
  virtual ~PostAnalysisStep() = default;

  /// Run the post analysis step. This function may modify the IR, but must keep
  /// `aliasInfo` (inside `state`) consistent. Newly created operations and
  /// operations that should be re-analyzed must be stored in `newOps`.
  ///
  /// Returns a LogicalResult indicating whether the step succeeded.
  virtual LogicalResult run(FuncOp funcOp, BufferizationState &state,
                            SmallVector<Operation *> &newOps) = 0;
};

/// Options for ComprehensiveBufferize.
struct BufferizationOptions {
  BufferizationOptions();

  // BufferizationOptions cannot be copied. (Copy assignment was already
  // implicitly deleted because of the `std::unique_ptr` members; it is spelled
  // out here for clarity.)
  BufferizationOptions(const BufferizationOptions &other) = delete;
  BufferizationOptions &operator=(const BufferizationOptions &other) = delete;

  /// Register a "post analysis" step. Such steps are executed after the
  /// analysis, but before bufferization.
  ///
  /// `Step` is the concrete step type to instantiate; `args` are perfectly
  /// forwarded to its constructor. The new step is appended to
  /// `postAnalysisSteps`.
  template <typename Step, typename... Args>
  void addPostAnalysisStep(Args &&...args) {
    postAnalysisSteps.emplace_back(
        std::make_unique<Step>(std::forward<Args>(args)...));
  }

  /// Helper functions for allocation, deallocation, memory copying.
  std::unique_ptr<AllocationCallbacks> allocationFns;

  /// Specifies whether returning newly allocated memrefs should be allowed.
  /// Otherwise, a pass failure is triggered.
  bool allowReturnMemref = false;

  /// Specifies whether not bufferizable ops are allowed in the input. If so,
  /// bufferization.to_memref and bufferization.to_tensor ops are inserted at
  /// the boundaries.
  bool allowUnknownOps = false;

  /// Seed for the analysis fuzzer. If set to `0`, the fuzzer is deactivated.
  /// Should be used only with `testAnalysisOnly = true`.
  unsigned analysisFuzzerSeed = 0;

  /// If set to `true`, does not modify the IR apart from adding attributes (for
  /// checking the results of the analysis) and post analysis steps.
  bool testAnalysisOnly = false;

  /// Registered post analysis steps, in registration order.
  std::vector<std::unique_ptr<PostAnalysisStep>> postAnalysisSteps;
};

/// Specify fine-grain relationship between buffers to enable more analysis.
enum class BufferRelation {
  /// No specific relationship between the buffers is specified.
  None,
  // TODO: ResultContainsOperand,
  // TODO: OperandContainsResult,
  /// The buffers are equivalent.
  Equivalent
};

/// The BufferizationAliasInfo class maintains a list of buffer aliases and
/// equivalence classes to support bufferization.
///
/// Two separate `llvm::EquivalenceClasses` structures are kept: `aliasInfo`
/// (values that may alias) and `equivalentInfo` (values that must be
/// equivalent). In addition, the class records which values are known to
/// bufferize to writable memory and which OpResults were marked as in-place.
class BufferizationAliasInfo {
public:
  /// Construct the alias info for `rootOp`.
  // NOTE(review): defined out of line; presumably creates entries for the
  // values nested under `rootOp` -- confirm against the implementation.
  explicit BufferizationAliasInfo(Operation *rootOp);

  // BufferizationAliasInfo should be passed as a reference.
  BufferizationAliasInfo(const BufferizationAliasInfo &) = delete;

  /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
  /// beginning the alias and equivalence sets only contain `v` itself.
  void createAliasInfoEntry(Value v);

  /// Insert an info entry for `newValue` and merge its alias set with that of
  /// `alias`.
  void insertNewBufferAlias(Value newValue, Value alias);

  /// Insert an info entry for `newValue` and merge its alias set with that of
  /// `alias`. Additionally, merge their equivalence classes.
  void insertNewBufferEquivalence(Value newValue, Value alias);

  /// Set the inPlace bufferization spec to true.
  /// Merge result's and operand's aliasing sets and iterate to a fixed point.
  void bufferizeInPlace(OpResult result, OpOperand &operand);

  /// Set the inPlace bufferization spec to false.
  void bufferizeOutOfPlace(OpResult result);

  /// Return true if `v1` and `v2` bufferize to equivalent buffers, i.e. if
  /// they are in the same class of `equivalentInfo`.
  bool areEquivalentBufferizedValues(Value v1, Value v2) const {
    return equivalentInfo.isEquivalent(v1, v2);
  }

  /// Return true if `v1` and `v2` bufferize to aliasing buffers, i.e. if they
  /// are in the same class of `aliasInfo`.
  bool areAliasingBufferizedValues(Value v1, Value v2) const {
    return aliasInfo.isEquivalent(v1, v2);
  }

  /// Union the alias sets of `v1` and `v2`.
  void unionAliasSets(Value v1, Value v2) { aliasInfo.unionSets(v1, v2); }

  /// Union the equivalence classes of `v1` and `v2`.
  void unionEquivalenceClasses(Value v1, Value v2) {
    equivalentInfo.unionSets(v1, v2);
  }

  /// Apply `fun` to all the members of the equivalence class of `v`.
  void applyOnEquivalenceClass(Value v, function_ref<void(Value)> fun) const;

  /// Apply `fun` to all aliases of `v`.
  void applyOnAliases(Value v, function_ref<void(Value)> fun) const;

  // TODO: Move these out of BufferizationAliasInfo.
  /// Return true if the value is known to bufferize to writable memory.
  bool bufferizesToWritableMemory(Value v) const;

  /// Specify that the value is known to bufferize to writable memory.
  void setBufferizesToWritableMemory(Value v);

  /// Mark a value as in-place bufferized by adding it to `inplaceBufferized`.
  void markInPlace(OpResult v) { inplaceBufferized.insert(v); }

  /// Return `true` if a value was marked as in-place bufferized.
  bool isInPlace(OpResult opResult) const;

private:
  /// llvm::EquivalenceClasses wants comparable elements. This comparator
  /// orders Values by comparing the results of `Value::getImpl()`. This is a
  /// poor man's comparison but it's not like UnionFind needs ordering anyway.
  // NOTE(review): `getImpl()` presumably yields the underlying implementation
  // pointer of the Value, so the order is not stable across runs -- confirm.
  struct ValueComparator {
    bool operator()(const Value &lhs, const Value &rhs) const {
      return lhs.getImpl() < rhs.getImpl();
    }
  };

  /// Range over the members of one class of an
  /// `llvm::EquivalenceClasses<Value, ValueComparator>`.
  using EquivalenceClassRangeType = llvm::iterator_range<
      llvm::EquivalenceClasses<Value, ValueComparator>::member_iterator>;
  /// Check that aliasInfo for `v` exists and return the range of its aliases.
  EquivalenceClassRangeType getAliases(Value v) const;

  /// Set of tensors that are known to bufferize to writable memory.
  llvm::DenseSet<Value> bufferizeToWritableMemory;

  /// Set of all OpResults that were decided to bufferize in-place.
  llvm::DenseSet<OpResult> inplaceBufferized;

  /// Auxiliary structure to store all the values a given value may alias with.
  /// Alias information is "may be" conservative: In the presence of branches, a
  /// value may alias with one of multiple other values. The concrete aliasing
  /// value may not even be known at compile time. All such values are
  /// considered to be aliases.
  llvm::EquivalenceClasses<Value, ValueComparator> aliasInfo;

  /// Auxiliary structure to store all the equivalent buffer classes. Equivalent
  /// buffer information is "must be" conservative: Only if two values are
  /// guaranteed to be equivalent at runtime are they said to be equivalent. It
  /// is possible that, in the presence of branches, it cannot be determined
  /// statically if two values are equivalent. In that case, the values are
  /// considered to be not equivalent.
  llvm::EquivalenceClasses<Value, ValueComparator> equivalentInfo;
};

/// Determine which OpOperand* will alias with `result` if the op is bufferized
/// in place.
///
/// Returns the aliasing OpOperands, or an empty vector if the op is not
/// bufferizable.
SmallVector<OpOperand *> getAliasingOpOperand(OpResult result);

/// Determine which OpResult will alias with `opOperand` if the op is bufferized
/// in place.
///
/// Returns the aliasing OpResult, or an empty OpResult if the op is not
/// bufferizable.
OpResult getAliasingOpResult(OpOperand &opOperand);

/// Return true if `opOperand` bufferizes to a memory read. Return `true` if the
/// op is not bufferizable, i.e. such ops are conservatively treated as
/// reading.
bool bufferizesToMemoryRead(OpOperand &opOperand);

/// Return true if `opOperand` bufferizes to a memory write. Return
/// `true` if the op is not bufferizable, i.e. such ops are conservatively
/// treated as writing.
bool bufferizesToMemoryWrite(OpOperand &opOperand);

/// Return true if `opOperand` neither reads nor writes but bufferizes to an
/// alias. Return false if the op is not bufferizable.
bool bufferizesToAliasOnly(OpOperand &opOperand);

/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
bool isValueRead(Value value);

/// Return the relationship between the operand and its corresponding OpResult
/// that it may alias with. Return `BufferRelation::None` if the op is not
/// bufferizable.
BufferRelation bufferRelation(OpOperand &opOperand);

/// Starting from `value`, follow the use-def chain in reverse, always selecting
/// the aliasing OpOperands. Find and return Values for which `condition`
/// evaluates to true. OpOperands of such matching Values are not traversed any
/// further.
///
/// When reaching the end of a chain (BlockArgument or Value without aliasing
/// OpOperands), also return the last Value of that chain.
///
/// Example:
///
///                               8
///                               |
///   6*         7*         +-----+----+
///   |          |          |          |
///   2*         3          4*         5
///   |          |          |          |
///   +----------+----------+----------+
///              |
///              1
///
/// In the above example, Values with a star satisfy the condition. When
/// starting the traversal from Value 1, the resulting SetVector is:
/// { 2, 7, 8, 5 }
// NOTE(review): Going by the rules stated above, Value 4 (which satisfies the
// condition) would be expected in the result, while Value 5 (which neither
// satisfies the condition nor ends a chain in the diagram) would not. Confirm
// the documented result against the implementation.
llvm::SetVector<Value>
findValueInReverseUseDefChain(Value value,
                              std::function<bool(Value)> condition);

/// Find the Value of the last preceding write of a given Value.
///
/// Note: Unknown ops are handled conservatively and assumed to be writes.
/// Furthermore, BlockArguments are also assumed to be writes. There is no
/// analysis across block boundaries.
///
/// Note: When reaching an end of the reverse SSA use-def chain, that value
/// is returned regardless of whether it is a memory write or not.
Value findLastPrecedingWrite(Value value);

/// Dialect-specific bufferization state. Analysis/bufferization information
/// that is specific to ops from a certain dialect can be stored in derived
/// variants of this struct.
struct DialectBufferizationState {
  /// Virtual destructor so that derived state objects can be destroyed through
  /// a pointer to this base struct (e.g. a
  /// `std::unique_ptr<DialectBufferizationState>`).
  virtual ~DialectBufferizationState() = default;
};

/// BufferizationState keeps track of bufferization state and provides access to
/// the results of the analysis.
struct BufferizationState {
  BufferizationState(ModuleOp moduleOp, const BufferizationOptions &options)
      : aliasInfo(moduleOp), options(options) {}

  // BufferizationState should be passed as a reference.
  BufferizationState(const BufferizationState &) = delete;

  /// A function that creates an alloc-dealloc pair. This function may perform
  /// additional optimizations such as buffer allocation hoisting. This function
  /// calls `allocationFn` and `deallocationFn` to create (de)allocations.
  Value createAllocDeallocFn(OpBuilder &builder, Location loc,
                             Value shapedValue);

  /// Map tensor values to memref buffers.
  void mapBuffer(ValueRange tensors, ValueRange buffers);

  /// Map a value to another value.
  void mapValue(Value from, Value to);

  /// Map a tensor value to a memref buffer.
  void mapBuffer(Value tensor, Value buffer);

  /// Lookup the memref buffer that is associated to the given tensor value.
  /// Asserts if no buffer is associated.
  Value lookupBuffer(Value tensor);

  /// Lookup the value that is associated to the given value. Asserts if no
  /// value is associated.
  Value lookupValue(Value value) const;

  /// Return `true` if the given value is mapped.
  bool isMapped(Value value) const;

  /// Mark `op` as obsolete, so that it is deleted after bufferization.
  void markOpObsolete(Operation *op);

  /// Erase all ops that were marked obsolete.
  void eraseObsoleteOps();

  /// Return dialect-specific bufferization state.
  template <typename StateT> StateT &getDialectState(StringRef name) {
    // Create state if it does not exist yet.
    if (!dialectState.count(name))
      dialectState[name] = std::make_unique<StateT>();
    return static_cast<StateT &>(*dialectState[name]);
  }

  /// `aliasInfo` keeps track of aliasing and equivalent values.
  BufferizationAliasInfo aliasInfo;

  /// The mapping of tensors to buffers. May also contain mappings of non-tensor
  /// values.
  BlockAndValueMapping mapping;

  /// Obsolete ops that should be deleted after bufferization.
  SmallVector<Operation *> obsoleteOps;

  /// Dialect-specific bufferization state.
  DenseMap<StringRef, std::unique_ptr<DialectBufferizationState>> dialectState;

  /// A reference to current bufferization options.
  const BufferizationOptions &options;
};

/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary.
///
/// `b` is the builder passed in for creating IR and `state` is the
/// bufferization state that is consulted.
Value getResultBuffer(OpBuilder &b, OpResult result, BufferizationState &state);

/// Bufferize all ops in the given region. Returns a LogicalResult indicating
/// whether bufferization succeeded.
LogicalResult bufferize(Region *region, BufferizationState &state);

/// Bufferize all ops in the given block. Returns a LogicalResult indicating
/// whether bufferization succeeded.
LogicalResult bufferize(Block *block, BufferizationState &state);

/// Bufferize the given op. If the op has no tensor OpOperands/OpResults, this
/// function returns immediately. Otherwise, it calls the `bufferize` interface
/// method of `BufferizableOpInterface`.
LogicalResult bufferize(Operation *op, BufferizationState &state);

/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
/// with the same shape as `shapedType` and specified `layout` and
/// `memorySpace`. Both `layout` and `memorySpace` default to `{}`.
MemRefType getContiguousMemRefType(ShapedType shapedType,
                                   MemRefLayoutAttrInterface layout = {},
                                   Attribute memorySpace = {});

/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
/// with the same shape as `type` and specified `layout` and `memorySpace`, or
/// an UnrankedMemRefType otherwise. Both `layout` and `memorySpace` default to
/// `{}`.
// NOTE(review): "otherwise" presumably means "if `type` is not ranked" --
// confirm against the implementation.
Type getContiguousOrUnrankedMemRefType(Type type,
                                       MemRefLayoutAttrInterface layout = {},
                                       Attribute memorySpace = {});

/// Return a MemRefType to which the `tensorType` can be bufferized in a
/// composable fashion. The layout must be the most dynamic possible and
/// canonicalize away once bufferization is finished.
///
/// `addressSpace` defaults to 0.
MemRefType getDynamicMemRefType(RankedTensorType tensorType,
                                unsigned addressSpace = 0);

} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir

#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h.inc"

namespace mlir {
namespace linalg {
namespace comprehensive_bufferize {

/// External model of BufferizableOpInterface for ops that are not (yet)
/// bufferizable but are known to act as allocation hoisting barriers. Apart
/// from `isAllocationHoistingBarrier`, every interface method is implemented
/// conservatively.
template <typename OpTy>
struct AllocationHoistingBarrierOnly
    : public BufferizableOpInterface::ExternalModel<
          AllocationHoistingBarrierOnly<OpTy>, OpTy> {
  /// The op is always reported as an allocation hoisting barrier.
  bool isAllocationHoistingBarrier(Operation *op) const { return true; }

  /// Every operand is reported as bufferizing to a memory read.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
    return true;
  }

  /// No operand is reported as bufferizing to a memory write.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
    return false;
  }

  /// No value is reported as writable.
  bool isWritable(Operation *op, Value value) const { return false; }

  /// No OpOperand aliases with any result: the returned vector is empty.
  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
                                                OpResult opResult) const {
    return SmallVector<OpOperand *>();
  }

  /// No OpResult aliases with any operand: an empty OpResult is returned.
  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
    return OpResult();
  }

  /// No buffer relation is reported between operands and results.
  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
    return BufferRelation::None;
  }

  /// Emit an error if `op` has tensor operands or results and unknown ops are
  /// not allowed by the options. Otherwise, bufferize all regions of `op` and
  /// fail as soon as one of them fails.
  LogicalResult bufferize(Operation *op, OpBuilder &b,
                          BufferizationState &state) const {
    auto isTensorType = [](Type type) { return type.isa<TensorType>(); };
    const bool involvesTensors = any_of(op->getOperandTypes(), isTensorType) ||
                                 any_of(op->getResultTypes(), isTensorType);
    if (involvesTensors && !state.options.allowUnknownOps)
      return op->emitError() << "unsupported op with tensors";

    // The op itself is left untouched; only its nested regions are processed.
    for (Region &nestedRegion : op->getRegions()) {
      if (failed(comprehensive_bufferize::bufferize(&nestedRegion, state)))
        return failure();
    }

    return success();
  }
};

} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir

#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZABLEOPINTERFACE_H_
