[openmp] Apply code change from D109500

(cherry picked from commit 71052ea1e3c63b7209731fdc1726d10640d97480)
diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index d6b9791..75eec25 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -1996,7 +1996,8 @@
         UndefValue::get(Int8Ty), F->getName() + ".ID");
 
     for (Use *U : ToBeReplacedStateMachineUses)
-      U->set(ConstantExpr::getBitCast(ID, U->get()->getType()));
+      U->set(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+          ID, U->get()->getType()));
 
     ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
 
@@ -3183,10 +3184,14 @@
     IsWorker->setDebugLoc(DLoc);
     BranchInst::Create(StateMachineBeginBB, UserCodeEntryBB, IsWorker, InitBB);
 
+    Module &M = *Kernel->getParent();
+
     // Create local storage for the work function pointer.
+    const DataLayout &DL = M.getDataLayout();
     Type *VoidPtrTy = Type::getInt8PtrTy(Ctx);
-    AllocaInst *WorkFnAI = new AllocaInst(VoidPtrTy, 0, "worker.work_fn.addr",
-                                          &Kernel->getEntryBlock().front());
+    Instruction *WorkFnAI =
+        new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
+                       "worker.work_fn.addr", &Kernel->getEntryBlock().front());
     WorkFnAI->setDebugLoc(DLoc);
 
     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
@@ -3199,13 +3204,23 @@
     Value *Ident = KernelInitCB->getArgOperand(0);
     Value *GTid = KernelInitCB;
 
-    Module &M = *Kernel->getParent();
     FunctionCallee BarrierFn =
         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
             M, OMPRTL___kmpc_barrier_simple_spmd);
     CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB)
         ->setDebugLoc(DLoc);
 
+    if (WorkFnAI->getType()->getPointerAddressSpace() !=
+        (unsigned int)AddressSpace::Generic) {
+      WorkFnAI = new AddrSpaceCastInst(
+          WorkFnAI,
+          PointerType::getWithSamePointeeType(
+              cast<PointerType>(WorkFnAI->getType()),
+              (unsigned int)AddressSpace::Generic),
+          WorkFnAI->getName() + ".generic", StateMachineBeginBB);
+      WorkFnAI->setDebugLoc(DLoc);
+    }
+
     FunctionCallee KernelParallelFn =
         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
             M, OMPRTL___kmpc_kernel_parallel);