//===- StructuredOpsUtils.h - Utilities used by structured ops --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines utilities that operate on builtin types and are
// useful across multiple dialects that use structured ops abstractions. These
// abstractions consist of defining custom operations that encode and transport
// information about their semantics (e.g. type of iterators like parallel,
// reduction, etc.) as attributes.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
#define MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H

#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Location.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"

namespace mlir {

class OpBuilder;

/// Returns true if `indexingMaps` describe a row-major matmul. The test is
/// permutation-invariant. Only the affine maps taken from an operation are
/// inspected, so no check is performed on the math carried out within the
/// reduction.
bool isRowMajorMatmul(ArrayAttr indexingMaps);

/// Returns true if `indexingMaps` describe a column-major matmul. The test is
/// permutation-invariant. Only the affine maps taken from an operation are
/// inspected, so no check is performed on the math carried out within the
/// reduction.
bool isColumnMajorMatmul(ArrayAttr indexingMaps);

/// Returns true if `indexingMaps` describe a row-major batch matmul. The test
/// is permutation-invariant. Only the affine maps taken from an operation are
/// inspected, so no check is performed on the math carried out within the
/// reduction.
bool isRowMajorBatchMatmul(ArrayAttr indexingMaps);

/// Name of the attribute (an AffineArrayAttr) that encodes the relationship
/// between a structured op's iterators and its operands.
constexpr StringRef getIndexingMapsAttrName() {
  return StringRef("indexing_maps");
}

/// Name of the attribute (a StrArrayAttr) that encodes the type of each of a
/// structured op's iterators.
constexpr StringRef getIteratorTypesAttrName() {
  return StringRef("iterator_types");
}

/// Name of the attribute (a StrArrayAttr) that encodes the distribution type
/// for `linalg.tiled_loop`.
constexpr StringRef getDistributionTypesAttrName() {
  return StringRef("distribution_types");
}

/// Name of the attribute (a StringAttr) that holds an optional documentation
/// string for the structured op.
constexpr StringRef getDocAttrName() {
  return StringRef("doc");
}

/// Name of the attribute (a StrArrayAttr) that encodes the external library
/// function implementing the structured op.
constexpr StringRef getLibraryCallAttrName() {
  return StringRef("library_call");
}

/// Name of the attribute (a StrArrayAttr) that encodes the value of strides.
constexpr StringRef getStridesAttrName() {
  return StringRef("strides");
}

/// Name of the attribute (a StrArrayAttr) that encodes the value of dilations.
constexpr StringRef getDilationsAttrName() {
  return StringRef("dilations");
}

/// Name of the attribute (a StrArrayAttr) that encodes the value of paddings.
constexpr StringRef getPaddingAttrName() {
  return StringRef("padding");
}

/// Iterator type name used to encode that an iterator has parallel semantics.
constexpr StringRef getParallelIteratorTypeName() {
  return StringRef("parallel");
}
/// Returns true if `attr` is a StringAttr whose value equals the parallel
/// iterator type name. Null and non-string attributes yield false.
inline bool isParallelIterator(Attribute attr) {
  if (auto iteratorName = attr.dyn_cast_or_null<StringAttr>())
    return iteratorName.getValue() == getParallelIteratorTypeName();
  return false;
}

/// Iterator type name used to encode that an iterator has reduction semantics.
constexpr StringRef getReductionIteratorTypeName() {
  return StringRef("reduction");
}
/// Returns true if `attr` is a StringAttr whose value equals the reduction
/// iterator type name. Null and non-string attributes yield false.
inline bool isReductionIterator(Attribute attr) {
  if (auto iteratorName = attr.dyn_cast_or_null<StringAttr>())
    return iteratorName.getValue() == getReductionIteratorTypeName();
  return false;
}

/// Iterator type name used to encode that an iterator has window semantics.
constexpr StringRef getWindowIteratorTypeName() {
  return StringRef("window");
}
/// Returns true if `attr` is a StringAttr whose value equals the window
/// iterator type name. Null and non-string attributes yield false.
inline bool isWindowIterator(Attribute attr) {
  if (auto iteratorName = attr.dyn_cast_or_null<StringAttr>())
    return iteratorName.getValue() == getWindowIteratorTypeName();
  return false;
}

/// Returns the list of all known iterator type names, in the order: parallel,
/// reduction, window. The returned ArrayRef refers to function-local static
/// storage.
inline ArrayRef<StringRef> getAllIteratorTypeNames() {
  static constexpr StringRef kIteratorTypeNames[] = {
      getParallelIteratorTypeName(), getReductionIteratorTypeName(),
      getWindowIteratorTypeName()};
  return llvm::makeArrayRef(kIteratorTypeNames);
}

/// Returns the number of entries in `iteratorTypes` whose string value equals
/// `name`.
///
/// `name` must be one of the names returned by getAllIteratorTypeNames(); this
/// is checked with an assertion. Every element of `iteratorTypes` is cast to
/// StringAttr, so all elements are expected to be string attributes.
inline unsigned getNumIterators(StringRef name, ArrayAttr iteratorTypes) {
  assert(llvm::is_contained(getAllIteratorTypeNames(), name) &&
         "expected a known iterator type name");
  return llvm::count_if(iteratorTypes, [name](Attribute a) {
    return a.cast<StringAttr>().getValue() == name;
  });
}

/// Returns the total number of entries in `iteratorTypes`, obtained by summing
/// the per-name counts over every name in getAllIteratorTypeNames().
inline unsigned getNumIterators(ArrayAttr iteratorTypes) {
  unsigned total = 0;
  for (StringRef typeName : getAllIteratorTypeNames()) {
    total += getNumIterators(typeName, iteratorTypes);
  }
  return total;
}

/// Typed representation for loop type strings.
enum class IteratorType {
  /// Corresponds to the parallel iterator type name.
  Parallel,
  /// Corresponds to the reduction iterator type name.
  Reduction
};

/// Converts a typed IteratorType into the matching iterator type name string.
/// Reaching the end of the switch (a value outside the enumerators) is
/// reported through llvm_unreachable.
inline StringRef toString(IteratorType t) {
  // Keep this a switch without a default so that adding an enumerator
  // triggers a compiler diagnostic here.
  switch (t) {
  case IteratorType::Reduction:
    return getReductionIteratorTypeName();
  case IteratorType::Parallel:
    return getParallelIteratorTypeName();
  }
  llvm_unreachable("Unsupported IteratorType");
}

/// Helper StructuredGenerator class to manipulate and rewrite ops with
/// `StructuredOpInterface`. This is templated for now because VectorOps do not
/// yet implement the StructuredOpInterface itself.
template <typename StructuredOpInterface>
class StructuredGenerator {
public:
  using MapList = ArrayRef<ArrayRef<AffineExpr>>;

  struct IteratorType {
    IteratorType(StringRef strRef) : strRef(strRef) {}
    bool isOfType(Attribute attr) const {
      auto sAttr = attr.dyn_cast<StringAttr>();
      return sAttr && sAttr.getValue() == strRef;
    }
    StringRef strRef;
  };
  struct Par : public IteratorType {
    Par() : IteratorType(getParallelIteratorTypeName()) {}
  };
  struct Red : public IteratorType {
    Red() : IteratorType(getReductionIteratorTypeName()) {}
  };
  struct Win : public IteratorType {
    Win() : IteratorType(getWindowIteratorTypeName()) {}
  };

  StructuredGenerator(OpBuilder &builder, StructuredOpInterface op)
      : builder(builder), ctx(op.getContext()), loc(op.getLoc()),
        iterators(op.iterator_types()), maps(op.getIndexingMaps()), op(op) {}

  bool iters(ArrayRef<IteratorType> its) {
    if (its.size() != iterators.size())
      return false;
    for (int i = 0, e = its.size(); i != e; ++i) {
      if (!its[i].isOfType(iterators[i]))
        return false;
    }
    return true;
  }

  bool layout(MapList l) {
    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
    return maps == infer(l);
  }

protected:
  OpBuilder &builder;
  MLIRContext *ctx;
  Location loc;
  ArrayAttr iterators;
  SmallVector<AffineMap, 4> maps;
  Operation *op;
};

} // end namespace mlir

#endif // MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
