[Acxxel] Remove setActiveDeviceForThread

Summary:
After experimenting with CUDA, I realized that we really only need to
set the active context right before creating an object such as a stream
or a device memory allocation. When we go on to use these objects later,
it is fine if the context that created them is no longer active;
operations with those objects will succeed anyway.

Since it turns out that we don't have to check the active context for
every operation, it makes sense to hide this active context from users
(by removing the "ActiveDeviceForThread" setter and getter) and to
change the Acxxel API to explicitly pass in the device ID to create
objects.

This change improves the Acxxel API and greatly simplifies the CUDA and
OpenCL implementations because they no longer require thread_local data.

Reviewers: jlebar, jprice

Subscribers: mgorny, parallel_libs-commits

Differential Revision: https://reviews.llvm.org/D26050

llvm-svn: 285372
GitOrigin-RevId: bdc410babaee93aad9e7b582e49ba35a774e3825
diff --git a/acxxel/CMakeLists.txt b/acxxel/CMakeLists.txt
index a0c017e..4b21da3 100644
--- a/acxxel/CMakeLists.txt
+++ b/acxxel/CMakeLists.txt
@@ -1,6 +1,7 @@
 cmake_minimum_required(VERSION 3.1)
 
 option(ACXXEL_ENABLE_UNIT_TESTS "enable acxxel unit tests" ON)
+option(ACXXEL_ENABLE_MULTI_DEVICE_UNIT_TESTS "enable acxxel multi-device unit tests" OFF)
 option(ACXXEL_ENABLE_EXAMPLES "enable acxxel examples" OFF)
 option(ACXXEL_ENABLE_DOXYGEN "enable Doxygen for acxxel" OFF)
 option(ACXXEL_ENABLE_CUDA "enable CUDA for acxxel" ON)
diff --git a/acxxel/acxxel.h b/acxxel/acxxel.h
index 8cfc5c5..c98941a 100644
--- a/acxxel/acxxel.h
+++ b/acxxel/acxxel.h
@@ -229,12 +229,7 @@
 /// All operations enqueued on a Stream are serialized, but operations enqueued
 /// on different Streams may run concurrently.
 ///
-/// Each Platform has a notion of the currently active device on a particular
-/// thread (see Platform::getActiveDeviceForThread and
-/// Platform::setActiveDeviceForThread). Each Stream is associated with a
-/// specific, fixed device, set to the current thread's active device when the
-/// Stream is created. Whenver a thread enqueues commands onto a Stream, its
-/// active device must match the Stream's device.
+/// Each Stream is associated with a specific, fixed device.
 class Stream {
 public:
   Stream(const Stream &) = delete;
@@ -447,10 +442,16 @@
 private:
   // Only a platform can make an event.
   friend class Platform;
-  Event(Platform *APlatform, void *AHandle, HandleDestructor Destructor)
-      : ThePlatform(APlatform), TheHandle(AHandle, Destructor) {}
+  Event(Platform *APlatform, int DeviceIndex, void *AHandle,
+        HandleDestructor Destructor)
+      : ThePlatform(APlatform), TheDeviceIndex(DeviceIndex),
+        TheHandle(AHandle, Destructor) {}
 
   Platform *ThePlatform;
+
+  // The index of the device on which the event can be enqueued.
+  int TheDeviceIndex;
+
   std::unique_ptr<void, HandleDestructor> TheHandle;
 };
 
@@ -470,29 +471,21 @@
   /// Gets the number of devices for this platform in this system.
   virtual Expected<int> getDeviceCount() = 0;
 
-  /// Sets the active device for this platform in this thread.
-  virtual Status setActiveDeviceForThread(int DeviceIndex) = 0;
+  /// Creates a stream on the given device for the platform.
+  virtual Expected<Stream> createStream(int DeviceIndex = 0) = 0;
 
-  /// Gets the currently active device for this platform in this thread.
-  virtual int getActiveDeviceForThread() = 0;
-
-  /// Creates a stream for the platform.
-  ///
-  /// The created Stream is associated with the active device for this thread.
-  virtual Expected<Stream> createStream() = 0;
-
-  /// Creates an event for the platform.
-  ///
-  /// The created Event is associated with the active device for this thread.
-  virtual Expected<Event> createEvent() = 0;
+  /// Creates an event on the given device for the platform.
+  virtual Expected<Event> createEvent(int DeviceIndex = 0) = 0;
 
   /// Allocates owned device memory.
   ///
   /// \warning This function only allocates space in device memory, it does not
   /// call the constructor of T.
   template <typename T>
-  Expected<DeviceMemory<T>> mallocD(ptrdiff_t ElementCount) {
-    Expected<void *> MaybePointer = rawMallocD(ElementCount * sizeof(T));
+  Expected<DeviceMemory<T>> mallocD(ptrdiff_t ElementCount,
+                                    int DeviceIndex = 0) {
+    Expected<void *> MaybePointer =
+        rawMallocD(ElementCount * sizeof(T), DeviceIndex);
     if (MaybePointer.isError())
       return MaybePointer.getError();
     return DeviceMemory<T>(this, MaybePointer.getValue(), ElementCount,
@@ -505,12 +498,14 @@
   /// pointer to a __device__ variable, this function returns a DeviceMemorySpan
   /// referencing the device memory that stores that __device__ variable.
   template <typename ElementType>
-  Expected<DeviceMemorySpan<ElementType>> getSymbolMemory(ElementType *Symbol) {
-    Expected<void *> MaybeAddress = rawGetDeviceSymbolAddress(Symbol);
+  Expected<DeviceMemorySpan<ElementType>> getSymbolMemory(ElementType *Symbol,
+                                                          int DeviceIndex = 0) {
+    Expected<void *> MaybeAddress =
+        rawGetDeviceSymbolAddress(Symbol, DeviceIndex);
     if (MaybeAddress.isError())
       return MaybeAddress.getError();
     ElementType *Address = static_cast<ElementType *>(MaybeAddress.getValue());
-    Expected<ptrdiff_t> MaybeSize = rawGetDeviceSymbolSize(Symbol);
+    Expected<ptrdiff_t> MaybeSize = rawGetDeviceSymbolSize(Symbol, DeviceIndex);
     if (MaybeSize.isError())
       return MaybeSize.getError();
     ptrdiff_t Size = MaybeSize.getValue();
@@ -584,8 +579,8 @@
 
   /// \}
 
-  virtual Expected<Program>
-  createProgramFromSource(Span<const char> Source) = 0;
+  virtual Expected<Program> createProgramFromSource(Span<const char> Source,
+                                                    int DeviceIndex = 0) = 0;
 
 protected:
   friend class Stream;
@@ -597,15 +592,15 @@
   void *getEventHandle(Event &Event) { return Event.TheHandle.get(); }
 
   // Pass along access to Stream constructor to subclasses.
-  Stream constructStream(Platform *APlatform, void *AHandle,
+  Stream constructStream(Platform *APlatform, int DeviceIndex, void *AHandle,
                          HandleDestructor Destructor) {
-    return Stream(APlatform, getActiveDeviceForThread(), AHandle, Destructor);
+    return Stream(APlatform, DeviceIndex, AHandle, Destructor);
   }
 
   // Pass along access to Event constructor to subclasses.
-  Event constructEvent(Platform *APlatform, void *AHandle,
+  Event constructEvent(Platform *APlatform, int DeviceIndex, void *AHandle,
                        HandleDestructor Destructor) {
-    return Event(APlatform, AHandle, Destructor);
+    return Event(APlatform, DeviceIndex, AHandle, Destructor);
   }
 
   // Pass along access to Program constructor to subclasses.
@@ -623,28 +618,16 @@
   virtual Expected<float> getSecondsBetweenEvents(void *StartEvent,
                                                   void *EndEvent) = 0;
 
-  virtual Expected<void *> rawMallocD(ptrdiff_t ByteCount) = 0;
+  virtual Expected<void *> rawMallocD(ptrdiff_t ByteCount, int DeviceIndex) = 0;
   virtual HandleDestructor getDeviceMemoryHandleDestructor() = 0;
   virtual void *getDeviceMemorySpanHandle(void *BaseHandle, size_t ByteSize,
                                           size_t ByteOffset) = 0;
   virtual void rawDestroyDeviceMemorySpanHandle(void *Handle) = 0;
 
-  virtual Expected<void *> rawGetDeviceSymbolAddress(const void *Symbol) = 0;
-  virtual Expected<ptrdiff_t> rawGetDeviceSymbolSize(const void *Symbol) = 0;
-
-  virtual Status rawCopyDToD(const void *DeviceSrc,
-                             ptrdiff_t DeviceSrcByteOffset, void *DeviceDst,
-                             ptrdiff_t DeviceDstByteOffset,
-                             ptrdiff_t ByteCount) = 0;
-  virtual Status rawCopyDToH(const void *DeviceSrc,
-                             ptrdiff_t DeviceSrcByteOffset, void *HostDst,
-                             ptrdiff_t ByteCount) = 0;
-  virtual Status rawCopyHToD(const void *HostSrc, void *DeviceDst,
-                             ptrdiff_t DeviceDstByteOffset,
-                             ptrdiff_t ByteCount) = 0;
-
-  virtual Status rawMemsetD(void *DeviceDst, ptrdiff_t ByteOffset,
-                            ptrdiff_t ByteCount, char ByteValue) = 0;
+  virtual Expected<void *> rawGetDeviceSymbolAddress(const void *Symbol,
+                                                     int DeviceIndex) = 0;
+  virtual Expected<ptrdiff_t> rawGetDeviceSymbolSize(const void *Symbol,
+                                                     int DeviceIndex) = 0;
 
   virtual Status rawRegisterHostMem(const void *Memory,
                                     ptrdiff_t ByteCount) = 0;
diff --git a/acxxel/cuda_acxxel.cpp b/acxxel/cuda_acxxel.cpp
index 9052b8d..d8ec44c 100644
--- a/acxxel/cuda_acxxel.cpp
+++ b/acxxel/cuda_acxxel.cpp
@@ -25,9 +25,6 @@
 
 namespace {
 
-/// Index of active device for this thread.
-thread_local int ActiveDeviceIndex = 0;
-
 static std::string getCUErrorMessage(CUresult Result) {
   if (!Result)
     return "success";
@@ -85,39 +82,25 @@
 
   Expected<int> getDeviceCount() override;
 
-  Status setActiveDeviceForThread(int DeviceIndex) override;
-
-  int getActiveDeviceForThread() override;
-
-  Expected<Stream> createStream() override;
+  Expected<Stream> createStream(int DeviceIndex) override;
 
   Status streamSync(void *Stream) override;
 
   Status streamWaitOnEvent(void *Stream, void *Event) override;
 
-  Expected<Event> createEvent() override;
+  Expected<Event> createEvent(int DeviceIndex) override;
 
 protected:
-  Expected<void *> rawMallocD(ptrdiff_t ByteCount) override;
+  Expected<void *> rawMallocD(ptrdiff_t ByteCount, int DeviceIndex) override;
   HandleDestructor getDeviceMemoryHandleDestructor() override;
   void *getDeviceMemorySpanHandle(void *BaseHandle, size_t ByteSize,
                                   size_t ByteOffset) override;
   virtual void rawDestroyDeviceMemorySpanHandle(void *Handle) override;
 
-  Expected<void *> rawGetDeviceSymbolAddress(const void *Symbol) override;
-  Expected<ptrdiff_t> rawGetDeviceSymbolSize(const void *Symbol) override;
-
-  Status rawCopyDToD(const void *DeviceSrc, ptrdiff_t DeviceSrcByteOffset,
-                     void *DeviceDst, ptrdiff_t DeviceDstByteOffset,
-                     ptrdiff_t ByteCount) override;
-  Status rawCopyDToH(const void *DeviceSrc, ptrdiff_t DeviceSrcByteOffset,
-                     void *HostDst, ptrdiff_t ByteCount) override;
-  Status rawCopyHToD(const void *HostSrc, void *DeviceDst,
-                     ptrdiff_t DeviceDstByteOffset,
-                     ptrdiff_t ByteCount) override;
-
-  Status rawMemsetD(void *DeviceDst, ptrdiff_t ByteOffset, ptrdiff_t ByteCount,
-                    char ByteValue) override;
+  Expected<void *> rawGetDeviceSymbolAddress(const void *Symbol,
+                                             int DeviceIndex) override;
+  Expected<ptrdiff_t> rawGetDeviceSymbolSize(const void *Symbol,
+                                             int DeviceIndex) override;
 
   Status rawRegisterHostMem(const void *Memory, ptrdiff_t ByteCount) override;
   HandleDestructor getUnregisterHostMemoryHandleDestructor() override;
@@ -141,7 +124,8 @@
 
   Status addStreamCallback(Stream &Stream, StreamCallback Callback) override;
 
-  Expected<Program> createProgramFromSource(Span<const char> Source) override;
+  Expected<Program> createProgramFromSource(Span<const char> Source,
+                                            int DeviceIndex) override;
 
   Status enqueueEvent(void *Event, void *Stream) override;
   bool eventIsDone(void *Event) override;
@@ -163,6 +147,16 @@
   explicit CUDAPlatform(const std::vector<CUcontext> &Contexts)
       : TheContexts(Contexts) {}
 
+  // Makes the context for DeviceIndex current via cuCtxSetCurrent. Returns an
+  // error Status if DeviceIndex has no entry in TheContexts.
+  Status setContext(int DeviceIndex) {
+    if (DeviceIndex < 0 ||
+        static_cast<size_t>(DeviceIndex) >= TheContexts.size())
+      return Status("invalid device index " + std::to_string(DeviceIndex));
+    return getCUError(cuCtxSetCurrent(TheContexts[DeviceIndex]),
+                      "cuCtxSetCurrent");
+  }
+
   // Vector of contexts for each device.
   std::vector<CUcontext> TheContexts;
 };
@@ -191,17 +183,6 @@
   return CUDAPlatform(Contexts);
 }
 
-Status CUDAPlatform::setActiveDeviceForThread(int DeviceIndex) {
-  if (static_cast<size_t>(DeviceIndex) >= TheContexts.size())
-    return Status("invalid device index for SetActiveDevice: " +
-                  std::to_string(DeviceIndex));
-  ActiveDeviceIndex = DeviceIndex;
-  return getCUError(cuCtxSetCurrent(TheContexts[DeviceIndex]),
-                    "setActiveDeviceForThread cuCtxSetCurrent");
-}
-
-int CUDAPlatform::getActiveDeviceForThread() { return ActiveDeviceIndex; }
-
 Expected<int> CUDAPlatform::getDeviceCount() {
   int Count = 0;
   if (CUresult Result = cuDeviceGetCount(&Count))
@@ -214,12 +195,15 @@
                "cuStreamDestroy");
 }
 
-Expected<Stream> CUDAPlatform::createStream() {
+Expected<Stream> CUDAPlatform::createStream(int DeviceIndex) {
+  Status S = setContext(DeviceIndex);
+  if (S.isError())
+    return S;
   unsigned int Flags = CU_STREAM_DEFAULT;
   CUstream Handle;
   if (CUresult Result = cuStreamCreate(&Handle, Flags))
     return getCUError(Result, "cuStreamCreate");
-  return constructStream(this, Handle, cudaDestroyStream);
+  return constructStream(this, DeviceIndex, Handle, cudaDestroyStream);
 }
 
 Status CUDAPlatform::streamSync(void *Stream) {
@@ -239,12 +223,15 @@
   logCUWarning(cuEventDestroy(static_cast<CUevent_st *>(H)), "cuEventDestroy");
 }
 
-Expected<Event> CUDAPlatform::createEvent() {
+Expected<Event> CUDAPlatform::createEvent(int DeviceIndex) {
+  Status S = setContext(DeviceIndex);
+  if (S.isError())
+    return S;
   unsigned int Flags = CU_EVENT_DEFAULT;
   CUevent Handle;
   if (CUresult Result = cuEventCreate(&Handle, Flags))
     return getCUError(Result, "cuEventCreate");
-  return constructEvent(this, Handle, cudaDestroyEvent);
+  return constructEvent(this, DeviceIndex, Handle, cudaDestroyEvent);
 }
 
 Status CUDAPlatform::enqueueEvent(void *Event, void *Stream) {
@@ -272,7 +259,11 @@
   return Milliseconds * 1e-6;
 }
 
-Expected<void *> CUDAPlatform::rawMallocD(ptrdiff_t ByteCount) {
+Expected<void *> CUDAPlatform::rawMallocD(ptrdiff_t ByteCount,
+                                          int DeviceIndex) {
+  Status S = setContext(DeviceIndex);
+  if (S.isError())
+    return S;
   if (!ByteCount)
     return nullptr;
   CUdeviceptr Pointer;
@@ -298,14 +289,22 @@
   // Do nothing for this platform.
 }
 
-Expected<void *> CUDAPlatform::rawGetDeviceSymbolAddress(const void *Symbol) {
+Expected<void *> CUDAPlatform::rawGetDeviceSymbolAddress(const void *Symbol,
+                                                         int DeviceIndex) {
+  Status S = setContext(DeviceIndex);
+  if (S.isError())
+    return S;
   void *Address;
   if (cudaError_t Status = cudaGetSymbolAddress(&Address, Symbol))
     return getCUDAError(Status, "cudaGetSymbolAddress");
   return Address;
 }
 
-Expected<ptrdiff_t> CUDAPlatform::rawGetDeviceSymbolSize(const void *Symbol) {
+Expected<ptrdiff_t> CUDAPlatform::rawGetDeviceSymbolSize(const void *Symbol,
+                                                         int DeviceIndex) {
+  Status S = setContext(DeviceIndex);
+  if (S.isError())
+    return S;
   size_t Size;
   if (cudaError_t Status = cudaGetSymbolSize(&Size, Symbol))
     return getCUDAError(Status, "cudaGetSymbolSize");
@@ -320,45 +319,6 @@
   return static_cast<void *>(static_cast<char *>(Ptr) + ByteOffset);
 }
 
-Status CUDAPlatform::rawCopyDToD(const void *DeviceSrc,
-                                 ptrdiff_t DeviceSrcByteOffset, void *DeviceDst,
-                                 ptrdiff_t DeviceDstByteOffset,
-                                 ptrdiff_t ByteCount) {
-  return getCUError(cuMemcpyDtoD(reinterpret_cast<CUdeviceptr>(offsetVoidPtr(
-                                     DeviceDst, DeviceDstByteOffset)),
-                                 reinterpret_cast<CUdeviceptr>(offsetVoidPtr(
-                                     DeviceSrc, DeviceSrcByteOffset)),
-                                 ByteCount),
-                    "cuMemcpyDtoD");
-}
-
-Status CUDAPlatform::rawCopyDToH(const void *DeviceSrc,
-                                 ptrdiff_t DeviceSrcByteOffset, void *HostDst,
-                                 ptrdiff_t ByteCount) {
-  return getCUError(
-      cuMemcpyDtoH(HostDst, reinterpret_cast<CUdeviceptr>(
-                                offsetVoidPtr(DeviceSrc, DeviceSrcByteOffset)),
-                   ByteCount),
-      "cuMemcpyDtoH");
-}
-
-Status CUDAPlatform::rawCopyHToD(const void *HostSrc, void *DeviceDst,
-                                 ptrdiff_t DeviceDstByteOffset,
-                                 ptrdiff_t ByteCount) {
-  return getCUError(cuMemcpyHtoD(reinterpret_cast<CUdeviceptr>(offsetVoidPtr(
-                                     DeviceDst, DeviceDstByteOffset)),
-                                 HostSrc, ByteCount),
-                    "cuMemcpyHtoD");
-}
-
-Status CUDAPlatform::rawMemsetD(void *DeviceDst, ptrdiff_t ByteOffset,
-                                ptrdiff_t ByteCount, char ByteValue) {
-  return getCUError(cuMemsetD8(reinterpret_cast<CUdeviceptr>(
-                                   offsetVoidPtr(DeviceDst, ByteOffset)),
-                               ByteValue, ByteCount),
-                    "cuMemsetD8");
-}
-
 Status CUDAPlatform::rawRegisterHostMem(const void *Memory,
                                         ptrdiff_t ByteCount) {
   unsigned int Flags = 0;
@@ -468,8 +428,11 @@
   logCUWarning(cuModuleUnload(static_cast<CUmod_st *>(H)), "cuModuleUnload");
 }
 
-Expected<Program>
-CUDAPlatform::createProgramFromSource(Span<const char> Source) {
+Expected<Program> CUDAPlatform::createProgramFromSource(Span<const char> Source,
+                                                        int DeviceIndex) {
+  Status S = setContext(DeviceIndex);
+  if (S.isError())
+    return S;
   CUmodule Module;
   constexpr int LogBufferSizeBytes = 1024;
   char InfoLogBuffer[LogBufferSizeBytes];
diff --git a/acxxel/opencl_acxxel.cpp b/acxxel/opencl_acxxel.cpp
index 2ca74ed..0c2d9b6 100644
--- a/acxxel/opencl_acxxel.cpp
+++ b/acxxel/opencl_acxxel.cpp
@@ -33,8 +33,6 @@
       : PlatformID(PlatformID), DeviceID(DeviceID) {}
 };
 
-thread_local int ActiveDeviceIndex = 0;
-
 static std::string getOpenCLErrorMessage(cl_int Result) {
   if (!Result)
     return "success";
@@ -67,41 +65,28 @@
 
   Expected<int> getDeviceCount() override;
 
-  Status setActiveDeviceForThread(int DeviceIndex) override;
+  Expected<Stream> createStream(int DeviceIndex) override;
 
-  int getActiveDeviceForThread() override;
+  Expected<Event> createEvent(int DeviceIndex) override;
 
-  Expected<Stream> createStream() override;
-
-  Expected<Event> createEvent() override;
-
-  Expected<Program> createProgramFromSource(Span<const char> Source) override;
+  Expected<Program> createProgramFromSource(Span<const char> Source,
+                                            int DeviceIndex) override;
 
 protected:
   Status streamSync(void *Stream) override;
 
   Status streamWaitOnEvent(void *Stream, void *Event) override;
 
-  Expected<void *> rawMallocD(ptrdiff_t ByteCount) override;
+  Expected<void *> rawMallocD(ptrdiff_t ByteCount, int DeviceIndex) override;
   HandleDestructor getDeviceMemoryHandleDestructor() override;
   void *getDeviceMemorySpanHandle(void *BaseHandle, size_t ByteSize,
                                   size_t ByteOffset) override;
   void rawDestroyDeviceMemorySpanHandle(void *Handle) override;
 
-  Expected<void *> rawGetDeviceSymbolAddress(const void *Symbol) override;
-  Expected<ptrdiff_t> rawGetDeviceSymbolSize(const void *Symbol) override;
-
-  Status rawCopyDToD(const void *DeviceSrc, ptrdiff_t DeviceSrcByteOffset,
-                     void *DeviceDst, ptrdiff_t DeviceDstByteOffset,
-                     ptrdiff_t ByteCount) override;
-  Status rawCopyDToH(const void *DeviceSrc, ptrdiff_t DeviceSrcByteOffset,
-                     void *HostDst, ptrdiff_t ByteCount) override;
-  Status rawCopyHToD(const void *HostSrc, void *DeviceDst,
-                     ptrdiff_t DeviceDstByteOffset,
-                     ptrdiff_t ByteCount) override;
-
-  Status rawMemsetD(void *DeviceDst, ptrdiff_t ByteOffset, ptrdiff_t ByteCount,
-                    char ByteValue) override;
+  Expected<void *> rawGetDeviceSymbolAddress(const void *Symbol,
+                                             int DeviceIndex) override;
+  Expected<ptrdiff_t> rawGetDeviceSymbolSize(const void *Symbol,
+                                             int DeviceIndex) override;
 
   Status rawRegisterHostMem(const void *Memory, ptrdiff_t ByteCount) override;
   HandleDestructor getUnregisterHostMemoryHandleDestructor() override;
@@ -200,31 +185,23 @@
 
 Expected<int> OpenCLPlatform::getDeviceCount() { return FullDeviceIDs.size(); }
 
-Status OpenCLPlatform::setActiveDeviceForThread(int DeviceIndex) {
-  if (static_cast<size_t>(DeviceIndex) >= FullDeviceIDs.size())
-    return Status("Could not set active device index to " +
-                  std::to_string(DeviceIndex) + " because there are only " +
-                  std::to_string(FullDeviceIDs.size()) +
-                  " devices in the system");
-  ActiveDeviceIndex = DeviceIndex;
-  return Status();
-}
-
-int OpenCLPlatform::getActiveDeviceForThread() { return ActiveDeviceIndex; }
-
 static void openCLDestroyStream(void *H) {
   logOpenCLWarning(clReleaseCommandQueue(static_cast<cl_command_queue>(H)),
                    "clReleaseCommandQueue");
 }
 
-Expected<Stream> OpenCLPlatform::createStream() {
+Expected<Stream> OpenCLPlatform::createStream(int DeviceIndex) {
+  // Reject indices outside [0, device count) before indexing the vectors.
+  if (DeviceIndex < 0 ||
+      static_cast<size_t>(DeviceIndex) >= FullDeviceIDs.size())
+    return Status("invalid device index " + std::to_string(DeviceIndex));
   cl_int Result;
   cl_command_queue Queue = clCreateCommandQueue(
-      Contexts[ActiveDeviceIndex], FullDeviceIDs[ActiveDeviceIndex].DeviceID,
+      Contexts[DeviceIndex], FullDeviceIDs[DeviceIndex].DeviceID,
       CL_QUEUE_PROFILING_ENABLE, &Result);
   if (Result)
     return getOpenCLError(Result, "clCreateCommandQueue");
-  return constructStream(this, Queue, openCLDestroyStream);
+  return constructStream(this, DeviceIndex, Queue, openCLDestroyStream);
 }
 
 static void openCLEventDestroy(void *H) {
@@ -246,14 +219,19 @@
       "clEnqueueMarkerWithWaitList");
 }
 
-Expected<Event> OpenCLPlatform::createEvent() {
+Expected<Event> OpenCLPlatform::createEvent(int DeviceIndex) {
+  // Reject indices outside [0, device count) before indexing Contexts.
+  if (DeviceIndex < 0 ||
+      static_cast<size_t>(DeviceIndex) >= FullDeviceIDs.size())
+    return Status("invalid device index " + std::to_string(DeviceIndex));
   cl_int Result;
-  cl_event Event = clCreateUserEvent(Contexts[ActiveDeviceIndex], &Result);
+  cl_event Event = clCreateUserEvent(Contexts[DeviceIndex], &Result);
   if (Result)
     return getOpenCLError(Result, "clCreateUserEvent");
   if (cl_int Result = clSetUserEventStatus(Event, CL_COMPLETE))
     return getOpenCLError(Result, "clSetUserEventStatus");
-  return constructEvent(this, new cl_event(Event), openCLEventDestroy);
+  return constructEvent(this, DeviceIndex, new cl_event(Event),
+                        openCLEventDestroy);
 }
 
 static void openCLDestroyProgram(void *H) {
@@ -262,24 +236,34 @@
 }
 
 Expected<Program>
-OpenCLPlatform::createProgramFromSource(Span<const char> Source) {
+OpenCLPlatform::createProgramFromSource(Span<const char> Source,
+                                        int DeviceIndex) {
+  // Reject indices outside [0, device count) before indexing the vectors.
+  if (DeviceIndex < 0 ||
+      static_cast<size_t>(DeviceIndex) >= FullDeviceIDs.size())
+    return Status("invalid device index " + std::to_string(DeviceIndex));
   cl_int Error;
   const char *CSource = Source.data();
   size_t SourceSize = Source.size();
-  cl_program Program = clCreateProgramWithSource(Contexts[ActiveDeviceIndex], 1,
+  cl_program Program = clCreateProgramWithSource(Contexts[DeviceIndex], 1,
                                                  &CSource, &SourceSize, &Error);
   if (Error)
     return getOpenCLError(Error, "clCreateProgramWithSource");
-  cl_device_id DeviceID = FullDeviceIDs[ActiveDeviceIndex].DeviceID;
+  cl_device_id DeviceID = FullDeviceIDs[DeviceIndex].DeviceID;
   if (cl_int Error =
           clBuildProgram(Program, 1, &DeviceID, nullptr, nullptr, nullptr))
     return getOpenCLError(Error, "clBuildProgram");
   return constructProgram(this, Program, openCLDestroyProgram);
 }
 
-Expected<void *> OpenCLPlatform::rawMallocD(ptrdiff_t ByteCount) {
+Expected<void *> OpenCLPlatform::rawMallocD(ptrdiff_t ByteCount,
+                                            int DeviceIndex) {
+  // Reject indices outside [0, device count) before indexing Contexts.
+  if (DeviceIndex < 0 ||
+      static_cast<size_t>(DeviceIndex) >= FullDeviceIDs.size())
+    return Status("invalid device index " + std::to_string(DeviceIndex));
   cl_int Result;
-  cl_mem Memory = clCreateBuffer(Contexts[ActiveDeviceIndex], CL_MEM_READ_WRITE,
+  cl_mem Memory = clCreateBuffer(Contexts[DeviceIndex], CL_MEM_READ_WRITE,
                                  ByteCount, nullptr, &Result);
   if (Result)
     return getOpenCLError(Result, "clCreateBuffer");
@@ -316,66 +292,19 @@
 }
 
 Expected<void *>
-OpenCLPlatform::rawGetDeviceSymbolAddress(const void * /*Symbol*/) {
+OpenCLPlatform::rawGetDeviceSymbolAddress(const void * /*Symbol*/,
+                                          int /*DeviceIndex*/) {
   // This doesn't seem to have any equivalent in OpenCL.
   return Status("not implemented");
 }
 
 Expected<ptrdiff_t>
-OpenCLPlatform::rawGetDeviceSymbolSize(const void * /*Symbol*/) {
+OpenCLPlatform::rawGetDeviceSymbolSize(const void * /*Symbol*/,
+                                       int /*DeviceIndex*/) {
   // This doesn't seem to have any equivalent in OpenCL.
   return Status("not implemented");
 }
 
-Status OpenCLPlatform::rawCopyDToD(const void *DeviceSrc,
-                                   ptrdiff_t DeviceSrcByteOffset,
-                                   void *DeviceDst,
-                                   ptrdiff_t DeviceDstByteOffset,
-                                   ptrdiff_t ByteCount) {
-  cl_event DoneEvent;
-  if (cl_int Result = clEnqueueCopyBuffer(
-          CommandQueues[ActiveDeviceIndex],
-          static_cast<cl_mem>(const_cast<void *>(DeviceSrc)),
-          static_cast<cl_mem>(DeviceDst), DeviceSrcByteOffset,
-          DeviceDstByteOffset, ByteCount, 0, nullptr, &DoneEvent))
-    return getOpenCLError(Result, "clEnqueueCopyBuffer");
-  return getOpenCLError(clWaitForEvents(1, &DoneEvent), "clWaitForEvents");
-}
-
-Status OpenCLPlatform::rawCopyDToH(const void *DeviceSrc,
-                                   ptrdiff_t DeviceSrcByteOffset, void *HostDst,
-                                   ptrdiff_t ByteCount) {
-  cl_event DoneEvent;
-  if (cl_int Result = clEnqueueReadBuffer(
-          CommandQueues[ActiveDeviceIndex],
-          static_cast<cl_mem>(const_cast<void *>(DeviceSrc)), CL_TRUE,
-          DeviceSrcByteOffset, ByteCount, HostDst, 0, nullptr, &DoneEvent))
-    return getOpenCLError(Result, "clEnqueueReadBuffer");
-  return getOpenCLError(clWaitForEvents(1, &DoneEvent), "clWaitForEvents");
-}
-
-Status OpenCLPlatform::rawCopyHToD(const void *HostSrc, void *DeviceDst,
-                                   ptrdiff_t DeviceDstByteOffset,
-                                   ptrdiff_t ByteCount) {
-  cl_event DoneEvent;
-  if (cl_int Result = clEnqueueWriteBuffer(
-          CommandQueues[ActiveDeviceIndex], static_cast<cl_mem>(DeviceDst),
-          CL_TRUE, DeviceDstByteOffset, ByteCount, HostSrc, 0, nullptr,
-          &DoneEvent))
-    return getOpenCLError(Result, "clEnqueueWriteBuffer");
-  return getOpenCLError(clWaitForEvents(1, &DoneEvent), "clWaitForEvents");
-}
-
-Status OpenCLPlatform::rawMemsetD(void *DeviceDst, ptrdiff_t ByteOffset,
-                                  ptrdiff_t ByteCount, char ByteValue) {
-  cl_event DoneEvent;
-  if (cl_int Result = clEnqueueFillBuffer(
-          CommandQueues[ActiveDeviceIndex], static_cast<cl_mem>(DeviceDst),
-          &ByteValue, 1, ByteOffset, ByteCount, 0, nullptr, &DoneEvent))
-    return getOpenCLError(Result, "clEnqueueFillBuffer");
-  return getOpenCLError(clWaitForEvents(1, &DoneEvent), "clWaitForEvents");
-}
-
 static void noOpHandleDestructor(void *) {}
 
 Status OpenCLPlatform::rawRegisterHostMem(const void * /*Memory*/,
@@ -478,10 +407,12 @@
 Status OpenCLPlatform::addStreamCallback(Stream &TheStream,
                                          StreamCallback Callback) {
   cl_int Result;
-  cl_event StartEvent = clCreateUserEvent(Contexts[ActiveDeviceIndex], &Result);
+  cl_event StartEvent =
+      clCreateUserEvent(Contexts[TheStream.getDeviceIndex()], &Result);
   if (Result)
     return getOpenCLError(Result, "clCreateUserEvent");
-  cl_event EndEvent = clCreateUserEvent(Contexts[ActiveDeviceIndex], &Result);
+  cl_event EndEvent =
+      clCreateUserEvent(Contexts[TheStream.getDeviceIndex()], &Result);
   if (Result)
     return getOpenCLError(Result, "clCreateUserEvent");
   cl_event StartBarrierEvent;
diff --git a/acxxel/tests/CMakeLists.txt b/acxxel/tests/CMakeLists.txt
index 17c2df4..d971e49 100644
--- a/acxxel/tests/CMakeLists.txt
+++ b/acxxel/tests/CMakeLists.txt
@@ -29,3 +29,13 @@
     ${CMAKE_THREAD_LIBS_INIT})
 add_test(OpenCLTest opencl_test)
 endif()
+
+if(ACXXEL_ENABLE_MULTI_DEVICE_UNIT_TESTS)
+add_executable(multi_device_test multi_device_test.cpp)
+target_link_libraries(
+    multi_device_test
+    acxxel
+    ${GTEST_BOTH_LIBRARIES}
+    ${CMAKE_THREAD_LIBS_INIT})
+add_test(MultiDeviceTest multi_device_test)
+endif()
diff --git a/acxxel/tests/acxxel_test.cpp b/acxxel/tests/acxxel_test.cpp
index f0bc89c..b7bb3b4 100644
--- a/acxxel/tests/acxxel_test.cpp
+++ b/acxxel/tests/acxxel_test.cpp
@@ -18,7 +18,9 @@
 
 namespace {
 
-template <typename T, size_t N> constexpr size_t size(T (&)[N]) { return N; }
+template <typename T, size_t N> constexpr size_t arraySize(T (&)[N]) {
+  return N;
+}
 
 using PlatformGetter = acxxel::Expected<acxxel::Platform *> (*)();
 class AcxxelTest : public ::testing::TestWithParam<PlatformGetter> {};
@@ -167,11 +169,12 @@
   acxxel::Platform *Platform = GetParam()().takeValue();
   acxxel::Stream Stream = Platform->createStream().takeValue();
   int A[] = {0, 1, 2};
-  std::array<int, size(A)> B;
-  acxxel::DeviceMemory<int> X = Platform->mallocD<int>(size(A)).takeValue();
+  std::array<int, arraySize(A)> B;
+  acxxel::DeviceMemory<int> X =
+      Platform->mallocD<int>(arraySize(A)).takeValue();
   Stream.syncCopyHToD(A, X);
   Stream.syncCopyDToH(X, B);
-  for (size_t I = 0; I < size(A); ++I)
+  for (size_t I = 0; I < arraySize(A); ++I)
     EXPECT_EQ(A[I], B[I]);
   EXPECT_FALSE(Stream.takeStatus().isError());
 }
@@ -180,13 +183,15 @@
   acxxel::Platform *Platform = GetParam()().takeValue();
   acxxel::Stream Stream = Platform->createStream().takeValue();
   int A[] = {0, 1, 2};
-  std::array<int, size(A)> B;
-  acxxel::DeviceMemory<int> X = Platform->mallocD<int>(size(A)).takeValue();
-  acxxel::DeviceMemory<int> Y = Platform->mallocD<int>(size(A)).takeValue();
+  std::array<int, arraySize(A)> B;
+  acxxel::DeviceMemory<int> X =
+      Platform->mallocD<int>(arraySize(A)).takeValue();
+  acxxel::DeviceMemory<int> Y =
+      Platform->mallocD<int>(arraySize(A)).takeValue();
   Stream.syncCopyHToD(A, X);
   Stream.syncCopyDToD(X, Y);
   Stream.syncCopyDToH(Y, B);
-  for (size_t I = 0; I < size(A); ++I)
+  for (size_t I = 0; I < arraySize(A); ++I)
     EXPECT_EQ(A[I], B[I]);
   EXPECT_FALSE(Stream.takeStatus().isError());
 }
@@ -194,8 +199,9 @@
 TEST_P(AcxxelTest, AsyncCopyHostAndDevice) {
   acxxel::Platform *Platform = GetParam()().takeValue();
   int A[] = {0, 1, 2};
-  std::array<int, size(A)> B;
-  acxxel::DeviceMemory<int> X = Platform->mallocD<int>(size(A)).takeValue();
+  std::array<int, arraySize(A)> B;
+  acxxel::DeviceMemory<int> X =
+      Platform->mallocD<int>(arraySize(A)).takeValue();
   acxxel::Stream Stream = Platform->createStream().takeValue();
   acxxel::AsyncHostMemory<int> AsyncA =
       Platform->registerHostMem(A).takeValue();
@@ -204,7 +210,7 @@
   EXPECT_FALSE(Stream.asyncCopyHToD(AsyncA, X).takeStatus().isError());
   EXPECT_FALSE(Stream.asyncCopyDToH(X, AsyncB).takeStatus().isError());
   EXPECT_FALSE(Stream.sync().isError());
-  for (size_t I = 0; I < size(A); ++I)
+  for (size_t I = 0; I < arraySize(A); ++I)
     EXPECT_EQ(A[I], B[I]);
 }
 
@@ -280,9 +286,11 @@
 TEST_P(AcxxelTest, AsyncCopyDToD) {
   acxxel::Platform *Platform = GetParam()().takeValue();
   int A[] = {0, 1, 2};
-  std::array<int, size(A)> B;
-  acxxel::DeviceMemory<int> X = Platform->mallocD<int>(size(A)).takeValue();
-  acxxel::DeviceMemory<int> Y = Platform->mallocD<int>(size(A)).takeValue();
+  std::array<int, arraySize(A)> B;
+  acxxel::DeviceMemory<int> X =
+      Platform->mallocD<int>(arraySize(A)).takeValue();
+  acxxel::DeviceMemory<int> Y =
+      Platform->mallocD<int>(arraySize(A)).takeValue();
   acxxel::Stream Stream = Platform->createStream().takeValue();
   acxxel::AsyncHostMemory<int> AsyncA =
       Platform->registerHostMem(A).takeValue();
@@ -292,7 +300,7 @@
   EXPECT_FALSE(Stream.asyncCopyDToD(X, Y).takeStatus().isError());
   EXPECT_FALSE(Stream.asyncCopyDToH(Y, AsyncB).takeStatus().isError());
   EXPECT_FALSE(Stream.sync().isError());
-  for (size_t I = 0; I < size(A); ++I)
+  for (size_t I = 0; I < arraySize(A); ++I)
     EXPECT_EQ(A[I], B[I]);
 }
 
diff --git a/acxxel/tests/multi_device_test.cpp b/acxxel/tests/multi_device_test.cpp
new file mode 100644
index 0000000..1ff8003
--- /dev/null
+++ b/acxxel/tests/multi_device_test.cpp
@@ -0,0 +1,106 @@
+#include "acxxel.h"
+#include "config.h"
+#include "gtest/gtest.h"
+
+#include <memory>
+
+namespace {
+
+/// Signature of the platform factory functions (such as
+/// acxxel::getCUDAPlatform) that these tests are parameterized over.
+using PlatformGetter = acxxel::Expected<acxxel::Platform *> (*)();
+
+/// Fixture for tests that create objects on device 0 and device 1 of the
+/// platform returned by the test parameter.
+class MultiDeviceTest : public ::testing::TestWithParam<PlatformGetter> {};
+
+/// Copies a host buffer to device memory and back on device 0 and on device 1,
+/// each through a stream created for that device, and checks both results.
+TEST_P(MultiDeviceTest, AsyncCopy) {
+  acxxel::Platform *Platform = GetParam()().takeValue();
+  int DeviceCount = Platform->getDeviceCount().getValue();
+  // Device indices 0 and 1 are both used below, so stop here (ASSERT, not
+  // EXPECT) unless at least two devices exist.
+  ASSERT_GT(DeviceCount, 1);
+
+  int Length = 3;
+  auto A = std::unique_ptr<int[]>(new int[Length]);
+  auto B0 = std::unique_ptr<int[]>(new int[Length]);
+  auto B1 = std::unique_ptr<int[]>(new int[Length]);
+
+  auto ASpan = acxxel::Span<int>(A.get(), Length);
+  auto B0Span = acxxel::Span<int>(B0.get(), Length);
+  auto B1Span = acxxel::Span<int>(B1.get(), Length);
+
+  // Source data: A[I] == I, so the expected results below are just I.
+  for (int I = 0; I < Length; ++I)
+    A[I] = I;
+
+  auto AsyncA = Platform->registerHostMem(ASpan).takeValue();
+  auto AsyncB0 = Platform->registerHostMem(B0Span).takeValue();
+  auto AsyncB1 = Platform->registerHostMem(B1Span).takeValue();
+
+  // One stream and one device allocation per device index.
+  acxxel::Stream Stream0 = Platform->createStream(0).takeValue();
+  acxxel::Stream Stream1 = Platform->createStream(1).takeValue();
+  auto Device0 = Platform->mallocD<int>(Length, 0).takeValue();
+  auto Device1 = Platform->mallocD<int>(Length, 1).takeValue();
+
+  EXPECT_FALSE(Stream0.asyncCopyHToD(AsyncA, Device0, Length)
+                   .asyncCopyDToH(Device0, AsyncB0, Length)
+                   .sync()
+                   .isError());
+
+  EXPECT_FALSE(Stream1.asyncCopyHToD(AsyncA, Device1, Length)
+                   .asyncCopyDToH(Device1, AsyncB1, Length)
+                   .sync()
+                   .isError());
+
+  for (int I = 0; I < Length; ++I) {
+    EXPECT_EQ(B0[I], I);
+    EXPECT_EQ(B1[I], I);
+  }
+}
+
+/// Enqueues an event on a stream for device 0 and for device 1, and checks
+/// that each event reports done once its stream has been synchronized.
+TEST_P(MultiDeviceTest, Events) {
+  acxxel::Platform *Platform = GetParam()().takeValue();
+  int DeviceCount = Platform->getDeviceCount().getValue();
+  // Device indices 0 and 1 are both used below, so stop here (ASSERT, not
+  // EXPECT) unless at least two devices exist.
+  ASSERT_GT(DeviceCount, 1);
+
+  acxxel::Stream Stream0 = Platform->createStream(0).takeValue();
+  acxxel::Stream Stream1 = Platform->createStream(1).takeValue();
+  acxxel::Event Event0 = Platform->createEvent(0).takeValue();
+  acxxel::Event Event1 = Platform->createEvent(1).takeValue();
+
+  EXPECT_FALSE(Stream0.enqueueEvent(Event0).sync().isError());
+  EXPECT_FALSE(Stream1.enqueueEvent(Event1).sync().isError());
+
+  EXPECT_TRUE(Event0.isDone());
+  EXPECT_TRUE(Event1.isDone());
+
+  EXPECT_FALSE(Event0.sync().isError());
+  EXPECT_FALSE(Event1.sync().isError());
+}
+
+// Run the tests against every platform whose ACXXEL_ENABLE_* macro is defined.
+// The lone comma is emitted only when both CUDA and OpenCL are enabled.
+#if defined(ACXXEL_ENABLE_CUDA) || defined(ACXXEL_ENABLE_OPENCL)
+INSTANTIATE_TEST_CASE_P(BothPlatformTest, MultiDeviceTest,
+                        ::testing::Values(
+#ifdef ACXXEL_ENABLE_CUDA
+                            acxxel::getCUDAPlatform
+#ifdef ACXXEL_ENABLE_OPENCL
+                            ,
+#endif
+#endif
+#ifdef ACXXEL_ENABLE_OPENCL
+                            acxxel::getOpenCLPlatform
+#endif
+                            ));
+#endif
+
+} // namespace