| //===----------- MemoryManager.h - Target independent memory manager ------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // Target independent memory manager. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #ifndef LLVM_OPENMP_LIBOMPTARGET_PLUGINS_COMMON_MEMORYMANAGER_H |
| #define LLVM_OPENMP_LIBOMPTARGET_PLUGINS_COMMON_MEMORYMANAGER_H |
| |
| #include <cassert> |
| #include <functional> |
| #include <list> |
| #include <mutex> |
| #include <set> |
| #include <unordered_map> |
| #include <vector> |
| |
| #include "Shared/Debug.h" |
| #include "Shared/Utils.h" |
| #include "omptarget.h" |
| |
| #include "llvm/Support/Error.h" |
| |
| using namespace llvm::omp::target::debug; |
| |
| namespace llvm { |
| |
| /// Base class of per-device allocator. |
| class DeviceAllocatorTy { |
| public: |
| virtual ~DeviceAllocatorTy() = default; |
| |
| /// Allocate a memory of size \p Size . \p HstPtr is used to assist the |
| /// allocation. |
| virtual Expected<void *> |
| allocate(size_t Size, void *HstPtr, |
| TargetAllocTy Kind = TARGET_ALLOC_DEFAULT) = 0; |
| |
| /// Delete the pointer \p TgtPtr on the device |
| virtual Error free(void *TgtPtr, |
| TargetAllocTy Kind = TARGET_ALLOC_DEFAULT) = 0; |
| }; |
| |
/// The memory manager. Each instance serves a single device through that
/// device's allocator, so a plugin that uses memory managers must provide one
/// allocator per device.
| class MemoryManagerTy { |
/// Size classes of the free lists, in ascending order: 0 followed by the
/// powers of two from 2^2 up to 2^13. \p findBucket binary-searches this
/// table, so it has to stay sorted.
static constexpr const size_t BucketSize[] = {
    0,        1U << 2,  1U << 3,  1U << 4, 1U << 5,
    1U << 6,  1U << 7,  1U << 8,  1U << 9, 1U << 10,
    1U << 11, 1U << 12, 1U << 13};

/// Number of entries in \p BucketSize.
static constexpr const int NumBuckets =
    sizeof(BucketSize) / sizeof(*BucketSize);
| |
/// Round \p Num down to the nearest power of two.
///
/// A power of two is returned unchanged and 0 maps to 0. The result is exact
/// for every \p size_t value, including those with the most significant bit
/// set.
static size_t floorToPowerOfTwo(size_t Num) {
  // Smear the highest set bit into every lower position, turning Num into
  // 2^(k+1) - 1 where k is the index of its highest set bit.
  Num |= Num >> 1;
  Num |= Num >> 2;
  Num |= Num >> 4;
  Num |= Num >> 8;
  Num |= Num >> 16;
#if INTPTR_MAX == INT64_MAX
  Num |= Num >> 32;
#elif INTPTR_MAX == INT32_MAX
  // Nothing more to do on 32-bit targets.
#else
#error Unsupported architecture
#endif
  // Removing the lower half of the mask leaves only bit k. Unlike
  // `(Num + 1) >> 1` this cannot wrap around to 0 when the mask is all ones.
  return Num - (Num >> 1);
}
| |
| /// Find a suitable bucket |
| static int findBucket(size_t Size) { |
| const size_t F = floorToPowerOfTwo(Size); |
| |
| ODBG(ODT_Alloc) << "findBucket: Size " << Size << " is floored to " << F |
| << "."; |
| |
| int L = 0, H = NumBuckets - 1; |
| while (H - L > 1) { |
| int M = (L + H) >> 1; |
| if (BucketSize[M] == F) |
| return M; |
| if (BucketSize[M] > F) |
| H = M - 1; |
| else |
| L = M; |
| } |
| |
| assert(L >= 0 && L < NumBuckets && "L is out of range"); |
| |
| ODBG(ODT_Alloc) << "findBucket: Size " << Size << " goes to bucket " << L; |
| |
| return L; |
| } |
| |
/// Bookkeeping record for one target pointer tracked by the manager.
struct NodeTy {
  /// Size of the memory behind \p Ptr; fixed for the lifetime of the node.
  const size_t Size;
  /// The target pointer itself.
  void *Ptr;

  /// Build a node describing \p Ptr with size \p Size.
  NodeTy(size_t Size, void *Ptr) : Size{Size}, Ptr{Ptr} {}
};
| |
| /// To make \p NodePtrTy ordered when they're put into \p std::multiset. |
| struct NodeCmpTy { |
| bool operator()(const NodeTy &LHS, const NodeTy &RHS) const { |
| return LHS.Size < RHS.Size; |
| } |
| }; |
| |
| /// A \p FreeList is a set of Nodes. We're using \p std::multiset here to make |
| /// the look up procedure more efficient. |
| using FreeListTy = std::multiset<std::reference_wrapper<NodeTy>, NodeCmpTy>; |
| |
| /// A list of \p FreeListTy entries, each of which is a \p std::multiset of |
| /// Nodes whose size is less or equal to a specific bucket size. |
| std::vector<FreeListTy> FreeLists; |
| /// A list of mutex for each \p FreeListTy entry |
| std::vector<std::mutex> FreeListLocks; |
| /// A table to map from a target pointer to its node |
| std::unordered_map<void *, NodeTy> PtrToNodeTable; |
| /// The mutex for the table \p PtrToNodeTable |
| std::mutex MapTableLock; |
| |
| /// The reference to a device allocator |
| DeviceAllocatorTy &DeviceAllocator; |
| |
| /// The threshold to manage memory using memory manager. If the request size |
| /// is larger than \p SizeThreshold, the allocation will not be managed by the |
| /// memory manager. |
| size_t SizeThreshold = 1U << 13; |
| |
| /// Request memory from target device |
| Expected<void *> allocateOnDevice(size_t Size, void *HstPtr) const { |
| return DeviceAllocator.allocate(Size, HstPtr, TARGET_ALLOC_DEVICE); |
| } |
| |
| /// Deallocate data on device |
| Error deleteOnDevice(void *Ptr) const { return DeviceAllocator.free(Ptr); } |
| |
| /// This function is called when it tries to allocate memory on device but the |
| /// device returns out of memory. It will first free all memory in the |
| /// FreeList and try to allocate again. |
| Expected<void *> freeAndAllocate(size_t Size, void *HstPtr) { |
| std::vector<void *> RemoveList; |
| |
| // Deallocate all memory in FreeList |
| for (int I = 0; I < NumBuckets; ++I) { |
| FreeListTy &List = FreeLists[I]; |
| std::lock_guard<std::mutex> Lock(FreeListLocks[I]); |
| if (List.empty()) |
| continue; |
| for (const NodeTy &N : List) { |
| if (auto Err = deleteOnDevice(N.Ptr)) |
| return Err; |
| RemoveList.push_back(N.Ptr); |
| } |
| FreeLists[I].clear(); |
| } |
| |
| // Remove all nodes in the map table which have been released |
| if (!RemoveList.empty()) { |
| std::lock_guard<std::mutex> LG(MapTableLock); |
| for (void *P : RemoveList) |
| PtrToNodeTable.erase(P); |
| } |
| |
| // Try allocate memory again |
| return allocateOnDevice(Size, HstPtr); |
| } |
| |
| /// The goal is to allocate memory on the device. It first tries to |
| /// allocate directly on the device. If a \p nullptr is returned, it might |
| /// be because the device is OOM. In that case, it will free all unused |
| /// memory and then try again. |
| Expected<void *> allocateOrFreeAndAllocateOnDevice(size_t Size, |
| void *HstPtr) { |
| auto TgtPtrOrErr = allocateOnDevice(Size, HstPtr); |
| if (!TgtPtrOrErr) |
| return TgtPtrOrErr.takeError(); |
| |
| void *TgtPtr = *TgtPtrOrErr; |
| // We cannot get memory from the device. It might be due to OOM. Let's |
| // free all memory in FreeLists and try again. |
| if (TgtPtr == nullptr) { |
| ODBG(ODT_Alloc) << "Failed to get memory on device. Free all memory " |
| << "in FreeLists and try again."; |
| TgtPtrOrErr = freeAndAllocate(Size, HstPtr); |
| if (!TgtPtrOrErr) |
| return TgtPtrOrErr.takeError(); |
| TgtPtr = *TgtPtrOrErr; |
| } |
| |
| if (TgtPtr == nullptr) |
| ODBG(ODT_Alloc) << "Still cannot get memory on device probably because " |
| << "the device is OOM."; |
| |
| return TgtPtr; |
| } |
| |
| public: |
| /// Constructor. If \p Threshold is non-zero, then the default threshold will |
| /// be overwritten by \p Threshold. |
| MemoryManagerTy(DeviceAllocatorTy &DeviceAllocator, size_t Threshold = 0) |
| : FreeLists(NumBuckets), FreeListLocks(NumBuckets), |
| DeviceAllocator(DeviceAllocator) { |
| if (Threshold) |
| SizeThreshold = Threshold; |
| } |
| |
| /// Destructor |
| ~MemoryManagerTy() { |
| for (auto &PtrToNode : PtrToNodeTable) { |
| assert(PtrToNode.second.Ptr && "nullptr in map table"); |
| if (auto Err = deleteOnDevice(PtrToNode.second.Ptr)) |
| REPORT() << "Failure to delete memory: " << toString(std::move(Err)); |
| } |
| } |
| |
| /// Allocate memory of size \p Size from target device. \p HstPtr is used to |
| /// assist the allocation. |
| Expected<void *> allocate(size_t Size, void *HstPtr) { |
| // If the size is zero, we will not bother the target device. Just return |
| // nullptr directly. |
| if (Size == 0) |
| return nullptr; |
| |
| ODBG(ODT_Alloc) << "MemoryManagerTy::allocate: size " << Size |
| << " with host pointer " << HstPtr << "."; |
| |
| // If the size is greater than the threshold, allocate it directly from |
| // device. |
| if (Size > SizeThreshold) { |
| ODBG(ODT_Alloc) << Size << " is greater than the threshold " |
| << SizeThreshold << ". Allocate it directly from device"; |
| auto TgtPtrOrErr = allocateOrFreeAndAllocateOnDevice(Size, HstPtr); |
| if (!TgtPtrOrErr) |
| return TgtPtrOrErr.takeError(); |
| |
| ODBG(ODT_Alloc) << "Got target pointer " << *TgtPtrOrErr |
| << ". Return directly."; |
| |
| return *TgtPtrOrErr; |
| } |
| |
| NodeTy *NodePtr = nullptr; |
| |
| // Try to get a node from FreeList |
| { |
| const int B = findBucket(Size); |
| FreeListTy &List = FreeLists[B]; |
| |
| NodeTy TempNode(Size, nullptr); |
| std::lock_guard<std::mutex> LG(FreeListLocks[B]); |
| const auto Itr = List.find(TempNode); |
| |
| if (Itr != List.end()) { |
| NodePtr = &Itr->get(); |
| List.erase(Itr); |
| } |
| } |
| |
| if (NodePtr != nullptr) |
| ODBG(ODT_Alloc) << "Find one node " << NodePtr << " in the bucket."; |
| |
| // We cannot find a valid node in FreeLists. Let's allocate on device and |
| // create a node for it. |
| if (NodePtr == nullptr) { |
| ODBG(ODT_Alloc) << "Cannot find a node in the FreeLists. " |
| << "Allocate on device."; |
| // Allocate one on device |
| auto TgtPtrOrErr = allocateOrFreeAndAllocateOnDevice(Size, HstPtr); |
| if (!TgtPtrOrErr) |
| return TgtPtrOrErr.takeError(); |
| |
| void *TgtPtr = *TgtPtrOrErr; |
| if (TgtPtr == nullptr) |
| return nullptr; |
| |
| // Create a new node and add it into the map table |
| { |
| std::lock_guard<std::mutex> Guard(MapTableLock); |
| auto Itr = PtrToNodeTable.emplace(TgtPtr, NodeTy(Size, TgtPtr)); |
| NodePtr = &Itr.first->second; |
| } |
| |
| ODBG(ODT_Alloc) << "Node address " << NodePtr << ", target pointer " |
| << TgtPtr << ", size " << Size; |
| } |
| |
| assert(NodePtr && "NodePtr should not be nullptr at this point"); |
| |
| return NodePtr->Ptr; |
| } |
| |
| /// Deallocate memory pointed by \p TgtPtr |
| Error free(void *TgtPtr) { |
| ODBG(ODT_Alloc) << "MemoryManagerTy::free: target memory " << TgtPtr << "."; |
| |
| NodeTy *P = nullptr; |
| |
| // Look it up into the table |
| { |
| std::lock_guard<std::mutex> G(MapTableLock); |
| auto Itr = PtrToNodeTable.find(TgtPtr); |
| |
| // We don't remove the node from the map table because the map does not |
| // change. |
| if (Itr != PtrToNodeTable.end()) |
| P = &Itr->second; |
| } |
| |
| // The memory is not managed by the manager |
| if (P == nullptr) { |
| ODBG(ODT_Alloc) << "Cannot find its node. Delete it on device directly."; |
| return deleteOnDevice(TgtPtr); |
| } |
| |
| // Insert the node to the free list |
| const int B = findBucket(P->Size); |
| |
| ODBG(ODT_Alloc) << "Found its node " << P << ". Insert it to bucket " << B |
| << "."; |
| |
| { |
| std::lock_guard<std::mutex> G(FreeListLocks[B]); |
| FreeLists[B].insert(*P); |
| } |
| |
| return Error::success(); |
| } |
| |
| /// Get the size threshold from the environment variable |
| /// \p LIBOMPTARGET_MEMORY_MANAGER_THRESHOLD . Returns a <tt> |
| /// std::pair<size_t, bool> </tt> where the first element represents the |
| /// threshold and the second element represents whether user disables memory |
| /// manager explicitly by setting the var to 0. If user doesn't specify |
| /// anything, returns <0, true>. |
| static std::pair<size_t, bool> getSizeThresholdFromEnv() { |
| static UInt64Envar MemoryManagerThreshold( |
| "LIBOMPTARGET_MEMORY_MANAGER_THRESHOLD", 0); |
| |
| size_t Threshold = MemoryManagerThreshold.get(); |
| |
| if (MemoryManagerThreshold.isPresent() && Threshold == 0) { |
| ODBG(ODT_Alloc) << "Disabled memory manager as user set " |
| << "LIBOMPTARGET_MEMORY_MANAGER_THRESHOLD=0."; |
| return std::make_pair(0, false); |
| } |
| |
| return std::make_pair(Threshold, true); |
| } |
| }; |
| |
| // GCC still cannot handle the static data member like Clang so we still need |
| // this part. |
| constexpr const size_t MemoryManagerTy::BucketSize[]; |
| constexpr const int MemoryManagerTy::NumBuckets; |
| |
| } // namespace llvm |
| |
| #endif // LLVM_OPENMP_LIBOMPTARGET_PLUGINS_COMMON_MEMORYMANAGER_H |