| //===- SyclRuntimeWrappers.cpp - MLIR SYCL wrapper library ------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // Implements wrappers around the sycl runtime library with C linkage |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include <CL/sycl.hpp> |
| #include <level_zero/ze_api.h> |
| #include <sycl/ext/oneapi/backend/level_zero.hpp> |
| |
// Export annotation for the C entry points defined below. On Windows the
// symbols must be marked __declspec(dllexport) to be visible from the DLL;
// on other platforms the macro expands to nothing (presumably relying on the
// default symbol visibility of the build — confirm against the build flags).
#ifdef _WIN32
#define SYCL_RUNTIME_EXPORT __declspec(dllexport)
#else
#define SYCL_RUNTIME_EXPORT
#endif // _WIN32
| |
namespace {

/// Invokes `func` and returns its result.
///
/// Any exception escaping `func` is reported on stdout and the process is
/// aborted, so no C++ exception propagates out of the `extern "C"` wrappers
/// that route their work through this helper.
///
/// \param func Callable taking no arguments.
/// \returns The (decayed) value returned by `func()`; may be void.
template <typename F>
auto catchAll(F &&func) {
  try {
    return func();
  } catch (const std::exception &e) {
    fprintf(stdout, "An exception was thrown: %s\n", e.what());
    fflush(stdout);
    abort();
  } catch (...) {
    fprintf(stdout, "An unknown exception was thrown\n");
    fflush(stdout);
    abort();
  }
}

// Evaluates a Level Zero API call exactly once and aborts the process,
// printing the numeric error code, if it does not return ZE_RESULT_SUCCESS.
//
// The body is wrapped in `do { } while (0)` so that `L0_SAFE_CALL(x);` is a
// single statement and composes with unbraced if/else. The status is cast to
// int because `%d` expects an int. Being a macro, it is not actually scoped by
// the enclosing anonymous namespace.
#define L0_SAFE_CALL(call)                                                     \
  do {                                                                         \
    ze_result_t l0Status = (call);                                             \
    if (l0Status != ZE_RESULT_SUCCESS) {                                       \
      fprintf(stdout, "L0 error %d\n", static_cast<int>(l0Status));            \
      fflush(stdout);                                                          \
      abort();                                                                 \
    }                                                                          \
  } while (0)

} // namespace
| |
| static sycl::device getDefaultDevice() { |
| static sycl::device syclDevice; |
| static bool isDeviceInitialised = false; |
| if (!isDeviceInitialised) { |
| auto platformList = sycl::platform::get_platforms(); |
| for (const auto &platform : platformList) { |
| auto platformName = platform.get_info<sycl::info::platform::name>(); |
| bool isLevelZero = platformName.find("Level-Zero") != std::string::npos; |
| if (!isLevelZero) |
| continue; |
| |
| syclDevice = platform.get_devices()[0]; |
| isDeviceInitialised = true; |
| return syclDevice; |
| } |
| throw std::runtime_error("getDefaultDevice failed"); |
| } else |
| return syclDevice; |
| } |
| |
| static sycl::context getDefaultContext() { |
| static sycl::context syclContext{getDefaultDevice()}; |
| return syclContext; |
| } |
| |
| static void *allocDeviceMemory(sycl::queue *queue, size_t size, bool isShared) { |
| void *memPtr = nullptr; |
| if (isShared) { |
| memPtr = sycl::aligned_alloc_shared(64, size, getDefaultDevice(), |
| getDefaultContext()); |
| } else { |
| memPtr = sycl::aligned_alloc_device(64, size, getDefaultDevice(), |
| getDefaultContext()); |
| } |
| if (memPtr == nullptr) { |
| throw std::runtime_error("mem allocation failed!"); |
| } |
| return memPtr; |
| } |
| |
| static void deallocDeviceMemory(sycl::queue *queue, void *ptr) { |
| sycl::free(ptr, *queue); |
| } |
| |
| static ze_module_handle_t loadModule(const void *data, size_t dataSize) { |
| assert(data); |
| ze_module_handle_t zeModule; |
| ze_module_desc_t desc = {ZE_STRUCTURE_TYPE_MODULE_DESC, |
| nullptr, |
| ZE_MODULE_FORMAT_IL_SPIRV, |
| dataSize, |
| (const uint8_t *)data, |
| nullptr, |
| nullptr}; |
| auto zeDevice = sycl::get_native<sycl::backend::ext_oneapi_level_zero>( |
| getDefaultDevice()); |
| auto zeContext = sycl::get_native<sycl::backend::ext_oneapi_level_zero>( |
| getDefaultContext()); |
| L0_SAFE_CALL(zeModuleCreate(zeContext, zeDevice, &desc, &zeModule, nullptr)); |
| return zeModule; |
| } |
| |
| static sycl::kernel *getKernel(ze_module_handle_t zeModule, const char *name) { |
| assert(zeModule); |
| assert(name); |
| ze_kernel_handle_t zeKernel; |
| ze_kernel_desc_t desc = {}; |
| desc.pKernelName = name; |
| |
| L0_SAFE_CALL(zeKernelCreate(zeModule, &desc, &zeKernel)); |
| sycl::kernel_bundle<sycl::bundle_state::executable> kernelBundle = |
| sycl::make_kernel_bundle<sycl::backend::ext_oneapi_level_zero, |
| sycl::bundle_state::executable>( |
| {zeModule}, getDefaultContext()); |
| |
| auto kernel = sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>( |
| {kernelBundle, zeKernel}, getDefaultContext()); |
| return new sycl::kernel(kernel); |
| } |
| |
| static void launchKernel(sycl::queue *queue, sycl::kernel *kernel, size_t gridX, |
| size_t gridY, size_t gridZ, size_t blockX, |
| size_t blockY, size_t blockZ, size_t sharedMemBytes, |
| void **params, size_t paramsCount) { |
| auto syclGlobalRange = |
| sycl::range<3>(blockZ * gridZ, blockY * gridY, blockX * gridX); |
| auto syclLocalRange = sycl::range<3>(blockZ, blockY, blockX); |
| sycl::nd_range<3> syclNdRange(syclGlobalRange, syclLocalRange); |
| |
| queue->submit([&](sycl::handler &cgh) { |
| for (size_t i = 0; i < paramsCount; i++) { |
| cgh.set_arg(static_cast<uint32_t>(i), *(static_cast<void **>(params[i]))); |
| } |
| cgh.parallel_for(syclNdRange, *kernel); |
| }); |
| } |
| |
| // Wrappers |
| |
| extern "C" SYCL_RUNTIME_EXPORT sycl::queue *mgpuStreamCreate() { |
| |
| return catchAll([&]() { |
| sycl::queue *queue = |
| new sycl::queue(getDefaultContext(), getDefaultDevice()); |
| return queue; |
| }); |
| } |
| |
| extern "C" SYCL_RUNTIME_EXPORT void mgpuStreamDestroy(sycl::queue *queue) { |
| catchAll([&]() { delete queue; }); |
| } |
| |
| extern "C" SYCL_RUNTIME_EXPORT void * |
| mgpuMemAlloc(uint64_t size, sycl::queue *queue, bool isShared) { |
| return catchAll([&]() { |
| return allocDeviceMemory(queue, static_cast<size_t>(size), true); |
| }); |
| } |
| |
| extern "C" SYCL_RUNTIME_EXPORT void mgpuMemFree(void *ptr, sycl::queue *queue) { |
| catchAll([&]() { |
| if (ptr) { |
| deallocDeviceMemory(queue, ptr); |
| } |
| }); |
| } |
| |
| extern "C" SYCL_RUNTIME_EXPORT ze_module_handle_t |
| mgpuModuleLoad(const void *data, size_t gpuBlobSize) { |
| return catchAll([&]() { return loadModule(data, gpuBlobSize); }); |
| } |
| |
| extern "C" SYCL_RUNTIME_EXPORT sycl::kernel * |
| mgpuModuleGetFunction(ze_module_handle_t module, const char *name) { |
| return catchAll([&]() { return getKernel(module, name); }); |
| } |
| |
| extern "C" SYCL_RUNTIME_EXPORT void |
| mgpuLaunchKernel(sycl::kernel *kernel, size_t gridX, size_t gridY, size_t gridZ, |
| size_t blockX, size_t blockY, size_t blockZ, |
| size_t sharedMemBytes, sycl::queue *queue, void **params, |
| void ** /*extra*/, size_t paramsCount) { |
| return catchAll([&]() { |
| launchKernel(queue, kernel, gridX, gridY, gridZ, blockX, blockY, blockZ, |
| sharedMemBytes, params, paramsCount); |
| }); |
| } |
| |
| extern "C" SYCL_RUNTIME_EXPORT void mgpuStreamSynchronize(sycl::queue *queue) { |
| |
| catchAll([&]() { queue->wait(); }); |
| } |
| |
| extern "C" SYCL_RUNTIME_EXPORT void |
| mgpuModuleUnload(ze_module_handle_t module) { |
| |
| catchAll([&]() { L0_SAFE_CALL(zeModuleDestroy(module)); }); |
| } |