| //===- ol_impl.cpp - Implementation of the new LLVM/Offload API ------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This contains the definitions of the new LLVM/Offload API entry points. See |
| // new-api/API/README.md for more information. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "OffloadImpl.hpp" |
| #include "Helpers.hpp" |
| #include "OffloadPrint.hpp" |
| #include "PluginManager.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include <OffloadAPI.h> |
| |
| #include <mutex> |
| |
| // TODO: Some plugins expect to be linked into libomptarget which defines these |
| // symbols to implement ompt callbacks. The least invasive workaround here is to |
| // define them in libLLVMOffload as false/null so they are never used. In future |
| // it would be better to allow the plugins to implement callbacks without |
| // pulling in details from libomptarget. |
| #ifdef OMPT_SUPPORT |
| namespace llvm::omp::target { |
| namespace ompt { |
| bool Initialized = false; |
| ompt_get_callback_t lookupCallbackByCode = nullptr; |
| ompt_function_lookup_t lookupCallbackByName = nullptr; |
| } // namespace ompt |
| } // namespace llvm::omp::target |
| #endif |
| |
| using namespace llvm::omp::target; |
| using namespace llvm::omp::target::plugin; |
| using namespace error; |
| |
| // Handle type definitions. Ideally these would be 1:1 with the plugins, but |
| // we add some additional data here for now to avoid churn in the plugin |
| // interface. |
| struct ol_device_impl_t { |
| ol_device_impl_t(int DeviceNum, GenericDeviceTy *Device, |
| ol_platform_handle_t Platform, InfoTreeNode &&DevInfo) |
| : DeviceNum(DeviceNum), Device(Device), Platform(Platform), |
| Info(std::forward<InfoTreeNode>(DevInfo)) {} |
| int DeviceNum; |
| GenericDeviceTy *Device; |
| ol_platform_handle_t Platform; |
| InfoTreeNode Info; |
| }; |
| |
| struct ol_platform_impl_t { |
| ol_platform_impl_t(std::unique_ptr<GenericPluginTy> Plugin, |
| ol_platform_backend_t BackendType) |
| : Plugin(std::move(Plugin)), BackendType(BackendType) {} |
| std::unique_ptr<GenericPluginTy> Plugin; |
| std::vector<ol_device_impl_t> Devices; |
| ol_platform_backend_t BackendType; |
| }; |
| |
| struct ol_queue_impl_t { |
| ol_queue_impl_t(__tgt_async_info *AsyncInfo, ol_device_handle_t Device) |
| : AsyncInfo(AsyncInfo), Device(Device) {} |
| __tgt_async_info *AsyncInfo; |
| ol_device_handle_t Device; |
| }; |
| |
| struct ol_event_impl_t { |
| ol_event_impl_t(void *EventInfo, ol_queue_handle_t Queue) |
| : EventInfo(EventInfo), Queue(Queue) {} |
| // EventInfo may be null, in which case the event should be considered always |
| // complete |
| void *EventInfo; |
| ol_queue_handle_t Queue; |
| }; |
| |
| struct ol_program_impl_t { |
| ol_program_impl_t(plugin::DeviceImageTy *Image, |
| std::unique_ptr<llvm::MemoryBuffer> ImageData, |
| const __tgt_device_image &DeviceImage) |
| : Image(Image), ImageData(std::move(ImageData)), |
| DeviceImage(DeviceImage) {} |
| plugin::DeviceImageTy *Image; |
| std::unique_ptr<llvm::MemoryBuffer> ImageData; |
| std::mutex SymbolListMutex; |
| __tgt_device_image DeviceImage; |
| llvm::StringMap<std::unique_ptr<ol_symbol_impl_t>> KernelSymbols; |
| llvm::StringMap<std::unique_ptr<ol_symbol_impl_t>> GlobalSymbols; |
| }; |
| |
| struct ol_symbol_impl_t { |
| ol_symbol_impl_t(const char *Name, GenericKernelTy *Kernel) |
| : PluginImpl(Kernel), Kind(OL_SYMBOL_KIND_KERNEL), Name(Name) {} |
| ol_symbol_impl_t(const char *Name, GlobalTy &&Global) |
| : PluginImpl(Global), Kind(OL_SYMBOL_KIND_GLOBAL_VARIABLE), Name(Name) {} |
| std::variant<GenericKernelTy *, GlobalTy> PluginImpl; |
| ol_symbol_kind_t Kind; |
| llvm::StringRef Name; |
| }; |
| |
| namespace llvm { |
| namespace offload { |
| |
| struct AllocInfo { |
| ol_device_handle_t Device; |
| ol_alloc_type_t Type; |
| }; |
| |
| // Global shared state for liboffload |
| struct OffloadContext; |
| // This pointer is non-null if and only if the context is valid and fully |
| // initialized |
| static std::atomic<OffloadContext *> OffloadContextVal; |
| std::mutex OffloadContextValMutex; |
| struct OffloadContext { |
| OffloadContext(OffloadContext &) = delete; |
| OffloadContext(OffloadContext &&) = delete; |
| OffloadContext &operator=(OffloadContext &) = delete; |
| OffloadContext &operator=(OffloadContext &&) = delete; |
| |
| bool TracingEnabled = false; |
| bool ValidationEnabled = true; |
| DenseMap<void *, AllocInfo> AllocInfoMap{}; |
| SmallVector<ol_platform_impl_t, 4> Platforms{}; |
| size_t RefCount; |
| |
| ol_device_handle_t HostDevice() { |
| // The host platform is always inserted last |
| return &Platforms.back().Devices[0]; |
| } |
| |
| static OffloadContext &get() { |
| assert(OffloadContextVal); |
| return *OffloadContextVal; |
| } |
| }; |
| |
| // If the context is uninited, then we assume tracing is disabled |
| bool isTracingEnabled() { |
| return isOffloadInitialized() && OffloadContext::get().TracingEnabled; |
| } |
| bool isValidationEnabled() { return OffloadContext::get().ValidationEnabled; } |
| bool isOffloadInitialized() { return OffloadContextVal != nullptr; } |
| |
| template <typename HandleT> Error olDestroy(HandleT Handle) { |
| delete Handle; |
| return Error::success(); |
| } |
| |
| constexpr ol_platform_backend_t pluginNameToBackend(StringRef Name) { |
| if (Name == "amdgpu") { |
| return OL_PLATFORM_BACKEND_AMDGPU; |
| } else if (Name == "cuda") { |
| return OL_PLATFORM_BACKEND_CUDA; |
| } else { |
| return OL_PLATFORM_BACKEND_UNKNOWN; |
| } |
| } |
| |
| // Every plugin exports this method to create an instance of the plugin type. |
| #define PLUGIN_TARGET(Name) extern "C" GenericPluginTy *createPlugin_##Name(); |
| #include "Shared/Targets.def" |
| |
| Error initPlugins(OffloadContext &Context) { |
| // Attempt to create an instance of each supported plugin. |
| #define PLUGIN_TARGET(Name) \ |
| do { \ |
| Context.Platforms.emplace_back(ol_platform_impl_t{ \ |
| std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), \ |
| pluginNameToBackend(#Name)}); \ |
| } while (false); |
| #include "Shared/Targets.def" |
| |
| // Preemptively initialize all devices in the plugin |
| for (auto &Platform : Context.Platforms) { |
| // Do not use the host plugin - it isn't supported. |
| if (Platform.BackendType == OL_PLATFORM_BACKEND_UNKNOWN) |
| continue; |
| auto Err = Platform.Plugin->init(); |
| [[maybe_unused]] std::string InfoMsg = toString(std::move(Err)); |
| for (auto DevNum = 0; DevNum < Platform.Plugin->number_of_devices(); |
| DevNum++) { |
| if (Platform.Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) { |
| auto Device = &Platform.Plugin->getDevice(DevNum); |
| auto Info = Device->obtainInfoImpl(); |
| if (auto Err = Info.takeError()) |
| return Err; |
| Platform.Devices.emplace_back(DevNum, Device, &Platform, |
| std::move(*Info)); |
| } |
| } |
| } |
| |
| // Add the special host device |
| auto &HostPlatform = Context.Platforms.emplace_back( |
| ol_platform_impl_t{nullptr, OL_PLATFORM_BACKEND_HOST}); |
| HostPlatform.Devices.emplace_back(-1, nullptr, nullptr, InfoTreeNode{}); |
| Context.HostDevice()->Platform = &HostPlatform; |
| |
| Context.TracingEnabled = std::getenv("OFFLOAD_TRACE"); |
| Context.ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION"); |
| |
| return Plugin::success(); |
| } |
| |
| Error olInit_impl() { |
| std::lock_guard<std::mutex> Lock(OffloadContextValMutex); |
| |
| if (isOffloadInitialized()) { |
| OffloadContext::get().RefCount++; |
| return Plugin::success(); |
| } |
| |
| // Use a temporary to ensure that entry points querying OffloadContextVal do |
| // not get a partially initialized context |
| auto *NewContext = new OffloadContext{}; |
| Error InitResult = initPlugins(*NewContext); |
| OffloadContextVal.store(NewContext); |
| OffloadContext::get().RefCount++; |
| |
| return InitResult; |
| } |
| |
| Error olShutDown_impl() { |
| std::lock_guard<std::mutex> Lock(OffloadContextValMutex); |
| |
| if (--OffloadContext::get().RefCount != 0) |
| return Error::success(); |
| |
| llvm::Error Result = Error::success(); |
| auto *OldContext = OffloadContextVal.exchange(nullptr); |
| |
| for (auto &P : OldContext->Platforms) { |
| // Host plugin is nullptr and has no deinit |
| if (!P.Plugin || !P.Plugin->is_initialized()) |
| continue; |
| |
| if (auto Res = P.Plugin->deinit()) |
| Result = llvm::joinErrors(std::move(Result), std::move(Res)); |
| } |
| |
| delete OldContext; |
| return Result; |
| } |
| |
| Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform, |
| ol_platform_info_t PropName, size_t PropSize, |
| void *PropValue, size_t *PropSizeRet) { |
| InfoWriter Info(PropSize, PropValue, PropSizeRet); |
| bool IsHost = Platform->BackendType == OL_PLATFORM_BACKEND_HOST; |
| |
| switch (PropName) { |
| case OL_PLATFORM_INFO_NAME: |
| return Info.writeString(IsHost ? "Host" : Platform->Plugin->getName()); |
| case OL_PLATFORM_INFO_VENDOR_NAME: |
| // TODO: Implement this |
| return Info.writeString("Unknown platform vendor"); |
| case OL_PLATFORM_INFO_VERSION: { |
| return Info.writeString(formatv("v{0}.{1}.{2}", OL_VERSION_MAJOR, |
| OL_VERSION_MINOR, OL_VERSION_PATCH) |
| .str()); |
| } |
| case OL_PLATFORM_INFO_BACKEND: { |
| return Info.write<ol_platform_backend_t>(Platform->BackendType); |
| } |
| default: |
| return createOffloadError(ErrorCode::INVALID_ENUMERATION, |
| "getPlatformInfo enum '%i' is invalid", PropName); |
| } |
| |
| return Error::success(); |
| } |
| |
| Error olGetPlatformInfo_impl(ol_platform_handle_t Platform, |
| ol_platform_info_t PropName, size_t PropSize, |
| void *PropValue) { |
| return olGetPlatformInfoImplDetail(Platform, PropName, PropSize, PropValue, |
| nullptr); |
| } |
| |
| Error olGetPlatformInfoSize_impl(ol_platform_handle_t Platform, |
| ol_platform_info_t PropName, |
| size_t *PropSizeRet) { |
| return olGetPlatformInfoImplDetail(Platform, PropName, 0, nullptr, |
| PropSizeRet); |
| } |
| |
| Error olGetDeviceInfoImplDetail(ol_device_handle_t Device, |
| ol_device_info_t PropName, size_t PropSize, |
| void *PropValue, size_t *PropSizeRet) { |
| assert(Device != OffloadContext::get().HostDevice()); |
| InfoWriter Info(PropSize, PropValue, PropSizeRet); |
| |
| auto makeError = [&](ErrorCode Code, StringRef Err) { |
| std::string ErrBuffer; |
| llvm::raw_string_ostream(ErrBuffer) << PropName << ": " << Err; |
| return Plugin::error(ErrorCode::UNIMPLEMENTED, ErrBuffer.c_str()); |
| }; |
| |
| // These are not implemented by the plugin interface |
| switch (PropName) { |
| case OL_DEVICE_INFO_PLATFORM: |
| return Info.write<void *>(Device->Platform); |
| |
| case OL_DEVICE_INFO_TYPE: |
| return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU); |
| |
| case OL_DEVICE_INFO_SINGLE_FP_CONFIG: |
| case OL_DEVICE_INFO_DOUBLE_FP_CONFIG: { |
| ol_device_fp_capability_flags_t flags{0}; |
| flags |= OL_DEVICE_FP_CAPABILITY_FLAG_CORRECTLY_ROUNDED_DIVIDE_SQRT | |
| OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_NEAREST | |
| OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_ZERO | |
| OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_INF | |
| OL_DEVICE_FP_CAPABILITY_FLAG_INF_NAN | |
| OL_DEVICE_FP_CAPABILITY_FLAG_DENORM | |
| OL_DEVICE_FP_CAPABILITY_FLAG_FMA; |
| return Info.write(flags); |
| } |
| |
| case OL_DEVICE_INFO_HALF_FP_CONFIG: |
| return Info.write<ol_device_fp_capability_flags_t>(0); |
| |
| case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_CHAR: |
| case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_SHORT: |
| case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_INT: |
| case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_LONG: |
| case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_FLOAT: |
| case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_DOUBLE: |
| return Info.write<uint32_t>(1); |
| |
| case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_HALF: |
| return Info.write<uint32_t>(0); |
| |
| // None of the existing plugins specify a limit on a single allocation, |
| // so return the global memory size instead |
| case OL_DEVICE_INFO_MAX_MEM_ALLOC_SIZE: |
| PropName = OL_DEVICE_INFO_GLOBAL_MEM_SIZE; |
| break; |
| |
| default: |
| break; |
| } |
| |
| if (PropName >= OL_DEVICE_INFO_LAST) |
| return createOffloadError(ErrorCode::INVALID_ENUMERATION, |
| "getDeviceInfo enum '%i' is invalid", PropName); |
| |
| auto EntryOpt = Device->Info.get(static_cast<DeviceInfo>(PropName)); |
| if (!EntryOpt) |
| return makeError(ErrorCode::UNIMPLEMENTED, |
| "plugin did not provide a response for this information"); |
| auto Entry = *EntryOpt; |
| |
| // Retrieve properties from the plugin interface |
| switch (PropName) { |
| case OL_DEVICE_INFO_NAME: |
| case OL_DEVICE_INFO_VENDOR: |
| case OL_DEVICE_INFO_DRIVER_VERSION: { |
| // String values |
| if (!std::holds_alternative<std::string>(Entry->Value)) |
| return makeError(ErrorCode::BACKEND_FAILURE, |
| "plugin returned incorrect type"); |
| return Info.writeString(std::get<std::string>(Entry->Value).c_str()); |
| } |
| |
| case OL_DEVICE_INFO_GLOBAL_MEM_SIZE: { |
| // Uint64 values |
| if (!std::holds_alternative<uint64_t>(Entry->Value)) |
| return makeError(ErrorCode::BACKEND_FAILURE, |
| "plugin returned incorrect type"); |
| return Info.write(std::get<uint64_t>(Entry->Value)); |
| } |
| |
| case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE: |
| case OL_DEVICE_INFO_VENDOR_ID: |
| case OL_DEVICE_INFO_NUM_COMPUTE_UNITS: |
| case OL_DEVICE_INFO_ADDRESS_BITS: |
| case OL_DEVICE_INFO_MAX_CLOCK_FREQUENCY: |
| case OL_DEVICE_INFO_MEMORY_CLOCK_RATE: { |
| // Uint32 values |
| if (!std::holds_alternative<uint64_t>(Entry->Value)) |
| return makeError(ErrorCode::BACKEND_FAILURE, |
| "plugin returned incorrect type"); |
| auto Value = std::get<uint64_t>(Entry->Value); |
| if (Value > std::numeric_limits<uint32_t>::max()) |
| return makeError(ErrorCode::BACKEND_FAILURE, |
| "plugin returned out of range device info"); |
| return Info.write(static_cast<uint32_t>(Value)); |
| } |
| |
| case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE_PER_DIMENSION: { |
| // {x, y, z} triples |
| ol_dimensions_t Out{0, 0, 0}; |
| |
| auto getField = [&](StringRef Name, uint32_t &Dest) { |
| if (auto F = Entry->get(Name)) { |
| if (!std::holds_alternative<size_t>((*F)->Value)) |
| return makeError( |
| ErrorCode::BACKEND_FAILURE, |
| "plugin returned incorrect type for dimensions element"); |
| Dest = std::get<size_t>((*F)->Value); |
| } else |
| return makeError(ErrorCode::BACKEND_FAILURE, |
| "plugin didn't provide all values for dimensions"); |
| return Plugin::success(); |
| }; |
| |
| if (auto Res = getField("x", Out.x)) |
| return Res; |
| if (auto Res = getField("y", Out.y)) |
| return Res; |
| if (auto Res = getField("z", Out.z)) |
| return Res; |
| |
| return Info.write(Out); |
| } |
| |
| default: |
| llvm_unreachable("Unimplemented device info"); |
| } |
| } |
| |
| Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device, |
| ol_device_info_t PropName, size_t PropSize, |
| void *PropValue, size_t *PropSizeRet) { |
| assert(Device == OffloadContext::get().HostDevice()); |
| InfoWriter Info(PropSize, PropValue, PropSizeRet); |
| |
| switch (PropName) { |
| case OL_DEVICE_INFO_PLATFORM: |
| return Info.write<void *>(Device->Platform); |
| case OL_DEVICE_INFO_TYPE: |
| return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_HOST); |
| case OL_DEVICE_INFO_NAME: |
| return Info.writeString("Virtual Host Device"); |
| case OL_DEVICE_INFO_VENDOR: |
| return Info.writeString("Liboffload"); |
| case OL_DEVICE_INFO_DRIVER_VERSION: |
| return Info.writeString(LLVM_VERSION_STRING); |
| case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE: |
| return Info.write<uint32_t>(1); |
| case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE_PER_DIMENSION: |
| return Info.write<ol_dimensions_t>(ol_dimensions_t{1, 1, 1}); |
| case OL_DEVICE_INFO_VENDOR_ID: |
| return Info.write<uint32_t>(0); |
| case OL_DEVICE_INFO_NUM_COMPUTE_UNITS: |
| return Info.write<uint32_t>(1); |
| case OL_DEVICE_INFO_SINGLE_FP_CONFIG: |
| case OL_DEVICE_INFO_DOUBLE_FP_CONFIG: |
| return Info.write<ol_device_fp_capability_flags_t>( |
| OL_DEVICE_FP_CAPABILITY_FLAG_CORRECTLY_ROUNDED_DIVIDE_SQRT | |
| OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_NEAREST | |
| OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_ZERO | |
| OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_INF | |
| OL_DEVICE_FP_CAPABILITY_FLAG_INF_NAN | |
| OL_DEVICE_FP_CAPABILITY_FLAG_DENORM | OL_DEVICE_FP_CAPABILITY_FLAG_FMA); |
| case OL_DEVICE_INFO_HALF_FP_CONFIG: |
| return Info.write<ol_device_fp_capability_flags_t>(0); |
| case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_CHAR: |
| case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_SHORT: |
| case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_INT: |
| case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_LONG: |
| case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_FLOAT: |
| case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_DOUBLE: |
| return Info.write<uint32_t>(1); |
| case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_HALF: |
| return Info.write<uint32_t>(0); |
| case OL_DEVICE_INFO_MAX_CLOCK_FREQUENCY: |
| case OL_DEVICE_INFO_MEMORY_CLOCK_RATE: |
| case OL_DEVICE_INFO_ADDRESS_BITS: |
| return Info.write<uint32_t>(std::numeric_limits<uintptr_t>::digits); |
| case OL_DEVICE_INFO_MAX_MEM_ALLOC_SIZE: |
| case OL_DEVICE_INFO_GLOBAL_MEM_SIZE: |
| return Info.write<uint64_t>(0); |
| default: |
| return createOffloadError(ErrorCode::INVALID_ENUMERATION, |
| "getDeviceInfo enum '%i' is invalid", PropName); |
| } |
| |
| return Error::success(); |
| } |
| |
| Error olGetDeviceInfo_impl(ol_device_handle_t Device, ol_device_info_t PropName, |
| size_t PropSize, void *PropValue) { |
| if (Device == OffloadContext::get().HostDevice()) |
| return olGetDeviceInfoImplDetailHost(Device, PropName, PropSize, PropValue, |
| nullptr); |
| return olGetDeviceInfoImplDetail(Device, PropName, PropSize, PropValue, |
| nullptr); |
| } |
| |
| Error olGetDeviceInfoSize_impl(ol_device_handle_t Device, |
| ol_device_info_t PropName, size_t *PropSizeRet) { |
| if (Device == OffloadContext::get().HostDevice()) |
| return olGetDeviceInfoImplDetailHost(Device, PropName, 0, nullptr, |
| PropSizeRet); |
| return olGetDeviceInfoImplDetail(Device, PropName, 0, nullptr, PropSizeRet); |
| } |
| |
| Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) { |
| for (auto &Platform : OffloadContext::get().Platforms) { |
| for (auto &Device : Platform.Devices) { |
| if (!Callback(&Device, UserData)) { |
| break; |
| } |
| } |
| } |
| |
| return Error::success(); |
| } |
| |
| TargetAllocTy convertOlToPluginAllocTy(ol_alloc_type_t Type) { |
| switch (Type) { |
| case OL_ALLOC_TYPE_DEVICE: |
| return TARGET_ALLOC_DEVICE; |
| case OL_ALLOC_TYPE_HOST: |
| return TARGET_ALLOC_HOST; |
| case OL_ALLOC_TYPE_MANAGED: |
| default: |
| return TARGET_ALLOC_SHARED; |
| } |
| } |
| |
| Error olMemAlloc_impl(ol_device_handle_t Device, ol_alloc_type_t Type, |
| size_t Size, void **AllocationOut) { |
| auto Alloc = |
| Device->Device->dataAlloc(Size, nullptr, convertOlToPluginAllocTy(Type)); |
| if (!Alloc) |
| return Alloc.takeError(); |
| |
| *AllocationOut = *Alloc; |
| OffloadContext::get().AllocInfoMap.insert_or_assign(*Alloc, |
| AllocInfo{Device, Type}); |
| return Error::success(); |
| } |
| |
| Error olMemFree_impl(void *Address) { |
| if (!OffloadContext::get().AllocInfoMap.contains(Address)) |
| return createOffloadError(ErrorCode::INVALID_ARGUMENT, |
| "address is not a known allocation"); |
| |
| auto AllocInfo = OffloadContext::get().AllocInfoMap.at(Address); |
| auto Device = AllocInfo.Device; |
| auto Type = AllocInfo.Type; |
| |
| if (auto Res = |
| Device->Device->dataDelete(Address, convertOlToPluginAllocTy(Type))) |
| return Res; |
| |
| OffloadContext::get().AllocInfoMap.erase(Address); |
| |
| return Error::success(); |
| } |
| |
| Error olCreateQueue_impl(ol_device_handle_t Device, ol_queue_handle_t *Queue) { |
| auto CreatedQueue = std::make_unique<ol_queue_impl_t>(nullptr, Device); |
| if (auto Err = Device->Device->initAsyncInfo(&(CreatedQueue->AsyncInfo))) |
| return Err; |
| |
| *Queue = CreatedQueue.release(); |
| return Error::success(); |
| } |
| |
| Error olDestroyQueue_impl(ol_queue_handle_t Queue) { return olDestroy(Queue); } |
| |
| Error olSyncQueue_impl(ol_queue_handle_t Queue) { |
| // Host plugin doesn't have a queue set so it's not safe to call synchronize |
| // on it, but we have nothing to synchronize in that situation anyway. |
| if (Queue->AsyncInfo->Queue) { |
| // We don't need to release the queue and we would like the ability for |
| // other offload threads to submit work concurrently, so pass "false" here |
| // so we don't release the underlying queue object. |
| if (auto Err = Queue->Device->Device->synchronize(Queue->AsyncInfo, false)) |
| return Err; |
| } |
| |
| return Error::success(); |
| } |
| |
| Error olWaitEvents_impl(ol_queue_handle_t Queue, ol_event_handle_t *Events, |
| size_t NumEvents) { |
| auto *Device = Queue->Device->Device; |
| |
| for (size_t I = 0; I < NumEvents; I++) { |
| auto *Event = Events[I]; |
| |
| if (!Event) |
| return Plugin::error(ErrorCode::INVALID_NULL_HANDLE, |
| "olWaitEvents asked to wait on a NULL event"); |
| |
| // Do nothing if the event is for this queue or the event is always complete |
| if (Event->Queue == Queue || !Event->EventInfo) |
| continue; |
| |
| if (auto Err = Device->waitEvent(Event->EventInfo, Queue->AsyncInfo)) |
| return Err; |
| } |
| |
| return Error::success(); |
| } |
| |
| Error olGetQueueInfoImplDetail(ol_queue_handle_t Queue, |
| ol_queue_info_t PropName, size_t PropSize, |
| void *PropValue, size_t *PropSizeRet) { |
| InfoWriter Info(PropSize, PropValue, PropSizeRet); |
| |
| switch (PropName) { |
| case OL_QUEUE_INFO_DEVICE: |
| return Info.write<ol_device_handle_t>(Queue->Device); |
| case OL_QUEUE_INFO_EMPTY: { |
| auto Pending = Queue->Device->Device->hasPendingWork(Queue->AsyncInfo); |
| if (auto Err = Pending.takeError()) |
| return Err; |
| return Info.write<bool>(!*Pending); |
| } |
| default: |
| return createOffloadError(ErrorCode::INVALID_ENUMERATION, |
| "olGetQueueInfo enum '%i' is invalid", PropName); |
| } |
| |
| return Error::success(); |
| } |
| |
| Error olGetQueueInfo_impl(ol_queue_handle_t Queue, ol_queue_info_t PropName, |
| size_t PropSize, void *PropValue) { |
| return olGetQueueInfoImplDetail(Queue, PropName, PropSize, PropValue, |
| nullptr); |
| } |
| |
| Error olGetQueueInfoSize_impl(ol_queue_handle_t Queue, ol_queue_info_t PropName, |
| size_t *PropSizeRet) { |
| return olGetQueueInfoImplDetail(Queue, PropName, 0, nullptr, PropSizeRet); |
| } |
| |
| Error olSyncEvent_impl(ol_event_handle_t Event) { |
| if (!Event->EventInfo) |
| // Event always complete |
| return Plugin::success(); |
| |
| if (auto Res = Event->Queue->Device->Device->syncEvent(Event->EventInfo)) |
| return Res; |
| |
| return Error::success(); |
| } |
| |
| Error olDestroyEvent_impl(ol_event_handle_t Event) { |
| if (Event->EventInfo) |
| if (auto Res = Event->Queue->Device->Device->destroyEvent(Event->EventInfo)) |
| return Res; |
| |
| return olDestroy(Event); |
| } |
| |
| Error olGetEventInfoImplDetail(ol_event_handle_t Event, |
| ol_event_info_t PropName, size_t PropSize, |
| void *PropValue, size_t *PropSizeRet) { |
| InfoWriter Info(PropSize, PropValue, PropSizeRet); |
| |
| switch (PropName) { |
| case OL_EVENT_INFO_QUEUE: |
| return Info.write<ol_queue_handle_t>(Event->Queue); |
| default: |
| return createOffloadError(ErrorCode::INVALID_ENUMERATION, |
| "olGetEventInfo enum '%i' is invalid", PropName); |
| } |
| |
| return Error::success(); |
| } |
| |
| Error olGetEventInfo_impl(ol_event_handle_t Event, ol_event_info_t PropName, |
| size_t PropSize, void *PropValue) { |
| |
| return olGetEventInfoImplDetail(Event, PropName, PropSize, PropValue, |
| nullptr); |
| } |
| |
| Error olGetEventInfoSize_impl(ol_event_handle_t Event, ol_event_info_t PropName, |
| size_t *PropSizeRet) { |
| return olGetEventInfoImplDetail(Event, PropName, 0, nullptr, PropSizeRet); |
| } |
| |
| Error olCreateEvent_impl(ol_queue_handle_t Queue, ol_event_handle_t *EventOut) { |
| auto Pending = Queue->Device->Device->hasPendingWork(Queue->AsyncInfo); |
| if (auto Err = Pending.takeError()) |
| return Err; |
| |
| *EventOut = new ol_event_impl_t(nullptr, Queue); |
| if (!*Pending) |
| // Queue is empty, don't record an event and consider the event always |
| // complete |
| return Plugin::success(); |
| |
| if (auto Res = Queue->Device->Device->createEvent(&(*EventOut)->EventInfo)) |
| return Res; |
| |
| if (auto Res = Queue->Device->Device->recordEvent((*EventOut)->EventInfo, |
| Queue->AsyncInfo)) |
| return Res; |
| |
| return Plugin::success(); |
| } |
| |
| Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr, |
| ol_device_handle_t DstDevice, const void *SrcPtr, |
| ol_device_handle_t SrcDevice, size_t Size) { |
| auto Host = OffloadContext::get().HostDevice(); |
| if (DstDevice == Host && SrcDevice == Host) { |
| if (!Queue) { |
| std::memcpy(DstPtr, SrcPtr, Size); |
| return Error::success(); |
| } else { |
| return createOffloadError( |
| ErrorCode::INVALID_ARGUMENT, |
| "ane of DstDevice and SrcDevice must be a non-host device if " |
| "queue is specified"); |
| } |
| } |
| |
| // If no queue is given the memcpy will be synchronous |
| auto QueueImpl = Queue ? Queue->AsyncInfo : nullptr; |
| |
| if (DstDevice == Host) { |
| if (auto Res = |
| SrcDevice->Device->dataRetrieve(DstPtr, SrcPtr, Size, QueueImpl)) |
| return Res; |
| } else if (SrcDevice == Host) { |
| if (auto Res = |
| DstDevice->Device->dataSubmit(DstPtr, SrcPtr, Size, QueueImpl)) |
| return Res; |
| } else { |
| if (auto Res = SrcDevice->Device->dataExchange(SrcPtr, *DstDevice->Device, |
| DstPtr, Size, QueueImpl)) |
| return Res; |
| } |
| |
| return Error::success(); |
| } |
| |
| Error olCreateProgram_impl(ol_device_handle_t Device, const void *ProgData, |
| size_t ProgDataSize, ol_program_handle_t *Program) { |
| // Make a copy of the program binary in case it is released by the caller. |
| auto ImageData = MemoryBuffer::getMemBufferCopy( |
| StringRef(reinterpret_cast<const char *>(ProgData), ProgDataSize)); |
| |
| auto DeviceImage = __tgt_device_image{ |
| const_cast<char *>(ImageData->getBuffer().data()), |
| const_cast<char *>(ImageData->getBuffer().data()) + ProgDataSize, nullptr, |
| nullptr}; |
| |
| ol_program_handle_t Prog = |
| new ol_program_impl_t(nullptr, std::move(ImageData), DeviceImage); |
| |
| auto Res = |
| Device->Device->loadBinary(Device->Device->Plugin, &Prog->DeviceImage); |
| if (!Res) { |
| delete Prog; |
| return Res.takeError(); |
| } |
| assert(*Res != nullptr && "loadBinary returned nullptr"); |
| |
| Prog->Image = *Res; |
| *Program = Prog; |
| |
| return Error::success(); |
| } |
| |
| Error olDestroyProgram_impl(ol_program_handle_t Program) { |
| auto &Device = Program->Image->getDevice(); |
| if (auto Err = Device.unloadBinary(Program->Image)) |
| return Err; |
| |
| auto &LoadedImages = Device.LoadedImages; |
| LoadedImages.erase( |
| std::find(LoadedImages.begin(), LoadedImages.end(), Program->Image)); |
| |
| return olDestroy(Program); |
| } |
| |
| Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device, |
| ol_symbol_handle_t Kernel, const void *ArgumentsData, |
| size_t ArgumentsSize, |
| const ol_kernel_launch_size_args_t *LaunchSizeArgs) { |
| auto *DeviceImpl = Device->Device; |
| if (Queue && Device != Queue->Device) { |
| return createOffloadError( |
| ErrorCode::INVALID_DEVICE, |
| "device specified does not match the device of the given queue"); |
| } |
| |
| if (Kernel->Kind != OL_SYMBOL_KIND_KERNEL) |
| return createOffloadError(ErrorCode::SYMBOL_KIND, |
| "provided symbol is not a kernel"); |
| |
| auto *QueueImpl = Queue ? Queue->AsyncInfo : nullptr; |
| AsyncInfoWrapperTy AsyncInfoWrapper(*DeviceImpl, QueueImpl); |
| KernelArgsTy LaunchArgs{}; |
| LaunchArgs.NumTeams[0] = LaunchSizeArgs->NumGroups.x; |
| LaunchArgs.NumTeams[1] = LaunchSizeArgs->NumGroups.y; |
| LaunchArgs.NumTeams[2] = LaunchSizeArgs->NumGroups.z; |
| LaunchArgs.ThreadLimit[0] = LaunchSizeArgs->GroupSize.x; |
| LaunchArgs.ThreadLimit[1] = LaunchSizeArgs->GroupSize.y; |
| LaunchArgs.ThreadLimit[2] = LaunchSizeArgs->GroupSize.z; |
| LaunchArgs.DynCGroupMem = LaunchSizeArgs->DynSharedMemory; |
| |
| KernelLaunchParamsTy Params; |
| Params.Data = const_cast<void *>(ArgumentsData); |
| Params.Size = ArgumentsSize; |
| LaunchArgs.ArgPtrs = reinterpret_cast<void **>(&Params); |
| // Don't do anything with pointer indirection; use arg data as-is |
| LaunchArgs.Flags.IsCUDA = true; |
| |
| auto *KernelImpl = std::get<GenericKernelTy *>(Kernel->PluginImpl); |
| auto Err = KernelImpl->launch(*DeviceImpl, LaunchArgs.ArgPtrs, nullptr, |
| LaunchArgs, AsyncInfoWrapper); |
| |
| AsyncInfoWrapper.finalize(Err); |
| if (Err) |
| return Err; |
| |
| return Error::success(); |
| } |
| |
| Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name, |
| ol_symbol_kind_t Kind, ol_symbol_handle_t *Symbol) { |
| auto &Device = Program->Image->getDevice(); |
| |
| std::lock_guard<std::mutex> Lock(Program->SymbolListMutex); |
| |
| switch (Kind) { |
| case OL_SYMBOL_KIND_KERNEL: { |
| auto &Kernel = Program->KernelSymbols[Name]; |
| if (!Kernel) { |
| auto KernelImpl = Device.constructKernel(Name); |
| if (!KernelImpl) |
| return KernelImpl.takeError(); |
| |
| if (auto Err = KernelImpl->init(Device, *Program->Image)) |
| return Err; |
| |
| Kernel = std::make_unique<ol_symbol_impl_t>(KernelImpl->getName(), |
| &*KernelImpl); |
| } |
| |
| *Symbol = Kernel.get(); |
| return Error::success(); |
| } |
| case OL_SYMBOL_KIND_GLOBAL_VARIABLE: { |
| auto &Global = Program->GlobalSymbols[Name]; |
| if (!Global) { |
| GlobalTy GlobalObj{Name}; |
| if (auto Res = |
| Device.Plugin.getGlobalHandler().getGlobalMetadataFromDevice( |
| Device, *Program->Image, GlobalObj)) |
| return Res; |
| |
| Global = std::make_unique<ol_symbol_impl_t>(GlobalObj.getName().c_str(), |
| std::move(GlobalObj)); |
| } |
| |
| *Symbol = Global.get(); |
| return Error::success(); |
| } |
| default: |
| return createOffloadError(ErrorCode::INVALID_ENUMERATION, |
| "getSymbol kind enum '%i' is invalid", Kind); |
| } |
| } |
| |
| Error olGetSymbolInfoImplDetail(ol_symbol_handle_t Symbol, |
| ol_symbol_info_t PropName, size_t PropSize, |
| void *PropValue, size_t *PropSizeRet) { |
| InfoWriter Info(PropSize, PropValue, PropSizeRet); |
| |
| auto CheckKind = [&](ol_symbol_kind_t Required) { |
| if (Symbol->Kind != Required) { |
| std::string ErrBuffer; |
| llvm::raw_string_ostream(ErrBuffer) |
| << PropName << ": Expected a symbol of Kind " << Required |
| << " but given a symbol of Kind " << Symbol->Kind; |
| return Plugin::error(ErrorCode::SYMBOL_KIND, ErrBuffer.c_str()); |
| } |
| return Plugin::success(); |
| }; |
| |
| switch (PropName) { |
| case OL_SYMBOL_INFO_KIND: |
| return Info.write<ol_symbol_kind_t>(Symbol->Kind); |
| case OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS: |
| if (auto Err = CheckKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE)) |
| return Err; |
| return Info.write<void *>(std::get<GlobalTy>(Symbol->PluginImpl).getPtr()); |
| case OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE: |
| if (auto Err = CheckKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE)) |
| return Err; |
| return Info.write<size_t>(std::get<GlobalTy>(Symbol->PluginImpl).getSize()); |
| default: |
| return createOffloadError(ErrorCode::INVALID_ENUMERATION, |
| "olGetSymbolInfo enum '%i' is invalid", PropName); |
| } |
| |
| return Error::success(); |
| } |
| |
| Error olGetSymbolInfo_impl(ol_symbol_handle_t Symbol, ol_symbol_info_t PropName, |
| size_t PropSize, void *PropValue) { |
| |
| return olGetSymbolInfoImplDetail(Symbol, PropName, PropSize, PropValue, |
| nullptr); |
| } |
| |
| Error olGetSymbolInfoSize_impl(ol_symbol_handle_t Symbol, |
| ol_symbol_info_t PropName, size_t *PropSizeRet) { |
| return olGetSymbolInfoImplDetail(Symbol, PropName, 0, nullptr, PropSizeRet); |
| } |
| |
| Error olLaunchHostFunction_impl(ol_queue_handle_t Queue, |
| ol_host_function_cb_t Callback, |
| void *UserData) { |
| return Queue->Device->Device->enqueueHostCall(Callback, UserData, |
| Queue->AsyncInfo); |
| } |
| |
| } // namespace offload |
| } // namespace llvm |