[StreamExecutor] Use raw device handles in platform interface

Summary:
This is the first in a series of patches that will convert
GlobalDeviceMemory to own its device memory handle. The first step is to
remove GlobalDeviceMemoryBase from the PlatformInterface interfaces and
use raw handles there instead. This is useful because
GlobalDeviceMemoryBase is going to lose its importance in this process.

Reviewers: jlebar

Subscribers: jprice, parallel_libs-commits

Differential Revision: https://reviews.llvm.org/D24114

llvm-svn: 280401
GitOrigin-RevId: 8e5b54021ee42418c95397f2adde0ff7e755526b
diff --git a/streamexecutor/include/streamexecutor/Device.h b/streamexecutor/include/streamexecutor/Device.h
index 48ecf22..2493781 100644
--- a/streamexecutor/include/streamexecutor/Device.h
+++ b/streamexecutor/include/streamexecutor/Device.h
@@ -56,16 +56,17 @@
   /// Allocates an array of ElementCount entries of type T in device memory.
   template <typename T>
   Expected<GlobalDeviceMemory<T>> allocateDeviceMemory(size_t ElementCount) {
-    Expected<GlobalDeviceMemoryBase> MaybeBase =
+    Expected<void *> MaybeMemory =
         PDevice->allocateDeviceMemory(ElementCount * sizeof(T));
-    if (!MaybeBase)
-      return MaybeBase.takeError();
-    return GlobalDeviceMemory<T>(*MaybeBase);
+    if (!MaybeMemory)
+      return MaybeMemory.takeError();
+    return GlobalDeviceMemory<T>::makeFromElementCount(*MaybeMemory,
+                                                       ElementCount);
   }
 
   /// Frees memory previously allocated with allocateDeviceMemory.
   template <typename T> Error freeDeviceMemory(GlobalDeviceMemory<T> Memory) {
-    return PDevice->freeDeviceMemory(Memory);
+    return PDevice->freeDeviceMemory(Memory.getHandle());
   }
 
   /// Allocates an array of ElementCount entries of type T in host memory.
@@ -140,7 +141,7 @@
       return make_error(
           "copying too many elements, " + llvm::Twine(ElementCount) +
           ", to a host array of element count " + llvm::Twine(Dst.size()));
-    return PDevice->synchronousCopyD2H(Src.getBaseMemory(),
+    return PDevice->synchronousCopyD2H(Src.getBaseMemory().getHandle(),
                                        Src.getElementOffset() * sizeof(T),
                                        Dst.data(), 0, ElementCount * sizeof(T));
   }
@@ -194,9 +195,9 @@
                         llvm::Twine(ElementCount) +
                         ", to a device array of element count " +
                         llvm::Twine(Dst.getElementCount()));
-    return PDevice->synchronousCopyH2D(Src.data(), 0, Dst.getBaseMemory(),
-                                       Dst.getElementOffset() * sizeof(T),
-                                       ElementCount * sizeof(T));
+    return PDevice->synchronousCopyH2D(
+        Src.data(), 0, Dst.getBaseMemory().getHandle(),
+        Dst.getElementOffset() * sizeof(T), ElementCount * sizeof(T));
   }
 
   template <typename T>
@@ -250,8 +251,8 @@
                         ", to a device array of element count " +
                         llvm::Twine(Dst.getElementCount()));
     return PDevice->synchronousCopyD2D(
-        Src.getBaseMemory(), Src.getElementOffset() * sizeof(T),
-        Dst.getBaseMemory(), Dst.getElementOffset() * sizeof(T),
+        Src.getBaseMemory().getHandle(), Src.getElementOffset() * sizeof(T),
+        Dst.getBaseMemory().getHandle(), Dst.getElementOffset() * sizeof(T),
         ElementCount * sizeof(T));
   }
 
diff --git a/streamexecutor/include/streamexecutor/PlatformInterfaces.h b/streamexecutor/include/streamexecutor/PlatformInterfaces.h
index 8fa31b6..b3deff3 100644
--- a/streamexecutor/include/streamexecutor/PlatformInterfaces.h
+++ b/streamexecutor/include/streamexecutor/PlatformInterfaces.h
@@ -94,8 +94,7 @@
   ///
   /// HostDst should have been allocated by allocateHostMemory or registered
   /// with registerHostMemory.
-  virtual Error copyD2H(PlatformStreamHandle *S,
-                        const GlobalDeviceMemoryBase &DeviceSrc,
+  virtual Error copyD2H(PlatformStreamHandle *S, const void *DeviceSrcHandle,
                         size_t SrcByteOffset, void *HostDst,
                         size_t DstByteOffset, size_t ByteCount) {
     return make_error("copyD2H not implemented for platform " + getName());
@@ -106,15 +105,14 @@
   /// HostSrc should have been allocated by allocateHostMemory or registered
   /// with registerHostMemory.
   virtual Error copyH2D(PlatformStreamHandle *S, const void *HostSrc,
-                        size_t SrcByteOffset, GlobalDeviceMemoryBase DeviceDst,
+                        size_t SrcByteOffset, const void *DeviceDstHandle,
                         size_t DstByteOffset, size_t ByteCount) {
     return make_error("copyH2D not implemented for platform " + getName());
   }
 
   /// Copies data from one device location to another.
-  virtual Error copyD2D(PlatformStreamHandle *S,
-                        const GlobalDeviceMemoryBase &DeviceSrc,
-                        size_t SrcByteOffset, GlobalDeviceMemoryBase DeviceDst,
+  virtual Error copyD2D(PlatformStreamHandle *S, const void *DeviceSrcHandle,
+                        size_t SrcByteOffset, const void *DeviceDstHandle,
                         size_t DstByteOffset, size_t ByteCount) {
     return make_error("copyD2D not implemented for platform " + getName());
   }
@@ -127,14 +125,13 @@
   }
 
   /// Allocates untyped device memory of a given size in bytes.
-  virtual Expected<GlobalDeviceMemoryBase>
-  allocateDeviceMemory(size_t ByteCount) {
+  virtual Expected<void *> allocateDeviceMemory(size_t ByteCount) {
     return make_error("allocateDeviceMemory not implemented for platform " +
                       getName());
   }
 
   /// Frees device memory previously allocated by allocateDeviceMemory.
-  virtual Error freeDeviceMemory(GlobalDeviceMemoryBase Memory) {
+  virtual Error freeDeviceMemory(const void *Handle) {
     return make_error("freeDeviceMemory not implemented for platform " +
                       getName());
   }
@@ -172,29 +169,29 @@
   /// Blocks the calling host thread until the copy is completed. Can operate on
   /// any host memory, not just registered host memory or host memory allocated
   /// by allocateHostMemory. Does not block any ongoing device calls.
-  virtual Error synchronousCopyD2H(const GlobalDeviceMemoryBase &DeviceSrc,
+  virtual Error synchronousCopyD2H(const void *DeviceSrcHandle,
                                    size_t SrcByteOffset, void *HostDst,
                                    size_t DstByteOffset, size_t ByteCount) {
     return make_error("synchronousCopyD2H not implemented for platform " +
                       getName());
   }
 
-  /// Similar to synchronousCopyD2H(const GlobalDeviceMemoryBase &, size_t, void
+  /// Similar to synchronousCopyD2H(const void *, size_t, void
   /// *, size_t, size_t), but copies memory from host to device rather than
   /// device to host.
   virtual Error synchronousCopyH2D(const void *HostSrc, size_t SrcByteOffset,
-                                   GlobalDeviceMemoryBase DeviceDst,
+                                   const void *DeviceDstHandle,
                                    size_t DstByteOffset, size_t ByteCount) {
     return make_error("synchronousCopyH2D not implemented for platform " +
                       getName());
   }
 
-  /// Similar to synchronousCopyD2H(const GlobalDeviceMemoryBase &, size_t, void
+  /// Similar to synchronousCopyD2H(const void *, size_t, void
   /// *, size_t, size_t), but copies memory from one location in device memory
   /// to another rather than from device to host.
-  virtual Error synchronousCopyD2D(GlobalDeviceMemoryBase DeviceDst,
+  virtual Error synchronousCopyD2D(const void *DeviceDstHandle,
                                    size_t DstByteOffset,
-                                   const GlobalDeviceMemoryBase &DeviceSrc,
+                                   const void *DeviceSrcHandle,
                                    size_t SrcByteOffset, size_t ByteCount) {
     return make_error("synchronousCopyD2D not implemented for platform " +
                       getName());
diff --git a/streamexecutor/include/streamexecutor/Stream.h b/streamexecutor/include/streamexecutor/Stream.h
index 1acb181..054b159 100644
--- a/streamexecutor/include/streamexecutor/Stream.h
+++ b/streamexecutor/include/streamexecutor/Stream.h
@@ -136,7 +136,8 @@
       setError("copying too many elements, " + llvm::Twine(ElementCount) +
                ", to a host array of element count " + llvm::Twine(Dst.size()));
     else
-      setError(PDevice->copyD2H(ThePlatformStream.get(), Src.getBaseMemory(),
+      setError(PDevice->copyD2H(ThePlatformStream.get(),
+                                Src.getBaseMemory().getHandle(),
                                 Src.getElementOffset() * sizeof(T), Dst.data(),
                                 0, ElementCount * sizeof(T)));
     return *this;
@@ -193,9 +194,10 @@
                ", to a device array of element count " +
                llvm::Twine(Dst.getElementCount()));
     else
-      setError(PDevice->copyH2D(
-          ThePlatformStream.get(), Src.data(), 0, Dst.getBaseMemory(),
-          Dst.getElementOffset() * sizeof(T), ElementCount * sizeof(T)));
+      setError(PDevice->copyH2D(ThePlatformStream.get(), Src.data(), 0,
+                                Dst.getBaseMemory().getHandle(),
+                                Dst.getElementOffset() * sizeof(T),
+                                ElementCount * sizeof(T)));
     return *this;
   }
 
@@ -250,8 +252,8 @@
                llvm::Twine(Dst.getElementCount()));
     else
       setError(PDevice->copyD2D(
-          ThePlatformStream.get(), Src.getBaseMemory(),
-          Src.getElementOffset() * sizeof(T), Dst.getBaseMemory(),
+          ThePlatformStream.get(), Src.getBaseMemory().getHandle(),
+          Src.getElementOffset() * sizeof(T), Dst.getBaseMemory().getHandle(),
           Dst.getElementOffset() * sizeof(T), ElementCount * sizeof(T)));
     return *this;
   }
diff --git a/streamexecutor/lib/unittests/DeviceTest.cpp b/streamexecutor/lib/unittests/DeviceTest.cpp
index cb34b8b..93d378f 100644
--- a/streamexecutor/lib/unittests/DeviceTest.cpp
+++ b/streamexecutor/lib/unittests/DeviceTest.cpp
@@ -15,6 +15,7 @@
 #include <cstdlib>
 #include <cstring>
 
+#include "SimpleHostPlatformDevice.h"
 #include "streamexecutor/Device.h"
 #include "streamexecutor/PlatformInterfaces.h"
 
@@ -24,79 +25,6 @@
 
 namespace se = ::streamexecutor;
 
-class MockPlatformDevice : public se::PlatformDevice {
-public:
-  ~MockPlatformDevice() override {}
-
-  std::string getName() const override { return "MockPlatformDevice"; }
-
-  se::Expected<std::unique_ptr<se::PlatformStreamHandle>>
-  createStream() override {
-    return se::make_error("not implemented");
-  }
-
-  se::Expected<se::GlobalDeviceMemoryBase>
-  allocateDeviceMemory(size_t ByteCount) override {
-    return se::GlobalDeviceMemoryBase(std::malloc(ByteCount));
-  }
-
-  se::Error freeDeviceMemory(se::GlobalDeviceMemoryBase Memory) override {
-    std::free(const_cast<void *>(Memory.getHandle()));
-    return se::Error::success();
-  }
-
-  se::Expected<void *> allocateHostMemory(size_t ByteCount) override {
-    return std::malloc(ByteCount);
-  }
-
-  se::Error freeHostMemory(void *Memory) override {
-    std::free(Memory);
-    return se::Error::success();
-  }
-
-  se::Error registerHostMemory(void *, size_t) override {
-    return se::Error::success();
-  }
-
-  se::Error unregisterHostMemory(void *) override {
-    return se::Error::success();
-  }
-
-  se::Error synchronousCopyD2H(const se::GlobalDeviceMemoryBase &DeviceSrc,
-                               size_t SrcByteOffset, void *HostDst,
-                               size_t DstByteOffset,
-                               size_t ByteCount) override {
-    std::memcpy(static_cast<char *>(HostDst) + DstByteOffset,
-                static_cast<const char *>(DeviceSrc.getHandle()) +
-                    SrcByteOffset,
-                ByteCount);
-    return se::Error::success();
-  }
-
-  se::Error synchronousCopyH2D(const void *HostSrc, size_t SrcByteOffset,
-                               se::GlobalDeviceMemoryBase DeviceDst,
-                               size_t DstByteOffset,
-                               size_t ByteCount) override {
-    std::memcpy(static_cast<char *>(const_cast<void *>(DeviceDst.getHandle())) +
-                    DstByteOffset,
-                static_cast<const char *>(HostSrc) + SrcByteOffset, ByteCount);
-    return se::Error::success();
-  }
-
-  se::Error synchronousCopyD2D(se::GlobalDeviceMemoryBase DeviceDst,
-                               size_t DstByteOffset,
-                               const se::GlobalDeviceMemoryBase &DeviceSrc,
-                               size_t SrcByteOffset,
-                               size_t ByteCount) override {
-    std::memcpy(static_cast<char *>(const_cast<void *>(DeviceDst.getHandle())) +
-                    DstByteOffset,
-                static_cast<const char *>(DeviceSrc.getHandle()) +
-                    SrcByteOffset,
-                ByteCount);
-    return se::Error::success();
-  }
-};
-
 /// Test fixture to hold objects used by tests.
 class DeviceTest : public ::testing::Test {
 public:
@@ -124,7 +52,7 @@
   int Host5[5];
   int Host7[7];
 
-  MockPlatformDevice PDevice;
+  SimpleHostPlatformDevice PDevice;
   se::Device Device;
 };
 
diff --git a/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h b/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h
new file mode 100644
index 0000000..a2dd3c8
--- /dev/null
+++ b/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h
@@ -0,0 +1,162 @@
+//===-- SimpleHostPlatformDevice.h - Host device for testing ----*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// The SimpleHostPlatformDevice class is a streamexecutor::PlatformDevice that
+/// is really just the host processor and memory. It is useful for testing
+/// because no extra device platform is required.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef STREAMEXECUTOR_LIB_UNITTESTS_SIMPLEHOSTPLATFORMDEVICE_H
+#define STREAMEXECUTOR_LIB_UNITTESTS_SIMPLEHOSTPLATFORMDEVICE_H
+
+#include <cstdlib>
+#include <cstring>
+#include <memory>
+#include <string>
+
+#include "streamexecutor/PlatformInterfaces.h"
+
+/// A streamexecutor::PlatformDevice that simply forwards all operations to the
+/// host platform.
+///
+/// Device memory handles are plain host pointers returned by std::malloc, and
+/// every copy method moves bytes with std::memmove. The stream argument of the
+/// stream-based copy methods is ignored; the copy is done before they return.
+class SimpleHostPlatformDevice : public streamexecutor::PlatformDevice {
+public:
+  std::string getName() const override { return "SimpleHostPlatformDevice"; }
+
+  /// Always yields a null stream handle.
+  streamexecutor::Expected<
+      std::unique_ptr<streamexecutor::PlatformStreamHandle>>
+  createStream() override {
+    return nullptr;
+  }
+
+  /// Returns std::malloc(ByteCount) as the device memory handle.
+  streamexecutor::Expected<void *>
+  allocateDeviceMemory(size_t ByteCount) override {
+    return std::malloc(ByteCount);
+  }
+
+  /// Releases a handle produced by allocateDeviceMemory.
+  streamexecutor::Error freeDeviceMemory(const void *Handle) override {
+    // std::free takes void *, so the const on the handle is cast away.
+    std::free(const_cast<void *>(Handle));
+    return streamexecutor::Error::success();
+  }
+
+  /// Returns std::malloc(ByteCount) as host memory.
+  streamexecutor::Expected<void *>
+  allocateHostMemory(size_t ByteCount) override {
+    return std::malloc(ByteCount);
+  }
+
+  /// Releases memory produced by allocateHostMemory.
+  streamexecutor::Error freeHostMemory(void *Memory) override {
+    std::free(Memory);
+    return streamexecutor::Error::success();
+  }
+
+  /// Registration is a no-op that always succeeds.
+  streamexecutor::Error registerHostMemory(void *Memory,
+                                           size_t ByteCount) override {
+    return streamexecutor::Error::success();
+  }
+
+  /// Unregistration is a no-op that always succeeds.
+  streamexecutor::Error unregisterHostMemory(void *Memory) override {
+    return streamexecutor::Error::success();
+  }
+
+  /// Copies from a device handle to host memory; S is ignored.
+  streamexecutor::Error copyD2H(streamexecutor::PlatformStreamHandle *S,
+                                const void *DeviceHandleSrc,
+                                size_t SrcByteOffset, void *HostDst,
+                                size_t DstByteOffset,
+                                size_t ByteCount) override {
+    return hostCopy(DeviceHandleSrc, SrcByteOffset, HostDst, DstByteOffset,
+                    ByteCount);
+  }
+
+  /// Copies from host memory to a device handle; S is ignored.
+  streamexecutor::Error copyH2D(streamexecutor::PlatformStreamHandle *S,
+                                const void *HostSrc, size_t SrcByteOffset,
+                                const void *DeviceHandleDst,
+                                size_t DstByteOffset,
+                                size_t ByteCount) override {
+    return hostCopy(HostSrc, SrcByteOffset, DeviceHandleDst, DstByteOffset,
+                    ByteCount);
+  }
+
+  /// Copies from one device handle to another; S is ignored.
+  streamexecutor::Error
+  copyD2D(streamexecutor::PlatformStreamHandle *S, const void *DeviceHandleSrc,
+          size_t SrcByteOffset, const void *DeviceHandleDst,
+          size_t DstByteOffset, size_t ByteCount) override {
+    return hostCopy(DeviceHandleSrc, SrcByteOffset, DeviceHandleDst,
+                    DstByteOffset, ByteCount);
+  }
+
+  /// Copies from a device handle to host memory.
+  streamexecutor::Error synchronousCopyD2H(const void *DeviceHandleSrc,
+                                           size_t SrcByteOffset, void *HostDst,
+                                           size_t DstByteOffset,
+                                           size_t ByteCount) override {
+    return hostCopy(DeviceHandleSrc, SrcByteOffset, HostDst, DstByteOffset,
+                    ByteCount);
+  }
+
+  /// Copies from host memory to a device handle.
+  streamexecutor::Error synchronousCopyH2D(const void *HostSrc,
+                                           size_t SrcByteOffset,
+                                           const void *DeviceHandleDst,
+                                           size_t DstByteOffset,
+                                           size_t ByteCount) override {
+    return hostCopy(HostSrc, SrcByteOffset, DeviceHandleDst, DstByteOffset,
+                    ByteCount);
+  }
+
+  /// Copies from one device handle to another.
+  ///
+  /// The source comes first here, which is the argument order used by the call
+  /// in streamexecutor/Device.h.
+  ///
+  /// NOTE(review): PlatformInterfaces.h names the same parameter positions
+  /// destination first; the base declaration and this override should be
+  /// reconciled.
+  streamexecutor::Error synchronousCopyD2D(const void *DeviceHandleSrc,
+                                           size_t SrcByteOffset,
+                                           const void *DeviceHandleDst,
+                                           size_t DstByteOffset,
+                                           size_t ByteCount) override {
+    return hostCopy(DeviceHandleSrc, SrcByteOffset, DeviceHandleDst,
+                    DstByteOffset, ByteCount);
+  }
+
+private:
+  /// Moves ByteCount bytes from Src + SrcByteOffset to Dst + DstByteOffset.
+  ///
+  /// std::memmove is used instead of std::memcpy so that a copy between
+  /// overlapping byte ranges is well defined.
+  static streamexecutor::Error
+  hostCopy(const void *Src, size_t SrcByteOffset, const void *Dst,
+           size_t DstByteOffset, size_t ByteCount) {
+    // Device handles reach this class as const void *, so constness must be
+    // cast away before writing through the destination.
+    char *To = static_cast<char *>(const_cast<void *>(Dst)) + DstByteOffset;
+    const char *From = static_cast<const char *>(Src) + SrcByteOffset;
+    std::memmove(To, From, ByteCount);
+    return streamexecutor::Error::success();
+  }
+};
+
+#endif // STREAMEXECUTOR_LIB_UNITTESTS_SIMPLEHOSTPLATFORMDEVICE_H
diff --git a/streamexecutor/lib/unittests/StreamTest.cpp b/streamexecutor/lib/unittests/StreamTest.cpp
index d05c928..b194bf0 100644
--- a/streamexecutor/lib/unittests/StreamTest.cpp
+++ b/streamexecutor/lib/unittests/StreamTest.cpp
@@ -14,6 +14,7 @@
 
 #include <cstring>
 
+#include "SimpleHostPlatformDevice.h"
 #include "streamexecutor/Device.h"
 #include "streamexecutor/Kernel.h"
 #include "streamexecutor/KernelSpec.h"
@@ -26,52 +27,6 @@
 
 namespace se = ::streamexecutor;
 
-/// Mock PlatformDevice that performs asynchronous memcpy operations by
-/// ignoring the stream argument and calling std::memcpy on device memory
-/// handles.
-class MockPlatformDevice : public se::PlatformDevice {
-public:
-  ~MockPlatformDevice() override {}
-
-  std::string getName() const override { return "MockPlatformDevice"; }
-
-  se::Expected<std::unique_ptr<se::PlatformStreamHandle>>
-  createStream() override {
-    return nullptr;
-  }
-
-  se::Error copyD2H(se::PlatformStreamHandle *S,
-                    const se::GlobalDeviceMemoryBase &DeviceSrc,
-                    size_t SrcByteOffset, void *HostDst, size_t DstByteOffset,
-                    size_t ByteCount) override {
-    std::memcpy(HostDst, static_cast<const char *>(DeviceSrc.getHandle()) +
-                             SrcByteOffset,
-                ByteCount);
-    return se::Error::success();
-  }
-
-  se::Error copyH2D(se::PlatformStreamHandle *S, const void *HostSrc,
-                    size_t SrcByteOffset, se::GlobalDeviceMemoryBase DeviceDst,
-                    size_t DstByteOffset, size_t ByteCount) override {
-    std::memcpy(static_cast<char *>(const_cast<void *>(DeviceDst.getHandle())) +
-                    DstByteOffset,
-                HostSrc, ByteCount);
-    return se::Error::success();
-  }
-
-  se::Error copyD2D(se::PlatformStreamHandle *S,
-                    const se::GlobalDeviceMemoryBase &DeviceSrc,
-                    size_t SrcByteOffset, se::GlobalDeviceMemoryBase DeviceDst,
-                    size_t DstByteOffset, size_t ByteCount) override {
-    std::memcpy(static_cast<char *>(const_cast<void *>(DeviceDst.getHandle())) +
-                    DstByteOffset,
-                static_cast<const char *>(DeviceSrc.getHandle()) +
-                    SrcByteOffset,
-                ByteCount);
-    return se::Error::success();
-  }
-};
-
 /// Test fixture to hold objects used by tests.
 class StreamTest : public ::testing::Test {
 public:
@@ -100,7 +55,7 @@
   int Host5[5];
   int Host7[7];
 
-  MockPlatformDevice PDevice;
+  SimpleHostPlatformDevice PDevice;
   se::Stream Stream;
 };