blob: 87a2c7c3885d985a2b77330339a504c39f3cb057 [file] [log] [blame]
//===-- Stream.h - A stream of execution ------------------------*- C++ -*-===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
///
/// \file
///
/// A Stream instance represents a queue of sequential, host-asynchronous work
/// to be performed on a device.
///
/// To enqueue work on a device, first create an Executor instance for a
/// given device and then use that Executor to create a Stream instance.
/// The Stream instance will perform its work on the device managed by the
/// Executor that created it.
///
/// The various "then" methods of the Stream object, such as thenCopyH2D and
/// thenLaunch, may be used to enqueue work on the Stream, and the
/// blockHostUntilDone() method may be used to block the host code until the
/// Stream has completed all its work.
///
/// Multiple Stream instances can be created for the same Executor. This
/// allows several independent streams of computation to be performed
/// simultaneously on a single device.
///
//===----------------------------------------------------------------------===//
#ifndef STREAMEXECUTOR_STREAM_H
#define STREAMEXECUTOR_STREAM_H
#include <cassert>
#include <memory>
#include <string>
#include "streamexecutor/DeviceMemory.h"
#include "streamexecutor/Kernel.h"
#include "streamexecutor/LaunchDimensions.h"
#include "streamexecutor/PackedKernelArgumentArray.h"
#include "streamexecutor/PlatformInterfaces.h"
#include "streamexecutor/Utils/Error.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/RWMutex.h"
namespace streamexecutor {
/// Represents a stream of dependent computations on a device.
///
/// The operations within a stream execute sequentially and asynchronously until
/// blockHostUntilDone() is invoked, which synchronously joins host code with
/// the execution of the stream.
///
/// If any given operation fails when entraining work for the stream, isOK()
/// will indicate that an error has occurred and getStatus() will get the first
/// error that occurred on the stream. There is no way to clear the error state
/// of a stream once it is in an error state. Work entrained after an error has
/// been recorded is still forwarded to the platform executor; any further
/// errors it produces are discarded because only the first message is kept.
///
/// The recorded error state is guarded by a reader/writer mutex.
class Stream {
public:
  /// Creates a stream that takes ownership of the given platform-specific
  /// stream handle.
  explicit Stream(std::unique_ptr<PlatformStreamHandle> PStream);
  ~Stream();

  Stream(const Stream &) = delete;
  Stream &operator=(const Stream &) = delete;

  /// Returns true if no error has been recorded while entraining work on this
  /// stream.
  bool isOK() const {
    llvm::sys::ScopedReader ReaderLock(ErrorMessageMutex);
    return !ErrorMessage;
  }

  /// Returns an Error built from the message of the first error recorded on
  /// this stream, or Error::success() if no error has been recorded.
  ///
  /// The recorded message is not cleared, so repeated calls report the same
  /// error.
  Error getStatus() const {
    llvm::sys::ScopedReader ReaderLock(ErrorMessageMutex);
    if (ErrorMessage)
      return make_error(*ErrorMessage);
    return Error::success();
  }

  /// Entrains onto the stream of operations a kernel launch with the given
  /// arguments.
  ///
  /// These arguments can be device memory types like GlobalDeviceMemory<T> and
  /// SharedDeviceMemory<T>, or they can be primitive types such as int. The
  /// allowable argument types are determined by the template parameters to the
  /// TypedKernel argument.
  ///
  /// Any error returned by the platform launch is recorded on the stream.
  template <typename... ParameterTs>
  Stream &thenLaunch(BlockDimensions BlockSize, GridDimensions GridSize,
                     const TypedKernel<ParameterTs...> &Kernel,
                     const ParameterTs &... Arguments) {
    auto ArgumentArray =
        make_kernel_argument_pack<ParameterTs...>(Arguments...);
    setError(PExecutor->launch(ThePlatformStream.get(), BlockSize, GridSize,
                               Kernel, ArgumentArray));
    return *this;
  }

  /// Enqueues on this stream a command to copy a slice of an array of elements
  /// of type T from device to host memory.
  ///
  /// Sets an error, and enqueues nothing, if ElementCount is too large for the
  /// source or the destination.
  ///
  /// If the Dst host memory was not created by allocateHostMemory or
  /// registered with registerHostMemory, then the copy operation may cause the
  /// host and device to block until the copy operation is completed.
  template <typename T>
  Stream &thenCopyD2H(GlobalDeviceMemorySlice<T> Src,
                      llvm::MutableArrayRef<T> Dst, size_t ElementCount) {
    if (ElementCount > Src.getElementCount())
      setError("copying too many elements, " + llvm::Twine(ElementCount) +
               ", from a device array of element count " +
               llvm::Twine(Src.getElementCount()));
    else if (ElementCount > Dst.size())
      setError("copying too many elements, " + llvm::Twine(ElementCount) +
               ", to a host array of element count " + llvm::Twine(Dst.size()));
    else
      setError(PExecutor->copyD2H(ThePlatformStream.get(), Src.getBaseMemory(),
                                  Src.getElementOffset() * sizeof(T),
                                  Dst.data(), 0, ElementCount * sizeof(T)));
    return *this;
  }

  /// Similar to thenCopyD2H(GlobalDeviceMemorySlice<T>,
  /// llvm::MutableArrayRef<T>, size_t) but does not take an element count
  /// argument because it copies the entire source array.
  ///
  /// Sets an error if the Src and Dst sizes do not match.
  template <typename T>
  Stream &thenCopyD2H(GlobalDeviceMemorySlice<T> Src,
                      llvm::MutableArrayRef<T> Dst) {
    if (Src.getElementCount() != Dst.size())
      setError("array size mismatch for D2H, device source has element count " +
               llvm::Twine(Src.getElementCount()) +
               " but host destination has element count " +
               llvm::Twine(Dst.size()));
    else
      thenCopyD2H(Src, Dst, Src.getElementCount());
    return *this;
  }

  /// Similar to thenCopyD2H(GlobalDeviceMemorySlice<T>,
  /// llvm::MutableArrayRef<T>, size_t) but copies to a pointer rather than an
  /// llvm::MutableArrayRef.
  ///
  /// Dst is assumed to point to at least ElementCount elements. Sets an error
  /// if ElementCount is too large for the source slice.
  template <typename T>
  Stream &thenCopyD2H(GlobalDeviceMemorySlice<T> Src, T *Dst,
                      size_t ElementCount) {
    thenCopyD2H(Src, llvm::MutableArrayRef<T>(Dst, ElementCount), ElementCount);
    return *this;
  }

  /// Similar to thenCopyD2H(GlobalDeviceMemorySlice<T>,
  /// llvm::MutableArrayRef<T>, size_t) but the source is a GlobalDeviceMemory
  /// rather than a GlobalDeviceMemorySlice.
  template <typename T>
  Stream &thenCopyD2H(GlobalDeviceMemory<T> Src, llvm::MutableArrayRef<T> Dst,
                      size_t ElementCount) {
    thenCopyD2H(Src.asSlice(), Dst, ElementCount);
    return *this;
  }

  /// Similar to thenCopyD2H(GlobalDeviceMemorySlice<T>,
  /// llvm::MutableArrayRef<T>) but the source is a GlobalDeviceMemory rather
  /// than a GlobalDeviceMemorySlice.
  template <typename T>
  Stream &thenCopyD2H(GlobalDeviceMemory<T> Src, llvm::MutableArrayRef<T> Dst) {
    thenCopyD2H(Src.asSlice(), Dst);
    return *this;
  }

  /// Similar to thenCopyD2H(GlobalDeviceMemorySlice<T>, T*, size_t) but the
  /// source is a GlobalDeviceMemory rather than a GlobalDeviceMemorySlice.
  template <typename T>
  Stream &thenCopyD2H(GlobalDeviceMemory<T> Src, T *Dst, size_t ElementCount) {
    thenCopyD2H(Src.asSlice(), Dst, ElementCount);
    return *this;
  }

  /// Similar to thenCopyD2H(GlobalDeviceMemorySlice<T>,
  /// llvm::MutableArrayRef<T>, size_t) but copies from host to device memory
  /// rather than device to host memory.
  ///
  /// Sets an error, and enqueues nothing, if ElementCount is too large for the
  /// source or the destination.
  template <typename T>
  Stream &thenCopyH2D(llvm::ArrayRef<T> Src, GlobalDeviceMemorySlice<T> Dst,
                      size_t ElementCount) {
    if (ElementCount > Src.size())
      setError("copying too many elements, " + llvm::Twine(ElementCount) +
               ", from a host array of element count " +
               llvm::Twine(Src.size()));
    else if (ElementCount > Dst.getElementCount())
      setError("copying too many elements, " + llvm::Twine(ElementCount) +
               ", to a device array of element count " +
               llvm::Twine(Dst.getElementCount()));
    else
      setError(PExecutor->copyH2D(
          ThePlatformStream.get(), Src.data(), 0, Dst.getBaseMemory(),
          Dst.getElementOffset() * sizeof(T), ElementCount * sizeof(T)));
    return *this;
  }

  /// Similar to thenCopyH2D(llvm::ArrayRef<T>, GlobalDeviceMemorySlice<T>,
  /// size_t) but does not take an element count argument because it copies the
  /// entire source array.
  ///
  /// Sets an error if the Src and Dst sizes do not match.
  template <typename T>
  Stream &thenCopyH2D(llvm::ArrayRef<T> Src, GlobalDeviceMemorySlice<T> Dst) {
    if (Src.size() != Dst.getElementCount())
      setError("array size mismatch for H2D, host source has element count " +
               llvm::Twine(Src.size()) +
               " but device destination has element count " +
               llvm::Twine(Dst.getElementCount()));
    else
      thenCopyH2D(Src, Dst, Dst.getElementCount());
    return *this;
  }

  /// Similar to thenCopyH2D(llvm::ArrayRef<T>, GlobalDeviceMemorySlice<T>,
  /// size_t) but copies from a pointer rather than an llvm::ArrayRef.
  ///
  /// The source is only read, so it is taken as a pointer-to-const. This lets
  /// read-only host buffers be passed directly; a plain T* still deduces the
  /// same T. Src is assumed to point to at least ElementCount elements. Sets
  /// an error if ElementCount is too large for the destination.
  template <typename T>
  Stream &thenCopyH2D(const T *Src, GlobalDeviceMemorySlice<T> Dst,
                      size_t ElementCount) {
    thenCopyH2D(llvm::ArrayRef<T>(Src, ElementCount), Dst, ElementCount);
    return *this;
  }

  /// Similar to thenCopyH2D(llvm::ArrayRef<T>, GlobalDeviceMemorySlice<T>,
  /// size_t) but the destination is a GlobalDeviceMemory rather than a
  /// GlobalDeviceMemorySlice.
  template <typename T>
  Stream &thenCopyH2D(llvm::ArrayRef<T> Src, GlobalDeviceMemory<T> Dst,
                      size_t ElementCount) {
    thenCopyH2D(Src, Dst.asSlice(), ElementCount);
    return *this;
  }

  /// Similar to thenCopyH2D(llvm::ArrayRef<T>, GlobalDeviceMemorySlice<T>) but
  /// the destination is a GlobalDeviceMemory rather than a
  /// GlobalDeviceMemorySlice.
  template <typename T>
  Stream &thenCopyH2D(llvm::ArrayRef<T> Src, GlobalDeviceMemory<T> Dst) {
    thenCopyH2D(Src, Dst.asSlice());
    return *this;
  }

  /// Similar to thenCopyH2D(const T*, GlobalDeviceMemorySlice<T>, size_t) but
  /// the destination is a GlobalDeviceMemory rather than a
  /// GlobalDeviceMemorySlice.
  template <typename T>
  Stream &thenCopyH2D(const T *Src, GlobalDeviceMemory<T> Dst,
                      size_t ElementCount) {
    thenCopyH2D(Src, Dst.asSlice(), ElementCount);
    return *this;
  }

  /// Similar to thenCopyD2H(GlobalDeviceMemorySlice<T>,
  /// llvm::MutableArrayRef<T>, size_t) but copies from one location in device
  /// memory to another rather than from device to host memory.
  ///
  /// Sets an error, and enqueues nothing, if ElementCount is too large for the
  /// source or the destination.
  template <typename T>
  Stream &thenCopyD2D(GlobalDeviceMemorySlice<T> Src,
                      GlobalDeviceMemorySlice<T> Dst, size_t ElementCount) {
    if (ElementCount > Src.getElementCount())
      setError("copying too many elements, " + llvm::Twine(ElementCount) +
               ", from a device array of element count " +
               llvm::Twine(Src.getElementCount()));
    else if (ElementCount > Dst.getElementCount())
      setError("copying too many elements, " + llvm::Twine(ElementCount) +
               ", to a device array of element count " +
               llvm::Twine(Dst.getElementCount()));
    else
      setError(PExecutor->copyD2D(
          ThePlatformStream.get(), Src.getBaseMemory(),
          Src.getElementOffset() * sizeof(T), Dst.getBaseMemory(),
          Dst.getElementOffset() * sizeof(T), ElementCount * sizeof(T)));
    return *this;
  }

  /// Similar to thenCopyD2D(GlobalDeviceMemorySlice<T>,
  /// GlobalDeviceMemorySlice<T>, size_t) but does not take an element count
  /// argument because it copies the entire source array.
  ///
  /// Sets an error if the Src and Dst sizes do not match.
  template <typename T>
  Stream &thenCopyD2D(GlobalDeviceMemorySlice<T> Src,
                      GlobalDeviceMemorySlice<T> Dst) {
    if (Src.getElementCount() != Dst.getElementCount())
      setError("array size mismatch for D2D, device source has element count " +
               llvm::Twine(Src.getElementCount()) +
               " but device destination has element count " +
               llvm::Twine(Dst.getElementCount()));
    else
      thenCopyD2D(Src, Dst, Src.getElementCount());
    return *this;
  }

  /// Similar to thenCopyD2D(GlobalDeviceMemorySlice<T>,
  /// GlobalDeviceMemorySlice<T>, size_t) but the source is a
  /// GlobalDeviceMemory<T> rather than a GlobalDeviceMemorySlice<T>.
  template <typename T>
  Stream &thenCopyD2D(GlobalDeviceMemory<T> Src, GlobalDeviceMemorySlice<T> Dst,
                      size_t ElementCount) {
    thenCopyD2D(Src.asSlice(), Dst, ElementCount);
    return *this;
  }

  /// Similar to thenCopyD2D(GlobalDeviceMemorySlice<T>,
  /// GlobalDeviceMemorySlice<T>) but the source is a GlobalDeviceMemory<T>
  /// rather than a GlobalDeviceMemorySlice<T>.
  template <typename T>
  Stream &thenCopyD2D(GlobalDeviceMemory<T> Src,
                      GlobalDeviceMemorySlice<T> Dst) {
    thenCopyD2D(Src.asSlice(), Dst);
    return *this;
  }

  /// Similar to thenCopyD2D(GlobalDeviceMemorySlice<T>,
  /// GlobalDeviceMemorySlice<T>, size_t) but the destination is a
  /// GlobalDeviceMemory<T> rather than a GlobalDeviceMemorySlice<T>.
  template <typename T>
  Stream &thenCopyD2D(GlobalDeviceMemorySlice<T> Src, GlobalDeviceMemory<T> Dst,
                      size_t ElementCount) {
    thenCopyD2D(Src, Dst.asSlice(), ElementCount);
    return *this;
  }

  /// Similar to thenCopyD2D(GlobalDeviceMemorySlice<T>,
  /// GlobalDeviceMemorySlice<T>) but the destination is a GlobalDeviceMemory<T>
  /// rather than a GlobalDeviceMemorySlice<T>.
  template <typename T>
  Stream &thenCopyD2D(GlobalDeviceMemorySlice<T> Src,
                      GlobalDeviceMemory<T> Dst) {
    thenCopyD2D(Src, Dst.asSlice());
    return *this;
  }

  /// Similar to thenCopyD2D(GlobalDeviceMemorySlice<T>,
  /// GlobalDeviceMemorySlice<T>, size_t) but the source and destination are
  /// GlobalDeviceMemory<T> rather than GlobalDeviceMemorySlice<T>.
  template <typename T>
  Stream &thenCopyD2D(GlobalDeviceMemory<T> Src, GlobalDeviceMemory<T> Dst,
                      size_t ElementCount) {
    thenCopyD2D(Src.asSlice(), Dst.asSlice(), ElementCount);
    return *this;
  }

  /// Similar to thenCopyD2D(GlobalDeviceMemorySlice<T>,
  /// GlobalDeviceMemorySlice<T>) but the source and destination are
  /// GlobalDeviceMemory<T> rather than GlobalDeviceMemorySlice<T>.
  template <typename T>
  Stream &thenCopyD2D(GlobalDeviceMemory<T> Src, GlobalDeviceMemory<T> Dst) {
    thenCopyD2D(Src.asSlice(), Dst.asSlice());
    return *this;
  }

private:
  /// Sets the error state from an Error object.
  ///
  /// Does not overwrite the error if it is already set. A failure value is
  /// always consumed, even when an earlier error has already been recorded.
  /// Otherwise it would be destroyed while still unchecked, which is fatal in
  /// LLVM builds with ABI-breaking checks enabled.
  void setError(Error &&E) {
    if (!E)
      return;
    // Extract the message before taking the lock. This consumes E on every
    // failure path and keeps the critical section short.
    std::string Message = consumeAndGetMessage(std::move(E));
    llvm::sys::ScopedWriter WriterLock(ErrorMessageMutex);
    if (!ErrorMessage)
      ErrorMessage = std::move(Message);
  }

  /// Sets the error state from an error message.
  ///
  /// Does not overwrite the error if it is already set. The Twine is rendered
  /// immediately, so it may refer to temporaries of the caller's expression.
  void setError(const llvm::Twine &Message) {
    llvm::sys::ScopedWriter WriterLock(ErrorMessageMutex);
    if (!ErrorMessage)
      ErrorMessage = Message.str();
  }

  /// The PlatformExecutor that supports the operations of this stream.
  PlatformExecutor *PExecutor;

  /// The platform-specific stream handle for this instance.
  std::unique_ptr<PlatformStreamHandle> ThePlatformStream;

  /// Mutex that guards the error state.
  ///
  /// Mutable so that it can be obtained via const reader lock.
  mutable llvm::sys::RWMutex ErrorMessageMutex;

  /// First error message for an operation in this stream, or empty if there
  /// have been no errors.
  llvm::Optional<std::string> ErrorMessage;
};
} // namespace streamexecutor
#endif // STREAMEXECUTOR_STREAM_H