[SE] GlobalDeviceMemory owns its handle

Summary:
Final step in getting GlobalDeviceMemory to own its handle.

* Make GlobalDeviceMemory movable, but no longer copyable.
* Make Device::freeDeviceMemory function private and make
  GlobalDeviceMemoryBase a friend of Device so GlobalDeviceMemoryBase
  can free its memory in its destructor.
* Make GlobalDeviceMemory constructor private and make Device a friend
  so it can construct GlobalDeviceMemory.
* Remove SharedDeviceMemoryBase class because it is never used.
* Remove explicit memory freeing from example code.

This change just consumes any errors generated during device memory freeing.
The real error handling will be added in a future patch.

Reviewers: jlebar

Subscribers: jprice, parallel_libs-commits

Differential Revision: https://reviews.llvm.org/D24195

llvm-svn: 280509
GitOrigin-RevId: 31b88cb030fab7b35720c78798e7efff1596187a
diff --git a/streamexecutor/examples/Example.cpp b/streamexecutor/examples/Example.cpp
index 76027a8..a96648a 100644
--- a/streamexecutor/examples/Example.cpp
+++ b/streamexecutor/examples/Example.cpp
@@ -133,9 +133,5 @@
   for (size_t I = 0; I < ArraySize; ++I) {
     assert(HostX[I] == ExpectedX[I]);
   }
-
-  // Free device memory.
-  se::dieIfError(Device->freeDeviceMemory(X));
-  se::dieIfError(Device->freeDeviceMemory(Y));
   /// [Example saxpy host main]
 }
diff --git a/streamexecutor/include/streamexecutor/Device.h b/streamexecutor/include/streamexecutor/Device.h
index 26d0636..0171d06 100644
--- a/streamexecutor/include/streamexecutor/Device.h
+++ b/streamexecutor/include/streamexecutor/Device.h
@@ -56,13 +56,7 @@
         PDevice->allocateDeviceMemory(ElementCount * sizeof(T));
     if (!MaybeMemory)
       return MaybeMemory.takeError();
-    return GlobalDeviceMemory<T>::makeFromElementCount(*MaybeMemory,
-                                                       ElementCount);
-  }
-
-  /// Frees memory previously allocated with allocateDeviceMemory.
-  template <typename T> Error freeDeviceMemory(GlobalDeviceMemory<T> Memory) {
-    return PDevice->freeDeviceMemory(Memory.getHandle());
+    return GlobalDeviceMemory<T>(this, *MaybeMemory, ElementCount);
   }
 
   /// Allocates an array of ElementCount entries of type T in host memory.
@@ -304,6 +298,12 @@
   ///@} End host-synchronous device memory copying functions
 
 private:
+  // Only a GlobalDeviceMemoryBase may free device memory.
+  friend GlobalDeviceMemoryBase;
+  Error freeDeviceMemory(const GlobalDeviceMemoryBase &Memory) {
+    return PDevice->freeDeviceMemory(Memory.getHandle());
+  }
+
   PlatformDevice *PDevice;
 };
 
diff --git a/streamexecutor/include/streamexecutor/DeviceMemory.h b/streamexecutor/include/streamexecutor/DeviceMemory.h
index d841d26..b7cd3d1 100644
--- a/streamexecutor/include/streamexecutor/DeviceMemory.h
+++ b/streamexecutor/include/streamexecutor/DeviceMemory.h
@@ -14,24 +14,15 @@
 /// from the device. Host code cannot have a handle to device shared memory
 /// because that memory only exists during the execution of a kernel.
 ///
-/// GlobalDeviceMemoryBase is similar to a pair consisting of a void* pointer
-/// and a byte count to tell how much memory is pointed to by that void*.
+/// GlobalDeviceMemory<T> is a handle to an array of elements of type T in
+/// global device memory. It is similar to a pair of a std::unique_ptr<T> and an
+/// element count to tell how many elements of type T fit in the memory owned
+/// by that std::unique_ptr<T>.
 ///
-/// GlobalDeviceMemory<T> is a subclass of GlobalDeviceMemoryBase which keeps
-/// track of the type of element to be stored in the device memory. It is
-/// similar to a pair of a T* pointer and an element count to tell how many
-/// elements of type T fit in the memory pointed to by that T*.
-///
-/// SharedDeviceMemoryBase is just the size in bytes of a shared memory buffer.
-///
-/// SharedDeviceMemory<T> is a subclass of SharedDeviceMemoryBase which knows
-/// how many elements of type T it can hold.
-///
-/// These classes are useful for keeping track of which memory space a buffer
-/// lives in, and the typed subclasses are useful for type-checking.
-///
-/// The typed subclass will be used by user code, and the untyped base classes
-/// will be used for type-unsafe operations inside of StreamExecutor.
+/// SharedDeviceMemory<T> is just the size in elements of an array of elements
+/// of type T in device shared memory. No resources are actually attached to
+/// this class; it is just like a memo to the device to allocate space in shared
+/// memory.
 ///
 //===----------------------------------------------------------------------===//
 
@@ -41,56 +32,12 @@
 #include <cassert>
 #include <cstddef>
+#include <utility>
 
+#include "streamexecutor/Utils/Error.h"
+
 namespace streamexecutor {
 
-/// Wrapper around a generic global device memory allocation.
-///
-/// This class represents a buffer of untyped bytes in the global memory space
-/// of a device. See GlobalDeviceMemory<T> for the corresponding type that
-/// includes type information for the elements in its buffer.
-///
-/// This is effectively a pair consisting of an opaque handle and a buffer size
-/// in bytes. The opaque handle is a platform-dependent handle to the actual
-/// memory that is allocated on the device.
-///
-/// In some cases, such as in the CUDA platform, the opaque handle may actually
-/// be a pointer in the virtual address space and it may be valid to perform
-/// arithmetic on it to obtain other device pointers, but this is not the case
-/// in general.
-///
-/// For example, in the OpenCL platform, the handle is a pointer to a _cl_mem
-/// handle object which really is completely opaque to the user.
-///
-/// The only fully platform-generic operations on handles are using them to
-/// create new GlobalDeviceMemoryBase objects, and comparing them to each other
-/// for equality.
-class GlobalDeviceMemoryBase {
-public:
-  /// Creates a GlobalDeviceMemoryBase from an optional handle and an optional
-  /// byte count.
-  explicit GlobalDeviceMemoryBase(const void *Handle = nullptr,
-                                  size_t ByteCount = 0)
-      : Handle(Handle), ByteCount(ByteCount) {}
-
-  /// Copyable like a pointer.
-  GlobalDeviceMemoryBase(const GlobalDeviceMemoryBase &) = default;
-
-  /// Copy-assignable like a pointer.
-  GlobalDeviceMemoryBase &operator=(const GlobalDeviceMemoryBase &) = default;
-
-  /// Returns the size, in bytes, for the backing memory.
-  size_t getByteCount() const { return ByteCount; }
-
-  /// Gets the internal handle.
-  ///
-  /// Warning: note that the pointer returned is not necessarily directly to
-  /// device virtual address space, but is platform-dependent.
-  const void *getHandle() const { return Handle; }
-
-private:
-  const void *Handle; // Platform-dependent value representing allocated memory.
-  size_t ByteCount;   // Size in bytes of this allocation.
-};
+class Device;
 
 template <typename ElemT> class GlobalDeviceMemory;
 
@@ -115,7 +61,7 @@
   }
 
   /// Gets the GlobalDeviceMemory backing this slice.
-  GlobalDeviceMemory<ElemT> getBaseMemory() const { return BaseMemory; }
+  const GlobalDeviceMemory<ElemT> &getBaseMemory() const { return BaseMemory; }
 
   /// Gets the offset of this slice from the base memory.
   ///
@@ -152,11 +98,68 @@
   }
 
 private:
-  GlobalDeviceMemory<ElemT> BaseMemory;
+  const GlobalDeviceMemory<ElemT> &BaseMemory;
   size_t ElementOffset;
   size_t ElementCount;
 };
 
+/// Wrapper around a generic global device memory allocation.
+///
+/// This class represents a buffer of untyped bytes in the global memory space
+/// of a device. See GlobalDeviceMemory<T> for the corresponding type that
+/// includes type information for the elements in its buffer.
+///
+/// This is effectively a pair consisting of an opaque handle and a buffer size
+/// in bytes. The opaque handle is a platform-dependent handle to the actual
+/// memory that is allocated on the device.
+///
+/// In some cases, such as in the CUDA platform, the opaque handle may actually
+/// be a pointer in the virtual address space and it may be valid to perform
+/// arithmetic on it to obtain other device pointers, but this is not the case
+/// in general.
+///
+/// For example, in the OpenCL platform, the handle is a pointer to a _cl_mem
+/// handle object which really is completely opaque to the user.
+class GlobalDeviceMemoryBase {
+public:
+  /// Returns the opaque, platform-dependent handle to the underlying memory.
+  const void *getHandle() const { return Handle; }
+
+  // Not copyable: exactly one object may own (and later free) the handle.
+  GlobalDeviceMemoryBase(const GlobalDeviceMemoryBase &) = delete;
+  GlobalDeviceMemoryBase &operator=(const GlobalDeviceMemoryBase &) = delete;
+
+protected:
+  /// Takes ownership of Handle, a ByteCount-byte allocation on device D.
+  GlobalDeviceMemoryBase(Device *D, const void *Handle, size_t ByteCount)
+      : TheDevice(D), Handle(Handle), ByteCount(ByteCount) {}
+
+  /// Takes over Other's handle and leaves Other empty (null handle, size 0).
+  GlobalDeviceMemoryBase(GlobalDeviceMemoryBase &&Other)
+      : TheDevice(Other.TheDevice), Handle(Other.Handle),
+        ByteCount(Other.ByteCount) {
+    Other.TheDevice = nullptr;
+    Other.Handle = nullptr;
+    Other.ByteCount = 0;
+  }
+
+  GlobalDeviceMemoryBase &operator=(GlobalDeviceMemoryBase &&Other) {
+    // Swap instead of overwriting: the allocation *this held is handed to
+    // Other, whose destructor frees it, so it no longer leaks. This also makes
+    // self-move-assignment a no-op rather than dropping the handle.
+    std::swap(TheDevice, Other.TheDevice);
+    std::swap(Handle, Other.Handle);
+    std::swap(ByteCount, Other.ByteCount);
+    return *this;
+  }
+
+  ~GlobalDeviceMemoryBase(); // Frees the handle, if any, via TheDevice.
+
+  Device *TheDevice;  // Pointer to the device on which this memory lives.
+  const void *Handle; // Platform-dependent value representing allocated memory.
+  size_t ByteCount;   // Size in bytes of this allocation.
+};
+
 /// Typed wrapper around the "void *"-like GlobalDeviceMemoryBase class.
 ///
 /// For example, GlobalDeviceMemory<int> is a simple wrapper around
@@ -165,31 +168,12 @@
 template <typename ElemT>
 class GlobalDeviceMemory : public GlobalDeviceMemoryBase {
 public:
-  /// Creates a typed area of GlobalDeviceMemory with a given opaque handle and
-  /// the given element count.
-  static GlobalDeviceMemory<ElemT> makeFromElementCount(const void *Handle,
-                                                        size_t ElementCount) {
-    return GlobalDeviceMemory<ElemT>(Handle, ElementCount);
-  }
-
-  /// Creates a typed device memory region from an untyped device memory region.
-  ///
-  /// This effectively amounts to a cast from a void* to an ElemT*, but it also
-  /// manages the difference in the size measurements when
-  /// GlobalDeviceMemoryBase is measured in bytes and GlobalDeviceMemory is
-  /// measured in elements.
-  explicit GlobalDeviceMemory(const GlobalDeviceMemoryBase &Other)
-      : GlobalDeviceMemoryBase(Other.getHandle(), Other.getByteCount()) {}
-
-  /// Copyable like a pointer.
-  GlobalDeviceMemory(const GlobalDeviceMemory &) = default;
-
-  /// Copy-assignable like a pointer.
-  GlobalDeviceMemory &operator=(const GlobalDeviceMemory &) = default;
+  GlobalDeviceMemory(GlobalDeviceMemory &&Other) = default;
+  GlobalDeviceMemory &operator=(GlobalDeviceMemory &&Other) = default;
 
   /// Returns the number of elements of type ElemT that constitute this
   /// allocation.
-  size_t getElementCount() const { return getByteCount() / sizeof(ElemT); }
+  size_t getElementCount() const { return ByteCount / sizeof(ElemT); }
 
   /// Converts this memory object into a slice.
   GlobalDeviceMemorySlice<ElemT> asSlice() const {
@@ -197,23 +181,17 @@
   }
 
 private:
-  /// Constructs a GlobalDeviceMemory instance from an opaque handle and an
-  /// element count.
-  ///
-  /// This constructor is not public because there is a potential for confusion
-  /// between the size of the buffer in bytes and the size of the buffer in
-  /// elements.
-  ///
-  /// The static method makeFromElementCount is provided for users of this class
-  /// because its name makes the meaning of the size parameter clear.
-  GlobalDeviceMemory(const void *Handle, size_t ElementCount)
-      : GlobalDeviceMemoryBase(Handle, ElementCount * sizeof(ElemT)) {}
+  GlobalDeviceMemory(const GlobalDeviceMemory &) = delete;
+  GlobalDeviceMemory &operator=(const GlobalDeviceMemory &) = delete;
+
+  // Only a Device can create a GlobalDeviceMemory instance.
+  friend Device;
+  GlobalDeviceMemory(Device *D, const void *Handle, size_t ElementCount)
+      : GlobalDeviceMemoryBase(D, Handle, ElementCount * sizeof(ElemT)) {}
 };
 
-/// A class to represent the size of a dynamic shared memory buffer on a device.
-///
-/// This class maintains no information about the types to be stored in the
-/// buffer. For the typed version of this class see SharedDeviceMemory<ElemT>.
+/// A class to represent the size of a dynamic shared memory buffer of elements
+/// of type T on a device.
 ///
 /// Shared memory buffers exist only on the device and cannot be manipulated
 /// from the host, so instances of this class do not have an opaque handle, only
@@ -232,31 +210,7 @@
 /// multiple SharedDeviceMemory arguments, and simply adding together all the
 /// shared memory sizes to get the final shared memory size that is used to
 /// launch the kernel.
-class SharedDeviceMemoryBase {
-public:
-  /// Creates an untyped shared memory array from a byte count.
-  SharedDeviceMemoryBase(size_t ByteCount) : ByteCount(ByteCount) {}
-
-  /// Copyable because it is just an array size.
-  SharedDeviceMemoryBase(const SharedDeviceMemoryBase &) = default;
-
-  /// Copy-assignable because it is just an array size.
-  SharedDeviceMemoryBase &operator=(const SharedDeviceMemoryBase &) = default;
-
-  /// Gets the byte count.
-  size_t getByteCount() const { return ByteCount; }
-
-private:
-  size_t ByteCount;
-};
-
-/// Typed wrapper around the untyped SharedDeviceMemoryBase class.
-///
-/// For example, SharedDeviceMemory<int> is a wrapper around
-/// SharedDeviceMemoryBase that represents a buffer of integers stored in shared
-/// device memory.
-template <typename ElemT>
-class SharedDeviceMemory : public SharedDeviceMemoryBase {
+template <typename ElemT> class SharedDeviceMemory {
 public:
   /// Creates a typed area of shared device memory with a given number of
   /// elements.
@@ -272,7 +226,7 @@
 
   /// Returns the number of elements of type ElemT that can fit this memory
   /// buffer.
-  size_t getElementCount() const { return getByteCount() / sizeof(ElemT); }
+  size_t getElementCount() const { return ElementCount; }
 
   /// Returns whether this is a single-element memory buffer.
   bool isScalar() const { return getElementCount() == 1; }
@@ -287,7 +241,9 @@
   /// The static method makeFromElementCount is provided for users of this class
   /// because its name makes the meaning of the size parameter clear.
   explicit SharedDeviceMemory(size_t ElementCount)
-      : SharedDeviceMemoryBase(ElementCount * sizeof(ElemT)) {}
+      : ElementCount(ElementCount) {}
+
+  size_t ElementCount;
 };
 
 } // namespace streamexecutor
diff --git a/streamexecutor/lib/CMakeLists.txt b/streamexecutor/lib/CMakeLists.txt
index aa16f50..79ae5c7 100644
--- a/streamexecutor/lib/CMakeLists.txt
+++ b/streamexecutor/lib/CMakeLists.txt
@@ -7,6 +7,7 @@
     streamexecutor
     $<TARGET_OBJECTS:utils>
     Device.cpp
+    DeviceMemory.cpp
     Kernel.cpp
     KernelSpec.cpp
     PackedKernelArgumentArray.cpp
diff --git a/streamexecutor/lib/DeviceMemory.cpp b/streamexecutor/lib/DeviceMemory.cpp
new file mode 100644
index 0000000..62b702b
--- /dev/null
+++ b/streamexecutor/lib/DeviceMemory.cpp
@@ -0,0 +1,28 @@
+//===-- DeviceMemory.cpp - DeviceMemory implementation --------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Implementation of DeviceMemory class internals.
+///
+//===----------------------------------------------------------------------===//
+
+#include "streamexecutor/DeviceMemory.h"
+
+#include "streamexecutor/Device.h"
+
+namespace streamexecutor {
+
+GlobalDeviceMemoryBase::~GlobalDeviceMemoryBase() {
+  if (Handle) {
+    // TODO(jhen): How to handle errors here.
+    consumeError(TheDevice->freeDeviceMemory(*this));
+  }
+}
+
+} // namespace streamexecutor
diff --git a/streamexecutor/lib/unittests/DeviceTest.cpp b/streamexecutor/lib/unittests/DeviceTest.cpp
index 6e55aa5..593f1d1 100644
--- a/streamexecutor/lib/unittests/DeviceTest.cpp
+++ b/streamexecutor/lib/unittests/DeviceTest.cpp
@@ -78,7 +78,6 @@
   se::Expected<se::GlobalDeviceMemory<int>> MaybeMemory =
       Device.allocateDeviceMemory<int>(10);
   EXPECT_TRUE(static_cast<bool>(MaybeMemory));
-  EXPECT_NO_ERROR(Device.freeDeviceMemory(*MaybeMemory));
 }
 
 TEST_F(DeviceTest, AllocateAndFreeHostMemory) {