blob: 6c3425f1e17c2564cda92e601d51fbfd8387011d [file] [log] [blame] [edit]
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file contains the implementation of helpers and non-template member
/// functions for the DeviceContext class.
///
//===----------------------------------------------------------------------===//
#include "mathtest/DeviceContext.hpp"
#include "mathtest/ErrorHandling.hpp"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/Path.h"
#include <OffloadAPI.h>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <system_error>
#include <vector>
using namespace mathtest;
//===----------------------------------------------------------------------===//
// Helpers
//===----------------------------------------------------------------------===//
namespace {
/// Guard object for the Offload runtime: the constructor calls olInit() and
/// the destructor calls olShutDown(), each checked through OL_CHECK.
///
/// The single static instance 'Wrapper' below ensures olInit() is called once
/// at program startup and olShutDown() is called once at program termination.
/// Copy and move operations are deleted so that an accidental copy of the
/// guard can never cause a second olShutDown() for a single olInit().
struct OffloadInitWrapper {
  OffloadInitWrapper() { OL_CHECK(olInit()); }
  ~OffloadInitWrapper() { OL_CHECK(olShutDown()); }

  OffloadInitWrapper(const OffloadInitWrapper &) = delete;
  OffloadInitWrapper &operator=(const OffloadInitWrapper &) = delete;
  OffloadInitWrapper(OffloadInitWrapper &&) = delete;
  OffloadInitWrapper &operator=(OffloadInitWrapper &&) = delete;
};

// The one and only guard instance for this program
static OffloadInitWrapper Wrapper{};
/// Returns the product name of \p DeviceHandle.
///
/// The size of the OL_DEVICE_INFO_PRODUCT_NAME property is queried first. If
/// that size is zero an empty string is returned; otherwise the property is
/// read into a buffer of exactly that size and the final character (the null
/// terminator) is dropped from the result.
[[nodiscard]] std::string getDeviceName(ol_device_handle_t DeviceHandle) {
  std::size_t PropSize = 0;
  OL_CHECK(olGetDeviceInfoSize(DeviceHandle, OL_DEVICE_INFO_PRODUCT_NAME,
                               &PropSize));

  // Nothing to read
  if (PropSize == 0)
    return std::string();

  std::string PropValue(PropSize, '\0');
  OL_CHECK(olGetDeviceInfo(DeviceHandle, OL_DEVICE_INFO_PRODUCT_NAME, PropSize,
                           PropValue.data()));

  // Shrink by one character to discard the null terminator
  PropValue.resize(PropSize - 1);
  return PropValue;
}
/// Returns the platform handle that \p DeviceHandle reports for its
/// OL_DEVICE_INFO_PLATFORM property.
///
/// The result starts out as a null handle and is filled in by
/// olGetDeviceInfo(); the call is checked through OL_CHECK.
[[nodiscard]] ol_platform_handle_t
getDevicePlatform(ol_device_handle_t DeviceHandle) noexcept {
  ol_platform_handle_t PlatformHandle{};
  OL_CHECK(olGetDeviceInfo(DeviceHandle, OL_DEVICE_INFO_PLATFORM,
                           sizeof(PlatformHandle), &PlatformHandle));
  return PlatformHandle;
}
/// Returns the name of the platform identified by \p PlatformHandle.
///
/// The size of the OL_PLATFORM_INFO_NAME property is queried first. If that
/// size is zero an empty string is returned; otherwise the property is read
/// into a buffer of exactly that size and the final character (the null
/// terminator) is dropped from the result.
[[nodiscard]] std::string getPlatformName(ol_platform_handle_t PlatformHandle) {
  std::size_t PropSize = 0;
  OL_CHECK(
      olGetPlatformInfoSize(PlatformHandle, OL_PLATFORM_INFO_NAME, &PropSize));

  // Nothing to read
  if (PropSize == 0)
    return std::string();

  std::string PropValue(PropSize, '\0');
  OL_CHECK(olGetPlatformInfo(PlatformHandle, OL_PLATFORM_INFO_NAME, PropSize,
                             PropValue.data()));

  // Shrink by one character to discard the null terminator
  PropValue.resize(PropSize - 1);
  return PropValue;
}
/// Returns the backend that \p PlatformHandle reports for its
/// OL_PLATFORM_INFO_BACKEND property.
///
/// The result is pre-set to OL_PLATFORM_BACKEND_UNKNOWN before
/// olGetPlatformInfo() writes into it; the call is checked through OL_CHECK.
[[nodiscard]] ol_platform_backend_t
getPlatformBackend(ol_platform_handle_t PlatformHandle) noexcept {
  ol_platform_backend_t Backend{OL_PLATFORM_BACKEND_UNKNOWN};
  OL_CHECK(olGetPlatformInfo(PlatformHandle, OL_PLATFORM_INFO_BACKEND,
                             sizeof(Backend), &Backend));
  return Backend;
}
// Description of a single device as collected by getDevices(). The member
// order matters: getDevices() fills this struct with positional aggregate
// initialization.
struct Device {
// Handle received by the olIterateDevices() callback
ol_device_handle_t Handle;
// Product name of the device, as returned by getDeviceName()
std::string Name;
// Name of the platform the device belongs to, as returned by getPlatformName()
std::string Platform;
// Backend of that platform, as returned by getPlatformBackend()
ol_platform_backend_t Backend;
};
const std::vector<Device> &getDevices() {
// Thread-safe initialization of a static local variable
static auto Devices = []() {
std::vector<Device> TmpDevices;
// Discovers all devices that are not the host
const auto *const ResultFromIterate = olIterateDevices(
[](ol_device_handle_t DeviceHandle, void *Data) {
ol_platform_handle_t PlatformHandle = getDevicePlatform(DeviceHandle);
ol_platform_backend_t Backend = getPlatformBackend(PlatformHandle);
if (Backend != OL_PLATFORM_BACKEND_HOST) {
auto Name = getDeviceName(DeviceHandle);
auto Platform = getPlatformName(PlatformHandle);
static_cast<std::vector<Device> *>(Data)->push_back(
{DeviceHandle, Name, Platform, Backend});
}
return true;
},
&TmpDevices);
OL_CHECK(ResultFromIterate);
return TmpDevices;
}();
return Devices;
}
} // namespace
/// Returns the set of distinct platform names found among the devices
/// returned by getDevices(), in order of first appearance.
///
/// The set is computed on the first call and cached in a function-local
/// static. Its StringRef elements refer to the 'Platform' strings stored in
/// the vector returned by getDevices().
const llvm::SetVector<llvm::StringRef> &mathtest::getPlatforms() {
  // Thread-safe initialization of a static local variable
  static auto Platforms = []() {
    const auto &AllDevices = getDevices();

    llvm::SetVector<llvm::StringRef> UniquePlatforms;
    for (std::size_t Index = 0; Index < AllDevices.size(); ++Index)
      UniquePlatforms.insert(AllDevices[Index].Platform);

    return UniquePlatforms;
  }();

  return Platforms;
}
/// Requests an allocation of \p Size bytes of type OL_ALLOC_TYPE_MANAGED for
/// \p DeviceHandle through olMemAlloc().
///
/// \p AllocationOut is forwarded to olMemAlloc() as the location that receives
/// the allocation. The call is checked through OL_CHECK.
void detail::allocManagedMemory(ol_device_handle_t DeviceHandle,
std::size_t Size,
void **AllocationOut) noexcept {
OL_CHECK(
olMemAlloc(DeviceHandle, OL_ALLOC_TYPE_MANAGED, Size, AllocationOut));
}
//===----------------------------------------------------------------------===//
// DeviceContext
//===----------------------------------------------------------------------===//
/// Creates a context for the device at position \p GlobalDeviceId in the list
/// returned by getDevices().
///
/// Reports a fatal error through FATAL_ERROR when \p GlobalDeviceId is not a
/// valid index into that list.
DeviceContext::DeviceContext(std::size_t GlobalDeviceId)
    : GlobalDeviceId(GlobalDeviceId), DeviceHandle(nullptr) {
  const auto &Devices = getDevices();
  const bool IsValidId = GlobalDeviceId < Devices.size();

  if (!IsValidId) {
    FATAL_ERROR("Invalid GlobalDeviceId: " + llvm::Twine(GlobalDeviceId) +
                ", but the number of available devices is " +
                llvm::Twine(Devices.size()));
  }

  DeviceHandle = Devices[GlobalDeviceId].Handle;
}
/// Creates a context for the \p DeviceId-th device (zero-based) among the
/// devices whose platform name equals \p Platform, ignoring case.
///
/// Reports a fatal error through FATAL_ERROR when no known platform matches
/// \p Platform, or when fewer than \p DeviceId + 1 devices exist on it.
DeviceContext::DeviceContext(llvm::StringRef Platform, std::size_t DeviceId)
    : DeviceHandle(nullptr) {
  const auto &Platforms = getPlatforms();

  const bool IsKnownPlatform =
      llvm::any_of(Platforms, [&](llvm::StringRef CurrentPlatform) {
        return CurrentPlatform.equals_insensitive(Platform);
      });
  if (!IsKnownPlatform) {
    FATAL_ERROR("There is no platform that matches with '" +
                llvm::Twine(Platform) +
                "'. Available platforms are: " + llvm::join(Platforms, ", "));
  }

  const auto &Devices = getDevices();
  std::optional<std::size_t> FoundGlobalDeviceId;
  // Number of matching devices seen before the current one
  std::size_t MatchCount = 0;

  for (std::size_t Index = 0, End = Devices.size(); Index != End; ++Index) {
    // Devices that belong to other platforms are not counted
    if (!Platform.equals_insensitive(Devices[Index].Platform))
      continue;

    if (MatchCount == DeviceId) {
      FoundGlobalDeviceId = Index;
      break;
    }
    ++MatchCount;
  }

  // When the search fails, MatchCount holds the total number of devices on
  // the requested platform
  if (!FoundGlobalDeviceId.has_value()) {
    FATAL_ERROR("Invalid DeviceId: " + llvm::Twine(DeviceId) +
                ", but the number of available devices on '" + Platform +
                "' is " + llvm::Twine(MatchCount));
  }

  GlobalDeviceId = *FoundGlobalDeviceId;
  DeviceHandle = Devices[GlobalDeviceId].Handle;
}
/// Loads a device binary from disk and creates an Offload program from it.
///
/// The file name is \p BinaryName followed by an extension chosen from the
/// backend of this context's device (".amdgpu.bin" for AMDGPU, ".nvptx64.bin"
/// for CUDA), looked up inside \p Directory.
///
/// \returns A shared DeviceImage built from this context's device handle and
/// the new program handle, or an error when the backend has no known
/// extension, the file cannot be read, or olCreateProgram() fails.
[[nodiscard]] llvm::Expected<std::shared_ptr<DeviceImage>>
DeviceContext::loadBinary(llvm::StringRef Directory,
                          llvm::StringRef BinaryName) const {
  const ol_platform_backend_t Backend = getDevices()[GlobalDeviceId].Backend;

  llvm::StringRef Extension;
  if (Backend == OL_PLATFORM_BACKEND_AMDGPU)
    Extension = ".amdgpu.bin";
  else if (Backend == OL_PLATFORM_BACKEND_CUDA)
    Extension = ".nvptx64.bin";
  else
    return llvm::createStringError(
        "Unsupported backend to infer binary extension");

  llvm::SmallString<128> FullPath(Directory);
  llvm::sys::path::append(FullPath, llvm::Twine(BinaryName) + Extension);

  auto FileOrErr = llvm::MemoryBuffer::getFile(FullPath);
  const std::error_code ErrorCode = FileOrErr.getError();
  if (ErrorCode)
    return llvm::createStringError(
        llvm::Twine("Failed to read device binary file '") + FullPath +
        "': " + ErrorCode.message());

  // NOTE(review): the file buffer is released when this function returns;
  // presumably olCreateProgram() copies the data - confirm against the API.
  const llvm::MemoryBuffer &BinaryData = **FileOrErr;

  ol_program_handle_t ProgramHandle = nullptr;
  const ol_result_t OlResult =
      olCreateProgram(DeviceHandle, BinaryData.getBufferStart(),
                      BinaryData.getBufferSize(), &ProgramHandle);

  if (OlResult == OL_SUCCESS)
    return std::shared_ptr<DeviceImage>(
        new DeviceImage(DeviceHandle, ProgramHandle));

  // Failure: forward the runtime's description, if any, plus its error code
  llvm::StringRef Details =
      OlResult->Details ? OlResult->Details : "No details provided";
  return llvm::createStringError(llvm::Twine(Details) + " (code " +
                                 llvm::Twine(OlResult->Code) + ")");
}
/// Looks up the kernel symbol named \p KernelName in \p ProgramHandle.
///
/// \returns The symbol handle produced by olGetSymbol() for
/// OL_SYMBOL_KIND_KERNEL, or an error carrying the runtime's details and
/// error code when the lookup fails.
[[nodiscard]] llvm::Expected<ol_symbol_handle_t>
DeviceContext::getKernelHandle(ol_program_handle_t ProgramHandle,
                               llvm::StringRef KernelName) const noexcept {
  // Copy the name into a local buffer so that c_str() can hand a
  // null-terminated string to the C API
  llvm::SmallString<32> NameBuffer(KernelName);

  ol_symbol_handle_t Handle = nullptr;
  const ol_result_t OlResult = olGetSymbol(ProgramHandle, NameBuffer.c_str(),
                                           OL_SYMBOL_KIND_KERNEL, &Handle);

  if (OlResult == OL_SUCCESS)
    return Handle;

  // Failure: forward the runtime's description, if any, plus its error code
  llvm::StringRef Details =
      OlResult->Details ? OlResult->Details : "No details provided";
  return llvm::createStringError(llvm::Twine(Details) + " (code " +
                                 llvm::Twine(OlResult->Code) + ")");
}
/// Launches \p KernelHandle on this context's device with a one-dimensional
/// configuration of \p NumGroups groups, each of size \p GroupSize, and no
/// dynamic shared memory.
///
/// \p KernelArgs and \p KernelArgsSize are forwarded unchanged to
/// olLaunchKernel(), whose first argument (presumably the queue - confirm
/// against the API) is passed as nullptr. The call is checked through
/// OL_CHECK.
void DeviceContext::launchKernelImpl(
    ol_symbol_handle_t KernelHandle, uint32_t NumGroups, uint32_t GroupSize,
    const void *KernelArgs, std::size_t KernelArgsSize) const noexcept {
  ol_kernel_launch_size_args_t LaunchSizeArgs;

  // Only the first dimension carries the caller's sizes; the other two are 1
  LaunchSizeArgs.GroupSize = {GroupSize, 1, 1};
  LaunchSizeArgs.NumGroups = {NumGroups, 1, 1};
  LaunchSizeArgs.Dimensions = 1;
  LaunchSizeArgs.DynSharedMemory = 0;

  OL_CHECK(olLaunchKernel(nullptr, DeviceHandle, KernelHandle, KernelArgs,
                          KernelArgsSize, &LaunchSizeArgs));
}
/// Returns the name recorded for this context's device in the list returned
/// by getDevices(). The returned StringRef refers to the string stored in
/// that list.
[[nodiscard]] llvm::StringRef DeviceContext::getName() const noexcept {
  const auto &ThisDevice = getDevices()[GlobalDeviceId];
  return ThisDevice.Name;
}
/// Returns the platform name recorded for this context's device in the list
/// returned by getDevices(). The returned StringRef refers to the string
/// stored in that list.
[[nodiscard]] llvm::StringRef DeviceContext::getPlatform() const noexcept {
  const auto &ThisDevice = getDevices()[GlobalDeviceId];
  return ThisDevice.Platform;
}