//===- StorageUniquerSupport.h - MLIR Storage Uniquer Utilities -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines utility classes for interfacing with StorageUniquer.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_IR_STORAGEUNIQUERSUPPORT_H
#define MLIR_IR_STORAGEUNIQUERSUPPORT_H

#include "mlir/Support/InterfaceSupport.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/StorageUniquer.h"
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/FunctionExtras.h"

namespace mlir {
class InFlightDiagnostic;
class Location;
class MLIRContext;

namespace detail {
/// Utility methods to generate a callback that can be used to generate a
/// diagnostic when checking the construction invariants of a storage object.
/// Each call of the returned callback yields an InFlightDiagnostic. Both
/// overloads are defined out-of-line to avoid the need to include Location.h.
///
/// Overload that builds the callback from a context.
llvm::unique_function<InFlightDiagnostic()>
getDefaultDiagnosticEmitFn(MLIRContext *ctx);
/// Overload that builds the callback from a location.
llvm::unique_function<InFlightDiagnostic()>
getDefaultDiagnosticEmitFn(const Location &loc);

//===----------------------------------------------------------------------===//
// StorageUserTraitBase
//===----------------------------------------------------------------------===//

/// Common base used to implement traits for storage classes. Everything in
/// here is protected: it is meant for the trait implementations deriving from
/// it, not for direct use by clients.
template <typename ConcreteType, template <typename> class TraitType>
class StorageUserTraitBase {
protected:
  /// Recover the concrete object this trait belongs to and return it by
  /// value.
  ConcreteType getInstance() const {
    // The concrete type multiply derives from the (content free) TraitBase
    // class, so a direct cast to it would not tell the compiler which
    // inheritance path to follow. Casting to the specific trait type first,
    // and only then to the concrete type, removes that ambiguity.
    using TraitT = TraitType<ConcreteType>;
    const TraitT *asTrait = static_cast<const TraitT *>(this);
    const ConcreteType *asConcrete = static_cast<const ConcreteType *>(asTrait);
    return *asConcrete;
  }
};

//===----------------------------------------------------------------------===//
// StorageUserBase
//===----------------------------------------------------------------------===//

namespace storage_user_base_impl {
/// Checks whether `traitID` is the ID of one of the trait templates listed in
/// `Traits`. Returns true on the first match, false when none matches.
template <template <typename T> class... Traits>
bool hasTrait(TypeID traitID) {
  // Materialize the ID of every listed trait, then scan them in order.
  TypeID candidateIDs[] = {TypeID::get<Traits>()...};
  for (TypeID &candidate : candidateIDs)
    if (candidate == traitID)
      return true;
  return false;
}

// Specialization for an empty trait list: nothing can match, and handling it
// here means the primary template never has to define an empty array.
template <>
inline bool hasTrait(TypeID traitID) {
  return false;
}
} // namespace storage_user_base_impl

/// Utility class for implementing users of storage classes uniqued by a
/// StorageUniquer. Clients are not expected to interact with this class
/// directly.
///
/// Template parameters:
///  - ConcreteT: the concrete class being implemented. It is used for its
///    type ID, its `verify` and `getChecked` entry points, its `AbstractTy`
///    and its constructors.
///  - BaseT: the class this one publicly derives from. Its constructors are
///    inherited and its `ImplType` is used by `getFromOpaquePointer`.
///  - StorageT: the storage class, re-exported as `ImplType` and returned by
///    `getImpl`.
///  - UniquerT: provides the static `get` and `mutate` templates used to
///    create and mutate instances.
///  - Traits: trait templates; each one is instantiated with ConcreteT and
///    publicly inherited.
template <typename ConcreteT, typename BaseT, typename StorageT,
          typename UniquerT, template <typename T> class... Traits>
class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
public:
  using BaseT::BaseT;

  /// Utility declarations for the concrete class.
  using Base = StorageUserBase<ConcreteT, BaseT, StorageT, UniquerT, Traits...>;
  using ImplType = StorageT;
  using HasTraitFn = bool (*)(TypeID);

  /// Return a unique identifier for the concrete type.
  static TypeID getTypeID() { return TypeID::get<ConcreteT>(); }

  /// Provide an implementation of 'classof' that compares the type id of the
  /// provided value with that of the concrete type. It is a compile-time
  /// error to use it with a `T` that ConcreteT is not convertible to.
  template <typename T> static bool classof(T val) {
    static_assert(std::is_convertible<ConcreteT, T>::value,
                  "casting from a non-convertible type");
    return val.getTypeID() == getTypeID();
  }

  /// Returns an interface map for the interfaces registered to this storage
  /// user, built from every trait instantiated with ConcreteT. This should
  /// not be used directly.
  static detail::InterfaceMap getInterfaceMap() {
    return detail::InterfaceMap::template get<Traits<ConcreteT>...>();
  }

  /// Returns the function that returns true if the given Trait ID matches the
  /// IDs of any of the traits defined by the storage user.
  static HasTraitFn getHasTraitFn() {
    // The lambda captures nothing, so it converts to a plain function pointer.
    return [](TypeID traitID) {
      return storage_user_base_impl::hasTrait<Traits...>(traitID);
    };
  }

  /// Attach the given models as implementations of the corresponding interfaces
  /// for the concrete storage user class. The type must be registered with the
  /// context, i.e. the dialect to which the type belongs must be loaded. A
  /// fatal error is reported otherwise.
  template <typename... IfaceModels>
  static void attachInterface(MLIRContext &context) {
    typename ConcreteT::AbstractTy *abstract =
        ConcreteT::AbstractTy::lookupMutable(TypeID::get<ConcreteT>(),
                                             &context);
    if (!abstract)
      llvm::report_fatal_error("Registering an interface for an attribute/type "
                               "that is not itself registered.");
    abstract->interfaceMap.template insert<IfaceModels...>();
  }

  /// Get or create a new ConcreteT instance within the ctx. This
  /// function is guaranteed to return a non null object and will assert if
  /// the arguments provided are invalid. Note that the verification is
  /// performed inside `assert`, so it is compiled out when NDEBUG is defined.
  template <typename... Args>
  static ConcreteT get(MLIRContext *ctx, Args... args) {
    // Ensure that the invariants are correct for construction.
    assert(
        succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...)));
    return UniquerT::template get<ConcreteT>(ctx, args...);
  }

  /// Get or create a new ConcreteT instance, defined at the given, potentially
  /// unknown, location. If the arguments provided are invalid, errors are
  /// emitted using the provided location and a null object is returned.
  /// This forwards to `ConcreteT::getChecked` with a diagnostic callback built
  /// from `loc`; `args` are passed through unchanged.
  template <typename... Args>
  static ConcreteT getChecked(const Location &loc, Args... args) {
    return ConcreteT::getChecked(getDefaultDiagnosticEmitFn(loc), args...);
  }

  /// Get or create a new ConcreteT instance within the ctx. If the arguments
  /// provided are invalid, errors are emitted using the provided `emitErrorFn`
  /// and a null object is returned.
  template <typename... Args>
  static ConcreteT getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                              MLIRContext *ctx, Args... args) {
    // If the construction invariants fail then we return a null
    // (default-constructed) ConcreteT.
    if (failed(ConcreteT::verify(emitErrorFn, args...)))
      return ConcreteT();
    return UniquerT::template get<ConcreteT>(ctx, args...);
  }

  /// Get an instance of the concrete type from a void pointer. The pointer is
  /// reinterpreted as a `const BaseT::ImplType *` without any checking.
  static ConcreteT getFromOpaquePointer(const void *ptr) {
    return ConcreteT(static_cast<const typename BaseT::ImplType *>(ptr));
  }

protected:
  /// Mutate the current storage instance. This will not change the unique key.
  /// The arguments are forwarded to 'ConcreteT::mutate'.
  template <typename... Args> LogicalResult mutate(Args &&...args) {
    return UniquerT::template mutate<ConcreteT>(this->getContext(), getImpl(),
                                                std::forward<Args>(args)...);
  }

  /// Default implementation that accepts any arguments and just returns
  /// success.
  template <typename... Args> static LogicalResult verify(Args... /*args*/) {
    return success();
  }

  /// Utility for easy access to the storage instance, cast to the concrete
  /// storage type.
  ImplType *getImpl() const { return static_cast<ImplType *>(this->impl); }
};
} // namespace detail
} // namespace mlir

#endif
