| //===----------------- Client.cpp - Client Implementation -----------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // gRPC (Client) for the remote plugin. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include <cmath> |
| |
| #include "Client.h" |
| #include "omptarget.h" |
| #include "openmp.pb.h" |
| |
| using namespace std::chrono; |
| |
| using grpc::ClientContext; |
| using grpc::ClientReader; |
| using grpc::ClientWriter; |
| using grpc::Status; |
| |
| template <typename Fn1, typename Fn2, typename TReturn> |
| auto RemoteOffloadClient::remoteCall(Fn1 Preprocess, Fn2 Postprocess, |
| TReturn ErrorValue, bool Timeout) { |
| ArenaAllocatorLock->lock(); |
| if (Arena->SpaceAllocated() >= MaxSize) |
| Arena->Reset(); |
| ArenaAllocatorLock->unlock(); |
| |
| ClientContext Context; |
| if (Timeout) { |
| auto Deadline = |
| std::chrono::system_clock::now() + std::chrono::seconds(Timeout); |
| Context.set_deadline(Deadline); |
| } |
| |
| Status RPCStatus; |
| auto Reply = Preprocess(RPCStatus, Context); |
| |
| // TODO: Error handle more appropriately |
| if (!RPCStatus.ok()) { |
| CLIENT_DBG("%s", RPCStatus.error_message().c_str()); |
| } else { |
| return Postprocess(Reply); |
| } |
| |
| CLIENT_DBG("Failed"); |
| return ErrorValue; |
| } |
| |
| int32_t RemoteOffloadClient::shutdown(void) { |
| ClientContext Context; |
| Null Request; |
| I32 Reply; |
| CLIENT_DBG("Shutting down server."); |
| auto Status = Stub->Shutdown(&Context, Request, &Reply); |
| if (Status.ok()) |
| return Reply.number(); |
| return 1; |
| } |
| |
| int32_t RemoteOffloadClient::registerLib(__tgt_bin_desc *Desc) { |
| return remoteCall( |
| /* Preprocess */ |
| [&](auto &RPCStatus, auto &Context) { |
| auto *Request = protobuf::Arena::CreateMessage<TargetBinaryDescription>( |
| Arena.get()); |
| auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get()); |
| loadTargetBinaryDescription(Desc, *Request); |
| Request->set_bin_ptr((uint64_t)Desc); |
| |
| CLIENT_DBG("Registering library"); |
| RPCStatus = Stub->RegisterLib(&Context, *Request, Reply); |
| return Reply; |
| }, |
| /* Postprocess */ |
| [&](const auto &Reply) { |
| if (Reply->number() == 0) { |
| CLIENT_DBG("Registered library"); |
| return 0; |
| } |
| return 1; |
| }, |
| /* Error Value */ 1); |
| } |
| |
| int32_t RemoteOffloadClient::unregisterLib(__tgt_bin_desc *Desc) { |
| return remoteCall( |
| /* Preprocess */ |
| [&](auto &RPCStatus, auto &Context) { |
| auto *Request = protobuf::Arena::CreateMessage<Pointer>(Arena.get()); |
| auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get()); |
| |
| Request->set_number((uint64_t)Desc); |
| |
| CLIENT_DBG("Unregistering library"); |
| RPCStatus = Stub->UnregisterLib(&Context, *Request, Reply); |
| return Reply; |
| }, |
| /* Postprocess */ |
| [&](const auto &Reply) { |
| if (Reply->number() == 0) { |
| CLIENT_DBG("Unregistered library"); |
| return 0; |
| } |
| CLIENT_DBG("Failed to unregister library"); |
| return 1; |
| }, |
| /* Error Value */ 1); |
| } |
| |
| int32_t RemoteOffloadClient::isValidBinary(__tgt_device_image *Image) { |
| return remoteCall( |
| /* Preprocess */ |
| [&](auto &RPCStatus, auto &Context) { |
| auto *Request = |
| protobuf::Arena::CreateMessage<TargetDeviceImagePtr>(Arena.get()); |
| auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get()); |
| |
| Request->set_image_ptr((uint64_t)Image->ImageStart); |
| |
| auto *EntryItr = Image->EntriesBegin; |
| while (EntryItr != Image->EntriesEnd) |
| Request->add_entry_ptrs((uint64_t)EntryItr++); |
| |
| CLIENT_DBG("Validating binary"); |
| RPCStatus = Stub->IsValidBinary(&Context, *Request, Reply); |
| return Reply; |
| }, |
| /* Postprocess */ |
| [&](const auto &Reply) { |
| if (Reply->number()) { |
| CLIENT_DBG("Validated binary"); |
| } else { |
| CLIENT_DBG("Could not validate binary"); |
| } |
| return Reply->number(); |
| }, |
| /* Error Value */ 0); |
| } |
| |
| int32_t RemoteOffloadClient::getNumberOfDevices() { |
| return remoteCall( |
| /* Preprocess */ |
| [&](Status &RPCStatus, ClientContext &Context) { |
| auto *Request = protobuf::Arena::CreateMessage<Null>(Arena.get()); |
| auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get()); |
| |
| CLIENT_DBG("Getting number of devices"); |
| RPCStatus = Stub->GetNumberOfDevices(&Context, *Request, Reply); |
| |
| return Reply; |
| }, |
| /* Postprocess */ |
| [&](const auto &Reply) { |
| if (Reply->number()) { |
| CLIENT_DBG("Found %d devices", Reply->number()); |
| } else { |
| CLIENT_DBG("Could not get the number of devices"); |
| } |
| return Reply->number(); |
| }, |
| /*Error Value*/ -1); |
| } |
| |
| int32_t RemoteOffloadClient::initDevice(int32_t DeviceId) { |
| return remoteCall( |
| /* Preprocess */ |
| [&](auto &RPCStatus, auto &Context) { |
| auto *Request = protobuf::Arena::CreateMessage<I32>(Arena.get()); |
| auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get()); |
| |
| Request->set_number(DeviceId); |
| |
| CLIENT_DBG("Initializing device %d", DeviceId); |
| RPCStatus = Stub->InitDevice(&Context, *Request, Reply); |
| |
| return Reply; |
| }, |
| /* Postprocess */ |
| [&](const auto &Reply) { |
| if (!Reply->number()) { |
| CLIENT_DBG("Initialized device %d", DeviceId); |
| } else { |
| CLIENT_DBG("Could not initialize device %d", DeviceId); |
| } |
| return Reply->number(); |
| }, |
| /* Error Value */ -1); |
| } |
| |
| int32_t RemoteOffloadClient::initRequires(int64_t RequiresFlags) { |
| return remoteCall( |
| /* Preprocess */ |
| [&](auto &RPCStatus, auto &Context) { |
| auto *Request = protobuf::Arena::CreateMessage<I64>(Arena.get()); |
| auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get()); |
| Request->set_number(RequiresFlags); |
| CLIENT_DBG("Initializing requires"); |
| RPCStatus = Stub->InitRequires(&Context, *Request, Reply); |
| return Reply; |
| }, |
| /* Postprocess */ |
| [&](const auto &Reply) { |
| if (Reply->number()) { |
| CLIENT_DBG("Initialized requires"); |
| } else { |
| CLIENT_DBG("Could not initialize requires"); |
| } |
| return Reply->number(); |
| }, |
| /* Error Value */ -1); |
| } |
| |
| __tgt_target_table *RemoteOffloadClient::loadBinary(int32_t DeviceId, |
| __tgt_device_image *Image) { |
| return remoteCall( |
| /* Preprocess */ |
| [&](auto &RPCStatus, auto &Context) { |
| auto *ImageMessage = |
| protobuf::Arena::CreateMessage<Binary>(Arena.get()); |
| auto *Reply = protobuf::Arena::CreateMessage<TargetTable>(Arena.get()); |
| ImageMessage->set_image_ptr((uint64_t)Image->ImageStart); |
| ImageMessage->set_device_id(DeviceId); |
| |
| CLIENT_DBG("Loading Image %p to device %d", Image, DeviceId); |
| RPCStatus = Stub->LoadBinary(&Context, *ImageMessage, Reply); |
| return Reply; |
| }, |
| /* Postprocess */ |
| [&](auto &Reply) { |
| if (Reply->entries_size() == 0) { |
| CLIENT_DBG("Could not load image %p onto device %d", Image, DeviceId); |
| return (__tgt_target_table *)nullptr; |
| } |
| DevicesToTables[DeviceId] = std::make_unique<__tgt_target_table>(); |
| unloadTargetTable(*Reply, DevicesToTables[DeviceId].get(), |
| RemoteEntries[DeviceId]); |
| |
| CLIENT_DBG("Loaded Image %p to device %d with %d entries", Image, |
| DeviceId, Reply->entries_size()); |
| |
| return DevicesToTables[DeviceId].get(); |
| }, |
| /* Error Value */ (__tgt_target_table *)nullptr, |
| /* Timeout */ false); |
| } |
| |
| int64_t RemoteOffloadClient::synchronize(int32_t DeviceId, |
| __tgt_async_info *AsyncInfo) { |
| return remoteCall( |
| /* Preprocess */ |
| [&](auto &RPCStatus, auto &Context) { |
| auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get()); |
| auto *Info = |
| protobuf::Arena::CreateMessage<SynchronizeDevice>(Arena.get()); |
| |
| Info->set_device_id(DeviceId); |
| Info->set_queue_ptr((uint64_t)AsyncInfo); |
| |
| CLIENT_DBG("Synchronizing device %d", DeviceId); |
| RPCStatus = Stub->Synchronize(&Context, *Info, Reply); |
| return Reply; |
| }, |
| /* Postprocess */ |
| [&](auto &Reply) { |
| if (Reply->number()) { |
| CLIENT_DBG("Synchronized device %d", DeviceId); |
| } else { |
| CLIENT_DBG("Could not synchronize device %d", DeviceId); |
| } |
| return Reply->number(); |
| }, |
| /* Error Value */ -1); |
| } |
| |
| int32_t RemoteOffloadClient::isDataExchangeable(int32_t SrcDevId, |
| int32_t DstDevId) { |
| return remoteCall( |
| /* Preprocess */ |
| [&](auto &RPCStatus, auto &Context) { |
| auto *Request = protobuf::Arena::CreateMessage<DevicePair>(Arena.get()); |
| auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get()); |
| |
| Request->set_src_dev_id(SrcDevId); |
| Request->set_dst_dev_id(DstDevId); |
| |
| CLIENT_DBG("Asking if data is exchangeable between %d, %d", SrcDevId, |
| DstDevId); |
| RPCStatus = Stub->IsDataExchangeable(&Context, *Request, Reply); |
| return Reply; |
| }, |
| /* Postprocess */ |
| [&](auto &Reply) { |
| if (Reply->number()) { |
| CLIENT_DBG("Data is exchangeable between %d, %d", SrcDevId, DstDevId); |
| } else { |
| CLIENT_DBG("Data is not exchangeable between %d, %d", SrcDevId, |
| DstDevId); |
| } |
| return Reply->number(); |
| }, |
| /* Error Value */ -1); |
| } |
| |
| void *RemoteOffloadClient::dataAlloc(int32_t DeviceId, int64_t Size, |
| void *HstPtr) { |
| return remoteCall( |
| /* Preprocess */ |
| [&](auto &RPCStatus, auto &Context) { |
| auto *Reply = protobuf::Arena::CreateMessage<Pointer>(Arena.get()); |
| auto *Request = protobuf::Arena::CreateMessage<AllocData>(Arena.get()); |
| |
| Request->set_device_id(DeviceId); |
| Request->set_size(Size); |
| Request->set_hst_ptr((uint64_t)HstPtr); |
| |
| CLIENT_DBG("Allocating %ld bytes on device %d", Size, DeviceId); |
| RPCStatus = Stub->DataAlloc(&Context, *Request, Reply); |
| return Reply; |
| }, |
| /* Postprocess */ |
| [&](auto &Reply) { |
| if (Reply->number()) { |
| CLIENT_DBG("Allocated %ld bytes on device %d at %p", Size, DeviceId, |
| (void *)Reply->number()); |
| } else { |
| CLIENT_DBG("Could not allocate %ld bytes on device %d at %p", Size, |
| DeviceId, (void *)Reply->number()); |
| } |
| return (void *)Reply->number(); |
| }, |
| /* Error Value */ (void *)nullptr); |
| } |
| |
| int32_t RemoteOffloadClient::dataSubmitAsync(int32_t DeviceId, void *TgtPtr, |
| void *HstPtr, int64_t Size, |
| __tgt_async_info *AsyncInfo) { |
| |
| return remoteCall( |
| /* Preprocess */ |
| [&](auto &RPCStatus, auto &Context) { |
| auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get()); |
| std::unique_ptr<ClientWriter<SubmitDataAsync>> Writer( |
| Stub->DataSubmitAsync(&Context, Reply)); |
| |
| if (Size > BlockSize) { |
| int64_t Start = 0, End = BlockSize; |
| for (auto I = 0; I < ceil((float)Size / BlockSize); I++) { |
| auto *Request = |
| protobuf::Arena::CreateMessage<SubmitDataAsync>(Arena.get()); |
| |
| Request->set_device_id(DeviceId); |
| Request->set_data((char *)HstPtr + Start, End - Start); |
| Request->set_hst_ptr((uint64_t)HstPtr); |
| Request->set_tgt_ptr((uint64_t)TgtPtr); |
| Request->set_start(Start); |
| Request->set_size(Size); |
| Request->set_queue_ptr((uint64_t)AsyncInfo); |
| |
| CLIENT_DBG("Submitting %ld-%ld/%ld bytes async on device %d at %p", |
| Start, End, Size, DeviceId, TgtPtr) |
| |
| if (!Writer->Write(*Request)) { |
| CLIENT_DBG("Broken stream when submitting data"); |
| Reply->set_number(0); |
| return Reply; |
| } |
| |
| Start += BlockSize; |
| End += BlockSize; |
| if (End >= Size) |
| End = Size; |
| } |
| } else { |
| auto *Request = |
| protobuf::Arena::CreateMessage<SubmitDataAsync>(Arena.get()); |
| |
| Request->set_device_id(DeviceId); |
| Request->set_data(HstPtr, Size); |
| Request->set_hst_ptr((uint64_t)HstPtr); |
| Request->set_tgt_ptr((uint64_t)TgtPtr); |
| Request->set_start(0); |
| Request->set_size(Size); |
| |
| CLIENT_DBG("Submitting %ld bytes async on device %d at %p", Size, |
| DeviceId, TgtPtr) |
| if (!Writer->Write(*Request)) { |
| CLIENT_DBG("Broken stream when submitting data"); |
| Reply->set_number(0); |
| return Reply; |
| } |
| } |
| |
| Writer->WritesDone(); |
| RPCStatus = Writer->Finish(); |
| |
| return Reply; |
| }, |
| /* Postprocess */ |
| [&](auto &Reply) { |
| if (!Reply->number()) { |
| CLIENT_DBG("Async submitted %ld bytes on device %d at %p", Size, |
| DeviceId, TgtPtr) |
| } else { |
| CLIENT_DBG("Could not async submit %ld bytes on device %d at %p", |
| Size, DeviceId, TgtPtr) |
| } |
| return Reply->number(); |
| }, |
| /* Error Value */ -1, |
| /* Timeout */ false); |
| } |
| |
| int32_t RemoteOffloadClient::dataRetrieveAsync(int32_t DeviceId, void *HstPtr, |
| void *TgtPtr, int64_t Size, |
| __tgt_async_info *AsyncInfo) { |
| return remoteCall( |
| /* Preprocess */ |
| [&](auto &RPCStatus, auto &Context) { |
| auto *Request = |
| protobuf::Arena::CreateMessage<RetrieveDataAsync>(Arena.get()); |
| |
| Request->set_device_id(DeviceId); |
| Request->set_size(Size); |
| Request->set_hst_ptr((int64_t)HstPtr); |
| Request->set_tgt_ptr((int64_t)TgtPtr); |
| Request->set_queue_ptr((uint64_t)AsyncInfo); |
| |
| auto *Reply = protobuf::Arena::CreateMessage<Data>(Arena.get()); |
| std::unique_ptr<ClientReader<Data>> Reader( |
| Stub->DataRetrieveAsync(&Context, *Request)); |
| Reader->WaitForInitialMetadata(); |
| while (Reader->Read(Reply)) { |
| if (Reply->ret()) { |
| CLIENT_DBG("Could not async retrieve %ld bytes on device %d at %p " |
| "for %p", |
| Size, DeviceId, TgtPtr, HstPtr) |
| return Reply; |
| } |
| |
| if (Reply->start() == 0 && Reply->size() == Reply->data().size()) { |
| CLIENT_DBG("Async retrieving %ld bytes on device %d at %p for %p", |
| Size, DeviceId, TgtPtr, HstPtr) |
| |
| memcpy(HstPtr, Reply->data().data(), Reply->data().size()); |
| |
| return Reply; |
| } |
| CLIENT_DBG("Retrieving %lu-%lu/%lu bytes async from (%p) to (%p) " |
| "on Device %d", |
| Reply->start(), Reply->start() + Reply->data().size(), |
| Reply->size(), (void *)Request->tgt_ptr(), HstPtr, |
| Request->device_id()); |
| |
| memcpy((void *)((char *)HstPtr + Reply->start()), |
| Reply->data().data(), Reply->data().size()); |
| } |
| RPCStatus = Reader->Finish(); |
| |
| return Reply; |
| }, |
| /* Postprocess */ |
| [&](auto &Reply) { |
| if (!Reply->ret()) { |
| CLIENT_DBG("Async retrieve %ld bytes on Device %d", Size, DeviceId); |
| } else { |
| CLIENT_DBG("Could not async retrieve %ld bytes on Device %d", Size, |
| DeviceId); |
| } |
| return Reply->ret(); |
| }, |
| /* Error Value */ -1, |
| /* Timeout */ false); |
| } |
| |
| int32_t RemoteOffloadClient::dataExchangeAsync(int32_t SrcDevId, void *SrcPtr, |
| int32_t DstDevId, void *DstPtr, |
| int64_t Size, |
| __tgt_async_info *AsyncInfo) { |
| return remoteCall( |
| /* Preprocess */ |
| [&](auto &RPCStatus, auto &Context) { |
| auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get()); |
| auto *Request = |
| protobuf::Arena::CreateMessage<ExchangeDataAsync>(Arena.get()); |
| |
| Request->set_src_dev_id(SrcDevId); |
| Request->set_src_ptr((uint64_t)SrcPtr); |
| Request->set_dst_dev_id(DstDevId); |
| Request->set_dst_ptr((uint64_t)DstPtr); |
| Request->set_size(Size); |
| Request->set_queue_ptr((uint64_t)AsyncInfo); |
| |
| CLIENT_DBG( |
| "Exchanging %ld bytes on device %d at %p for %p on device %d", Size, |
| SrcDevId, SrcPtr, DstPtr, DstDevId); |
| RPCStatus = Stub->DataExchangeAsync(&Context, *Request, Reply); |
| return Reply; |
| }, |
| /* Postprocess */ |
| [&](auto &Reply) { |
| if (Reply->number()) { |
| CLIENT_DBG( |
| "Exchanged %ld bytes on device %d at %p for %p on device %d", |
| Size, SrcDevId, SrcPtr, DstPtr, DstDevId); |
| } else { |
| CLIENT_DBG("Could not exchange %ld bytes on device %d at %p for %p " |
| "on device %d", |
| Size, SrcDevId, SrcPtr, DstPtr, DstDevId); |
| } |
| return Reply->number(); |
| }, |
| /* Error Value */ -1); |
| } |
| |
| int32_t RemoteOffloadClient::dataDelete(int32_t DeviceId, void *TgtPtr) { |
| return remoteCall( |
| /* Preprocess */ |
| [&](auto &RPCStatus, auto &Context) { |
| auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get()); |
| auto *Request = protobuf::Arena::CreateMessage<DeleteData>(Arena.get()); |
| |
| Request->set_device_id(DeviceId); |
| Request->set_tgt_ptr((uint64_t)TgtPtr); |
| |
| CLIENT_DBG("Deleting data at %p on device %d", TgtPtr, DeviceId) |
| RPCStatus = Stub->DataDelete(&Context, *Request, Reply); |
| return Reply; |
| }, |
| /* Postprocess */ |
| [&](auto &Reply) { |
| if (!Reply->number()) { |
| CLIENT_DBG("Deleted data at %p on device %d", TgtPtr, DeviceId) |
| } else { |
| CLIENT_DBG("Could not delete data at %p on device %d", TgtPtr, |
| DeviceId) |
| } |
| return Reply->number(); |
| }, |
| /* Error Value */ -1); |
| } |
| |
| int32_t RemoteOffloadClient::runTargetRegionAsync( |
| int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs, ptrdiff_t *TgtOffsets, |
| int32_t ArgNum, __tgt_async_info *AsyncInfo) { |
| return remoteCall( |
| /* Preprocess */ |
| [&](auto &RPCStatus, auto &Context) { |
| auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get()); |
| auto *Request = |
| protobuf::Arena::CreateMessage<TargetRegionAsync>(Arena.get()); |
| |
| Request->set_device_id(DeviceId); |
| Request->set_queue_ptr((uint64_t)AsyncInfo); |
| |
| Request->set_tgt_entry_ptr( |
| (uint64_t)RemoteEntries[DeviceId][TgtEntryPtr]); |
| |
| char **ArgPtr = (char **)TgtArgs; |
| for (auto I = 0; I < ArgNum; I++, ArgPtr++) |
| Request->add_tgt_args((uint64_t)*ArgPtr); |
| |
| char *OffsetPtr = (char *)TgtOffsets; |
| for (auto I = 0; I < ArgNum; I++, OffsetPtr++) |
| Request->add_tgt_offsets((uint64_t)*OffsetPtr); |
| |
| Request->set_arg_num(ArgNum); |
| |
| CLIENT_DBG("Running target region async on device %d", DeviceId); |
| RPCStatus = Stub->RunTargetRegionAsync(&Context, *Request, Reply); |
| return Reply; |
| }, |
| /* Postprocess */ |
| [&](auto &Reply) { |
| if (!Reply->number()) { |
| CLIENT_DBG("Ran target region async on device %d", DeviceId); |
| } else { |
| CLIENT_DBG("Could not run target region async on device %d", |
| DeviceId); |
| } |
| return Reply->number(); |
| }, |
| /* Error Value */ -1, |
| /* Timeout */ false); |
| } |
| |
| int32_t RemoteOffloadClient::runTargetTeamRegionAsync( |
| int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs, ptrdiff_t *TgtOffsets, |
| int32_t ArgNum, int32_t TeamNum, int32_t ThreadLimit, |
| uint64_t LoopTripcount, __tgt_async_info *AsyncInfo) { |
| return remoteCall( |
| /* Preprocess */ |
| [&](auto &RPCStatus, auto &Context) { |
| auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get()); |
| auto *Request = |
| protobuf::Arena::CreateMessage<TargetTeamRegionAsync>(Arena.get()); |
| |
| Request->set_device_id(DeviceId); |
| Request->set_queue_ptr((uint64_t)AsyncInfo); |
| |
| Request->set_tgt_entry_ptr( |
| (uint64_t)RemoteEntries[DeviceId][TgtEntryPtr]); |
| |
| char **ArgPtr = (char **)TgtArgs; |
| for (auto I = 0; I < ArgNum; I++, ArgPtr++) { |
| Request->add_tgt_args((uint64_t)*ArgPtr); |
| } |
| |
| char *OffsetPtr = (char *)TgtOffsets; |
| for (auto I = 0; I < ArgNum; I++, OffsetPtr++) |
| Request->add_tgt_offsets((uint64_t)*OffsetPtr); |
| |
| Request->set_arg_num(ArgNum); |
| Request->set_team_num(TeamNum); |
| Request->set_thread_limit(ThreadLimit); |
| Request->set_loop_tripcount(LoopTripcount); |
| |
| CLIENT_DBG("Running target team region async on device %d", DeviceId); |
| RPCStatus = Stub->RunTargetTeamRegionAsync(&Context, *Request, Reply); |
| return Reply; |
| }, |
| /* Postprocess */ |
| [&](auto &Reply) { |
| if (!Reply->number()) { |
| CLIENT_DBG("Ran target team region async on device %d", DeviceId); |
| } else { |
| CLIENT_DBG("Could not run target team region async on device %d", |
| DeviceId); |
| } |
| return Reply->number(); |
| }, |
| /* Error Value */ -1, |
| /* Timeout */ false); |
| } |
| |
| // TODO: Better error handling for the next three functions |
| int32_t RemoteClientManager::shutdown(void) { |
| int32_t Ret = 0; |
| for (auto &Client : Clients) |
| Ret &= Client.shutdown(); |
| return Ret; |
| } |
| |
| int32_t RemoteClientManager::registerLib(__tgt_bin_desc *Desc) { |
| int32_t Ret = 0; |
| for (auto &Client : Clients) |
| Ret &= Client.registerLib(Desc); |
| return Ret; |
| } |
| |
| int32_t RemoteClientManager::unregisterLib(__tgt_bin_desc *Desc) { |
| int32_t Ret = 0; |
| for (auto &Client : Clients) |
| Ret &= Client.unregisterLib(Desc); |
| return Ret; |
| } |
| |
| int32_t RemoteClientManager::isValidBinary(__tgt_device_image *Image) { |
| int32_t ClientIdx = 0; |
| for (auto &Client : Clients) { |
| if (auto Ret = Client.isValidBinary(Image)) |
| return Ret; |
| ClientIdx++; |
| } |
| return 0; |
| } |
| |
| int32_t RemoteClientManager::getNumberOfDevices() { |
| auto ClientIdx = 0; |
| for (auto &Client : Clients) { |
| if (auto NumDevices = Client.getNumberOfDevices()) { |
| Devices.push_back(NumDevices); |
| } |
| ClientIdx++; |
| } |
| |
| return std::accumulate(Devices.begin(), Devices.end(), 0); |
| } |
| |
| std::pair<int32_t, int32_t> RemoteClientManager::mapDeviceId(int32_t DeviceId) { |
| for (size_t ClientIdx = 0; ClientIdx < Devices.size(); ClientIdx++) { |
| if (!(DeviceId >= Devices[ClientIdx])) |
| return {ClientIdx, DeviceId}; |
| DeviceId -= Devices[ClientIdx]; |
| } |
| return {-1, -1}; |
| } |
| |
| int32_t RemoteClientManager::initDevice(int32_t DeviceId) { |
| int32_t ClientIdx, DeviceIdx; |
| std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId); |
| return Clients[ClientIdx].initDevice(DeviceIdx); |
| } |
| |
| int32_t RemoteClientManager::initRequires(int64_t RequiresFlags) { |
| for (auto &Client : Clients) |
| Client.initRequires(RequiresFlags); |
| |
| return RequiresFlags; |
| } |
| |
| __tgt_target_table *RemoteClientManager::loadBinary(int32_t DeviceId, |
| __tgt_device_image *Image) { |
| int32_t ClientIdx, DeviceIdx; |
| std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId); |
| return Clients[ClientIdx].loadBinary(DeviceIdx, Image); |
| } |
| |
| int64_t RemoteClientManager::synchronize(int32_t DeviceId, |
| __tgt_async_info *AsyncInfo) { |
| int32_t ClientIdx, DeviceIdx; |
| std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId); |
| return Clients[ClientIdx].synchronize(DeviceIdx, AsyncInfo); |
| } |
| |
| int32_t RemoteClientManager::isDataExchangeable(int32_t SrcDevId, |
| int32_t DstDevId) { |
| int32_t SrcClientIdx, SrcDeviceIdx, DstClientIdx, DstDeviceIdx; |
| std::tie(SrcClientIdx, SrcDeviceIdx) = mapDeviceId(SrcDevId); |
| std::tie(DstClientIdx, DstDeviceIdx) = mapDeviceId(DstDevId); |
| return Clients[SrcClientIdx].isDataExchangeable(SrcDeviceIdx, DstDeviceIdx); |
| } |
| |
| void *RemoteClientManager::dataAlloc(int32_t DeviceId, int64_t Size, |
| void *HstPtr) { |
| int32_t ClientIdx, DeviceIdx; |
| std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId); |
| return Clients[ClientIdx].dataAlloc(DeviceIdx, Size, HstPtr); |
| } |
| |
| int32_t RemoteClientManager::dataDelete(int32_t DeviceId, void *TgtPtr) { |
| int32_t ClientIdx, DeviceIdx; |
| std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId); |
| return Clients[ClientIdx].dataDelete(DeviceIdx, TgtPtr); |
| } |
| |
| int32_t RemoteClientManager::dataSubmitAsync(int32_t DeviceId, void *TgtPtr, |
| void *HstPtr, int64_t Size, |
| __tgt_async_info *AsyncInfo) { |
| int32_t ClientIdx, DeviceIdx; |
| std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId); |
| return Clients[ClientIdx].dataSubmitAsync(DeviceIdx, TgtPtr, HstPtr, Size, |
| AsyncInfo); |
| } |
| |
| int32_t RemoteClientManager::dataRetrieveAsync(int32_t DeviceId, void *HstPtr, |
| void *TgtPtr, int64_t Size, |
| __tgt_async_info *AsyncInfo) { |
| int32_t ClientIdx, DeviceIdx; |
| std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId); |
| return Clients[ClientIdx].dataRetrieveAsync(DeviceIdx, HstPtr, TgtPtr, Size, |
| AsyncInfo); |
| } |
| |
| int32_t RemoteClientManager::dataExchangeAsync(int32_t SrcDevId, void *SrcPtr, |
| int32_t DstDevId, void *DstPtr, |
| int64_t Size, |
| __tgt_async_info *AsyncInfo) { |
| int32_t SrcClientIdx, SrcDeviceIdx, DstClientIdx, DstDeviceIdx; |
| std::tie(SrcClientIdx, SrcDeviceIdx) = mapDeviceId(SrcDevId); |
| std::tie(DstClientIdx, DstDeviceIdx) = mapDeviceId(DstDevId); |
| return Clients[SrcClientIdx].dataExchangeAsync( |
| SrcDeviceIdx, SrcPtr, DstDeviceIdx, DstPtr, Size, AsyncInfo); |
| } |
| |
| int32_t RemoteClientManager::runTargetRegionAsync( |
| int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs, ptrdiff_t *TgtOffsets, |
| int32_t ArgNum, __tgt_async_info *AsyncInfo) { |
| int32_t ClientIdx, DeviceIdx; |
| std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId); |
| return Clients[ClientIdx].runTargetRegionAsync( |
| DeviceIdx, TgtEntryPtr, TgtArgs, TgtOffsets, ArgNum, AsyncInfo); |
| } |
| |
| int32_t RemoteClientManager::runTargetTeamRegionAsync( |
| int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs, ptrdiff_t *TgtOffsets, |
| int32_t ArgNum, int32_t TeamNum, int32_t ThreadLimit, |
| uint64_t LoopTripCount, __tgt_async_info *AsyncInfo) { |
| int32_t ClientIdx, DeviceIdx; |
| std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId); |
| return Clients[ClientIdx].runTargetTeamRegionAsync( |
| DeviceIdx, TgtEntryPtr, TgtArgs, TgtOffsets, ArgNum, TeamNum, ThreadLimit, |
| LoopTripCount, AsyncInfo); |
| } |