[flang] Search for #include "file" in right directory

Make the #include "file" preprocessing directive begin its
search in the same directory as the file containing the directive,
as other preprocessors and our Fortran INCLUDE statement do.

Avoid current working directory for all source files after the original.

Differential Revision: https://reviews.llvm.org/D95388

GitOrigin-RevId: d987b61b1dce9948801ac37704477e7c257100b1
diff --git a/include/flang/Parser/provenance.h b/include/flang/Parser/provenance.h
index 73661d9..08afcf1 100644
--- a/include/flang/Parser/provenance.h
+++ b/include/flang/Parser/provenance.h
@@ -148,9 +148,9 @@
     return *this;
   }
 
-  void PushSearchPathDirectory(std::string);
-  std::string PopSearchPathDirectory();
-  const SourceFile *Open(std::string path, llvm::raw_ostream &error);
+  void AppendSearchPathDirectory(std::string); // new last directory
+  const SourceFile *Open(std::string path, llvm::raw_ostream &error,
+      const std::optional<std::string> &prependPath);
   const SourceFile *ReadStandardInput(llvm::raw_ostream &error);
 
   ProvenanceRange AddIncludedFile(
@@ -210,7 +210,7 @@
   ProvenanceRange range_;
   std::map<char, Provenance> compilerInsertionProvenance_;
   std::vector<std::unique_ptr<SourceFile>> ownedSourceFiles_;
-  std::vector<std::string> searchPath_;
+  std::list<std::string> searchPath_;
   Encoding encoding_{Encoding::UTF_8};
 };
 
diff --git a/include/flang/Parser/source.h b/include/flang/Parser/source.h
index e0d1a53..4f387bd 100644
--- a/include/flang/Parser/source.h
+++ b/include/flang/Parser/source.h
@@ -17,6 +17,8 @@
 #include "characters.h"
 #include "llvm/Support/MemoryBuffer.h"
 #include <cstddef>
+#include <list>
+#include <optional>
 #include <string>
 #include <utility>
 #include <vector>
@@ -28,8 +30,8 @@
 namespace Fortran::parser {
 
 std::string DirectoryName(std::string path);
-std::string LocateSourceFile(
-    std::string name, const std::vector<std::string> &searchPath);
+std::optional<std::string> LocateSourceFile(
+    std::string name, const std::list<std::string> &searchPath);
 
 class SourceFile;
 
diff --git a/lib/Parser/parsing.cpp b/lib/Parser/parsing.cpp
index 7f3a4a6..0bd542c 100644
--- a/lib/Parser/parsing.cpp
+++ b/lib/Parser/parsing.cpp
@@ -25,7 +25,7 @@
   AllSources &allSources{allCooked_.allSources()};
   if (options.isModuleFile) {
     for (const auto &path : options.searchDirectories) {
-      allSources.PushSearchPathDirectory(path);
+      allSources.AppendSearchPathDirectory(path);
     }
   }
 
@@ -35,7 +35,8 @@
   if (path == "-") {
     sourceFile = allSources.ReadStandardInput(fileError);
   } else {
-    sourceFile = allSources.Open(path, fileError);
+    std::optional<std::string> currentDirectory{"."};
+    sourceFile = allSources.Open(path, fileError, currentDirectory);
   }
   if (!fileError.str().empty()) {
     ProvenanceRange range{allSources.AddCompilerInsertion(path)};
@@ -46,12 +47,12 @@
 
   if (!options.isModuleFile) {
     // For .mod files we always want to look in the search directories.
-    // For normal source files we don't push them until after the primary
+    // For normal source files we don't add them until after the primary
     // source file has been opened.  If foo.f is missing from the current
     // working directory, we don't want to accidentally read another foo.f
     // from another directory that's on the search path.
     for (const auto &path : options.searchDirectories) {
-      allSources.PushSearchPathDirectory(path);
+      allSources.AppendSearchPathDirectory(path);
     }
   }
 
diff --git a/lib/Parser/preprocessor.cpp b/lib/Parser/preprocessor.cpp
index c5422cc..14c9e54 100644
--- a/lib/Parser/preprocessor.cpp
+++ b/lib/Parser/preprocessor.cpp
@@ -399,6 +399,7 @@
   if (j == tokens) {
     return;
   }
+  CHECK(prescanner); // TODO: change to reference
   if (dir.TokenAt(j).ToString() != "#") {
     prescanner->Say(dir.GetTokenProvenanceRange(j), "missing '#'"_err_en_US);
     return;
@@ -578,6 +579,7 @@
       return;
     }
     std::string include;
+    std::optional<std::string> prependPath;
     if (dir.TokenAt(j).ToString() == "<") { // #include <foo>
       std::size_t k{j + 1};
       if (k >= tokens) {
@@ -598,6 +600,12 @@
     } else if ((include = dir.TokenAt(j).ToString()).substr(0, 1) == "\"" &&
         include.substr(include.size() - 1, 1) == "\"") { // #include "foo"
       include = include.substr(1, include.size() - 2);
+      // #include "foo" starts search in directory of file containing
+      // the directive
+      auto prov{dir.GetTokenProvenanceRange(dirOffset).start()};
+      if (const auto *currentFile{allSources_.GetSourceFile(prov)}) {
+        prependPath = DirectoryName(currentFile->path());
+      }
     } else {
       prescanner->Say(dir.GetTokenProvenanceRange(j < tokens ? j : tokens - 1),
           "#include: expected name of file to include"_err_en_US);
@@ -615,7 +623,7 @@
     }
     std::string buf;
     llvm::raw_string_ostream error{buf};
-    const SourceFile *included{allSources_.Open(include, error)};
+    const SourceFile *included{allSources_.Open(include, error, prependPath)};
     if (!included) {
       prescanner->Say(dir.GetTokenProvenanceRange(dirOffset),
           "#include: %s"_err_en_US, error.str());
diff --git a/lib/Parser/prescan.cpp b/lib/Parser/prescan.cpp
index dc6fbe5..7a78dd2 100644
--- a/lib/Parser/prescan.cpp
+++ b/lib/Parser/prescan.cpp
@@ -760,14 +760,11 @@
   std::string buf;
   llvm::raw_string_ostream error{buf};
   Provenance provenance{GetProvenance(nextLine_)};
-  const SourceFile *currentFile{allSources_.GetSourceFile(provenance)};
-  if (currentFile) {
-    allSources_.PushSearchPathDirectory(DirectoryName(currentFile->path()));
+  std::optional<std::string> prependPath;
+  if (const SourceFile * currentFile{allSources_.GetSourceFile(provenance)}) {
+    prependPath = DirectoryName(currentFile->path());
   }
-  const SourceFile *included{allSources_.Open(path, error)};
-  if (currentFile) {
-    allSources_.PopSearchPathDirectory();
-  }
+  const SourceFile *included{allSources_.Open(path, error, prependPath)};
   if (!included) {
     Say(provenance, "INCLUDE: %s"_err_en_US, error.str());
   } else if (included->bytes() > 0) {
diff --git a/lib/Parser/provenance.cpp b/lib/Parser/provenance.cpp
index 46a0dc9..bed9090 100644
--- a/lib/Parser/provenance.cpp
+++ b/lib/Parser/provenance.cpp
@@ -156,20 +156,28 @@
   return origin[origin.covers.MemberOffset(at)];
 }
 
-void AllSources::PushSearchPathDirectory(std::string directory) {
+void AllSources::AppendSearchPathDirectory(std::string directory) {
   // gfortran and ifort append to current path, PGI prepends
   searchPath_.push_back(directory);
 }
 
-std::string AllSources::PopSearchPathDirectory() {
-  std::string directory{searchPath_.back()};
-  searchPath_.pop_back();
-  return directory;
-}
-
-const SourceFile *AllSources::Open(std::string path, llvm::raw_ostream &error) {
+const SourceFile *AllSources::Open(std::string path, llvm::raw_ostream &error,
+    const std::optional<std::string> &prependPath) {
   std::unique_ptr<SourceFile> source{std::make_unique<SourceFile>(encoding_)};
-  if (source->Open(LocateSourceFile(path, searchPath_), error)) {
+  if (prependPath) {
+    // Set to "." for the initial source file; set to the directory name
+    // of the including file for #include "quoted-file" directives &
+    // INCLUDE statements.
+    searchPath_.push_front(*prependPath);
+  }
+  std::optional<std::string> found{LocateSourceFile(path, searchPath_)};
+  if (prependPath) {
+    searchPath_.pop_front();
+  }
+  if (!found) {
+    error << "Source file '" << path << "' was not found";
+    return nullptr;
+  } else if (source->Open(*found, error)) {
     return ownedSourceFiles_.emplace_back(std::move(source)).get();
   } else {
     return nullptr;
diff --git a/lib/Parser/source.cpp b/lib/Parser/source.cpp
index 11cd591..3fbbf78 100644
--- a/lib/Parser/source.cpp
+++ b/lib/Parser/source.cpp
@@ -56,9 +56,9 @@
   return pathBuf.str().str();
 }
 
-std::string LocateSourceFile(
-    std::string name, const std::vector<std::string> &searchPath) {
-  if (name.empty() || name == "-" || llvm::sys::path::is_absolute(name)) {
+std::optional<std::string> LocateSourceFile(
+    std::string name, const std::list<std::string> &searchPath) {
+  if (name == "-" || llvm::sys::path::is_absolute(name)) {
     return name;
   }
   for (const std::string &dir : searchPath) {
@@ -70,7 +70,7 @@
       return path.str().str();
     }
   }
-  return name;
+  return std::nullopt;
 }
 
 std::size_t RemoveCarriageReturns(llvm::MutableArrayRef<char> buf) {
@@ -123,7 +123,6 @@
 bool SourceFile::ReadStandardInput(llvm::raw_ostream &error) {
   Close();
   path_ = "standard input";
-
   auto buf_or = llvm::MemoryBuffer::getSTDIN();
   if (!buf_or) {
     auto err = buf_or.getError();
@@ -146,7 +145,6 @@
       auto tmp_buf{llvm::WritableMemoryBuffer::getNewUninitMemBuffer(
           content().size() + 1)};
       llvm::copy(content(), tmp_buf->getBufferStart());
-      Close();
       buf_ = std::move(tmp_buf);
     }
     buf_end_++;
diff --git a/tools/f18/f18.cpp b/tools/f18/f18.cpp
index 9a10aed..7cb0129 100644
--- a/tools/f18/f18.cpp
+++ b/tools/f18/f18.cpp
@@ -84,7 +84,7 @@
   bool verbose{false}; // -v
   bool compileOnly{false}; // -c
   std::string outputPath; // -o path
-  std::vector<std::string> searchDirectories{"."s}; // -I dir
+  std::vector<std::string> searchDirectories; // -I dir
   std::string moduleDirectory{"."s}; // -module dir
   std::string moduleFileSuffix{".mod"}; // -moduleSuffix suff
   bool forcedForm{false}; // -Mfixed or -Mfree appeared