| //===- StorageUniquer.h - Common Storage Class Uniquer ----------*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #ifndef MLIR_SUPPORT_STORAGEUNIQUER_H |
| #define MLIR_SUPPORT_STORAGEUNIQUER_H |
| |
| #include "mlir/Support/LLVM.h" |
| #include "mlir/Support/LogicalResult.h" |
| #include "mlir/Support/TypeID.h" |
| #include "llvm/ADT/ArrayRef.h" |
| #include "llvm/ADT/DenseSet.h" |
| #include "llvm/ADT/StringRef.h" |
| #include "llvm/Support/Allocator.h" |
| |
| namespace mlir { |
namespace detail {
struct StorageUniquerImpl;

/// Detection alias for a static 'getKey' hook on a storage class. It names the
/// result type of `StorageT::getKey(ArgTs...)` and is ill-formed (a
/// substitution failure in detection contexts) when no such call is valid.
template <typename StorageT, typename... ArgTs>
using has_impltype_getkey_t =
    decltype(StorageT::getKey(std::declval<ArgTs>()...));

/// Detection alias for a static 'hashKey' hook on a storage class. It names the
/// result type of `StorageT::hashKey(KeyT)` and is ill-formed (a substitution
/// failure in detection contexts) when no such call is valid.
template <typename StorageT, typename KeyT>
using has_impltype_hash_t =
    decltype(StorageT::hashKey(std::declval<KeyT>()));
} // namespace detail
| |
| /// A utility class to get or create instances of "storage classes". These |
| /// storage classes must derive from 'StorageUniquer::BaseStorage'. |
| /// |
| /// For non-parametric storage classes, i.e. singleton classes, nothing else is |
| /// needed. Instances of these classes can be created by calling `get` without |
| /// trailing arguments. |
| /// |
| /// Otherwise, the parametric storage classes may be created with `get`, |
| /// and must respect the following: |
| /// - Define a type alias, KeyTy, to a type that uniquely identifies the |
| /// instance of the storage class. |
///   * The key type must be constructible from the values passed into the
///     `get` call, unless the storage class provides a static 'getKey'
///     method that accepts those values and returns a KeyTy.
///   * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
///     storage class must define a hashing method whose result is used as an
///     llvm::hash_code:
///       'static llvm::hash_code hashKey(const KeyTy &)'
| /// |
| /// - Provide a method, 'bool operator==(const KeyTy &) const', to |
| /// compare the storage instance against an instance of the key type. |
| /// |
| /// - Provide a static construction method: |
| /// 'DerivedStorage *construct(StorageAllocator &, const KeyTy &key)' |
| /// that builds a unique instance of the derived storage. The arguments to |
| /// this function are an allocator to store any uniqued data and the key |
| /// type for this storage. |
| /// |
| /// - Provide a cleanup method: |
| /// 'void cleanup()' |
| /// that is called when erasing a storage instance. This should cleanup any |
| /// fields of the storage as necessary and not attempt to free the memory |
| /// of the storage itself. |
| /// |
| /// Storage classes may have an optional mutable component, which must not take |
| /// part in the unique immutable key. In this case, storage classes may be |
| /// mutated with `mutate` and must additionally respect the following: |
| /// - Provide a mutation method: |
| /// 'LogicalResult mutate(StorageAllocator &, <...>)' |
| /// that is called when mutating a storage instance. The first argument is |
| /// an allocator to store any mutable data, and the remaining arguments are |
| /// forwarded from the call site. The storage can be mutated at any time |
| /// after creation. Care must be taken to avoid excessive mutation since |
| /// the allocated storage can keep containing previous states. The return |
| /// value of the function is used to indicate whether the mutation was |
| /// successful, e.g., to limit the number of mutations or enable deferred |
| /// one-time assignment of the mutable component. |
| /// |
/// All storage classes must be registered with the uniquer via
/// `registerParametricStorageType` or `registerSingletonStorageType`, using an
/// appropriate unique `TypeID` for the storage class.
| class StorageUniquer { |
| public: |
| /// This class acts as the base storage that all storage classes must derived |
| /// from. |
| class alignas(8) BaseStorage { |
| protected: |
| BaseStorage() = default; |
| }; |
| |
| /// This is a utility allocator used to allocate memory for instances of |
| /// derived types. |
| class StorageAllocator { |
| public: |
| /// Copy the specified array of elements into memory managed by our bump |
| /// pointer allocator. This assumes the elements are all PODs. |
| template <typename T> ArrayRef<T> copyInto(ArrayRef<T> elements) { |
| if (elements.empty()) |
| return llvm::None; |
| auto result = allocator.Allocate<T>(elements.size()); |
| std::uninitialized_copy(elements.begin(), elements.end(), result); |
| return ArrayRef<T>(result, elements.size()); |
| } |
| |
| /// Copy the provided string into memory managed by our bump pointer |
| /// allocator. |
| StringRef copyInto(StringRef str) { |
| if (str.empty()) |
| return StringRef(); |
| |
| char *result = allocator.Allocate<char>(str.size() + 1); |
| std::uninitialized_copy(str.begin(), str.end(), result); |
| result[str.size()] = 0; |
| return StringRef(result, str.size()); |
| } |
| |
| /// Allocate an instance of the provided type. |
| template <typename T> T *allocate() { return allocator.Allocate<T>(); } |
| |
| /// Allocate 'size' bytes of 'alignment' aligned memory. |
| void *allocate(size_t size, size_t alignment) { |
| return allocator.Allocate(size, alignment); |
| } |
| |
| /// Returns true if this allocator allocated the provided object pointer. |
| bool allocated(const void *ptr) { |
| return allocator.identifyObject(ptr).hasValue(); |
| } |
| |
| private: |
| /// The raw allocator for type storage objects. |
| llvm::BumpPtrAllocator allocator; |
| }; |
| |
| StorageUniquer(); |
| ~StorageUniquer(); |
| |
| /// Set the flag specifying if multi-threading is disabled within the uniquer. |
| void disableMultithreading(bool disable = true); |
| |
| /// Register a new parametric storage class, this is necessary to create |
| /// instances of this class type. `id` is the type identifier that will be |
| /// used to identify this type when creating instances of it via 'get'. |
| template <typename Storage> void registerParametricStorageType(TypeID id) { |
| // If the storage is trivially destructible, we don't need a destructor |
| // function. |
| if (std::is_trivially_destructible<Storage>::value) |
| return registerParametricStorageTypeImpl(id, nullptr); |
| registerParametricStorageTypeImpl(id, [](BaseStorage *storage) { |
| static_cast<Storage *>(storage)->~Storage(); |
| }); |
| } |
| /// Utility override when the storage type represents the type id. |
| template <typename Storage> void registerParametricStorageType() { |
| registerParametricStorageType<Storage>(TypeID::get<Storage>()); |
| } |
| /// Register a new singleton storage class, this is necessary to get the |
| /// singletone instance. `id` is the type identifier that will be used to |
| /// access the singleton instance via 'get'. An optional initialization |
| /// function may also be provided to initialize the newly created storage |
| /// instance, and used when the singleton instance is created. |
| template <typename Storage> |
| void registerSingletonStorageType(TypeID id, |
| function_ref<void(Storage *)> initFn) { |
| auto ctorFn = [&](StorageAllocator &allocator) { |
| auto *storage = new (allocator.allocate<Storage>()) Storage(); |
| if (initFn) |
| initFn(storage); |
| return storage; |
| }; |
| registerSingletonImpl(id, ctorFn); |
| } |
| template <typename Storage> void registerSingletonStorageType(TypeID id) { |
| registerSingletonStorageType<Storage>(id, llvm::None); |
| } |
| /// Utility override when the storage type represents the type id. |
| template <typename Storage> |
| void registerSingletonStorageType(function_ref<void(Storage *)> initFn = {}) { |
| registerSingletonStorageType<Storage>(TypeID::get<Storage>(), initFn); |
| } |
| |
| /// Gets a uniqued instance of 'Storage'. 'id' is the type id used when |
| /// registering the storage instance. 'initFn' is an optional parameter that |
| /// can be used to initialize a newly inserted storage instance. This function |
| /// is used for derived types that have complex storage or uniquing |
| /// constraints. |
| template <typename Storage, typename... Args> |
| Storage *get(function_ref<void(Storage *)> initFn, TypeID id, |
| Args &&...args) { |
| // Construct a value of the derived key type. |
| auto derivedKey = getKey<Storage>(std::forward<Args>(args)...); |
| |
| // Create a hash of the derived key. |
| unsigned hashValue = getHash<Storage>(derivedKey); |
| |
| // Generate an equality function for the derived storage. |
| auto isEqual = [&derivedKey](const BaseStorage *existing) { |
| return static_cast<const Storage &>(*existing) == derivedKey; |
| }; |
| |
| // Generate a constructor function for the derived storage. |
| auto ctorFn = [&](StorageAllocator &allocator) { |
| auto *storage = Storage::construct(allocator, derivedKey); |
| if (initFn) |
| initFn(storage); |
| return storage; |
| }; |
| |
| // Get an instance for the derived storage. |
| return static_cast<Storage *>( |
| getParametricStorageTypeImpl(id, hashValue, isEqual, ctorFn)); |
| } |
| /// Utility override when the storage type represents the type id. |
| template <typename Storage, typename... Args> |
| Storage *get(function_ref<void(Storage *)> initFn, Args &&...args) { |
| return get<Storage>(initFn, TypeID::get<Storage>(), |
| std::forward<Args>(args)...); |
| } |
| |
| /// Gets a uniqued instance of 'Storage' which is a singleton storage type. |
| /// 'id' is the type id used when registering the storage instance. |
| template <typename Storage> Storage *get(TypeID id) { |
| return static_cast<Storage *>(getSingletonImpl(id)); |
| } |
| /// Utility override when the storage type represents the type id. |
| template <typename Storage> Storage *get() { |
| return get<Storage>(TypeID::get<Storage>()); |
| } |
| |
| /// Test if there is a singleton storage uniquer initialized for the provided |
| /// TypeID. This is only useful for debugging/diagnostic purpose: the uniquer |
| /// is initialized when a dialect is loaded. |
| bool isSingletonStorageInitialized(TypeID id); |
| |
| /// Test if there is a parametric storage uniquer initialized for the provided |
| /// TypeID. This is only useful for debugging/diagnostic purpose: the uniquer |
| /// is initialized when a dialect is loaded. |
| bool isParametricStorageInitialized(TypeID id); |
| |
| /// Changes the mutable component of 'storage' by forwarding the trailing |
| /// arguments to the 'mutate' function of the derived class. |
| template <typename Storage, typename... Args> |
| LogicalResult mutate(TypeID id, Storage *storage, Args &&...args) { |
| auto mutationFn = [&](StorageAllocator &allocator) -> LogicalResult { |
| return static_cast<Storage &>(*storage).mutate( |
| allocator, std::forward<Args>(args)...); |
| }; |
| return mutateImpl(id, storage, mutationFn); |
| } |
| |
| private: |
| /// Implementation for getting/creating an instance of a derived type with |
| /// parametric storage. |
| BaseStorage *getParametricStorageTypeImpl( |
| TypeID id, unsigned hashValue, |
| function_ref<bool(const BaseStorage *)> isEqual, |
| function_ref<BaseStorage *(StorageAllocator &)> ctorFn); |
| |
| /// Implementation for registering an instance of a derived type with |
| /// parametric storage. This method takes an optional destructor function that |
| /// destructs storage instances when necessary. |
| void registerParametricStorageTypeImpl( |
| TypeID id, function_ref<void(BaseStorage *)> destructorFn); |
| |
| /// Implementation for getting an instance of a derived type with default |
| /// storage. |
| BaseStorage *getSingletonImpl(TypeID id); |
| |
| /// Implementation for registering an instance of a derived type with default |
| /// storage. |
| void |
| registerSingletonImpl(TypeID id, |
| function_ref<BaseStorage *(StorageAllocator &)> ctorFn); |
| |
| /// Implementation for mutating an instance of a derived storage. |
| LogicalResult |
| mutateImpl(TypeID id, BaseStorage *storage, |
| function_ref<LogicalResult(StorageAllocator &)> mutationFn); |
| |
| /// The internal implementation class. |
| std::unique_ptr<detail::StorageUniquerImpl> impl; |
| |
| //===--------------------------------------------------------------------===// |
| // Key Construction |
| //===--------------------------------------------------------------------===// |
| |
| /// Used to construct an instance of 'ImplTy::KeyTy' if there is an |
| /// 'ImplTy::getKey' function for the provided arguments. |
| template <typename ImplTy, typename... Args> |
| static typename std::enable_if< |
| llvm::is_detected<detail::has_impltype_getkey_t, ImplTy, Args...>::value, |
| typename ImplTy::KeyTy>::type |
| getKey(Args &&...args) { |
| return ImplTy::getKey(args...); |
| } |
| /// If there is no 'ImplTy::getKey' method, then we try to directly construct |
| /// the 'ImplTy::KeyTy' with the provided arguments. |
| template <typename ImplTy, typename... Args> |
| static typename std::enable_if< |
| !llvm::is_detected<detail::has_impltype_getkey_t, ImplTy, Args...>::value, |
| typename ImplTy::KeyTy>::type |
| getKey(Args &&...args) { |
| return typename ImplTy::KeyTy(args...); |
| } |
| |
| //===--------------------------------------------------------------------===// |
| // Key Hashing |
| //===--------------------------------------------------------------------===// |
| |
| /// Used to generate a hash for the 'ImplTy::KeyTy' of a storage instance if |
| /// there is an 'ImplTy::hashKey' overload for 'DerivedKey'. |
| template <typename ImplTy, typename DerivedKey> |
| static typename std::enable_if< |
| llvm::is_detected<detail::has_impltype_hash_t, ImplTy, DerivedKey>::value, |
| ::llvm::hash_code>::type |
| getHash(const DerivedKey &derivedKey) { |
| return ImplTy::hashKey(derivedKey); |
| } |
| /// If there is no 'ImplTy::hashKey' default to using the 'llvm::DenseMapInfo' |
| /// definition for 'DerivedKey' for generating a hash. |
| template <typename ImplTy, typename DerivedKey> |
| static typename std::enable_if<!llvm::is_detected<detail::has_impltype_hash_t, |
| ImplTy, DerivedKey>::value, |
| ::llvm::hash_code>::type |
| getHash(const DerivedKey &derivedKey) { |
| return DenseMapInfo<DerivedKey>::getHashValue(derivedKey); |
| } |
| }; |
| } // end namespace mlir |
| |
| #endif |