//===- RPC.cpp - Interface for remote procedure calls from the GPU --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "RPC.h"

#include "Shared/Debug.h"
#include "Shared/RPCOpcodes.h"

#include "PluginInterface.h"

#include "shared/rpc.h"
#include "shared/rpc_opcodes.h"
#include "shared/rpc_server.h"

using namespace llvm;
using namespace omp;
using namespace target;

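/// Handles the opcodes serviced by the offload runtime itself: device-side
/// allocation, deallocation, and calls back into host functions. The handler
/// is instantiated per SIMT width so each lane can carry its own request.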
template <uint32_t NumLanes>
rpc::Status handleOffloadOpcodes(plugin::GenericDeviceTy &Device,
                                 rpc::Server::Port &Port) {

  switch (Port.get_opcode()) {
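  // Allocate non-blocking device memory on behalf of the GPU and send the
  // resulting pointer back to the caller.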
  case LIBC_MALLOC: {
    Port.recv_and_send([&](rpc::Buffer *Buffer, uint32_t) {
      Buffer->data[0] = reinterpret_cast<uintptr_t>(Device.allocate(
          Buffer->data[0], nullptr, TARGET_ALLOC_DEVICE_NON_BLOCKING));
    });
    break;
  }
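  // Free a device allocation, typically one handed out by the LIBC_MALLOC
  // handler above.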
  case LIBC_FREE: {
    Port.recv([&](rpc::Buffer *Buffer, uint32_t) {
      Device.free(reinterpret_cast<void *>(Buffer->data[0]),
                  TARGET_ALLOC_DEVICE_NON_BLOCKING);
    });
    break;
  }
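  // Run a host function on behalf of the device. Each lane sends a function
  // pointer and a copy of its argument buffer and receives the return value.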
  case OFFLOAD_HOST_CALL: {
    uint64_t Sizes[NumLanes] = {0};
    unsigned long long Results[NumLanes] = {0};
    void *Args[NumLanes] = {nullptr};
    Port.recv_n(Args, Sizes, [&](uint64_t Size) { return new char[Size]; });
    Port.recv([&](rpc::Buffer *Buffer, uint32_t ID) {
      using FuncPtrTy = unsigned long long (*)(void *);
      auto Func = reinterpret_cast<FuncPtrTy>(Buffer->data[0]);
      Results[ID] = Func(Args[ID]);
    });
    Port.send([&](rpc::Buffer *Buffer, uint32_t ID) {
      Buffer->data[0] = static_cast<uint64_t>(Results[ID]);
      delete[] reinterpret_cast<char *>(Args[ID]);
    });
    break;
  }
  default:
    return rpc::RPC_UNHANDLED_OPCODE;
  }
  return rpc::RPC_SUCCESS;
}

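/// Dispatches to the handler instantiated for the device's SIMT width. Only
/// widths of 1, 32, and 64 lanes are supported, e.g. NVIDIA warps use 32 lanes
/// and AMD wavefronts use 64.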
static rpc::Status handleOffloadOpcodes(plugin::GenericDeviceTy &Device,
                                        rpc::Server::Port &Port,
                                        uint32_t NumLanes) {
  if (NumLanes == 1)
    return handleOffloadOpcodes<1>(Device, Port);
  else if (NumLanes == 32)
    return handleOffloadOpcodes<32>(Device, Port);
  else if (NumLanes == 64)
    return handleOffloadOpcodes<64>(Device, Port);
  else
    return rpc::RPC_ERROR;
}

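/// Services a single port on the device's RPC buffer. Offload-specific opcodes
/// are tried first; anything unhandled is forwarded to the `libc` server.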
static rpc::Status runServer(plugin::GenericDeviceTy &Device, void *Buffer) {
  uint64_t NumPorts =
      std::min(Device.requestedRPCPortCount(), rpc::MAX_PORT_COUNT);
  rpc::Server Server(NumPorts, Buffer);

  auto Port = Server.try_open(Device.getWarpSize());
  if (!Port)
    return rpc::RPC_SUCCESS;

  rpc::Status Status =
      handleOffloadOpcodes(Device, *Port, Device.getWarpSize());

  // Let the `libc` library handle any other unhandled opcodes.
  if (Status == rpc::RPC_UNHANDLED_OPCODE)
    Status = LIBC_NAMESPACE::shared::handle_libc_opcodes(*Port,
                                                         Device.getWarpSize());

  Port->close();

  return Status;
}

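/// Launches the worker thread. The atomic test-and-set on `Running` makes
/// repeated calls harmless.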
void RPCServerTy::ServerThread::startThread() {
  if (!Running.fetch_or(true, std::memory_order_acquire))
    Worker = std::thread([this]() { run(); });
}

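/// Clears the running flag, wakes the worker so it observes the shutdown, and
/// joins it. Does nothing if the thread was never started.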
void RPCServerTy::ServerThread::shutDown() {
  if (!Running.fetch_and(false, std::memory_order_release))
    return;
  {
    std::lock_guard<decltype(Mutex)> Lock(Mutex);
    CV.notify_all();
  }
  if (Worker.joinable())
    Worker.join();
}

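/// Worker loop: sleep until at least one device is using RPC, then poll every
/// registered buffer until there are no users left or the server shuts down.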
void RPCServerTy::ServerThread::run() {
  std::unique_lock<decltype(Mutex)> Lock(Mutex);
  for (;;) {
    CV.wait(Lock, [&]() {
      return NumUsers.load(std::memory_order_acquire) > 0 ||
             !Running.load(std::memory_order_acquire);
    });

    if (!Running.load(std::memory_order_acquire))
      return;

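    // Release the wait lock while servicing and take the buffer lock instead
    // so the device list cannot change while it is being walked.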
    Lock.unlock();
    while (NumUsers.load(std::memory_order_relaxed) > 0 &&
           Running.load(std::memory_order_relaxed)) {
      std::lock_guard<decltype(Mutex)> Lock(BufferMutex);
      for (const auto &[Buffer, Device] : llvm::zip_equal(Buffers, Devices)) {
        if (!Buffer || !Device)
          continue;

        // If running the server failed, print a message but keep running.
        if (runServer(*Device, Buffer) != rpc::RPC_SUCCESS)
          FAILURE_MESSAGE("Unhandled or invalid RPC opcode!");
      }
    }
    Lock.lock();
  }
}

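/// Reserves one buffer slot and one device slot per device and hands both
/// arrays to the server thread object. The underlying std::thread is only
/// launched by startThread().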
RPCServerTy::RPCServerTy(plugin::GenericPluginTy &Plugin)
    : Buffers(std::make_unique<void *[]>(Plugin.getNumDevices())),
      Devices(std::make_unique<plugin::GenericDeviceTy *[]>(
          Plugin.getNumDevices())),
      Thread(new ServerThread(Buffers.get(), Devices.get(),
                              Plugin.getNumDevices(), BufferMutex)) {}

llvm::Error RPCServerTy::startThread() {
  Thread->startThread();
  return Error::success();
}

llvm::Error RPCServerTy::shutDown() {
  Thread->shutDown();
  return Error::success();
}

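/// A device image uses RPC if it contains the well-known `__llvm_rpc_client`
/// symbol, which is provided by the GPU C library.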
llvm::Expected<bool>
RPCServerTy::isDeviceUsingRPC(plugin::GenericDeviceTy &Device,
                              plugin::GenericGlobalHandlerTy &Handler,
                              plugin::DeviceImageTy &Image) {
  return Handler.isSymbolInImage(Device, Image, "__llvm_rpc_client");
}

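/// Sets up RPC for a single device: allocates the shared buffer, points the
/// device-side client at it, and registers the device with the server thread.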
Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
                              plugin::GenericGlobalHandlerTy &Handler,
                              plugin::DeviceImageTy &Image) {
  uint64_t NumPorts =
      std::min(Device.requestedRPCPortCount(), rpc::MAX_PORT_COUNT);
  void *RPCBuffer = Device.allocate(
      rpc::Server::allocation_size(Device.getWarpSize(), NumPorts), nullptr,
      TARGET_ALLOC_HOST);
  if (!RPCBuffer)
    return plugin::Plugin::error(
        error::ErrorCode::UNKNOWN,
        "failed to initialize RPC server for device %d", Device.getDeviceId());

  // Get the address of the RPC client from the device.
  plugin::GlobalTy ClientGlobal("__llvm_rpc_client", sizeof(rpc::Client));
  if (auto Err =
          Handler.getGlobalMetadataFromDevice(Device, Image, ClientGlobal))
    return Err;

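  // Construct the client on the host and copy it into the device's global so
  // the device-side client uses the same port count and shared buffer.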
  rpc::Client Client(NumPorts, RPCBuffer);
  if (auto Err = Device.dataSubmit(ClientGlobal.getPtr(), &Client,
                                   sizeof(rpc::Client), nullptr))
    return Err;
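  // Publish the buffer and device so the server thread begins servicing them.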
  std::lock_guard<decltype(BufferMutex)> Lock(BufferMutex);
  Buffers[Device.getDeviceId()] = RPCBuffer;
  Devices[Device.getDeviceId()] = &Device;

  return Error::success();
}

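/// Unregisters the device from the server thread and releases its RPC buffer.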
Error RPCServerTy::deinitDevice(plugin::GenericDeviceTy &Device) {
  std::lock_guard<decltype(BufferMutex)> Lock(BufferMutex);
  Device.free(Buffers[Device.getDeviceId()], TARGET_ALLOC_HOST);
  Buffers[Device.getDeviceId()] = nullptr;
  Devices[Device.getDeviceId()] = nullptr;
  return Error::success();
}