blob: 5c31fc3da53cd6954db6353d97ba13b292508188 [file] [log] [blame] [edit]
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file contains the definition of the DeviceContext class, which serves
/// as the high-level interface to a particular device (GPU).
///
/// This class provides methods for allocating buffers, loading binaries, and
/// getting and launching kernels on the device.
///
//===----------------------------------------------------------------------===//
#ifndef MATHTEST_DEVICECONTEXT_HPP
#define MATHTEST_DEVICECONTEXT_HPP
#include "mathtest/DeviceResources.hpp"
#include "mathtest/ErrorHandling.hpp"
#include "mathtest/Support.hpp"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Error.h"
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <tuple>
#include <type_traits>
#include <utility>
namespace mathtest {
const llvm::SetVector<llvm::StringRef> &getPlatforms();
namespace detail {
void allocManagedMemory(ol_device_handle_t DeviceHandle, std::size_t Size,
void **AllocationOut) noexcept;
} // namespace detail
class DeviceContext {
// For simplicity, the current design of this class doesn't have support for
// asynchronous operations and all types of memory allocation.
//
// Other use cases could benefit from operations like enqueued kernel launch
// and enqueued memcpy, as well as device and host memory allocation.
public:
explicit DeviceContext(std::size_t GlobalDeviceId = 0);
explicit DeviceContext(llvm::StringRef Platform, std::size_t DeviceId = 0);
template <typename T>
ManagedBuffer<T> createManagedBuffer(std::size_t Size) const noexcept {
void *UntypedAddress = nullptr;
detail::allocManagedMemory(DeviceHandle, Size * sizeof(T), &UntypedAddress);
T *TypedAddress = static_cast<T *>(UntypedAddress);
return ManagedBuffer<T>(TypedAddress, Size);
}
[[nodiscard]] llvm::Expected<std::shared_ptr<DeviceImage>>
loadBinary(llvm::StringRef Directory, llvm::StringRef BinaryName) const;
template <typename KernelSignature>
[[nodiscard]] llvm::Expected<DeviceKernel<KernelSignature>>
getKernel(const std::shared_ptr<DeviceImage> &Image,
llvm::StringRef KernelName) const {
assert(Image && "Image provided to getKernel is null");
if (Image->DeviceHandle != DeviceHandle)
return llvm::createStringError(
"Image provided to getKernel was created for a different device");
auto ExpectedHandle = getKernelHandle(Image->Handle, KernelName);
if (!ExpectedHandle)
return ExpectedHandle.takeError();
return DeviceKernel<KernelSignature>(Image, *ExpectedHandle);
}
template <typename KernelSignature, typename... ArgTypes>
void launchKernel(DeviceKernel<KernelSignature> Kernel, uint32_t NumGroups,
uint32_t GroupSize, ArgTypes &&...Args) const noexcept {
using ExpectedTypes =
typename FunctionTypeTraits<KernelSignature>::ArgTypesTuple;
using ProvidedTypes = std::tuple<std::decay_t<ArgTypes>...>;
static_assert(std::is_same_v<ExpectedTypes, ProvidedTypes>,
"Argument types provided to launchKernel do not match the "
"kernel's signature");
if (Kernel.Image->DeviceHandle != DeviceHandle)
FATAL_ERROR("Kernel provided to launchKernel was created for a different "
"device");
if constexpr (sizeof...(Args) == 0) {
launchKernelImpl(Kernel.Handle, NumGroups, GroupSize, nullptr, 0);
} else {
auto KernelArgs = makeKernelArgsPack(std::forward<ArgTypes>(Args)...);
static_assert(
(std::is_trivially_copyable_v<std::decay_t<ArgTypes>> && ...),
"Argument types provided to launchKernel must be trivially copyable");
launchKernelImpl(Kernel.Handle, NumGroups, GroupSize, &KernelArgs,
sizeof(KernelArgs));
}
}
[[nodiscard]] llvm::StringRef getName() const noexcept;
[[nodiscard]] llvm::StringRef getPlatform() const noexcept;
private:
[[nodiscard]] llvm::Expected<ol_symbol_handle_t>
getKernelHandle(ol_program_handle_t ProgramHandle,
llvm::StringRef KernelName) const noexcept;
void launchKernelImpl(ol_symbol_handle_t KernelHandle, uint32_t NumGroups,
uint32_t GroupSize, const void *KernelArgs,
std::size_t KernelArgsSize) const noexcept;
std::size_t GlobalDeviceId;
ol_device_handle_t DeviceHandle;
};
} // namespace mathtest
#endif // MATHTEST_DEVICECONTEXT_HPP