[ORC-RT][MachO] Unlock JDStatesMutex during push-initializers to avoid deadlock.

During __orc_rt_macho_jit_dlopen the ORC runtime will make a request to the JIT
to push any new initializers. Since this call may add new JD-state to the
runtime (and is expected to in general) we need to unlock the JDStatesMutex
during this operation (and similarly when running initializers and atexits, as
these may trigger push-initializers recursively).

No testcase yet: I haven't been able to reproduce the deadlock when running
llvm-jitlink in in-process mode, and we don't support out-of-process mode in
regression tests yet.

GitOrigin-RevId: e7707464a3f7bf0c6011809d80d0a6e525be184b
diff --git a/lib/orc/macho_platform.cpp b/lib/orc/macho_platform.cpp
index e99fe87..dd992ac 100644
--- a/lib/orc/macho_platform.cpp
+++ b/lib/orc/macho_platform.cpp
@@ -138,6 +138,9 @@
     /// Returns true if there are new sections to process.
     bool hasNewSections() const { return !New.empty(); }
 
+    /// Returns the number of new sections to process.
+    size_t numNewSections() const { return New.size(); }
+
     /// Process all new sections.
     template <typename ProcessSectionFunc>
     std::enable_if_t<std::is_void_v<
@@ -266,7 +269,8 @@
   void *dlsym(void *DSOHandle, std::string_view Symbol);
 
   int registerAtExit(void (*F)(void *), void *Arg, void *DSOHandle);
-  void runAtExits(JITDylibState &JDS);
+  void runAtExits(std::unique_lock<std::mutex> &JDStatesLock,
+                  JITDylibState &JDS);
   void runAtExits(void *DSOHandle);
 
   /// Returns the base address of the section containing ThreadData.
@@ -285,24 +289,34 @@
   static Error registerSwift5Protocols(JITDylibState &JDS);
   static Error registerSwift5ProtocolConformances(JITDylibState &JDS);
   static Error registerSwift5Types(JITDylibState &JDS);
-  static Error runModInits(JITDylibState &JDS);
+  static Error runModInits(std::unique_lock<std::mutex> &JDStatesLock,
+                           JITDylibState &JDS);
 
   Expected<void *> dlopenImpl(std::string_view Path, int Mode);
-  Error dlopenFull(JITDylibState &JDS);
-  Error dlopenInitialize(JITDylibState &JDS, MachOJITDylibDepInfoMap &DepInfo);
+  Error dlopenFull(std::unique_lock<std::mutex> &JDStatesLock,
+                   JITDylibState &JDS);
+  Error dlopenInitialize(std::unique_lock<std::mutex> &JDStatesLock,
+                         JITDylibState &JDS, MachOJITDylibDepInfoMap &DepInfo);
 
   Error dlcloseImpl(void *DSOHandle);
-  Error dlcloseDeinitialize(JITDylibState &JDS);
+  Error dlcloseDeinitialize(std::unique_lock<std::mutex> &JDStatesLock,
+                            JITDylibState &JDS);
 
   static MachOPlatformRuntimeState *MOPS;
 
   // FIXME: Move to thread-state.
   std::string DLFcnError;
 
-  std::recursive_mutex JDStatesMutex;
+  // DyldAPIMutex guards against concurrent entry into key "dyld" API functions
+  // (e.g. dlopen, dlclose).
+  std::recursive_mutex DyldAPIMutex;
+
+  // JDStatesMutex guards the data structures that hold JITDylib state.
+  std::mutex JDStatesMutex;
   std::unordered_map<void *, JITDylibState> JDStates;
   std::unordered_map<std::string_view, void *> JDNameToHeader;
 
+  // ThreadDataSectionsMutex guards thread local data section state.
   std::mutex ThreadDataSectionsMutex;
   std::map<const char *, size_t> ThreadDataSections;
 };
@@ -329,7 +343,7 @@
   ORC_RT_DEBUG({
     printdbg("Registering JITDylib %s: Header = %p\n", Name.c_str(), Header);
   });
-  std::lock_guard<std::recursive_mutex> Lock(JDStatesMutex);
+  std::lock_guard<std::mutex> Lock(JDStatesMutex);
   if (JDStates.count(Header)) {
     std::ostringstream ErrStream;
     ErrStream << "Duplicate JITDylib registration for header " << Header
@@ -351,7 +365,7 @@
 }
 
 Error MachOPlatformRuntimeState::deregisterJITDylib(void *Header) {
-  std::lock_guard<std::recursive_mutex> Lock(JDStatesMutex);
+  std::lock_guard<std::mutex> Lock(JDStatesMutex);
   auto I = JDStates.find(Header);
   if (I == JDStates.end()) {
     std::ostringstream ErrStream;
@@ -403,12 +417,16 @@
 Error MachOPlatformRuntimeState::registerObjectPlatformSections(
     ExecutorAddr HeaderAddr,
     std::vector<std::pair<std::string_view, ExecutorAddrRange>> Secs) {
+
+  // FIXME: Reject platform section registration after the JITDylib is
+  // sealed?
+
   ORC_RT_DEBUG({
     printdbg("MachOPlatform: Registering object sections for %p.\n",
              HeaderAddr.toPtr<void *>());
   });
 
-  std::lock_guard<std::recursive_mutex> Lock(JDStatesMutex);
+  std::lock_guard<std::mutex> Lock(JDStatesMutex);
   auto *JDS = getJITDylibStateByHeader(HeaderAddr.toPtr<void *>());
   if (!JDS) {
     std::ostringstream ErrStream;
@@ -468,7 +486,7 @@
              HeaderAddr.toPtr<void *>());
   });
 
-  std::lock_guard<std::recursive_mutex> Lock(JDStatesMutex);
+  std::lock_guard<std::mutex> Lock(JDStatesMutex);
   auto *JDS = getJITDylibStateByHeader(HeaderAddr.toPtr<void *>());
   if (!JDS) {
     std::ostringstream ErrStream;
@@ -525,7 +543,7 @@
     std::string S(Path.data(), Path.size());
     printdbg("MachOPlatform::dlopen(\"%s\")\n", S.c_str());
   });
-  std::lock_guard<std::recursive_mutex> Lock(JDStatesMutex);
+  std::lock_guard<std::recursive_mutex> Lock(DyldAPIMutex);
   if (auto H = dlopenImpl(Path, Mode))
     return *H;
   else {
@@ -546,7 +564,7 @@
       printdbg("MachOPlatform::dlclose(%p) (%s)\n", DSOHandle,
                "invalid handle");
   });
-  std::lock_guard<std::recursive_mutex> Lock(JDStatesMutex);
+  std::lock_guard<std::recursive_mutex> Lock(DyldAPIMutex);
   if (auto Err = dlcloseImpl(DSOHandle)) {
     // FIXME: Make dlerror thread safe.
     DLFcnError = toString(std::move(Err));
@@ -569,7 +587,7 @@
 int MachOPlatformRuntimeState::registerAtExit(void (*F)(void *), void *Arg,
                                               void *DSOHandle) {
   // FIXME: Handle out-of-memory errors, returning -1 if OOM.
-  std::lock_guard<std::recursive_mutex> Lock(JDStatesMutex);
+  std::lock_guard<std::mutex> Lock(JDStatesMutex);
   auto *JDS = getJITDylibStateByHeader(DSOHandle);
   if (!JDS) {
     ORC_RT_DEBUG({
@@ -583,16 +601,23 @@
   return 0;
 }
 
-void MachOPlatformRuntimeState::runAtExits(JITDylibState &JDS) {
-  while (!JDS.AtExits.empty()) {
-    auto &AE = JDS.AtExits.back();
+void MachOPlatformRuntimeState::runAtExits(
+    std::unique_lock<std::mutex> &JDStatesLock, JITDylibState &JDS) {
+  auto AtExits = std::move(JDS.AtExits);
+
+  // Unlock while running atexits, as they may trigger operations that modify
+  // JDStates.
+  JDStatesLock.unlock();
+  while (!AtExits.empty()) {
+    auto &AE = AtExits.back();
     AE.Func(AE.Arg);
-    JDS.AtExits.pop_back();
+    AtExits.pop_back();
   }
+  JDStatesLock.lock();
 }
 
 void MachOPlatformRuntimeState::runAtExits(void *DSOHandle) {
-  std::lock_guard<std::recursive_mutex> Lock(JDStatesMutex);
+  std::unique_lock<std::mutex> Lock(JDStatesMutex);
   auto *JDS = getJITDylibStateByHeader(DSOHandle);
   ORC_RT_DEBUG({
     printdbg("MachOPlatformRuntimeState::runAtExits called on unrecognized "
@@ -600,7 +625,7 @@
              DSOHandle);
   });
   if (JDS)
-    runAtExits(*JDS);
+    runAtExits(Lock, *JDS);
 }
 
 Expected<std::pair<const char *, size_t>>
@@ -761,16 +786,31 @@
   return Error::success();
 }
 
-Error MachOPlatformRuntimeState::runModInits(JITDylibState &JDS) {
-  JDS.ModInitsSections.processNewSections([](span<void (*)()> Inits) {
-    for (auto *Init : Inits)
+Error MachOPlatformRuntimeState::runModInits(
+    std::unique_lock<std::mutex> &JDStatesLock, JITDylibState &JDS) {
+  std::vector<span<void (*)()>> InitSections;
+  InitSections.reserve(JDS.ModInitsSections.numNewSections());
+
+  // Copy initializer sections: If the JITDylib is unsealed then the
+  // initializers could reach back into the JIT and cause more initializers to
+  // be added.
+  // FIXME: Skip unlock and run in-place on sealed JITDylibs?
+  JDS.ModInitsSections.processNewSections(
+      [&](span<void (*)()> Inits) { InitSections.push_back(Inits); });
+
+  JDStatesLock.unlock();
+  for (auto InitSec : InitSections)
+    for (auto *Init : InitSec)
       Init();
-  });
+  JDStatesLock.lock();
+
   return Error::success();
 }
 
 Expected<void *> MachOPlatformRuntimeState::dlopenImpl(std::string_view Path,
                                                        int Mode) {
+  std::unique_lock<std::mutex> Lock(JDStatesMutex);
+
   // Try to find JITDylib state by name.
   auto *JDS = getJITDylibStateByName(Path);
 
@@ -782,7 +822,7 @@
   // full dlopen path (update deps, push and run initializers, update ref
   // counts on all JITDylibs in the dep tree).
   if (!JDS->referenced() || !JDS->Sealed) {
-    if (auto Err = dlopenFull(*JDS))
+    if (auto Err = dlopenFull(Lock, *JDS))
       return std::move(Err);
   }
 
@@ -793,17 +833,22 @@
   return JDS->Header;
 }
 
-Error MachOPlatformRuntimeState::dlopenFull(JITDylibState &JDS) {
+Error MachOPlatformRuntimeState::dlopenFull(
+    std::unique_lock<std::mutex> &JDStatesLock, JITDylibState &JDS) {
   // Call back to the JIT to push the initializers.
   Expected<MachOJITDylibDepInfoMap> DepInfo((MachOJITDylibDepInfoMap()));
+  // Unlock so that we can accept the initializer update.
+  JDStatesLock.unlock();
   if (auto Err = WrapperFunction<SPSExpected<SPSMachOJITDylibDepInfoMap>(
           SPSExecutorAddr)>::call(&__orc_rt_macho_push_initializers_tag,
                                   DepInfo, ExecutorAddr::fromPtr(JDS.Header)))
     return Err;
+  JDStatesLock.lock();
+
   if (!DepInfo)
     return DepInfo.takeError();
 
-  if (auto Err = dlopenInitialize(JDS, *DepInfo))
+  if (auto Err = dlopenInitialize(JDStatesLock, JDS, *DepInfo))
     return Err;
 
   if (!DepInfo->empty()) {
@@ -822,7 +867,8 @@
 }
 
 Error MachOPlatformRuntimeState::dlopenInitialize(
-    JITDylibState &JDS, MachOJITDylibDepInfoMap &DepInfo) {
+    std::unique_lock<std::mutex> &JDStatesLock, JITDylibState &JDS,
+    MachOJITDylibDepInfoMap &DepInfo) {
   ORC_RT_DEBUG({
     printdbg("MachOPlatformRuntimeState::dlopenInitialize(\"%s\")\n",
              JDS.Name.c_str());
@@ -868,7 +914,7 @@
       return make_error<StringError>(ErrStream.str());
     }
     ++DepJDS->LinkedAgainstRefCount;
-    if (auto Err = dlopenInitialize(*DepJDS, DepInfo))
+    if (auto Err = dlopenInitialize(JDStatesLock, *DepJDS, DepInfo))
       return Err;
   }
 
@@ -883,7 +929,7 @@
     return Err;
   if (auto Err = registerSwift5Types(JDS))
     return Err;
-  if (auto Err = runModInits(JDS))
+  if (auto Err = runModInits(JDStatesLock, JDS))
     return Err;
 
   // Decrement old deps.
@@ -892,7 +938,7 @@
   for (auto *DepJDS : OldDeps) {
     --DepJDS->LinkedAgainstRefCount;
     if (!DepJDS->referenced())
-      if (auto Err = dlcloseDeinitialize(*DepJDS))
+      if (auto Err = dlcloseDeinitialize(JDStatesLock, *DepJDS))
         return Err;
   }
 
@@ -900,6 +946,8 @@
 }
 
 Error MachOPlatformRuntimeState::dlcloseImpl(void *DSOHandle) {
+  std::unique_lock<std::mutex> Lock(JDStatesMutex);
+
   // Try to find JITDylib state by header.
   auto *JDS = getJITDylibStateByHeader(DSOHandle);
 
@@ -913,19 +961,20 @@
   --JDS->DlRefCount;
 
   if (!JDS->referenced())
-    return dlcloseDeinitialize(*JDS);
+    return dlcloseDeinitialize(Lock, *JDS);
 
   return Error::success();
 }
 
-Error MachOPlatformRuntimeState::dlcloseDeinitialize(JITDylibState &JDS) {
+Error MachOPlatformRuntimeState::dlcloseDeinitialize(
+    std::unique_lock<std::mutex> &JDStatesLock, JITDylibState &JDS) {
 
   ORC_RT_DEBUG({
     printdbg("MachOPlatformRuntimeState::dlcloseDeinitialize(\"%s\")\n",
              JDS.Name.c_str());
   });
 
-  runAtExits(JDS);
+  runAtExits(JDStatesLock, JDS);
 
   // Reset mod-inits
   JDS.ModInitsSections.reset();
@@ -940,7 +989,7 @@
   for (auto *DepJDS : JDS.Deps) {
     --DepJDS->LinkedAgainstRefCount;
     if (!DepJDS->referenced())
-      if (auto Err = dlcloseDeinitialize(*DepJDS))
+      if (auto Err = dlcloseDeinitialize(JDStatesLock, *DepJDS))
         return Err;
   }