[OpenMP] Move synchronization into `__tgt_async_info`

The AsyncInfo object should be passed everywhere, and it should offer a way to
ensure synchronization on the libomptarget Device it is associated with.

This replaces D96431.

Reviewed By: tianshilei1992

Differential Revision: https://reviews.llvm.org/D96438

GitOrigin-RevId: a2fc0d34db72315f05b96681013e1e16d4df41c7
diff --git a/libomptarget/include/omptarget.h b/libomptarget/include/omptarget.h
index 7c70704..ffdf965 100644
--- a/libomptarget/include/omptarget.h
+++ b/libomptarget/include/omptarget.h
@@ -119,6 +119,8 @@
       *EntriesEnd; // End of the table with all the entries (non inclusive)
 };
 
+// clang-format on
+
 /// This struct contains information exchanged between different asynchronous
 /// operations for device-dependent optimization and potential synchronization
 struct __tgt_async_info {
@@ -128,6 +130,43 @@
   void *Queue = nullptr;
 };
 
+struct DeviceTy;
+
+/// The libomptarget wrapper around a __tgt_async_info object directly
+/// associated with a libomptarget layer device. RAII semantics to avoid
+/// mistakes: the destructor calls synchronize(), so any queue still attached
+/// to the wrapped __tgt_async_info is synchronized when the object goes out
+/// of scope.
+class AsyncInfoTy {
+  /// The plugin-facing async info; handed to the device via the conversion
+  /// operator below.
+  __tgt_async_info AsyncInfo;
+  /// The device the pending actions are synchronized on.
+  DeviceTy &Device;
+
+public:
+  explicit AsyncInfoTy(DeviceTy &Device) : Device(Device) {}
+
+  /// A copy would share the same Queue pointer and the destructors of both
+  /// objects would then try to synchronize it, so copying is disabled.
+  AsyncInfoTy(const AsyncInfoTy &) = delete;
+  AsyncInfoTy &operator=(const AsyncInfoTy &) = delete;
+
+  /// Synchronize outstanding actions on destruction. The result of
+  /// synchronize() is dropped here; call synchronize() explicitly to observe
+  /// a failure.
+  ~AsyncInfoTy() { synchronize(); }
+
+  /// Implicit conversion to the __tgt_async_info which is used in the
+  /// plugin interface.
+  operator __tgt_async_info *() { return &AsyncInfo; }
+
+  /// Synchronize all pending actions.
+  ///
+  /// \returns OFFLOAD_FAIL or OFFLOAD_SUCCESS appropriately.
+  int synchronize();
+};
+
 /// This struct is a record of non-contiguous information
 struct __tgt_target_non_contig {
   uint64_t Offset;
@@ -135,8 +160,6 @@
   uint64_t Stride;
 };
 
-// clang-format on
-
 #ifdef __cplusplus
 extern "C" {
 #endif
diff --git a/libomptarget/src/omptarget.cpp b/libomptarget/src/omptarget.cpp
index e133012..dc49345 100644
--- a/libomptarget/src/omptarget.cpp
+++ b/libomptarget/src/omptarget.cpp
@@ -19,6 +19,34 @@
 #include <cassert>
 #include <vector>
 
+/// Synchronize the queue attached to this object, if there is one.
+///
+/// Without a queue this is a no-op that reports success. Otherwise the
+/// associated device is asked to synchronize, and on success the plugin
+/// is expected to have reset the queue to nullptr.
+///
+/// \returns OFFLOAD_SUCCESS if there was nothing to do or the device
+/// synchronized successfully, the device's result code otherwise.
+int AsyncInfoTy::synchronize() {
+  // No queue means there is nothing to wait for.
+  if (!AsyncInfo.Queue)
+    return OFFLOAD_SUCCESS;
+
+  int Result = Device.synchronize(&AsyncInfo);
+  if (Result != OFFLOAD_SUCCESS) {
+    // Callers only see the return code, so emit the diagnostic here.
+    REPORT("Failed to synchronize device.\n");
+    return Result;
+  }
+
+  // Only a successful synchronization has to leave the queue cleared. A
+  // failing plugin may not reset it, so the failure path must not assert.
+  assert(AsyncInfo.Queue == nullptr &&
+         "The device plugin should have nulled the queue to indicate there "
+         "are no outstanding actions!");
+  return OFFLOAD_SUCCESS;
+}
+
 /* All begin addresses for partially mapped structs must be 8-aligned in order
  * to ensure proper alignment of members. E.g.
  *
@@ -248,7 +260,7 @@
                                 MapperArgsBase.data(), MapperArgs.data(),
                                 MapperArgSizes.data(), MapperArgTypes.data(),
                                 MapperArgNames.data(), /*arg_mappers*/ nullptr,
-                                /*__tgt_async_info*/ nullptr);
+                                /* AsyncInfoTy */ nullptr);
 
   return rc;
 }
@@ -257,7 +269,7 @@
 int targetDataBegin(ident_t *loc, DeviceTy &Device, int32_t arg_num,
                     void **args_base, void **args, int64_t *arg_sizes,
                     int64_t *arg_types, map_var_info_t *arg_names,
-                    void **arg_mappers, __tgt_async_info *AsyncInfo) {
+                    void **arg_mappers, AsyncInfoTy *AsyncInfo) {
   // process each input.
   for (int32_t i = 0; i < arg_num; ++i) {
     // Ignore private variables and arrays - there is no mapping for them.
@@ -404,7 +416,7 @@
         DP("Moving %" PRId64 " bytes (hst:" DPxMOD ") -> (tgt:" DPxMOD ")\n",
            data_size, DPxPTR(HstPtrBegin), DPxPTR(TgtPtrBegin));
         int rt =
-            Device.submitData(TgtPtrBegin, HstPtrBegin, data_size, AsyncInfo);
+            Device.submitData(TgtPtrBegin, HstPtrBegin, data_size, *AsyncInfo);
         if (rt != OFFLOAD_SUCCESS) {
           REPORT("Copying data to device failed.\n");
           return OFFLOAD_FAIL;
@@ -418,7 +430,7 @@
       uint64_t Delta = (uint64_t)HstPtrBegin - (uint64_t)HstPtrBase;
       void *TgtPtrBase = (void *)((uint64_t)TgtPtrBegin - Delta);
       int rt = Device.submitData(PointerTgtPtrBegin, &TgtPtrBase,
-                                 sizeof(void *), AsyncInfo);
+                                 sizeof(void *), *AsyncInfo);
       if (rt != OFFLOAD_SUCCESS) {
         REPORT("Copying data to device failed.\n");
         return OFFLOAD_FAIL;
@@ -452,24 +464,13 @@
       : HstPtrBegin(HstPtr), DataSize(Size), ForceDelete(ForceDelete),
         HasCloseModifier(HasCloseModifier) {}
 };
-
-/// Synchronize device
-static int syncDevice(DeviceTy &Device, __tgt_async_info *AsyncInfo) {
-  assert(AsyncInfo && AsyncInfo->Queue && "Invalid AsyncInfo");
-  if (Device.synchronize(AsyncInfo) != OFFLOAD_SUCCESS) {
-    REPORT("Failed to synchronize device.\n");
-    return OFFLOAD_FAIL;
-  }
-
-  return OFFLOAD_SUCCESS;
-}
 } // namespace
 
 /// Internal function to undo the mapping and retrieve the data from the device.
 int targetDataEnd(ident_t *loc, DeviceTy &Device, int32_t ArgNum,
                   void **ArgBases, void **Args, int64_t *ArgSizes,
                   int64_t *ArgTypes, map_var_info_t *ArgNames,
-                  void **ArgMappers, __tgt_async_info *AsyncInfo) {
+                  void **ArgMappers, AsyncInfoTy *AsyncInfo) {
   int Ret;
   std::vector<DeallocTgtPtrInfo> DeallocTgtPtrs;
   // process each input.
@@ -584,7 +585,7 @@
           DP("Moving %" PRId64 " bytes (tgt:" DPxMOD ") -> (hst:" DPxMOD ")\n",
              DataSize, DPxPTR(TgtPtrBegin), DPxPTR(HstPtrBegin));
           Ret = Device.retrieveData(HstPtrBegin, TgtPtrBegin, DataSize,
-                                    AsyncInfo);
+                                    *AsyncInfo);
           if (Ret != OFFLOAD_SUCCESS) {
             REPORT("Copying data from device failed.\n");
             return OFFLOAD_FAIL;
@@ -642,8 +643,8 @@
   // nullptr, there is no data transfer happened because once there is,
   // AsyncInfo->Queue will not be nullptr, so again, we don't need to
   // synchronize.
-  if (AsyncInfo && AsyncInfo->Queue) {
-    Ret = syncDevice(Device, AsyncInfo);
+  if (AsyncInfo) {
+    Ret = AsyncInfo->synchronize();
     if (Ret != OFFLOAD_SUCCESS)
       return OFFLOAD_FAIL;
   }
@@ -798,7 +799,7 @@
 int targetDataUpdate(ident_t *loc, DeviceTy &Device, int32_t ArgNum,
                      void **ArgsBase, void **Args, int64_t *ArgSizes,
                      int64_t *ArgTypes, map_var_info_t *ArgNames,
-                     void **ArgMappers, __tgt_async_info *AsyncInfo) {
+                     void **ArgMappers, AsyncInfoTy *AsyncInfo) {
   // process each input.
   for (int32_t I = 0; I < ArgNum; ++I) {
     if ((ArgTypes[I] & OMP_TGT_MAPTYPE_LITERAL) ||
@@ -948,8 +949,8 @@
 
   /// A reference to the \p DeviceTy object
   DeviceTy &Device;
-  /// A pointer to a \p __tgt_async_info object
-  __tgt_async_info *AsyncInfo;
+  /// A pointer to a \p AsyncInfoTy object
+  AsyncInfoTy *AsyncInfo;
 
   // TODO: What would be the best value here? Should we make it configurable?
   // If the size is larger than this threshold, we will allocate and transfer it
@@ -958,7 +959,7 @@
 
 public:
   /// Constructor
-  PrivateArgumentManagerTy(DeviceTy &Dev, __tgt_async_info *AsyncInfo)
+  PrivateArgumentManagerTy(DeviceTy &Dev, AsyncInfoTy *AsyncInfo)
       : Device(Dev), AsyncInfo(AsyncInfo) {}
 
   /// Add a private argument
@@ -985,7 +986,7 @@
 #endif
       // If first-private, copy data from host
       if (IsFirstPrivate) {
-        int Ret = Device.submitData(TgtPtr, HstPtr, ArgSize, AsyncInfo);
+        int Ret = Device.submitData(TgtPtr, HstPtr, ArgSize, *AsyncInfo);
         if (Ret != OFFLOAD_SUCCESS) {
           DP("Copying data to device failed, failed.\n");
           return OFFLOAD_FAIL;
@@ -1041,7 +1042,7 @@
          FirstPrivateArgSize, DPxPTR(TgtPtr));
       // Transfer data to target device
       int Ret = Device.submitData(TgtPtr, FirstPrivateArgBuffer.data(),
-                                  FirstPrivateArgSize, AsyncInfo);
+                                  FirstPrivateArgSize, *AsyncInfo);
       if (Ret != OFFLOAD_SUCCESS) {
         DP("Failed to submit data of private arguments.\n");
         return OFFLOAD_FAIL;
@@ -1089,7 +1090,7 @@
                              std::vector<void *> &TgtArgs,
                              std::vector<ptrdiff_t> &TgtOffsets,
                              PrivateArgumentManagerTy &PrivateArgumentManager,
-                             __tgt_async_info *AsyncInfo) {
+                             AsyncInfoTy *AsyncInfo) {
   TIMESCOPE_WITH_NAME_AND_IDENT("mappingBeforeTargetRegion", loc);
   DeviceTy &Device = PM->Devices[DeviceId];
   int Ret = targetDataBegin(loc, Device, ArgNum, ArgBases, Args, ArgSizes,
@@ -1140,7 +1141,7 @@
         DP("Update lambda reference (" DPxMOD ") -> [" DPxMOD "]\n",
            DPxPTR(PointerTgtPtrBegin), DPxPTR(TgtPtrBegin));
         Ret = Device.submitData(TgtPtrBegin, &PointerTgtPtrBegin,
-                                sizeof(void *), AsyncInfo);
+                                sizeof(void *), *AsyncInfo);
         if (Ret != OFFLOAD_SUCCESS) {
           REPORT("Copying data to device failed.\n");
           return OFFLOAD_FAIL;
@@ -1210,7 +1211,7 @@
                             int64_t *ArgSizes, int64_t *ArgTypes,
                             map_var_info_t *ArgNames, void **ArgMappers,
                             PrivateArgumentManagerTy &PrivateArgumentManager,
-                            __tgt_async_info *AsyncInfo) {
+                            AsyncInfoTy *AsyncInfo) {
   TIMESCOPE_WITH_NAME_AND_IDENT("mappingAfterTargetRegion", loc);
   DeviceTy &Device = PM->Devices[DeviceId];
 
@@ -1242,8 +1243,7 @@
 int target(ident_t *loc, DeviceTy &Device, void *HostPtr, int32_t ArgNum,
            void **ArgBases, void **Args, int64_t *ArgSizes, int64_t *ArgTypes,
            map_var_info_t *ArgNames, void **ArgMappers, int32_t TeamNum,
-           int32_t ThreadLimit, int IsTeamConstruct,
-           __tgt_async_info *AsyncInfo) {
+           int32_t ThreadLimit, int IsTeamConstruct, AsyncInfoTy *AsyncInfo) {
   int32_t DeviceId = Device.DeviceID;
 
   TableMap *TM = getTableMap(HostPtr);
@@ -1266,7 +1266,7 @@
 
   // TODO: This will go away as soon as we consequently pass in async info
   // objects (as references).
-  __tgt_async_info InternalAsyncInfo;
+  AsyncInfoTy InternalAsyncInfo(Device);
   if (!AsyncInfo)
     AsyncInfo = &InternalAsyncInfo;
 
@@ -1301,10 +1301,10 @@
     if (IsTeamConstruct)
       Ret = Device.runTeamRegion(TgtEntryPtr, &TgtArgs[0], &TgtOffsets[0],
                                  TgtArgs.size(), TeamNum, ThreadLimit,
-                                 LoopTripCount, AsyncInfo);
+                                 LoopTripCount, *AsyncInfo);
     else
       Ret = Device.runRegion(TgtEntryPtr, &TgtArgs[0], &TgtOffsets[0],
-                             TgtArgs.size(), AsyncInfo);
+                             TgtArgs.size(), *AsyncInfo);
   }
 
   if (Ret != OFFLOAD_SUCCESS) {
@@ -1322,11 +1322,13 @@
       REPORT("Failed to process data after launching the kernel.\n");
       return OFFLOAD_FAIL;
     }
-  } else if (AsyncInfo->Queue) {
+  } else {
+    // TODO: We should not synchronize here but on the outer level once we pass
+    // in a reference AsyncInfo object.
     // If ArgNum is zero, but AsyncInfo.Queue is valid, then the kernel doesn't
     // hava any argument, and the device supports async operations, so we need a
     // sync at this point.
-    return syncDevice(Device, AsyncInfo);
+    return AsyncInfo->synchronize();
   }
 
   return OFFLOAD_SUCCESS;
diff --git a/libomptarget/src/private.h b/libomptarget/src/private.h
index 43d9d4a..746eea2 100644
--- a/libomptarget/src/private.h
+++ b/libomptarget/src/private.h
@@ -23,23 +23,23 @@
 extern int targetDataBegin(ident_t *loc, DeviceTy &Device, int32_t arg_num,
                            void **args_base, void **args, int64_t *arg_sizes,
                            int64_t *arg_types, map_var_info_t *arg_names,
-                           void **arg_mappers, __tgt_async_info *AsyncInfo);
+                           void **arg_mappers, AsyncInfoTy *AsyncInfo);
 
 extern int targetDataEnd(ident_t *loc, DeviceTy &Device, int32_t ArgNum,
                          void **ArgBases, void **Args, int64_t *ArgSizes,
                          int64_t *ArgTypes, map_var_info_t *arg_names,
-                         void **ArgMappers, __tgt_async_info *AsyncInfo);
+                         void **ArgMappers, AsyncInfoTy *AsyncInfo);
 
 extern int targetDataUpdate(ident_t *loc, DeviceTy &Device, int32_t arg_num,
                             void **args_base, void **args, int64_t *arg_sizes,
                             int64_t *arg_types, map_var_info_t *arg_names,
-                            void **arg_mappers, __tgt_async_info *AsyncInfo);
+                            void **arg_mappers, AsyncInfoTy *AsyncInfo);
 
 extern int target(ident_t *loc, DeviceTy &Device, void *HostPtr, int32_t ArgNum,
                   void **ArgBases, void **Args, int64_t *ArgSizes,
                   int64_t *ArgTypes, map_var_info_t *arg_names,
                   void **ArgMappers, int32_t TeamNum, int32_t ThreadLimit,
-                  int IsTeamConstruct, __tgt_async_info *AsyncInfo);
+                  int IsTeamConstruct, AsyncInfoTy *AsyncInfo);
 
 extern int CheckDeviceAndCtors(int64_t device_id);
 
@@ -76,8 +76,7 @@
 // targetDataEnd and targetDataUpdate).
 typedef int (*TargetDataFuncPtrTy)(ident_t *, DeviceTy &, int32_t, void **,
                                    void **, int64_t *, int64_t *,
-                                   map_var_info_t *, void **,
-                                   __tgt_async_info *);
+                                   map_var_info_t *, void **, AsyncInfoTy *);
 
 // Implemented in libomp, they are called from within __tgt_* functions.
 #ifdef __cplusplus