[OpenMP][FIX] Ensure thread states do not crash on the GPU

Nested parallelism creates thread states. These still do not work
properly, but at least they no longer crash.

GitOrigin-RevId: d571af7f627491841fab7c456f774d7b8f546159
diff --git a/libomptarget/DeviceRTL/include/LibC.h b/libomptarget/DeviceRTL/include/LibC.h
index 87eed20..dde86af 100644
--- a/libomptarget/DeviceRTL/include/LibC.h
+++ b/libomptarget/DeviceRTL/include/LibC.h
@@ -17,6 +17,7 @@
 extern "C" {
 
 int memcmp(const void *lhs, const void *rhs, size_t count);
+void memset(void *dst, int C, size_t count);
 
 int printf(const char *format, ...);
 }
diff --git a/libomptarget/DeviceRTL/src/LibC.cpp b/libomptarget/DeviceRTL/src/LibC.cpp
index ae73a64..af675b9 100644
--- a/libomptarget/DeviceRTL/src/LibC.cpp
+++ b/libomptarget/DeviceRTL/src/LibC.cpp
@@ -47,6 +47,12 @@
   return 0;
 }
 
+void memset(void *dst, int C, size_t count) {
+  // Byte-wise fill: store C (narrowed to char) into each of the count bytes.
+  for (char *Out = static_cast<char *>(dst), *End = Out + count; Out != End;)
+    *Out++ = C;
+}
+
 /// printf() calls are rewritten by CGGPUBuiltin to __llvm_omp_vprintf
 int32_t __llvm_omp_vprintf(const char *Format, void *Arguments, uint32_t Size) {
   return impl::omp_vprintf(Format, Arguments, Size);
diff --git a/libomptarget/DeviceRTL/src/State.cpp b/libomptarget/DeviceRTL/src/State.cpp
index 68fe0b3..efa0502 100644
--- a/libomptarget/DeviceRTL/src/State.cpp
+++ b/libomptarget/DeviceRTL/src/State.cpp
@@ -12,6 +12,7 @@
 #include "Debug.h"
 #include "Environment.h"
 #include "Interface.h"
+#include "LibC.h"
 #include "Mapping.h"
 #include "Synchronization.h"
 #include "Types.h"
@@ -263,13 +264,14 @@
     return;
 
   unsigned TId = mapping::getThreadIdInBlock();
-  ThreadStateTy *NewThreadState =
-      static_cast<ThreadStateTy *>(__kmpc_alloc_shared(sizeof(ThreadStateTy)));
+  ThreadStateTy *NewThreadState = static_cast<ThreadStateTy *>(
+      memory::allocGlobal(sizeof(ThreadStateTy), "ThreadStates alloc"));
   uintptr_t *ThreadStatesBitsPtr = reinterpret_cast<uintptr_t *>(&ThreadStates);
   if (!atomic::load(ThreadStatesBitsPtr, atomic::seq_cst)) {
     uint32_t Bytes = sizeof(ThreadStates[0]) * mapping::getMaxTeamThreads();
     void *ThreadStatesPtr =
         memory::allocGlobal(Bytes, "Thread state array allocation");
+    memset(ThreadStatesPtr, 0, Bytes); // Zero fill: every slot starts null.
     if (!atomic::cas(ThreadStatesBitsPtr, uintptr_t(0),
                      reinterpret_cast<uintptr_t>(ThreadStatesPtr),
                      atomic::seq_cst, atomic::seq_cst))
@@ -298,7 +300,7 @@
     return;
 
   ThreadStateTy *PreviousThreadState = ThreadStates[TId]->PreviousThreadState;
-  __kmpc_free_shared(ThreadStates[TId], sizeof(ThreadStateTy));
+  memory::freeGlobal(ThreadStates[TId], "ThreadStates dealloc");
   ThreadStates[TId] = PreviousThreadState;
 }
 
diff --git a/libomptarget/test/offloading/thread_state_1.c b/libomptarget/test/offloading/thread_state_1.c
new file mode 100644
index 0000000..8725120
--- /dev/null
+++ b/libomptarget/test/offloading/thread_state_1.c
@@ -0,0 +1,36 @@
+// RUN: %libomptarget-compile-run-and-check-generic
+// RUN: %libomptarget-compileopt-run-and-check-generic
+
+// The CPU targets are supported and work, but the GPU computes bogus results.
+// For now we disable the CPU targets and re-enable them once the GPU is fixed.
+//
+// UNSUPPORTED: aarch64-unknown-linux-gnu
+// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
+// UNSUPPORTED: x86_64-pc-linux-gnu
+// UNSUPPORTED: x86_64-pc-linux-gnu-LTO
+
+#include <omp.h>
+#include <stdio.h>
+
+int main() { // Exercises nested parallel regions inside a target teams region.
+  // TODO: Test all ICVs
+  int lvl = 333, tid = 666, nt = 999; // Distinct sentinel values.
+#pragma omp target teams map(tofrom : lvl, tid, nt) num_teams(2)
+  {
+    if (omp_get_team_num() == 0) { // Only team 0 runs the nested regions.
+#pragma omp parallel num_threads(128)
+      if (omp_get_thread_num() == 17) { // Thread 17 opens the inner region.
+#pragma omp parallel num_threads(64)
+        if (omp_get_thread_num() == omp_get_num_threads() - 1) { // Last thread.
+          lvl = omp_get_level();
+          tid = omp_get_thread_num();
+          nt = omp_get_num_threads();
+        }
+      }
+    }
+  }
+  // TODO: Printing the sentinels is wrong, but at least nothing crashes.
+  // CHECK: lvl: 333, tid: 666, nt: 999
+  printf("lvl: %i, tid: %i, nt: %i\n", lvl, tid, nt);
+  return 0;
+}