[flang][runtime] Catch & report attempts at recursive I/O

When an I/O statement contains a function call that attempts
to perform I/O on the same unit, detect the recursive I/O
and terminate with a useful message rather than deadlocking in
the threading library.

Differential Revision: https://reviews.llvm.org/D131097
diff --git a/flang/runtime/io-api.cpp b/flang/runtime/io-api.cpp
index 60e52e5..bdde576 100644
--- a/flang/runtime/io-api.cpp
+++ b/flang/runtime/io-api.cpp
@@ -213,10 +213,10 @@
     }
     if (iostat == IostatOk) {
       return &unit->BeginIoStatement<STATE<DIR>>(
-          std::forward<A>(xs)..., *unit, sourceFile, sourceLine);
+          terminator, std::forward<A>(xs)..., *unit, sourceFile, sourceLine);
     } else {
       return &unit->BeginIoStatement<ErroneousIoStatementState>(
-          iostat, unit, sourceFile, sourceLine);
+          terminator, iostat, unit, sourceFile, sourceLine);
     }
   }
 }
@@ -270,10 +270,10 @@
     }
     if (iostat == IostatOk) {
       return &unit->BeginIoStatement<ExternalFormattedIoStatementState<DIR>>(
-          *unit, format, formatLength, sourceFile, sourceLine);
+          terminator, *unit, format, formatLength, sourceFile, sourceLine);
     } else {
       return &unit->BeginIoStatement<ErroneousIoStatementState>(
-          iostat, unit, sourceFile, sourceLine);
+          terminator, iostat, unit, sourceFile, sourceLine);
     }
   }
 }
@@ -327,7 +327,7 @@
     if (iostat == IostatOk) {
       IoStatementState &io{
           unit->BeginIoStatement<ExternalUnformattedIoStatementState<DIR>>(
-              *unit, sourceFile, sourceLine)};
+              terminator, *unit, sourceFile, sourceLine)};
       if constexpr (DIR == Direction::Output) {
         if (unit->access == Access::Sequential) {
           // Create space for (sub)record header to be completed by
@@ -339,7 +339,7 @@
       return &io;
     } else {
       return &unit->BeginIoStatement<ErroneousIoStatementState>(
-          iostat, unit, sourceFile, sourceLine);
+          terminator, iostat, unit, sourceFile, sourceLine);
     }
   }
 }
@@ -364,7 +364,7 @@
       unit{ExternalFileUnit::LookUpOrCreate(
           unitNumber, terminator, wasExtant)}) {
     return &unit->BeginIoStatement<OpenStatementState>(
-        *unit, wasExtant, sourceFile, sourceLine);
+        terminator, *unit, wasExtant, sourceFile, sourceLine);
   } else {
     return NoopUnit(terminator, unitNumber, IostatBadUnitNumber);
   }
@@ -376,21 +376,21 @@
   ExternalFileUnit &unit{
       ExternalFileUnit::NewUnit(terminator, false /*not child I/O*/)};
   return &unit.BeginIoStatement<OpenStatementState>(
-      unit, false /*was an existing file*/, sourceFile, sourceLine);
+      terminator, unit, false /*was an existing file*/, sourceFile, sourceLine);
 }
 
 Cookie IONAME(BeginWait)(ExternalUnit unitNumber, AsynchronousId id,
     const char *sourceFile, int sourceLine) {
+  Terminator terminator{sourceFile, sourceLine};
   if (ExternalFileUnit * unit{ExternalFileUnit::LookUp(unitNumber)}) {
     if (unit->Wait(id)) {
-      return &unit->BeginIoStatement<ExternalMiscIoStatementState>(
+      return &unit->BeginIoStatement<ExternalMiscIoStatementState>(terminator,
           *unit, ExternalMiscIoStatementState::Wait, sourceFile, sourceLine);
     } else {
       return &unit->BeginIoStatement<ErroneousIoStatementState>(
-          IostatBadWaitId, unit, sourceFile, sourceLine);
+          terminator, IostatBadWaitId, unit, sourceFile, sourceLine);
     }
   } else {
-    Terminator terminator{sourceFile, sourceLine};
     return NoopUnit(
         terminator, unitNumber, id == 0 ? IostatOk : IostatBadWaitUnit);
   }
@@ -402,24 +402,24 @@
 
 Cookie IONAME(BeginClose)(
     ExternalUnit unitNumber, const char *sourceFile, int sourceLine) {
+  Terminator terminator{sourceFile, sourceLine};
   if (ExternalFileUnit * unit{ExternalFileUnit::LookUpForClose(unitNumber)}) {
     return &unit->BeginIoStatement<CloseStatementState>(
-        *unit, sourceFile, sourceLine);
+        terminator, *unit, sourceFile, sourceLine);
   } else {
     // CLOSE(UNIT=bad unit) is just a no-op
-    Terminator terminator{sourceFile, sourceLine};
     return NoopUnit(terminator, unitNumber);
   }
 }
 
 Cookie IONAME(BeginFlush)(
     ExternalUnit unitNumber, const char *sourceFile, int sourceLine) {
+  Terminator terminator{sourceFile, sourceLine};
   if (ExternalFileUnit * unit{ExternalFileUnit::LookUp(unitNumber)}) {
-    return &unit->BeginIoStatement<ExternalMiscIoStatementState>(
+    return &unit->BeginIoStatement<ExternalMiscIoStatementState>(terminator,
         *unit, ExternalMiscIoStatementState::Flush, sourceFile, sourceLine);
   } else {
     // FLUSH(UNIT=bad unit) is an error; an unconnected unit is a no-op
-    Terminator terminator{sourceFile, sourceLine};
     return NoopUnit(terminator, unitNumber,
         unitNumber >= 0 ? IostatOk : IostatBadFlushUnit);
   }
@@ -429,7 +429,7 @@
     ExternalUnit unitNumber, const char *sourceFile, int sourceLine) {
   Terminator terminator{sourceFile, sourceLine};
   if (ExternalFileUnit * unit{ExternalFileUnit::LookUp(unitNumber)}) {
-    return &unit->BeginIoStatement<ExternalMiscIoStatementState>(
+    return &unit->BeginIoStatement<ExternalMiscIoStatementState>(terminator,
         *unit, ExternalMiscIoStatementState::Backspace, sourceFile, sourceLine);
   } else {
     return NoopUnit(terminator, unitNumber, IostatBadBackspaceUnit);
@@ -443,7 +443,7 @@
   if (ExternalFileUnit *
       unit{GetOrCreateUnit(unitNumber, Direction::Output, std::nullopt,
           terminator, errorCookie)}) {
-    return &unit->BeginIoStatement<ExternalMiscIoStatementState>(
+    return &unit->BeginIoStatement<ExternalMiscIoStatementState>(terminator,
         *unit, ExternalMiscIoStatementState::Endfile, sourceFile, sourceLine);
   } else {
     return errorCookie;
@@ -457,7 +457,7 @@
   if (ExternalFileUnit *
       unit{GetOrCreateUnit(unitNumber, Direction::Input, std::nullopt,
           terminator, errorCookie)}) {
-    return &unit->BeginIoStatement<ExternalMiscIoStatementState>(
+    return &unit->BeginIoStatement<ExternalMiscIoStatementState>(terminator,
         *unit, ExternalMiscIoStatementState::Rewind, sourceFile, sourceLine);
   } else {
     return errorCookie;
@@ -466,18 +466,19 @@
 
 Cookie IONAME(BeginInquireUnit)(
     ExternalUnit unitNumber, const char *sourceFile, int sourceLine) {
+  Terminator terminator{sourceFile, sourceLine};
   if (ExternalFileUnit * unit{ExternalFileUnit::LookUp(unitNumber)}) {
     if (ChildIo * child{unit->GetChildIo()}) {
       return &child->BeginIoStatement<InquireUnitState>(
           *unit, sourceFile, sourceLine);
     } else {
       return &unit->BeginIoStatement<InquireUnitState>(
-          *unit, sourceFile, sourceLine);
+          terminator, *unit, sourceFile, sourceLine);
     }
   } else {
     // INQUIRE(UNIT=unrecognized unit)
-    Terminator oom{sourceFile, sourceLine};
-    return &New<InquireNoUnitState>{oom}(sourceFile, sourceLine, unitNumber)
+    return &New<InquireNoUnitState>{terminator}(
+        sourceFile, sourceLine, unitNumber)
                 .release()
                 ->ioStatementState();
   }
@@ -485,17 +486,17 @@
 
 Cookie IONAME(BeginInquireFile)(const char *path, std::size_t pathLength,
     const char *sourceFile, int sourceLine) {
-  Terminator oom{sourceFile, sourceLine};
-  auto trimmed{
-      SaveDefaultCharacter(path, TrimTrailingSpaces(path, pathLength), oom)};
+  Terminator terminator{sourceFile, sourceLine};
+  auto trimmed{SaveDefaultCharacter(
+      path, TrimTrailingSpaces(path, pathLength), terminator)};
   if (ExternalFileUnit *
       unit{ExternalFileUnit::LookUp(
           trimmed.get(), std::strlen(trimmed.get()))}) {
     // INQUIRE(FILE=) to a connected unit
     return &unit->BeginIoStatement<InquireUnitState>(
-        *unit, sourceFile, sourceLine);
+        terminator, *unit, sourceFile, sourceLine);
   } else {
-    return &New<InquireUnconnectedFileState>{oom}(
+    return &New<InquireUnconnectedFileState>{terminator}(
         std::move(trimmed), sourceFile, sourceLine)
                 .release()
                 ->ioStatementState();
diff --git a/flang/runtime/unit.cpp b/flang/runtime/unit.cpp
index 12d8dbc..ac6311f0 100644
--- a/flang/runtime/unit.cpp
+++ b/flang/runtime/unit.cpp
@@ -698,7 +698,8 @@
 void ExternalFileUnit::EndIoStatement() {
   io_.reset();
   u_.emplace<std::monostate>();
-  lock_.Drop();
+  CriticalSection critical{lock_};
+  isBusy_ = false;
 }
 
 void ExternalFileUnit::BeginSequentialVariableUnformattedInputRecord(
diff --git a/flang/runtime/unit.h b/flang/runtime/unit.h
index 76666c6..a456a49 100644
--- a/flang/runtime/unit.h
+++ b/flang/runtime/unit.h
@@ -70,8 +70,18 @@
   Iostat SetDirection(Direction);
 
   template <typename A, typename... X>
-  IoStatementState &BeginIoStatement(X &&...xs) {
-    lock_.Take(); // dropped in EndIoStatement()
+  IoStatementState &BeginIoStatement(const Terminator &terminator, X &&...xs) {
+    bool alreadyBusy{false};
+    {
+      CriticalSection critical{lock_};
+      alreadyBusy = isBusy_;
+      isBusy_ = true; // cleared in EndIoStatement()
+    }
+    if (alreadyBusy) {
+      terminator.Crash("Could not acquire exclusive lock on unit %d, perhaps "
+                       "due to an attempt to perform recursive I/O",
+          unitNumber_);
+    }
     A &state{u_.emplace<A>(std::forward<X>(xs)...)};
     if constexpr (!std::is_same_v<A, OpenStatementState>) {
       state.mutableModes() = ConnectionState::modes;
@@ -125,6 +135,8 @@
   void HitEndOnRead(IoErrorHandler &);
 
   Lock lock_;
+  // TODO: replace with a thread ID
+  bool isBusy_{false}; // under lock_
 
   int unitNumber_{-1};
   Direction direction_{Direction::Output};