blob: bdff7ff9701ec32931836903e94453fa8153f970 [file] [log] [blame]
//===-- Stream.h - A stream of execution ------------------------*- C++ -*-===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
///
/// \file
///
/// A Stream instance represents a queue of sequential, host-asynchronous work
/// to be performed on a device.
///
/// To enqueue work on a device, first create a Device instance then use that
/// Device to create a Stream instance. The Stream instance will perform its
/// work on the device managed by the Device object that created it.
///
/// The various "then" methods of the Stream object, such as thenCopyH2D and
/// thenLaunch, may be used to enqueue work on the Stream, and the
/// blockHostUntilDone() method may be used to block the host code until the
/// Stream has completed all its work.
///
/// Multiple Stream instances can be created for the same Device. This allows
/// several independent streams of computation to be performed simultaneously on
/// a single device.
///
//===----------------------------------------------------------------------===//
#ifndef STREAMEXECUTOR_STREAM_H
#define STREAMEXECUTOR_STREAM_H
#include <cassert>
#include <memory>
#include <string>
#include <type_traits>
#include "streamexecutor/DeviceMemory.h"
#include "streamexecutor/Error.h"
#include "streamexecutor/HostMemory.h"
#include "streamexecutor/Kernel.h"
#include "streamexecutor/LaunchDimensions.h"
#include "streamexecutor/PackedKernelArgumentArray.h"
#include "streamexecutor/PlatformDevice.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/RWMutex.h"
namespace streamexecutor {
/// Represents a stream of dependent computations on a device.
///
/// The operations within a stream execute sequentially and asynchronously until
/// blockHostUntilDone() is invoked, which synchronously joins host code with
/// the execution of the stream.
///
/// If any given operation fails when entraining work for the stream, isOK()
/// will indicate that an error has occurred and getStatus() will get the first
/// error that occurred on the stream. There is no way to clear the error state
/// of a stream once it is in an error state.
class Stream {
public:
  /// Constructs a stream for the given platform device and platform-specific
  /// stream handle.
  ///
  /// Defined out of line.
  Stream(PlatformDevice *D, const void *PlatformStreamHandle);

  /// Streams are move-only: copying is disabled.
  Stream(const Stream &Other) = delete;
  Stream &operator=(const Stream &Other) = delete;

  /// Move operations and the destructor are defined out of line.
  Stream(Stream &&Other) noexcept;
  Stream &operator=(Stream &&Other) noexcept;

  ~Stream();

  /// Returns whether any error has occurred while entraining work on this
  /// stream.
  ///
  /// Takes a reader lock on the error-state mutex for the duration of the
  /// check.
  bool isOK() const {
    llvm::sys::ScopedReader ReaderLock(*ErrorMessageMutex);
    return !ErrorMessage;
  }

  /// Returns the status created by the first error that occurred while
  /// entraining work on this stream.
  ///
  /// \returns an Error built from the stored message if the stream is in an
  /// error state, otherwise Error::success().
  Error getStatus() const {
    llvm::sys::ScopedReader ReaderLock(*ErrorMessageMutex);
    if (ErrorMessage)
      return make_error(*ErrorMessage);
    return Error::success();
  }

  /// Blocks the calling host thread until all work enqueued on this Stream
  /// completes.
  ///
  /// Any error reported by the platform device while blocking is recorded in
  /// the stream's error state (unless an earlier error is already stored).
  ///
  /// \returns the result of getStatus() after the Stream work completes. Since
  /// only the first error is kept, this may be an error from an operation
  /// enqueued earlier rather than from the blocking call itself.
  Error blockHostUntilDone() {
    setError(PDevice->blockHostUntilDone(PlatformStreamHandle));
    return getStatus();
  }

  /// Entrains onto the stream of operations a kernel launch with the given
  /// arguments.
  ///
  /// These arguments can be device memory types like GlobalDeviceMemory<T> and
  /// SharedDeviceMemory<T>, or they can be primitive types such as int. The
  /// allowable argument types are determined by the template parameters to the
  /// Kernel argument.
  ///
  /// \param BlockSize block dimensions forwarded to the platform launch.
  /// \param GridSize grid dimensions forwarded to the platform launch.
  /// \param K the kernel whose platform handle is launched.
  /// \param Arguments the kernel arguments, packed with
  /// make_kernel_argument_pack before being handed to the platform device.
  ///
  /// \returns a reference to this stream so that calls can be chained. A launch
  /// failure is recorded in the stream's error state rather than returned.
  template <typename... ParameterTs>
  Stream &thenLaunch(BlockDimensions BlockSize, GridDimensions GridSize,
                     const Kernel<ParameterTs...> &K,
                     const ParameterTs &... Arguments) {
    auto ArgumentArray =
        make_kernel_argument_pack<ParameterTs...>(Arguments...);
    setError(PDevice->launch(PlatformStreamHandle, BlockSize, GridSize,
                             K.getPlatformHandle(), ArgumentArray));
    return *this;
  }

  /// \name Device memory copying functions
  ///
  /// These methods enqueue a device memory copy operation on the stream and
  /// return without waiting for the operation to complete.
  ///
  /// The arguments and bounds checking for these methods match the API of the
  /// \ref DeviceHostSyncCopyGroup
  /// "host-synchronous device memory copying functions" of Device.
  ///
  /// The template types SrcTy and DstTy must match the following constraints:
  ///   * Must define typename ElementTy (the type of element stored in the
  ///     memory);
  ///   * ElementTy for the source argument must be the same as ElementTy for
  ///     the destination argument;
  ///   * Must be convertible to the correct slice type:
  ///     * GlobalDeviceMemorySlice<ElementTy> for device memory arguments,
  ///     * RegisteredHostMemorySlice<ElementTy> for host memory source
  ///       arguments,
  ///     * MutableRegisteredHostMemorySlice<ElementTy> for host memory
  ///       destination arguments.
  ///
  /// The overloads taking an ElementCount copy that many elements and record an
  /// error if the count exceeds the size of either the source or the
  /// destination. The overloads without an ElementCount require the source and
  /// destination to have equal element counts and copy all of them.
  ///
  /// Every overload returns a reference to this stream; failures are recorded
  /// in the stream's error state instead of being returned.
  ///@{

  // D2H
  template <typename SrcTy, typename DstTy>
  Stream &thenCopyD2H(SrcTy &&Src, DstTy &&Dst, size_t ElementCount) {
    using SrcElemTy = typename std::remove_reference<SrcTy>::type::ElementTy;
    using DstElemTy = typename std::remove_reference<DstTy>::type::ElementTy;
    static_assert(std::is_same<SrcElemTy, DstElemTy>::value,
                  "src/dst element type mismatch for thenCopyD2H");
    GlobalDeviceMemorySlice<SrcElemTy> SrcSlice(Src);
    MutableRegisteredHostMemorySlice<DstElemTy> DstSlice(Dst);
    // Bounds are checked against the converted slices, not the raw arguments:
    // SrcTy/DstTy are only required to be convertible to a slice, and the
    // error messages below report the slice element counts.
    if (ElementCount > SrcSlice.getElementCount())
      setError("copying too many elements, " + llvm::Twine(ElementCount) +
               ", from a device array of element count " +
               llvm::Twine(SrcSlice.getElementCount()));
    else if (ElementCount > DstSlice.getElementCount())
      setError("copying too many elements, " + llvm::Twine(ElementCount) +
               ", to a host array of element count " +
               llvm::Twine(DstSlice.getElementCount()));
    else
      setError(PDevice->copyD2H(
          PlatformStreamHandle, SrcSlice.getBaseMemory().getHandle(),
          SrcSlice.getElementOffset() * sizeof(SrcElemTy),
          DstSlice.getPointer(), 0, ElementCount * sizeof(DstElemTy)));
    return *this;
  }

  template <typename SrcTy, typename DstTy>
  Stream &thenCopyD2H(SrcTy &&Src, DstTy &&Dst) {
    using SrcElemTy = typename std::remove_reference<SrcTy>::type::ElementTy;
    using DstElemTy = typename std::remove_reference<DstTy>::type::ElementTy;
    static_assert(std::is_same<SrcElemTy, DstElemTy>::value,
                  "src/dst element type mismatch for thenCopyD2H");
    GlobalDeviceMemorySlice<SrcElemTy> SrcSlice(Src);
    MutableRegisteredHostMemorySlice<DstElemTy> DstSlice(Dst);
    if (SrcSlice.getElementCount() != DstSlice.getElementCount())
      setError("array size mismatch for D2H, device source has element count " +
               llvm::Twine(SrcSlice.getElementCount()) +
               " but host destination has element count " +
               llvm::Twine(DstSlice.getElementCount()));
    else
      thenCopyD2H(SrcSlice, DstSlice, SrcSlice.getElementCount());
    return *this;
  }

  // H2D
  template <typename SrcTy, typename DstTy>
  Stream &thenCopyH2D(SrcTy &&Src, DstTy &&Dst, size_t ElementCount) {
    using SrcElemTy = typename std::remove_reference<SrcTy>::type::ElementTy;
    using DstElemTy = typename std::remove_reference<DstTy>::type::ElementTy;
    static_assert(std::is_same<SrcElemTy, DstElemTy>::value,
                  "src/dst element type mismatch for thenCopyH2D");
    RegisteredHostMemorySlice<SrcElemTy> SrcSlice(Src);
    GlobalDeviceMemorySlice<DstElemTy> DstSlice(Dst);
    if (ElementCount > SrcSlice.getElementCount())
      setError("copying too many elements, " + llvm::Twine(ElementCount) +
               ", from a host array of element count " +
               llvm::Twine(SrcSlice.getElementCount()));
    else if (ElementCount > DstSlice.getElementCount())
      setError("copying too many elements, " + llvm::Twine(ElementCount) +
               ", to a device array of element count " +
               llvm::Twine(DstSlice.getElementCount()));
    else
      setError(PDevice->copyH2D(PlatformStreamHandle, SrcSlice.getPointer(), 0,
                                DstSlice.getBaseMemory().getHandle(),
                                DstSlice.getElementOffset() * sizeof(DstElemTy),
                                ElementCount * sizeof(SrcElemTy)));
    return *this;
  }

  template <typename SrcTy, typename DstTy>
  Stream &thenCopyH2D(SrcTy &&Src, DstTy &&Dst) {
    using SrcElemTy = typename std::remove_reference<SrcTy>::type::ElementTy;
    using DstElemTy = typename std::remove_reference<DstTy>::type::ElementTy;
    static_assert(std::is_same<SrcElemTy, DstElemTy>::value,
                  "src/dst element type mismatch for thenCopyH2D");
    RegisteredHostMemorySlice<SrcElemTy> SrcSlice(Src);
    GlobalDeviceMemorySlice<DstElemTy> DstSlice(Dst);
    if (SrcSlice.getElementCount() != DstSlice.getElementCount())
      setError("array size mismatch for H2D, host source has element count " +
               llvm::Twine(SrcSlice.getElementCount()) +
               " but device destination has element count " +
               llvm::Twine(DstSlice.getElementCount()));
    else
      thenCopyH2D(SrcSlice, DstSlice, DstSlice.getElementCount());
    return *this;
  }

  // D2D
  template <typename SrcTy, typename DstTy>
  Stream &thenCopyD2D(SrcTy &&Src, DstTy &&Dst, size_t ElementCount) {
    using SrcElemTy = typename std::remove_reference<SrcTy>::type::ElementTy;
    using DstElemTy = typename std::remove_reference<DstTy>::type::ElementTy;
    static_assert(std::is_same<SrcElemTy, DstElemTy>::value,
                  "src/dst element type mismatch for thenCopyD2D");
    GlobalDeviceMemorySlice<SrcElemTy> SrcSlice(Src);
    GlobalDeviceMemorySlice<DstElemTy> DstSlice(Dst);
    if (ElementCount > SrcSlice.getElementCount())
      setError("copying too many elements, " + llvm::Twine(ElementCount) +
               ", from a device array of element count " +
               llvm::Twine(SrcSlice.getElementCount()));
    else if (ElementCount > DstSlice.getElementCount())
      setError("copying too many elements, " + llvm::Twine(ElementCount) +
               ", to a device array of element count " +
               llvm::Twine(DstSlice.getElementCount()));
    else
      setError(PDevice->copyD2D(PlatformStreamHandle,
                                SrcSlice.getBaseMemory().getHandle(),
                                SrcSlice.getElementOffset() * sizeof(SrcElemTy),
                                DstSlice.getBaseMemory().getHandle(),
                                DstSlice.getElementOffset() * sizeof(DstElemTy),
                                ElementCount * sizeof(SrcElemTy)));
    return *this;
  }

  template <typename SrcTy, typename DstTy>
  Stream &thenCopyD2D(SrcTy &&Src, DstTy &&Dst) {
    using SrcElemTy = typename std::remove_reference<SrcTy>::type::ElementTy;
    using DstElemTy = typename std::remove_reference<DstTy>::type::ElementTy;
    static_assert(std::is_same<SrcElemTy, DstElemTy>::value,
                  "src/dst element type mismatch for thenCopyD2D");
    GlobalDeviceMemorySlice<SrcElemTy> SrcSlice(Src);
    GlobalDeviceMemorySlice<DstElemTy> DstSlice(Dst);
    if (SrcSlice.getElementCount() != DstSlice.getElementCount())
      setError("array size mismatch for D2D, device source has element count " +
               llvm::Twine(SrcSlice.getElementCount()) +
               " but device destination has element count " +
               llvm::Twine(DstSlice.getElementCount()));
    else
      thenCopyD2D(SrcSlice, DstSlice, SrcSlice.getElementCount());
    return *this;
  }

  ///@} End device memory copying functions

private:
  /// Sets the error state from an Error object.
  ///
  /// Does nothing if E holds no error. Does not overwrite the error if it is
  /// already set.
  void setError(Error &&E) {
    if (!E)
      return;
    // Always consume a failing E, even when an earlier error takes
    // precedence, so that the Error object is never dropped unhandled.
    // NOTE(review): this assumes Error follows llvm::Error's
    // must-be-consumed contract -- confirm against streamexecutor/Error.h.
    std::string Message = consumeAndGetMessage(std::move(E));
    llvm::sys::ScopedWriter WriterLock(*ErrorMessageMutex);
    if (!ErrorMessage)
      ErrorMessage = std::move(Message);
  }

  /// Sets the error state from an error message.
  ///
  /// Does not overwrite the error if it is already set.
  void setError(const llvm::Twine &Message) {
    llvm::sys::ScopedWriter WriterLock(*ErrorMessageMutex);
    if (!ErrorMessage)
      ErrorMessage = Message.str();
  }

  /// The PlatformDevice that supports the operations of this stream.
  PlatformDevice *PDevice;

  /// The platform-specific stream handle for this instance.
  const void *PlatformStreamHandle;

  /// Mutex that guards the error state: readers of ErrorMessage take a reader
  /// lock and setError takes a writer lock.
  std::unique_ptr<llvm::sys::RWMutex> ErrorMessageMutex;

  /// First error message for an operation in this stream or empty if there have
  /// been no errors.
  llvm::Optional<std::string> ErrorMessage;
};
} // namespace streamexecutor
#endif // STREAMEXECUTOR_STREAM_H