[OPENMP][NVPTX]Correctly handle L2 parallelism in SPMD mode.

Summary:
The parallelLevel counter must be on per-thread basis to fully support
L2+ parallelism, otherwise we may end up with undefined behavior.
Introduce the parallelLevel on per-warp basis using shared memory. It
allows to avoid the problems with the synchronization and allows fully
support L2+ parallelism in SPMD mode with no runtime.

Reviewers: gtbercea, grokos

Subscribers: guansong, jdoerfert, caomhin, kkwli0, openmp-commits

Tags: #openmp

Differential Revision: https://reviews.llvm.org/D60918

git-svn-id: https://llvm.org/svn/llvm-project/openmp/trunk@359341 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/libomptarget/deviceRTLs/nvptx/src/libcall.cu b/libomptarget/deviceRTLs/nvptx/src/libcall.cu
index 9bc3f2c..c3877a8 100644
--- a/libomptarget/deviceRTLs/nvptx/src/libcall.cu
+++ b/libomptarget/deviceRTLs/nvptx/src/libcall.cu
@@ -165,7 +165,7 @@
     ASSERT0(LT_FUSSY, isSPMDMode(),
             "Expected SPMD mode only with uninitialized runtime.");
     // parallelLevel starts from 0, need to add 1 for correct level.
-    return parallelLevel + 1;
+    return parallelLevel[GetWarpId()] + 1;
   }
   int level = 0;
   omptarget_nvptx_TaskDescr *currTaskDescr =
diff --git a/libomptarget/deviceRTLs/nvptx/src/omp_data.cu b/libomptarget/deviceRTLs/nvptx/src/omp_data.cu
index 27666ae..1c2da46 100644
--- a/libomptarget/deviceRTLs/nvptx/src/omp_data.cu
+++ b/libomptarget/deviceRTLs/nvptx/src/omp_data.cu
@@ -31,7 +31,7 @@
 __device__ __shared__ uint32_t usedMemIdx;
 __device__ __shared__ uint32_t usedSlotIdx;
 
-__device__ __shared__ uint8_t parallelLevel;
+__device__ __shared__ uint8_t parallelLevel[MAX_THREADS_PER_TEAM / WARPSIZE];
 
 // Pointer to this team's OpenMP state object
 __device__ __shared__
diff --git a/libomptarget/deviceRTLs/nvptx/src/omptarget-nvptx.cu b/libomptarget/deviceRTLs/nvptx/src/omptarget-nvptx.cu
index b69a3be..b5cfac3 100644
--- a/libomptarget/deviceRTLs/nvptx/src/omptarget-nvptx.cu
+++ b/libomptarget/deviceRTLs/nvptx/src/omptarget-nvptx.cu
@@ -95,8 +95,10 @@
     // If OMP runtime is not required don't initialize OMP state.
     setExecutionParameters(Spmd, RuntimeUninitialized);
     if (GetThreadIdInBlock() == 0) {
-      parallelLevel = 0;
       usedSlotIdx = smid() % MAX_SM;
+      parallelLevel[0] = 0;
+    } else if (GetLaneId() == 0) {
+      parallelLevel[GetWarpId()] = 0;
     }
     __SYNCTHREADS();
     return;
diff --git a/libomptarget/deviceRTLs/nvptx/src/omptarget-nvptx.h b/libomptarget/deviceRTLs/nvptx/src/omptarget-nvptx.h
index d178b57..8ee69de 100644
--- a/libomptarget/deviceRTLs/nvptx/src/omptarget-nvptx.h
+++ b/libomptarget/deviceRTLs/nvptx/src/omptarget-nvptx.h
@@ -406,7 +406,8 @@
     omptarget_nvptx_simpleMemoryManager;
 extern __device__ __shared__ uint32_t usedMemIdx;
 extern __device__ __shared__ uint32_t usedSlotIdx;
-extern __device__ __shared__ uint8_t parallelLevel;
+extern __device__ __shared__ uint8_t
+    parallelLevel[MAX_THREADS_PER_TEAM / WARPSIZE];
 extern __device__ __shared__
     omptarget_nvptx_ThreadPrivateContext *omptarget_nvptx_threadPrivateContext;
 
diff --git a/libomptarget/deviceRTLs/nvptx/src/parallel.cu b/libomptarget/deviceRTLs/nvptx/src/parallel.cu
index 273e15c..1ad72d5 100644
--- a/libomptarget/deviceRTLs/nvptx/src/parallel.cu
+++ b/libomptarget/deviceRTLs/nvptx/src/parallel.cu
@@ -339,10 +339,12 @@
   if (checkRuntimeUninitialized(loc)) {
     ASSERT0(LT_FUSSY, checkSPMDMode(loc),
             "Expected SPMD mode with uninitialized runtime.");
-    __SYNCTHREADS();
-    if (GetThreadIdInBlock() == 0)
-      ++parallelLevel;
-    __SYNCTHREADS();
+    unsigned tnum = __ACTIVEMASK();
+    int leader = __ffs(tnum) - 1;
+    __SHFL_SYNC(tnum, leader, leader);
+    if (GetLaneId() == leader)
+      ++parallelLevel[GetWarpId()];
+    __SHFL_SYNC(tnum, leader, leader);
 
     return;
   }
@@ -382,10 +384,12 @@
   if (checkRuntimeUninitialized(loc)) {
     ASSERT0(LT_FUSSY, checkSPMDMode(loc),
             "Expected SPMD mode with uninitialized runtime.");
-    __SYNCTHREADS();
-    if (GetThreadIdInBlock() == 0)
-      --parallelLevel;
-    __SYNCTHREADS();
+    unsigned tnum = __ACTIVEMASK();
+    int leader = __ffs(tnum) - 1;
+    __SHFL_SYNC(tnum, leader, leader);
+    if (GetLaneId() == leader)
+      --parallelLevel[GetWarpId()];
+    __SHFL_SYNC(tnum, leader, leader);
     return;
   }
 
@@ -407,7 +411,7 @@
   if (checkRuntimeUninitialized(loc)) {
     ASSERT0(LT_FUSSY, checkSPMDMode(loc),
             "Expected SPMD mode with uninitialized runtime.");
-    return parallelLevel + 1;
+    return parallelLevel[GetWarpId()] + 1;
   }
 
   int threadId = GetLogicalThreadIdInBlock(checkSPMDMode(loc));
diff --git a/libomptarget/deviceRTLs/nvptx/src/support.h b/libomptarget/deviceRTLs/nvptx/src/support.h
index 5d4b403..f84da6d 100644
--- a/libomptarget/deviceRTLs/nvptx/src/support.h
+++ b/libomptarget/deviceRTLs/nvptx/src/support.h
@@ -40,6 +40,8 @@
 INLINE int GetBlockIdInKernel();
 INLINE int GetNumberOfBlocksInKernel();
 INLINE int GetNumberOfThreadsInBlock();
+INLINE unsigned GetWarpId();
+INLINE unsigned GetLaneId();
 
 // get global ids to locate tread/team info (constant regardless of OMP)
 INLINE int GetLogicalThreadIdInBlock(bool isSPMDExecutionMode);
diff --git a/libomptarget/deviceRTLs/nvptx/src/supporti.h b/libomptarget/deviceRTLs/nvptx/src/supporti.h
index 2052c17..3f313a9 100644
--- a/libomptarget/deviceRTLs/nvptx/src/supporti.h
+++ b/libomptarget/deviceRTLs/nvptx/src/supporti.h
@@ -102,6 +102,10 @@
 
 INLINE int GetNumberOfThreadsInBlock() { return blockDim.x; }
 
+INLINE unsigned GetWarpId() { return threadIdx.x / WARPSIZE; }
+
+INLINE unsigned GetLaneId() { return threadIdx.x & (WARPSIZE - 1); }
+
 ////////////////////////////////////////////////////////////////////////////////
 //
 // Calls to the Generic Scheme Implementation Layer (assuming 1D layout)
@@ -154,7 +158,7 @@
     ASSERT0(LT_FUSSY, isSPMDExecutionMode,
             "Uninitialized runtime with non-SPMD mode.");
     // For level 2 parallelism all parallel regions are executed sequentially.
-    if (parallelLevel > 0)
+    if (parallelLevel[GetWarpId()] > 0)
       rc = 0;
     else
       rc = GetThreadIdInBlock();
@@ -175,7 +179,7 @@
     ASSERT0(LT_FUSSY, isSPMDExecutionMode,
             "Uninitialized runtime with non-SPMD mode.");
     // For level 2 parallelism all parallel regions are executed sequentially.
-    if (parallelLevel > 0)
+    if (parallelLevel[GetWarpId()] > 0)
       rc = 1;
     else
       rc = GetNumberOfThreadsInBlock();
diff --git a/libomptarget/deviceRTLs/nvptx/test/parallel/spmd_parallel_regions.cpp b/libomptarget/deviceRTLs/nvptx/test/parallel/spmd_parallel_regions.cpp
index 3e08d1d..517db59 100644
--- a/libomptarget/deviceRTLs/nvptx/test/parallel/spmd_parallel_regions.cpp
+++ b/libomptarget/deviceRTLs/nvptx/test/parallel/spmd_parallel_regions.cpp
@@ -6,24 +6,31 @@
 int main(void) {
   int isHost = -1;
   int ParallelLevel1 = -1, ParallelLevel2 = -1;
+  int Count = 0;
 
 #pragma omp target parallel for map(tofrom                                     \
-                                    : isHost, ParallelLevel1, ParallelLevel2)
+                                    : isHost, ParallelLevel1, ParallelLevel2), reduction(+: Count) schedule(static, 1)
   for (int J = 0; J < 10; ++J) {
 #pragma omp critical
     {
-      isHost = (isHost < 0 || isHost == omp_is_initial_device())
-                   ? omp_is_initial_device()
-                   : 1;
-      ParallelLevel1 =
-          (ParallelLevel1 < 0 || ParallelLevel1 == 1) ? omp_get_level() : 2;
+      isHost = (isHost < 0 || isHost == 0) ? omp_is_initial_device() : isHost;
+      ParallelLevel1 = (ParallelLevel1 < 0 || ParallelLevel1 == 1)
+                           ? omp_get_level()
+                           : ParallelLevel1;
     }
-    int L2;
-#pragma omp parallel for schedule(dynamic) lastprivate(L2)
-    for (int I = 0; I < 10; ++I)
-      L2 = omp_get_level();
+    if (omp_get_thread_num() > 5) {
+      int L2;
+#pragma omp parallel for schedule(dynamic) lastprivate(L2) reduction(+: Count)
+      for (int I = 0; I < 10; ++I) {
+        L2 = omp_get_level();
+        Count += omp_get_level(); // (10-6)*10*2 = 80
+      }
 #pragma omp critical
-    ParallelLevel2 = (ParallelLevel2 < 0 || ParallelLevel2 == 2) ? L2 : 1;
+      ParallelLevel2 =
+          (ParallelLevel2 < 0 || ParallelLevel2 == 2) ? L2 : ParallelLevel2;
+    } else {
+      Count += omp_get_level(); // 6 * 1 = 6
+    }
   }
 
   if (isHost < 0) {
@@ -35,6 +42,10 @@
   // CHECK: Parallel level in SPMD mode: L1 is 1, L2 is 2
   printf("Parallel level in SPMD mode: L1 is %d, L2 is %d\n", ParallelLevel1,
          ParallelLevel2);
+  // Final result of Count is (10-6)(num of loops)*10(num of iterations)*2(par
+  // level) + 6(num of iterations) * 1(par level)
+  // CHECK: Expected count = 86
+  printf("Expected count = %d\n", Count);
 
   return isHost;
 }