//===- StorageUniquer.h - Common Storage Class Uniquer ----------*- C++ -*-===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_SUPPORT_STORAGEUNIQUER_H
#define MLIR_SUPPORT_STORAGEUNIQUER_H
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/Allocator.h"
#include <utility>
namespace mlir {
namespace detail {
struct StorageUniquerImpl;

/// Detection alias for a 'getKey' hook. It is well-formed only when the
/// expression 'StorageT::getKey(args...)' is valid for arguments of the types
/// 'ArgTs', and it then names the type that call returns.
template <typename StorageT, typename... ArgTs>
using has_impltype_getkey_t =
    decltype(StorageT::getKey(std::declval<ArgTs>()...));

/// Detection alias for a 'hashKey' hook. It is well-formed only when the
/// expression 'StorageT::hashKey(key)' is valid for a key of type 'KeyT', and
/// it then names the type that call returns.
template <typename StorageT, typename KeyT>
using has_impltype_hash_t = decltype(StorageT::hashKey(std::declval<KeyT>()));
} // namespace detail
/// A utility class to get, or create instances of storage classes. These
/// storage classes must respect the following constraints:
/// - Derive from StorageUniquer::BaseStorage.
/// - Provide an unsigned 'kind' value to be used as part of the unique'ing
/// process.
///
/// For non-parametric storage classes, i.e. those that are solely uniqued by
/// their kind, nothing else is needed. Instances of these classes can be
/// created by calling `get` without trailing arguments.
///
/// Otherwise, the parametric storage classes may be created with `get`,
/// and must respect the following:
/// - Define a type alias, KeyTy, to a type that uniquely identifies the
/// instance of the storage class within its kind.
/// * The key type must be constructible from the values passed into the
/// `get` call after the kind.
/// * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
/// storage class must define a hashing method:
/// 'static unsigned hashKey(const KeyTy &)'
///
/// - Provide a method, 'bool operator==(const KeyTy &) const', to
/// compare the storage instance against an instance of the key type.
///
/// - Provide a static construction method:
/// 'DerivedStorage *construct(StorageAllocator &, const KeyTy &key)'
/// that builds a unique instance of the derived storage. The arguments to
/// this function are an allocator to store any uniqued data and the key
/// type for this storage.
///
/// - Provide a cleanup method:
/// 'void cleanup()'
/// that is called when erasing a storage instance. This should cleanup any
/// fields of the storage as necessary and not attempt to free the memory
/// of the storage itself.
class StorageUniquer {
public:
StorageUniquer();
~StorageUniquer();
/// Common base class that all uniqued storage classes must derive from. It
/// holds the 'kind' classification of the storage instance.
class BaseStorage {
  /// The uniquer implementation is granted access to the private kind field.
  friend detail::StorageUniquerImpl;

  /// Classification of the subclass, used for type checking.
  unsigned kind;

protected:
  /// Only derived storage classes can construct the base. The kind starts out
  /// as zero.
  BaseStorage() : kind(0) {}

public:
  /// Return the kind classification of this storage.
  unsigned getKind() const { return kind; }
};
/// Utility allocator that storage classes use to place their data in memory
/// managed by a bump pointer allocator.
class StorageAllocator {
public:
  /// Copy 'elements' into memory obtained from the bump pointer allocator and
  /// return a view of the copy. An empty input produces an empty ArrayRef
  /// without allocating. The elements are assumed to be PODs.
  template <typename T> ArrayRef<T> copyInto(ArrayRef<T> elements) {
    if (elements.empty())
      return llvm::None;

    T *buffer = allocator.Allocate<T>(elements.size());
    std::uninitialized_copy(elements.begin(), elements.end(), buffer);
    return ArrayRef<T>(buffer, elements.size());
  }

  /// Copy the characters of 'str' into memory obtained from the bump pointer
  /// allocator and return a StringRef over the copy.
  StringRef copyInto(StringRef str) {
    // Reuse the array overload by viewing the string as a range of chars.
    ArrayRef<char> chars(str.data(), str.size());
    ArrayRef<char> copied = copyInto(chars);
    return StringRef(copied.data(), str.size());
  }

  /// Allocate memory for one instance of 'T'. No constructor is run here.
  template <typename T> T *allocate() { return allocator.Allocate<T>(); }

  /// Allocate 'size' bytes of 'alignment' aligned memory.
  void *allocate(size_t size, size_t alignment) {
    return allocator.Allocate(size, alignment);
  }

private:
  /// The raw allocator for type storage objects.
  llvm::BumpPtrAllocator allocator;
};
/// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter
/// that can be used to initialize a newly inserted storage instance. This
/// function is used for derived types that have complex storage or uniquing
/// constraints.
template <typename Storage, typename Arg, typename... Args>
Storage *get(std::function<void(Storage *)> initFn, unsigned kind, Arg &&arg,
Args &&... args) {
// Construct a value of the derived key type.
auto derivedKey =
getKey<Storage>(std::forward<Arg>(arg), std::forward<Args>(args)...);
// Create a hash of the kind and the derived key.
unsigned hashValue = getHash<Storage>(kind, derivedKey);
// Generate an equality function for the derived storage.
std::function<bool(const BaseStorage *)> isEqual =
[&derivedKey](const BaseStorage *existing) {
return static_cast<const Storage &>(*existing) == derivedKey;
};
// Generate a constructor function for the derived storage.
std::function<BaseStorage *(StorageAllocator &)> ctorFn =
[&](StorageAllocator &allocator) {
auto *storage = Storage::construct(allocator, derivedKey);
if (initFn)
initFn(storage);
return storage;
};
// Get an instance for the derived storage.
return static_cast<Storage *>(getImpl(kind, hashValue, isEqual, ctorFn));
}
/// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter
/// that can be used to initialize a newly inserted storage instance. This
/// function is used for derived types that use no additional storage or
/// uniquing outside of the kind.
template <typename Storage>
Storage *get(std::function<void(Storage *)> initFn, unsigned kind) {
auto ctorFn = [&](StorageAllocator &allocator) {
auto *storage = new (allocator.allocate<Storage>()) Storage();
if (initFn)
initFn(storage);
return storage;
};
return static_cast<Storage *>(getImpl(kind, ctorFn));
}
/// Erases a uniqued instance of 'Storage'. This function is used for derived
/// types that have complex storage or uniquing constraints.
///
/// - 'kind' is combined with the derived key to form the hash value.
/// - 'arg'/'args' are forwarded to 'getKey' to build the 'Storage::KeyTy'
///   value that identifies the instance to erase.
///
/// The cleanup callback handed to 'eraseImpl' calls 'Storage::cleanup()' on
/// the storage it is given.
template <typename Storage, typename Arg, typename... Args>
void erase(unsigned kind, Arg &&arg, Args &&... args) {
  // Construct a value of the derived key type.
  auto derivedKey =
      getKey<Storage>(std::forward<Arg>(arg), std::forward<Args>(args)...);

  // Create a hash of the kind and the derived key.
  unsigned hashValue = getHash<Storage>(kind, derivedKey);

  // Equality predicate for the derived storage. It is kept as a plain lambda:
  // 'eraseImpl' takes a function_ref, so wrapping it in a std::function first
  // would only add a redundant layer of type erasure.
  auto isEqual = [&derivedKey](const BaseStorage *existing) {
    return static_cast<const Storage &>(*existing) == derivedKey;
  };

  // Attempt to erase the storage instance.
  eraseImpl(kind, hashValue, isEqual, [](BaseStorage *storage) {
    static_cast<Storage *>(storage)->cleanup();
  });
}
private:
/// Implementation for getting/creating an instance of a derived type with
/// complex storage. Defined out of line. The public 'get' passes the storage
/// kind, the combined hash of the kind and key, a predicate comparing an
/// existing storage against the key, and a callback that constructs a new
/// instance from a StorageAllocator.
/// NOTE(review): presumably 'ctorFn' is only invoked when no existing instance
/// satisfies 'isEqual' -- confirm against the implementation file.
BaseStorage *getImpl(unsigned kind, unsigned hashValue,
function_ref<bool(const BaseStorage *)> isEqual,
std::function<BaseStorage *(StorageAllocator &)> ctorFn);
/// Implementation for getting/creating an instance of a derived type with
/// default storage, i.e. one identified by its kind alone. Defined out of
/// line. 'ctorFn' is the callback that constructs a new instance from a
/// StorageAllocator.
BaseStorage *getImpl(unsigned kind,
std::function<BaseStorage *(StorageAllocator &)> ctorFn);
/// Implementation for erasing an instance of a derived type with complex
/// storage. Defined out of line. The public 'erase' passes the storage kind,
/// the combined hash of the kind and key, a predicate comparing an existing
/// storage against the key, and a callback that runs the derived storage's
/// 'cleanup()' method.
void eraseImpl(unsigned kind, unsigned hashValue,
function_ref<bool(const BaseStorage *)> isEqual,
std::function<void(BaseStorage *)> cleanupFn);
/// The internal implementation class. Only forward declared in this header, so
/// the uniquer state is hidden behind this pointer.
std::unique_ptr<detail::StorageUniquerImpl> impl;
//===--------------------------------------------------------------------===//
// Key Construction
//===--------------------------------------------------------------------===//
/// Used to construct an instance of 'ImplTy::KeyTy' if there is an
/// 'ImplTy::getKey' function for the provided arguments. The arguments are
/// perfectly forwarded so that the call made here has the same value
/// categories as the expression checked by 'has_impltype_getkey_t'. Passing
/// them as lvalues would fail to compile for a 'getKey' that takes rvalue
/// parameters even though it was detected, and would force copies of
/// temporaries.
template <typename ImplTy, typename... Args>
static typename std::enable_if<
    is_detected<detail::has_impltype_getkey_t, ImplTy, Args...>::value,
    typename ImplTy::KeyTy>::type
getKey(Args &&... args) {
  return ImplTy::getKey(std::forward<Args>(args)...);
}
/// If there is no 'ImplTy::getKey' method, then we try to directly construct
/// the 'ImplTy::KeyTy' with the provided arguments. The arguments are
/// perfectly forwarded to the 'KeyTy' constructor so that temporaries passed
/// by the caller are moved rather than copied.
template <typename ImplTy, typename... Args>
static typename std::enable_if<
    !is_detected<detail::has_impltype_getkey_t, ImplTy, Args...>::value,
    typename ImplTy::KeyTy>::type
getKey(Args &&... args) {
  return typename ImplTy::KeyTy(std::forward<Args>(args)...);
}
//===--------------------------------------------------------------------===//
// Key and Kind Hashing
//===--------------------------------------------------------------------===//
/// Hashes a storage instance's kind together with its key. This overload is
/// selected when 'ImplTy' provides a 'hashKey' function usable with
/// 'DerivedKey'; the key is hashed through that function.
template <typename ImplTy, typename DerivedKey>
static typename std::enable_if<
    is_detected<detail::has_impltype_hash_t, ImplTy, DerivedKey>::value,
    ::llvm::hash_code>::type
getHash(unsigned kind, const DerivedKey &derivedKey) {
  // Hash the key with the storage class's own hook, then mix in the kind.
  auto keyHash = ImplTy::hashKey(derivedKey);
  return llvm::hash_combine(kind, keyHash);
}
/// Hashes a storage instance's kind together with its key. This overload is
/// selected when 'ImplTy' has no usable 'hashKey' function; the key is then
/// hashed through the 'llvm::DenseMapInfo' definition for 'DerivedKey'.
template <typename ImplTy, typename DerivedKey>
static typename std::enable_if<
    !is_detected<detail::has_impltype_hash_t, ImplTy, DerivedKey>::value,
    ::llvm::hash_code>::type
getHash(unsigned kind, const DerivedKey &derivedKey) {
  // Hash the key with its DenseMapInfo, then mix in the kind.
  auto keyHash = DenseMapInfo<DerivedKey>::getHashValue(derivedKey);
  return llvm::hash_combine(kind, keyHash);
}
};
} // end namespace mlir
#endif