[OpenMP] Added codegen for masked directive

Reviewed By: ABataev

Differential Revision: https://reviews.llvm.org/D100514

GitOrigin-RevId: e0c2125d1d1e72039b8e071d468d9f740c7dbfbd
diff --git a/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 997c078..5a4b406 100644
--- a/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -629,6 +629,17 @@
                              BodyGenCallbackTy BodyGenCB,
                              FinalizeCallbackTy FiniCB);
 
+  /// Generator for '#omp masked'
+  ///
+  /// \param Loc The insert and source location description.
+  /// \param BodyGenCB Callback that will generate the region code.
+  /// \param FiniCB Callback to finialize variable copies.
+  ///
+  /// \returns The insertion position *after* the master.
+  InsertPointTy createMasked(const LocationDescription &Loc,
+                             BodyGenCallbackTy BodyGenCB,
+                             FinalizeCallbackTy FiniCB, Value *Filter);
+
   /// Generator for '#omp critical'
   ///
   /// \param Loc The insert and source location description.
diff --git a/include/llvm/Frontend/OpenMP/OMPKinds.def b/include/llvm/Frontend/OpenMP/OMPKinds.def
index 533a10f..5b403e0 100644
--- a/include/llvm/Frontend/OpenMP/OMPKinds.def
+++ b/include/llvm/Frontend/OpenMP/OMPKinds.def
@@ -237,6 +237,8 @@
 
 __OMP_RTL(__kmpc_master, false, Int32, IdentPtr, Int32)
 __OMP_RTL(__kmpc_end_master, false, Void, IdentPtr, Int32)
+__OMP_RTL(__kmpc_masked, false, Int32, IdentPtr, Int32, Int32)
+__OMP_RTL(__kmpc_end_masked, false, Void, IdentPtr, Int32)
 __OMP_RTL(__kmpc_critical, false, Void, IdentPtr, Int32, KmpCriticalNamePtrTy)
 __OMP_RTL(__kmpc_critical_with_hint, false, Void, IdentPtr, Int32,
           KmpCriticalNamePtrTy, Int32)
@@ -640,6 +642,10 @@
                 ParamAttrs(ReadOnlyPtrAttrs))
 __OMP_RTL_ATTRS(__kmpc_end_master, InaccessibleArgOnlyAttrs, AttributeSet(),
                 ParamAttrs(ReadOnlyPtrAttrs))
+__OMP_RTL_ATTRS(__kmpc_masked, InaccessibleArgOnlyAttrs, AttributeSet(),
+                ParamAttrs(ReadOnlyPtrAttrs))
+__OMP_RTL_ATTRS(__kmpc_end_masked, InaccessibleArgOnlyAttrs, AttributeSet(),
+                ParamAttrs(ReadOnlyPtrAttrs))
 __OMP_RTL_ATTRS(__kmpc_critical, BarrierAttrs, AttributeSet(),
                 ParamAttrs(ReadOnlyPtrAttrs, AttributeSet(), AttributeSet()))
 __OMP_RTL_ATTRS(__kmpc_critical_with_hint, BarrierAttrs, AttributeSet(),
diff --git a/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 26f5901..ec9ecce 100644
--- a/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -901,6 +901,30 @@
                               /*Conditional*/ true, /*hasFinalize*/ true);
 }
 
+OpenMPIRBuilder::InsertPointTy
+OpenMPIRBuilder::createMasked(const LocationDescription &Loc,
+                              BodyGenCallbackTy BodyGenCB,
+                              FinalizeCallbackTy FiniCB, Value *Filter) {
+  if (!updateToLocation(Loc))
+    return Loc.IP;
+
+  Directive OMPD = Directive::OMPD_masked;
+  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
+  Value *Ident = getOrCreateIdent(SrcLocStr);
+  Value *ThreadId = getOrCreateThreadID(Ident);
+  Value *Args[] = {Ident, ThreadId, Filter};
+  Value *ArgsEnd[] = {Ident, ThreadId};
+
+  Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_masked);
+  Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
+
+  Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_masked);
+  Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, ArgsEnd);
+
+  return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
+                              /*Conditional*/ true, /*hasFinalize*/ true);
+}
+
 CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton(
     DebugLoc DL, Value *TripCount, Function *F, BasicBlock *PreInsertBefore,
     BasicBlock *PostInsertBefore, const Twine &Name) {
diff --git a/unittests/Frontend/OpenMPIRBuilderTest.cpp b/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 344c66c..da81367 100644
--- a/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -1788,6 +1788,88 @@
   EXPECT_EQ(MasterEndCI->getArgOperand(1), MasterEntryCI->getArgOperand(1));
 }
 
+TEST_F(OpenMPIRBuilderTest, MaskedDirective) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+
+  AllocaInst *PrivAI = nullptr;
+
+  BasicBlock *EntryBB = nullptr;
+  BasicBlock *ExitBB = nullptr;
+  BasicBlock *ThenBB = nullptr;
+
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+                       BasicBlock &FiniBB) {
+    if (AllocaIP.isSet())
+      Builder.restoreIP(AllocaIP);
+    else
+      Builder.SetInsertPoint(&*(F->getEntryBlock().getFirstInsertionPt()));
+    PrivAI = Builder.CreateAlloca(F->arg_begin()->getType());
+    Builder.CreateStore(F->arg_begin(), PrivAI);
+
+    llvm::BasicBlock *CodeGenIPBB = CodeGenIP.getBlock();
+    llvm::Instruction *CodeGenIPInst = &*CodeGenIP.getPoint();
+    EXPECT_EQ(CodeGenIPBB->getTerminator(), CodeGenIPInst);
+
+    Builder.restoreIP(CodeGenIP);
+
+    // collect some info for checks later
+    ExitBB = FiniBB.getUniqueSuccessor();
+    ThenBB = Builder.GetInsertBlock();
+    EntryBB = ThenBB->getUniquePredecessor();
+
+    // simple instructions for body
+    Value *PrivLoad =
+        Builder.CreateLoad(PrivAI->getAllocatedType(), PrivAI, "local.use");
+    Builder.CreateICmpNE(F->arg_begin(), PrivLoad);
+  };
+
+  auto FiniCB = [&](InsertPointTy IP) {
+    BasicBlock *IPBB = IP.getBlock();
+    EXPECT_NE(IPBB->end(), IP.getPoint());
+  };
+
+  Constant *Filter = ConstantInt::get(Type::getInt32Ty(M->getContext()), 0);
+  Builder.restoreIP(
+      OMPBuilder.createMasked(Builder, BodyGenCB, FiniCB, Filter));
+  Value *EntryBBTI = EntryBB->getTerminator();
+  EXPECT_NE(EntryBBTI, nullptr);
+  EXPECT_TRUE(isa<BranchInst>(EntryBBTI));
+  BranchInst *EntryBr = cast<BranchInst>(EntryBB->getTerminator());
+  EXPECT_TRUE(EntryBr->isConditional());
+  EXPECT_EQ(EntryBr->getSuccessor(0), ThenBB);
+  EXPECT_EQ(ThenBB->getUniqueSuccessor(), ExitBB);
+  EXPECT_EQ(EntryBr->getSuccessor(1), ExitBB);
+
+  CmpInst *CondInst = cast<CmpInst>(EntryBr->getCondition());
+  EXPECT_TRUE(isa<CallInst>(CondInst->getOperand(0)));
+
+  CallInst *MaskedEntryCI = cast<CallInst>(CondInst->getOperand(0));
+  EXPECT_EQ(MaskedEntryCI->getNumArgOperands(), 3U);
+  EXPECT_EQ(MaskedEntryCI->getCalledFunction()->getName(), "__kmpc_masked");
+  EXPECT_TRUE(isa<GlobalVariable>(MaskedEntryCI->getArgOperand(0)));
+
+  CallInst *MaskedEndCI = nullptr;
+  for (auto &FI : *ThenBB) {
+    Instruction *cur = &FI;
+    if (isa<CallInst>(cur)) {
+      MaskedEndCI = cast<CallInst>(cur);
+      if (MaskedEndCI->getCalledFunction()->getName() == "__kmpc_end_masked")
+        break;
+      MaskedEndCI = nullptr;
+    }
+  }
+  EXPECT_NE(MaskedEndCI, nullptr);
+  EXPECT_EQ(MaskedEndCI->getNumArgOperands(), 2U);
+  EXPECT_TRUE(isa<GlobalVariable>(MaskedEndCI->getArgOperand(0)));
+  EXPECT_EQ(MaskedEndCI->getArgOperand(1), MaskedEntryCI->getArgOperand(1));
+}
+
 TEST_F(OpenMPIRBuilderTest, CriticalDirective) {
   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
   OpenMPIRBuilder OMPBuilder(*M);