//===------------------ Client.h - Client Implementation ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// gRPC Client for the remote plugin.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_OPENMP_LIBOMPTARGET_PLUGINS_REMOTE_SRC_CLIENT_H
#define LLVM_OPENMP_LIBOMPTARGET_PLUGINS_REMOTE_SRC_CLIENT_H

#include "Utils.h"
#include "omptarget.h"
#include <google/protobuf/arena.h>
#include <grpcpp/grpcpp.h>
#include <grpcpp/security/credentials.h>
#include <grpcpp/support/channel_arguments.h>
#include <memory>
#include <mutex>
#include <numeric>

using grpc::Channel;
using openmp::libomptarget::remote::RemoteOffload;
using namespace RemoteOffloading;

using namespace google;

/// gRPC client bound to a single channel of the remote offloading service.
///
/// The class is move-only: it owns its stub, its protobuf arena and the
/// per-device target tables through unique_ptr members. Apart from the
/// constructors and the destructor, every member function is only declared
/// here; the definitions live out of line.
class RemoteOffloadClient {
  /// Timeout value supplied at construction. The unit is not visible in this
  /// header -- check the out-of-line users before changing it.
  const int Timeout;

  /// Value returned by getDebugLevel() when the client was constructed.
  int DebugLevel;
  uint64_t MaxSize;
  int64_t BlockSize;

  std::unique_ptr<RemoteOffload::Stub> Stub;
  std::unique_ptr<protobuf::Arena> Arena;

  /// Kept behind a unique_ptr so that the class stays movable (std::mutex is
  /// neither copyable nor movable). Presumably serializes allocations from
  /// Arena -- confirm against the implementation.
  std::unique_ptr<std::mutex> ArenaAllocatorLock;

  std::map<int32_t, std::unordered_map<void *, void *>> RemoteEntries;
  /// Target tables keyed by device id; their contents are released through
  /// freeTargetTable() in the destructor.
  std::map<int32_t, std::unique_ptr<__tgt_target_table>> DevicesToTables;

  /// Shared helper for issuing a call; defined out of line.
  ///
  /// NOTE(review): the \p Timeout flag shadows the Timeout data member inside
  /// the definition. Presumably it selects whether the member timeout is
  /// applied to the call -- confirm against the implementation.
  template <typename Fn1, typename Fn2, typename TReturn>
  auto remoteCall(Fn1 Preprocess, Fn2 Postprocess, TReturn ErrorValue,
                  bool Timeout = true);

public:
  /// Creates a stub on \p Channel and allocates a fresh arena and arena lock.
  ///
  /// All members are set in the member-initializer list, in declaration
  /// order, so none of them is ever observable in an unset state.
  RemoteOffloadClient(std::shared_ptr<Channel> Channel, int Timeout,
                      uint64_t MaxSize, int64_t BlockSize)
      : Timeout(Timeout), DebugLevel(getDebugLevel()), MaxSize(MaxSize),
        BlockSize(BlockSize), Stub(RemoteOffload::NewStub(Channel)),
        Arena(std::make_unique<protobuf::Arena>()),
        ArenaAllocatorLock(std::make_unique<std::mutex>()) {}

  // Copying is impossible because of the unique_ptr members; spell that out
  // instead of relying on the implicitly deleted operations.
  RemoteOffloadClient(const RemoteOffloadClient &) = delete;
  RemoteOffloadClient &operator=(const RemoteOffloadClient &) = delete;

  RemoteOffloadClient(RemoteOffloadClient &&) = default;

  /// Hands every stored target table to freeTargetTable().
  ~RemoteOffloadClient() {
    for (auto &TableIt : DevicesToTables)
      freeTargetTable(TableIt.second.get());
  }

  // The declarations below are defined out of line. Their names and
  // signatures appear to mirror the libomptarget plugin entry points --
  // confirm against the implementation file.

  int32_t shutdown();

  int32_t registerLib(__tgt_bin_desc *Desc);
  int32_t unregisterLib(__tgt_bin_desc *Desc);

  int32_t isValidBinary(__tgt_device_image *Image);
  int32_t getNumberOfDevices();

  int32_t initDevice(int32_t DeviceId);
  int32_t initRequires(int64_t RequiresFlags);

  __tgt_target_table *loadBinary(int32_t DeviceId, __tgt_device_image *Image);
  int64_t synchronize(int32_t DeviceId, __tgt_async_info *AsyncInfo);
  int32_t isDataExchangeable(int32_t SrcDevId, int32_t DstDevId);

  void *dataAlloc(int32_t DeviceId, int64_t Size, void *HstPtr);
  int32_t dataDelete(int32_t DeviceId, void *TgtPtr);

  int32_t dataSubmitAsync(int32_t DeviceId, void *TgtPtr, void *HstPtr,
                          int64_t Size, __tgt_async_info *AsyncInfo);
  int32_t dataRetrieveAsync(int32_t DeviceId, void *HstPtr, void *TgtPtr,
                            int64_t Size, __tgt_async_info *AsyncInfo);

  int32_t dataExchangeAsync(int32_t SrcDevId, void *SrcPtr, int32_t DstDevId,
                            void *DstPtr, int64_t Size,
                            __tgt_async_info *AsyncInfo);

  int32_t runTargetRegionAsync(int32_t DeviceId, void *TgtEntryPtr,
                               void **TgtArgs, ptrdiff_t *TgtOffsets,
                               int32_t ArgNum, __tgt_async_info *AsyncInfo);

  int32_t runTargetTeamRegionAsync(int32_t DeviceId, void *TgtEntryPtr,
                                   void **TgtArgs, ptrdiff_t *TgtOffsets,
                                   int32_t ArgNum, int32_t TeamNum,
                                   int32_t ThreadLimit, uint64_t LoopTripCount,
                                   __tgt_async_info *AsyncInfo);
};

class RemoteClientManager {
private:
  std::vector<std::string> Addresses;
  std::vector<RemoteOffloadClient> Clients;
  std::vector<int> Devices;

  std::pair<int32_t, int32_t> mapDeviceId(int32_t DeviceId);
  int DebugLevel;

public:
  RemoteClientManager(std::vector<std::string> Addresses, int Timeout,
                      uint64_t MaxSize, int64_t BlockSize)
      : Addresses(Addresses) {
    grpc::ChannelArguments ChArgs;
    ChArgs.SetMaxReceiveMessageSize(-1);
    DebugLevel = getDebugLevel();
    for (auto Address : Addresses) {
      Clients.push_back(RemoteOffloadClient(
          grpc::CreateChannel(Address, grpc::InsecureChannelCredentials()),
          Timeout, MaxSize, BlockSize));
    }
  }

  int32_t shutdown(void);

  int32_t registerLib(__tgt_bin_desc *Desc);
  int32_t unregisterLib(__tgt_bin_desc *Desc);

  int32_t isValidBinary(__tgt_device_image *Image);
  int32_t getNumberOfDevices();

  int32_t initDevice(int32_t DeviceId);
  int32_t initRequires(int64_t RequiresFlags);

  __tgt_target_table *loadBinary(int32_t DeviceId, __tgt_device_image *Image);
  int64_t synchronize(int32_t DeviceId, __tgt_async_info *AsyncInfo);
  int32_t isDataExchangeable(int32_t SrcDevId, int32_t DstDevId);

  void *dataAlloc(int32_t DeviceId, int64_t Size, void *HstPtr);
  int32_t dataDelete(int32_t DeviceId, void *TgtPtr);

  int32_t dataSubmitAsync(int32_t DeviceId, void *TgtPtr, void *HstPtr,
                          int64_t Size, __tgt_async_info *AsyncInfo);
  int32_t dataRetrieveAsync(int32_t DeviceId, void *HstPtr, void *TgtPtr,
                            int64_t Size, __tgt_async_info *AsyncInfo);

  int32_t dataExchangeAsync(int32_t SrcDevId, void *SrcPtr, int32_t DstDevId,
                            void *DstPtr, int64_t Size,
                            __tgt_async_info *AsyncInfo);

  int32_t runTargetRegionAsync(int32_t DeviceId, void *TgtEntryPtr,
                               void **TgtArgs, ptrdiff_t *TgtOffsets,
                               int32_t ArgNum, __tgt_async_info *AsyncInfo);

  int32_t runTargetTeamRegionAsync(int32_t DeviceId, void *TgtEntryPtr,
                                   void **TgtArgs, ptrdiff_t *TgtOffsets,
                                   int32_t ArgNum, int32_t TeamNum,
                                   int32_t ThreadLimit, uint64_t LoopTripCount,
                                   __tgt_async_info *AsyncInfo);
};

#endif
