[SE] Remove Platform*Handle classes

Summary:
As pointed out by jprice, these classes don't serve a purpose. Instead,
we stay consistent with the way memory is managed and let the Stream and
Kernel classes directly hold opaque handles to device Stream and Kernel
instances, respectively.

Reviewers: jprice, jlebar

Subscribers: parallel_libs-commits

Differential Revision: https://reviews.llvm.org/D24213

llvm-svn: 280719
GitOrigin-RevId: 18ea094df15dfb8b41d4798bb1cbb14c434fa98d
diff --git a/streamexecutor/include/streamexecutor/Device.h b/streamexecutor/include/streamexecutor/Device.h
index 95d9b5c..0ee2b2f 100644
--- a/streamexecutor/include/streamexecutor/Device.h
+++ b/streamexecutor/include/streamexecutor/Device.h
@@ -35,12 +35,11 @@
   Expected<typename std::enable_if<std::is_base_of<KernelBase, KernelT>::value,
                                    KernelT>::type>
   createKernel(const MultiKernelLoaderSpec &Spec) {
-    Expected<std::unique_ptr<PlatformKernelHandle>> MaybeKernelHandle =
-        PDevice->createKernel(Spec);
+    Expected<const void *> MaybeKernelHandle = PDevice->createKernel(Spec);
     if (!MaybeKernelHandle) {
       return MaybeKernelHandle.takeError();
     }
-    return KernelT(Spec.getKernelName(), std::move(*MaybeKernelHandle));
+    return KernelT(PDevice, *MaybeKernelHandle, Spec.getKernelName());
   }
 
   /// Creates a stream object for this device.
diff --git a/streamexecutor/include/streamexecutor/Kernel.h b/streamexecutor/include/streamexecutor/Kernel.h
index c9b4180..6ea7c36 100644
--- a/streamexecutor/include/streamexecutor/Kernel.h
+++ b/streamexecutor/include/streamexecutor/Kernel.h
@@ -28,19 +28,32 @@
 
 namespace streamexecutor {
 
-class PlatformKernelHandle;
+class PlatformDevice;
 
 /// The base class for all kernel types.
 ///
 /// Stores the name of the kernel in both mangled and demangled forms.
 class KernelBase {
 public:
-  KernelBase(llvm::StringRef Name);
+  KernelBase(PlatformDevice *D, const void *PlatformKernelHandle,
+             llvm::StringRef Name);
 
+  KernelBase(const KernelBase &Other) = delete;
+  KernelBase &operator=(const KernelBase &Other) = delete;
+
+  KernelBase(KernelBase &&Other);
+  KernelBase &operator=(KernelBase &&Other);
+
+  ~KernelBase();
+
+  const void *getPlatformHandle() const { return PlatformKernelHandle; }
   const std::string &getName() const { return Name; }
   const std::string &getDemangledName() const { return DemangledName; }
 
 private:
+  PlatformDevice *PDevice;
+  const void *PlatformKernelHandle;
+
   std::string Name;
   std::string DemangledName;
 };
@@ -51,17 +64,12 @@
 /// function.
 template <typename... ParameterTs> class Kernel : public KernelBase {
 public:
-  Kernel(llvm::StringRef Name, std::unique_ptr<PlatformKernelHandle> PHandle)
-      : KernelBase(Name), PHandle(std::move(PHandle)) {}
+  Kernel(PlatformDevice *D, const void *PlatformKernelHandle,
+         llvm::StringRef Name)
+      : KernelBase(D, PlatformKernelHandle, Name) {}
 
   Kernel(Kernel &&Other) = default;
   Kernel &operator=(Kernel &&Other) = default;
-
-  /// Gets the underlying platform-specific handle for this kernel.
-  PlatformKernelHandle *getPlatformHandle() const { return PHandle.get(); }
-
-private:
-  std::unique_ptr<PlatformKernelHandle> PHandle;
 };
 
 } // namespace streamexecutor
diff --git a/streamexecutor/include/streamexecutor/PlatformInterfaces.h b/streamexecutor/include/streamexecutor/PlatformInterfaces.h
index b3deff3..946f8f9 100644
--- a/streamexecutor/include/streamexecutor/PlatformInterfaces.h
+++ b/streamexecutor/include/streamexecutor/PlatformInterfaces.h
@@ -31,34 +31,6 @@
 
 namespace streamexecutor {
 
-class PlatformDevice;
-
-/// Platform-specific kernel handle.
-class PlatformKernelHandle {
-public:
-  explicit PlatformKernelHandle(PlatformDevice *PDevice) : PDevice(PDevice) {}
-
-  virtual ~PlatformKernelHandle();
-
-  PlatformDevice *getDevice() { return PDevice; }
-
-private:
-  PlatformDevice *PDevice;
-};
-
-/// Platform-specific stream handle.
-class PlatformStreamHandle {
-public:
-  explicit PlatformStreamHandle(PlatformDevice *PDevice) : PDevice(PDevice) {}
-
-  virtual ~PlatformStreamHandle();
-
-  PlatformDevice *getDevice() { return PDevice; }
-
-private:
-  PlatformDevice *PDevice;
-};
-
 /// Raw executor methods that must be implemented by each platform.
 ///
 /// This class defines the platform interface that supports executing work on a
@@ -73,19 +45,30 @@
   virtual std::string getName() const = 0;
 
   /// Creates a platform-specific kernel.
-  virtual Expected<std::unique_ptr<PlatformKernelHandle>>
+  virtual Expected<const void *>
   createKernel(const MultiKernelLoaderSpec &Spec) {
     return make_error("createKernel not implemented for platform " + getName());
   }
 
+  virtual Error destroyKernel(const void *Handle) {
+    return make_error("destroyKernel not implemented for platform " +
+                      getName());
+  }
+
   /// Creates a platform-specific stream.
-  virtual Expected<std::unique_ptr<PlatformStreamHandle>> createStream() {
+  virtual Expected<const void *> createStream() {
     return make_error("createStream not implemented for platform " + getName());
   }
 
+  virtual Error destroyStream(const void *Handle) {
+    return make_error("destroyStream not implemented for platform " +
+                      getName());
+  }
+
   /// Launches a kernel on the given stream.
-  virtual Error launch(PlatformStreamHandle *S, BlockDimensions BlockSize,
-                       GridDimensions GridSize, PlatformKernelHandle *K,
+  virtual Error launch(const void *PlatformStreamHandle,
+                       BlockDimensions BlockSize, GridDimensions GridSize,
+                       const void *PKernelHandle,
                        const PackedKernelArgumentArrayBase &ArgumentArray) {
     return make_error("launch not implemented for platform " + getName());
   }
@@ -94,9 +77,9 @@
   ///
   /// HostDst should have been allocated by allocateHostMemory or registered
   /// with registerHostMemory.
-  virtual Error copyD2H(PlatformStreamHandle *S, const void *DeviceSrcHandle,
-                        size_t SrcByteOffset, void *HostDst,
-                        size_t DstByteOffset, size_t ByteCount) {
+  virtual Error copyD2H(const void *PlatformStreamHandle,
+                        const void *DeviceSrcHandle, size_t SrcByteOffset,
+                        void *HostDst, size_t DstByteOffset, size_t ByteCount) {
     return make_error("copyD2H not implemented for platform " + getName());
   }
 
@@ -104,22 +87,23 @@
   ///
   /// HostSrc should have been allocated by allocateHostMemory or registered
   /// with registerHostMemory.
-  virtual Error copyH2D(PlatformStreamHandle *S, const void *HostSrc,
+  virtual Error copyH2D(const void *PlatformStreamHandle, const void *HostSrc,
                         size_t SrcByteOffset, const void *DeviceDstHandle,
                         size_t DstByteOffset, size_t ByteCount) {
     return make_error("copyH2D not implemented for platform " + getName());
   }
 
   /// Copies data from one device location to another.
-  virtual Error copyD2D(PlatformStreamHandle *S, const void *DeviceSrcHandle,
-                        size_t SrcByteOffset, const void *DeviceDstHandle,
-                        size_t DstByteOffset, size_t ByteCount) {
+  virtual Error copyD2D(const void *PlatformStreamHandle,
+                        const void *DeviceSrcHandle, size_t SrcByteOffset,
+                        const void *DeviceDstHandle, size_t DstByteOffset,
+                        size_t ByteCount) {
     return make_error("copyD2D not implemented for platform " + getName());
   }
 
   /// Blocks the host until the given stream completes all the work enqueued up
   /// to the point this function is called.
-  virtual Error blockHostUntilDone(PlatformStreamHandle *S) {
+  virtual Error blockHostUntilDone(const void *PlatformStreamHandle) {
     return make_error("blockHostUntilDone not implemented for platform " +
                       getName());
   }
diff --git a/streamexecutor/include/streamexecutor/Stream.h b/streamexecutor/include/streamexecutor/Stream.h
index 81f9ada..48dcf32 100644
--- a/streamexecutor/include/streamexecutor/Stream.h
+++ b/streamexecutor/include/streamexecutor/Stream.h
@@ -59,10 +59,13 @@
 /// of a stream once it is in an error state.
 class Stream {
 public:
-  explicit Stream(std::unique_ptr<PlatformStreamHandle> PStream);
+  Stream(PlatformDevice *D, const void *PlatformStreamHandle);
 
-  Stream(Stream &&Other) = default;
-  Stream &operator=(Stream &&Other) = default;
+  Stream(const Stream &Other) = delete;
+  Stream &operator=(const Stream &Other) = delete;
+
+  Stream(Stream &&Other);
+  Stream &operator=(Stream &&Other);
 
   ~Stream();
 
@@ -88,7 +91,7 @@
   //
   // Returns the result of getStatus() after the Stream work completes.
   Error blockHostUntilDone() {
-    setError(PDevice->blockHostUntilDone(ThePlatformStream.get()));
+    setError(PDevice->blockHostUntilDone(PlatformStreamHandle));
     return getStatus();
   }
 
@@ -105,7 +108,7 @@
                      const ParameterTs &... Arguments) {
     auto ArgumentArray =
         make_kernel_argument_pack<ParameterTs...>(Arguments...);
-    setError(PDevice->launch(ThePlatformStream.get(), BlockSize, GridSize,
+    setError(PDevice->launch(PlatformStreamHandle, BlockSize, GridSize,
                              K.getPlatformHandle(), ArgumentArray));
     return *this;
   }
@@ -136,7 +139,7 @@
       setError("copying too many elements, " + llvm::Twine(ElementCount) +
                ", to a host array of element count " + llvm::Twine(Dst.size()));
     else
-      setError(PDevice->copyD2H(ThePlatformStream.get(),
+      setError(PDevice->copyD2H(PlatformStreamHandle,
                                 Src.getBaseMemory().getHandle(),
                                 Src.getElementOffset() * sizeof(T), Dst.data(),
                                 0, ElementCount * sizeof(T)));
@@ -196,10 +199,9 @@
                ", to a device array of element count " +
                llvm::Twine(Dst.getElementCount()));
     else
-      setError(PDevice->copyH2D(ThePlatformStream.get(), Src.data(), 0,
-                                Dst.getBaseMemory().getHandle(),
-                                Dst.getElementOffset() * sizeof(T),
-                                ElementCount * sizeof(T)));
+      setError(PDevice->copyH2D(
+          PlatformStreamHandle, Src.data(), 0, Dst.getBaseMemory().getHandle(),
+          Dst.getElementOffset() * sizeof(T), ElementCount * sizeof(T)));
     return *this;
   }
 
@@ -254,7 +256,7 @@
                llvm::Twine(Dst.getElementCount()));
     else
       setError(PDevice->copyD2D(
-          ThePlatformStream.get(), Src.getBaseMemory().getHandle(),
+          PlatformStreamHandle, Src.getBaseMemory().getHandle(),
           Src.getElementOffset() * sizeof(T), Dst.getBaseMemory().getHandle(),
           Dst.getElementOffset() * sizeof(T), ElementCount * sizeof(T)));
     return *this;
@@ -342,7 +344,7 @@
   PlatformDevice *PDevice;
 
   /// The platform-specific stream handle for this instance.
-  std::unique_ptr<PlatformStreamHandle> ThePlatformStream;
+  const void *PlatformStreamHandle;
 
   /// Mutex that guards the error state flags.
   std::unique_ptr<llvm::sys::RWMutex> ErrorMessageMutex;
@@ -350,9 +352,6 @@
   /// First error message for an operation in this stream or empty if there have
   /// been no errors.
   llvm::Optional<std::string> ErrorMessage;
-
-  Stream(const Stream &) = delete;
-  void operator=(const Stream &) = delete;
 };
 
 } // namespace streamexecutor
diff --git a/streamexecutor/lib/Device.cpp b/streamexecutor/lib/Device.cpp
index 54f0384..0d81fb7 100644
--- a/streamexecutor/lib/Device.cpp
+++ b/streamexecutor/lib/Device.cpp
@@ -28,14 +28,11 @@
 Device::~Device() = default;
 
 Expected<Stream> Device::createStream() {
-  Expected<std::unique_ptr<PlatformStreamHandle>> MaybePlatformStream =
-      PDevice->createStream();
+  Expected<const void *> MaybePlatformStream = PDevice->createStream();
   if (!MaybePlatformStream) {
     return MaybePlatformStream.takeError();
   }
-  assert((*MaybePlatformStream)->getDevice() == PDevice &&
-         "an executor created a stream with a different stored executor");
-  return Stream(std::move(*MaybePlatformStream));
+  return Stream(PDevice, *MaybePlatformStream);
 }
 
 } // namespace streamexecutor
diff --git a/streamexecutor/lib/Kernel.cpp b/streamexecutor/lib/Kernel.cpp
index 1f4218c..6130537 100644
--- a/streamexecutor/lib/Kernel.cpp
+++ b/streamexecutor/lib/Kernel.cpp
@@ -12,16 +12,49 @@
 ///
 //===----------------------------------------------------------------------===//
 
-#include "streamexecutor/Kernel.h"
+#include <cassert>
+
 #include "streamexecutor/Device.h"
+#include "streamexecutor/Kernel.h"
 #include "streamexecutor/PlatformInterfaces.h"
 
 #include "llvm/DebugInfo/Symbolize/Symbolize.h"
 
 namespace streamexecutor {
 
-KernelBase::KernelBase(llvm::StringRef Name)
-    : Name(Name), DemangledName(llvm::symbolize::LLVMSymbolizer::DemangleName(
-                      Name, nullptr)) {}
+KernelBase::KernelBase(PlatformDevice *D, const void *PlatformKernelHandle,
+                       llvm::StringRef Name)
+    : PDevice(D), PlatformKernelHandle(PlatformKernelHandle), Name(Name),
+      DemangledName(
+          llvm::symbolize::LLVMSymbolizer::DemangleName(Name, nullptr)) {
+  assert(D != nullptr &&
+         "cannot construct a kernel object with a null platform device");
+  assert(PlatformKernelHandle != nullptr &&
+         "cannot construct a kernel object with a null platform kernel handle");
+}
+
+KernelBase::KernelBase(KernelBase &&Other)
+    : PDevice(Other.PDevice), PlatformKernelHandle(Other.PlatformKernelHandle),
+      Name(std::move(Other.Name)),
+      DemangledName(std::move(Other.DemangledName)) {
+  Other.PDevice = nullptr;
+  Other.PlatformKernelHandle = nullptr;
+}
+
+KernelBase &KernelBase::operator=(KernelBase &&Other) {
+  PDevice = Other.PDevice;
+  PlatformKernelHandle = Other.PlatformKernelHandle;
+  Name = std::move(Other.Name);
+  DemangledName = std::move(Other.DemangledName);
+  Other.PDevice = nullptr;
+  Other.PlatformKernelHandle = nullptr;
+  return *this;
+}
+
+KernelBase::~KernelBase() {
+  if (PlatformKernelHandle)
+    // TODO(jhen): Handle the error here.
+    consumeError(PDevice->destroyKernel(PlatformKernelHandle));
+}
 
 } // namespace streamexecutor
diff --git a/streamexecutor/lib/PlatformInterfaces.cpp b/streamexecutor/lib/PlatformInterfaces.cpp
index 770cd17..e9378b5 100644
--- a/streamexecutor/lib/PlatformInterfaces.cpp
+++ b/streamexecutor/lib/PlatformInterfaces.cpp
@@ -16,8 +16,6 @@
 
 namespace streamexecutor {
 
-PlatformStreamHandle::~PlatformStreamHandle() = default;
-
 PlatformDevice::~PlatformDevice() = default;
 
 } // namespace streamexecutor
diff --git a/streamexecutor/lib/Stream.cpp b/streamexecutor/lib/Stream.cpp
index e1fca58..96aad04 100644
--- a/streamexecutor/lib/Stream.cpp
+++ b/streamexecutor/lib/Stream.cpp
@@ -12,14 +12,43 @@
 ///
 //===----------------------------------------------------------------------===//
 
+#include <cassert>
+
 #include "streamexecutor/Stream.h"
 
 namespace streamexecutor {
 
-Stream::Stream(std::unique_ptr<PlatformStreamHandle> PStream)
-    : PDevice(PStream->getDevice()), ThePlatformStream(std::move(PStream)),
-      ErrorMessageMutex(llvm::make_unique<llvm::sys::RWMutex>()) {}
+Stream::Stream(PlatformDevice *D, const void *PlatformStreamHandle)
+    : PDevice(D), PlatformStreamHandle(PlatformStreamHandle),
+      ErrorMessageMutex(llvm::make_unique<llvm::sys::RWMutex>()) {
+  assert(D != nullptr &&
+         "cannot construct a stream object with a null platform device");
+  assert(PlatformStreamHandle != nullptr &&
+         "cannot construct a stream object with a null platform stream handle");
+}
 
-Stream::~Stream() = default;
+Stream::Stream(Stream &&Other)
+    : PDevice(Other.PDevice), PlatformStreamHandle(Other.PlatformStreamHandle),
+      ErrorMessageMutex(std::move(Other.ErrorMessageMutex)),
+      ErrorMessage(std::move(Other.ErrorMessage)) {
+  Other.PDevice = nullptr;
+  Other.PlatformStreamHandle = nullptr;
+}
+
+Stream &Stream::operator=(Stream &&Other) {
+  PDevice = Other.PDevice;
+  PlatformStreamHandle = Other.PlatformStreamHandle;
+  ErrorMessageMutex = std::move(Other.ErrorMessageMutex);
+  ErrorMessage = std::move(Other.ErrorMessage);
+  Other.PDevice = nullptr;
+  Other.PlatformStreamHandle = nullptr;
+  return *this;
+}
+
+Stream::~Stream() {
+  if (PlatformStreamHandle)
+    // TODO(jhen): Handle error condition here.
+    consumeError(PDevice->destroyStream(PlatformStreamHandle));
+}
 
 } // namespace streamexecutor
diff --git a/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h b/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h
index 184c2d7..b54b31d 100644
--- a/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h
+++ b/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h
@@ -34,9 +34,7 @@
 public:
   std::string getName() const override { return "SimpleHostPlatformDevice"; }
 
-  streamexecutor::Expected<
-      std::unique_ptr<streamexecutor::PlatformStreamHandle>>
-  createStream() override {
+  streamexecutor::Expected<const void *> createStream() override {
     return nullptr;
   }
 
@@ -69,7 +67,7 @@
     return streamexecutor::Error::success();
   }
 
-  streamexecutor::Error copyD2H(streamexecutor::PlatformStreamHandle *S,
+  streamexecutor::Error copyD2H(const void *StreamHandle,
                                 const void *DeviceHandleSrc,
                                 size_t SrcByteOffset, void *HostDst,
                                 size_t DstByteOffset,
@@ -80,8 +78,8 @@
     return streamexecutor::Error::success();
   }
 
-  streamexecutor::Error copyH2D(streamexecutor::PlatformStreamHandle *S,
-                                const void *HostSrc, size_t SrcByteOffset,
+  streamexecutor::Error copyH2D(const void *StreamHandle, const void *HostSrc,
+                                size_t SrcByteOffset,
                                 const void *DeviceHandleDst,
                                 size_t DstByteOffset,
                                 size_t ByteCount) override {
@@ -92,7 +90,7 @@
   }
 
   streamexecutor::Error
-  copyD2D(streamexecutor::PlatformStreamHandle *S, const void *DeviceHandleSrc,
+  copyD2D(const void *StreamHandle, const void *DeviceHandleSrc,
           size_t SrcByteOffset, const void *DeviceHandleDst,
           size_t DstByteOffset, size_t ByteCount) override {
     std::memcpy(static_cast<char *>(const_cast<void *>(DeviceHandleDst)) +
diff --git a/streamexecutor/lib/unittests/StreamTest.cpp b/streamexecutor/lib/unittests/StreamTest.cpp
index 4f42bbe..3a0f4e6 100644
--- a/streamexecutor/lib/unittests/StreamTest.cpp
+++ b/streamexecutor/lib/unittests/StreamTest.cpp
@@ -34,11 +34,11 @@
 class StreamTest : public ::testing::Test {
 public:
   StreamTest()
-      : Device(&PDevice),
-        Stream(llvm::make_unique<se::PlatformStreamHandle>(&PDevice)),
-        HostA5{0, 1, 2, 3, 4}, HostB5{5, 6, 7, 8, 9},
-        HostA7{10, 11, 12, 13, 14, 15, 16}, HostB7{17, 18, 19, 20, 21, 22, 23},
-        Host5{24, 25, 26, 27, 28}, Host7{29, 30, 31, 32, 33, 34, 35},
+      : DummyPlatformStream(1), Device(&PDevice),
+        Stream(&PDevice, &DummyPlatformStream), HostA5{0, 1, 2, 3, 4},
+        HostB5{5, 6, 7, 8, 9}, HostA7{10, 11, 12, 13, 14, 15, 16},
+        HostB7{17, 18, 19, 20, 21, 22, 23}, Host5{24, 25, 26, 27, 28},
+        Host7{29, 30, 31, 32, 33, 34, 35},
         DeviceA5(getOrDie(Device.allocateDeviceMemory<int>(5))),
         DeviceB5(getOrDie(Device.allocateDeviceMemory<int>(5))),
         DeviceA7(getOrDie(Device.allocateDeviceMemory<int>(7))),
@@ -50,6 +50,8 @@
   }
 
 protected:
+  int DummyPlatformStream; // Mimicking a platform where the platform stream
+                           // handle is just a stream number.
   se::test::SimpleHostPlatformDevice PDevice;
   se::Device Device;
   se::Stream Stream;