[OPENMP][NVPTX]Simplify handling of thread limit, NFC.

Summary:
Patch improves performance of the full runtime mode by moving
threads limit counter to the shared memory. It also allows to save
global memory.

Reviewers: grokos, kkwli0, gtbercea

Subscribers: guansong, jdoerfert, openmp-commits, caomhin

Tags: #openmp

Differential Revision: https://reviews.llvm.org/D61801

git-svn-id: https://llvm.org/svn/llvm-project/openmp/trunk@360584 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/libomptarget/deviceRTLs/nvptx/src/libcall.cu b/libomptarget/deviceRTLs/nvptx/src/libcall.cu
index 9c6d136..9580d75 100644
--- a/libomptarget/deviceRTLs/nvptx/src/libcall.cu
+++ b/libomptarget/deviceRTLs/nvptx/src/libcall.cu
@@ -37,10 +37,8 @@
   PRINT(LD_IO, "call omp_set_num_threads(num %d)\n", num);
   if (num <= 0) {
     WARNING0(LW_INPUT, "expected positive num; ignore\n");
-  } else {
-    omptarget_nvptx_TaskDescr *currTaskDescr =
-        getMyTopTaskDescriptor(/*isSPMDExecutionMode=*/false);
-    currTaskDescr->NThreads() = num;
+  } else if (parallelLevel[GetWarpId()] == 0) {
+    nThreads = num;
   }
 }
 
@@ -54,12 +52,10 @@
   if (parallelLevel[GetWarpId()] > 0)
     // We're already in parallel region.
     return 1; // default is 1 thread avail
-  omptarget_nvptx_TaskDescr *currTaskDescr =
-      getMyTopTaskDescriptor(/*isSPMDExecutionMode=*/false);
-  ASSERT0(LT_FUSSY, !currTaskDescr->InParallelRegion(),
-          "Should no be in the parallel region");
   // Not currently in a parallel region, return what was set.
-  int rc = currTaskDescr->NThreads();
+  int rc = 1;
+  if (parallelLevel[GetWarpId()] == 0)
+    rc = nThreads;
   ASSERT0(LT_FUSSY, rc >= 0, "bad number of threads");
   PRINT(LD_IO, "call omp_get_max_threads() return %d\n", rc);
   return rc;
@@ -175,7 +171,7 @@
                 (int)currTaskDescr->InParallelRegion(), (int)sched,
                 currTaskDescr->RuntimeChunkSize(),
                 (int)currTaskDescr->ThreadId(), (int)threadsInTeam,
-                (int)currTaskDescr->NThreads());
+                (int)nThreads);
         }
 
         if (currTaskDescr->IsParallelConstruct()) {
diff --git a/libomptarget/deviceRTLs/nvptx/src/omp_data.cu b/libomptarget/deviceRTLs/nvptx/src/omp_data.cu
index f7cd334..d369da1 100644
--- a/libomptarget/deviceRTLs/nvptx/src/omp_data.cu
+++ b/libomptarget/deviceRTLs/nvptx/src/omp_data.cu
@@ -34,6 +34,7 @@
 __device__ __shared__ uint8_t parallelLevel[MAX_THREADS_PER_TEAM / WARPSIZE];
 __device__ __shared__ uint16_t threadLimit;
 __device__ __shared__ uint16_t threadsInTeam;
+__device__ __shared__ uint16_t nThreads;
 // Pointer to this team's OpenMP state object
 __device__ __shared__
     omptarget_nvptx_ThreadPrivateContext *omptarget_nvptx_threadPrivateContext;
diff --git a/libomptarget/deviceRTLs/nvptx/src/omptarget-nvptx.cu b/libomptarget/deviceRTLs/nvptx/src/omptarget-nvptx.cu
index 5c861e3..706776a 100644
--- a/libomptarget/deviceRTLs/nvptx/src/omptarget-nvptx.cu
+++ b/libomptarget/deviceRTLs/nvptx/src/omptarget-nvptx.cu
@@ -63,7 +63,7 @@
 
   // init team context
   omptarget_nvptx_TeamDescr &currTeamDescr = getMyTeamDescriptor();
-  currTeamDescr.InitTeamDescr(/*isSPMDExecutionMode=*/false);
+  currTeamDescr.InitTeamDescr();
   // this thread will start execution... has to update its task ICV
   // to point to the level zero task ICV. That ICV was init in
   // InitTeamDescr()
@@ -73,7 +73,7 @@
   // set number of threads and thread limit in team to started value
   omptarget_nvptx_TaskDescr *currTaskDescr =
       omptarget_nvptx_threadPrivateContext->GetTopLevelTaskDescr(threadId);
-  currTaskDescr->NThreads() = GetNumberOfWorkersInTeam();
+  nThreads = GetNumberOfWorkersInTeam();
   threadLimit = ThreadLimit;
 }
 
@@ -123,7 +123,7 @@
     omptarget_nvptx_TeamDescr &currTeamDescr = getMyTeamDescriptor();
     omptarget_nvptx_WorkDescr &workDescr = getMyWorkDescriptor();
     // init team context
-    currTeamDescr.InitTeamDescr(/*isSPMDExecutionMode=*/true);
+    currTeamDescr.InitTeamDescr();
   }
   // FIXME: use __syncthreads instead when the function copy is fixed in LLVM.
   __SYNCTHREADS();
diff --git a/libomptarget/deviceRTLs/nvptx/src/omptarget-nvptx.h b/libomptarget/deviceRTLs/nvptx/src/omptarget-nvptx.h
index 9b6dcbf..cd51538 100644
--- a/libomptarget/deviceRTLs/nvptx/src/omptarget-nvptx.h
+++ b/libomptarget/deviceRTLs/nvptx/src/omptarget-nvptx.h
@@ -164,7 +164,6 @@
   }
   INLINE int IsTaskConstruct() const { return !IsParallelConstruct(); }
   // methods for other fields
-  INLINE uint16_t &NThreads() { return items.nthreads; }
   INLINE uint16_t &ThreadId() { return items.threadId; }
   INLINE uint64_t &RuntimeChunkSize() { return items.runtimeChunkSize; }
   INLINE omptarget_nvptx_TaskDescr *GetPrevTaskDescr() const { return prev; }
@@ -172,7 +171,7 @@
     prev = taskDescr;
   }
   // init & copy
-  INLINE void InitLevelZeroTaskDescr(bool isSPMDExecutionMode);
+  INLINE void InitLevelZeroTaskDescr();
   INLINE void InitLevelOneTaskDescr(omptarget_nvptx_TaskDescr *parentTaskDescr);
   INLINE void Copy(omptarget_nvptx_TaskDescr *sourceTaskDescr);
   INLINE void CopyData(omptarget_nvptx_TaskDescr *sourceTaskDescr);
@@ -208,7 +207,6 @@
   struct TaskDescr_items {
     uint8_t flags; // 6 bit used (see flag above)
     uint8_t unused;
-    uint16_t nthreads;         // thread num for subsequent parallel regions
     uint16_t threadId;         // thread id
     uint64_t runtimeChunkSize; // runtime chunk size
   } items;
@@ -249,7 +247,7 @@
   INLINE uint64_t *getLastprivateIterBuffer() { return &lastprivateIterBuffer; }
 
   // init
-  INLINE void InitTeamDescr(bool isSPMDExecutionMode);
+  INLINE void InitTeamDescr();
 
   INLINE __kmpc_data_sharing_slot *RootS(int wid, bool IsMasterThread) {
     // If this is invoked by the master thread of the master warp then intialize
@@ -404,6 +402,7 @@
     parallelLevel[MAX_THREADS_PER_TEAM / WARPSIZE];
 extern __device__ __shared__ uint16_t threadLimit;
 extern __device__ __shared__ uint16_t threadsInTeam;
+extern __device__ __shared__ uint16_t nThreads;
 extern __device__ __shared__
     omptarget_nvptx_ThreadPrivateContext *omptarget_nvptx_threadPrivateContext;
 
diff --git a/libomptarget/deviceRTLs/nvptx/src/omptarget-nvptxi.h b/libomptarget/deviceRTLs/nvptx/src/omptarget-nvptxi.h
index c890b95..e4efa18 100644
--- a/libomptarget/deviceRTLs/nvptx/src/omptarget-nvptxi.h
+++ b/libomptarget/deviceRTLs/nvptx/src/omptarget-nvptxi.h
@@ -31,7 +31,7 @@
 }
 
 INLINE void
-omptarget_nvptx_TaskDescr::InitLevelZeroTaskDescr(bool isSPMDExecutionMode) {
+omptarget_nvptx_TaskDescr::InitLevelZeroTaskDescr() {
   // slow method
   // flag:
   //   default sched is static,
@@ -39,8 +39,6 @@
   //   not in parallel
 
   items.flags = 0;
-  items.nthreads = GetNumberOfProcsInTeam(isSPMDExecutionMode);
-  ;                                // threads: whatever was alloc by kernel
   items.threadId = 0;         // is master
   items.runtimeChunkSize = 1; // prefered chunking statik with chunk 1
 }
@@ -57,7 +55,6 @@
 
   items.flags =
       TaskDescr_InPar | TaskDescr_IsParConstr; // set flag to parallel
-  items.nthreads = 0; // # threads for subsequent parallel region
   items.threadId =
       GetThreadIdInBlock(); // get ids from cuda (only called for 1st level)
   items.runtimeChunkSize = 1; // prefered chunking statik with chunk 1
@@ -173,8 +170,8 @@
 // Team Descriptor
 ////////////////////////////////////////////////////////////////////////////////
 
-INLINE void omptarget_nvptx_TeamDescr::InitTeamDescr(bool isSPMDExecutionMode) {
-  levelZeroTaskDescr.InitLevelZeroTaskDescr(isSPMDExecutionMode);
+INLINE void omptarget_nvptx_TeamDescr::InitTeamDescr() {
+  levelZeroTaskDescr.InitLevelZeroTaskDescr();
 }
 
 ////////////////////////////////////////////////////////////////////////////////
diff --git a/libomptarget/deviceRTLs/nvptx/src/parallel.cu b/libomptarget/deviceRTLs/nvptx/src/parallel.cu
index 3f2ec44..6747235 100644
--- a/libomptarget/deviceRTLs/nvptx/src/parallel.cu
+++ b/libomptarget/deviceRTLs/nvptx/src/parallel.cu
@@ -249,8 +249,8 @@
   uint16_t &NumThreadsClause =
       omptarget_nvptx_threadPrivateContext->NumThreadsForNextParallel(threadId);
 
-  uint16_t NumThreads = determineNumberOfThreads(
-      NumThreadsClause, currTaskDescr->NThreads(), threadLimit);
+  uint16_t NumThreads =
+      determineNumberOfThreads(NumThreadsClause, nThreads, threadLimit);
 
   if (NumThreadsClause != 0) {
     // Reset request to avoid propagating to successive #parallel
@@ -308,7 +308,7 @@
     PRINT(LD_PAR,
           "thread will execute parallel region with id %d in a team of "
           "%d threads\n",
-          (int)newTaskDescr->ThreadId(), (int)newTaskDescr->NThreads());
+          (int)newTaskDescr->ThreadId(), (int)nThreads);
 
     isActive = true;
     IncParallelLevel(threadsInTeam != 1);