[clangd] Fix toHalfOpenFileRange where start/end endpoints are in different files due to #include

Summary: https://github.com/clangd/clangd/issues/129

Reviewers: SureYeaah

Subscribers: ilya-biryukov, MaskRay, jkorous, arphaman, kadircet, cfe-commits

Tags: #clang

Differential Revision: https://reviews.llvm.org/D66590

git-svn-id: https://llvm.org/svn/llvm-project/clang-tools-extra/trunk@370029 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/clangd/SourceCode.cpp b/clangd/SourceCode.cpp
index 7a2f7c7..0c52dcd 100644
--- a/clangd/SourceCode.cpp
+++ b/clangd/SourceCode.cpp
@@ -264,6 +264,29 @@
   return L == R.getEnd() || halfOpenRangeContains(Mgr, R, L);
 }
 
+SourceLocation includeHashLoc(FileID IncludedFile, const SourceManager &SM) {
+  assert(SM.getLocForEndOfFile(IncludedFile).isFileID());
+  FileID IncludingFile;
+  unsigned Offset;
+  std::tie(IncludingFile, Offset) =
+      SM.getDecomposedExpansionLoc(SM.getIncludeLoc(IncludedFile));
+  bool Invalid = false;
+  llvm::StringRef Buf = SM.getBufferData(IncludingFile, &Invalid);
+  if (Invalid)
+    return SourceLocation();
+  // Now buf is "...\n#include <foo>\n..."
+  // and Offset points here:   ^
+  // Rewind to the preceding # on the line.
+  assert(Offset < Buf.size());
+  for (;; --Offset) {
+    if (Buf[Offset] == '#')
+      return SM.getComposedLoc(IncludingFile, Offset);
+    if (Buf[Offset] == '\n' || Offset == 0) // no hash, what's going on?
+      return SourceLocation();
+  }
+}
+
+
 static unsigned getTokenLengthAtLoc(SourceLocation Loc, const SourceManager &SM,
                                     const LangOptions &LangOpts) {
   Token TheTok;
@@ -308,16 +331,49 @@
 static SourceRange unionTokenRange(SourceRange R1, SourceRange R2,
                                    const SourceManager &SM,
                                    const LangOptions &LangOpts) {
-  SourceLocation E1 = getLocForTokenEnd(R1.getEnd(), SM, LangOpts);
-  SourceLocation E2 = getLocForTokenEnd(R2.getEnd(), SM, LangOpts);
-  return SourceRange(std::min(R1.getBegin(), R2.getBegin()),
-                     E1 < E2 ? R2.getEnd() : R1.getEnd());
+  SourceLocation Begin =
+      SM.isBeforeInTranslationUnit(R1.getBegin(), R2.getBegin())
+          ? R1.getBegin()
+          : R2.getBegin();
+  SourceLocation End =
+      SM.isBeforeInTranslationUnit(getLocForTokenEnd(R1.getEnd(), SM, LangOpts),
+                                   getLocForTokenEnd(R2.getEnd(), SM, LangOpts))
+          ? R2.getEnd()
+          : R1.getEnd();
+  return SourceRange(Begin, End);
 }
 
-// Check if two locations have the same file id.
-static bool inSameFile(SourceLocation Loc1, SourceLocation Loc2,
-                       const SourceManager &SM) {
-  return SM.getFileID(Loc1) == SM.getFileID(Loc2);
+// Given a range whose endpoints may be in different expansions or files,
+// tries to find a range within a common file by following up the expansion and
+// include location in each.
+static SourceRange rangeInCommonFile(SourceRange R, const SourceManager &SM,
+                                     const LangOptions &LangOpts) {
+  // Fast path for most common cases.
+  if (SM.isWrittenInSameFile(R.getBegin(), R.getEnd()))
+    return R;
+  // Record the stack of expansion locations for the beginning, keyed by FileID.
+  llvm::DenseMap<FileID, SourceLocation> BeginExpansions;
+  for (SourceLocation Begin = R.getBegin(); Begin.isValid();
+       Begin = Begin.isFileID()
+                   ? includeHashLoc(SM.getFileID(Begin), SM)
+                   : SM.getImmediateExpansionRange(Begin).getBegin()) {
+    BeginExpansions[SM.getFileID(Begin)] = Begin;
+  }
+  // Move up the stack of expansion locations for the end until we find the
+  // location in BeginExpansions with that has the same file id.
+  for (SourceLocation End = R.getEnd(); End.isValid();
+       End = End.isFileID() ? includeHashLoc(SM.getFileID(End), SM)
+                            : toTokenRange(SM.getImmediateExpansionRange(End),
+                                           SM, LangOpts)
+                                  .getEnd()) {
+    auto It = BeginExpansions.find(SM.getFileID(End));
+    if (It != BeginExpansions.end()) {
+      if (SM.getFileOffset(It->second) > SM.getFileOffset(End))
+        return SourceLocation();
+      return {It->second, End};
+    }
+  }
+  return SourceRange();
 }
 
 // Find an expansion range (not necessarily immediate) the ends of which are in
@@ -325,33 +381,11 @@
 static SourceRange
 getExpansionTokenRangeInSameFile(SourceLocation Loc, const SourceManager &SM,
                                  const LangOptions &LangOpts) {
-  SourceRange ExpansionRange =
-      toTokenRange(SM.getImmediateExpansionRange(Loc), SM, LangOpts);
-  // Fast path for most common cases.
-  if (inSameFile(ExpansionRange.getBegin(), ExpansionRange.getEnd(), SM))
-    return ExpansionRange;
-  // Record the stack of expansion locations for the beginning, keyed by FileID.
-  llvm::DenseMap<FileID, SourceLocation> BeginExpansions;
-  for (SourceLocation Begin = ExpansionRange.getBegin(); Begin.isValid();
-       Begin = Begin.isFileID()
-                   ? SourceLocation()
-                   : SM.getImmediateExpansionRange(Begin).getBegin()) {
-    BeginExpansions[SM.getFileID(Begin)] = Begin;
-  }
-  // Move up the stack of expansion locations for the end until we find the
-  // location in BeginExpansions with that has the same file id.
-  for (SourceLocation End = ExpansionRange.getEnd(); End.isValid();
-       End = End.isFileID() ? SourceLocation()
-                            : toTokenRange(SM.getImmediateExpansionRange(End),
-                                           SM, LangOpts)
-                                  .getEnd()) {
-    auto It = BeginExpansions.find(SM.getFileID(End));
-    if (It != BeginExpansions.end())
-      return {It->second, End};
-  }
-  llvm_unreachable(
-      "We should able to find a common ancestor in the expansion tree.");
+  return rangeInCommonFile(
+      toTokenRange(SM.getImmediateExpansionRange(Loc), SM, LangOpts), SM,
+      LangOpts);
 }
+
 // Returns the file range for a given Location as a Token Range
 // This is quite similar to getFileLoc in SourceManager as both use
 // getImmediateExpansionRange and getImmediateSpellingLoc (for macro IDs).
@@ -371,14 +405,17 @@
       FileRange = unionTokenRange(
           SM.getImmediateSpellingLoc(FileRange.getBegin()),
           SM.getImmediateSpellingLoc(FileRange.getEnd()), SM, LangOpts);
-      assert(inSameFile(FileRange.getBegin(), FileRange.getEnd(), SM));
+      assert(SM.isWrittenInSameFile(FileRange.getBegin(), FileRange.getEnd()));
     } else {
       SourceRange ExpansionRangeForBegin =
           getExpansionTokenRangeInSameFile(FileRange.getBegin(), SM, LangOpts);
       SourceRange ExpansionRangeForEnd =
           getExpansionTokenRangeInSameFile(FileRange.getEnd(), SM, LangOpts);
-      assert(inSameFile(ExpansionRangeForBegin.getBegin(),
-                        ExpansionRangeForEnd.getBegin(), SM) &&
+      if (ExpansionRangeForBegin.isInvalid() ||
+          ExpansionRangeForEnd.isInvalid())
+        return SourceRange();
+      assert(SM.isWrittenInSameFile(ExpansionRangeForBegin.getBegin(),
+                                    ExpansionRangeForEnd.getBegin()) &&
              "Both Expansion ranges should be in same file.");
       FileRange = unionTokenRange(ExpansionRangeForBegin, ExpansionRangeForEnd,
                                   SM, LangOpts);
@@ -402,7 +439,8 @@
   if (!isValidFileRange(SM, R2))
     return llvm::None;
 
-  SourceRange Result = unionTokenRange(R1, R2, SM, LangOpts);
+  SourceRange Result =
+      rangeInCommonFile(unionTokenRange(R1, R2, SM, LangOpts), SM, LangOpts);
   unsigned TokLen = getTokenLengthAtLoc(Result.getEnd(), SM, LangOpts);
   // Convert from closed token range to half-open (char) range
   Result.setEnd(Result.getEnd().getLocWithOffset(TokLen));
diff --git a/clangd/SourceCode.h b/clangd/SourceCode.h
index 6517043..4e20706 100644
--- a/clangd/SourceCode.h
+++ b/clangd/SourceCode.h
@@ -83,6 +83,11 @@
 /// the main file.
 bool isInsideMainFile(SourceLocation Loc, const SourceManager &SM);
 
+/// Returns the #include location through which IncludedFIle was loaded.
+/// Where SM.getIncludeLoc() returns the location of the *filename*, which may
+/// be in a macro, includeHashLoc() returns the location of the #.
+SourceLocation includeHashLoc(FileID IncludedFile, const SourceManager &SM);
+
 /// Returns true if the token at Loc is spelled in the source code.
 /// This is not the case for:
 ///   * symbols formed via macro concatenation, the spelling location will
diff --git a/clangd/unittests/SelectionTests.cpp b/clangd/unittests/SelectionTests.cpp
index f8aba42..a8ca324 100644
--- a/clangd/unittests/SelectionTests.cpp
+++ b/clangd/unittests/SelectionTests.cpp
@@ -346,6 +346,25 @@
   }
 }
 
+TEST(SelectionTest, PathologicalPreprocessor) {
+  const char *Case = R"cpp(
+#define MACRO while(1)
+    void test() {
+#include "Expand.inc"
+        br^eak;
+    }
+  )cpp";
+  Annotations Test(Case);
+  auto TU = TestTU::withCode(Test.code());
+  TU.AdditionalFiles["Expand.inc"] = "MACRO\n";
+  auto AST = TU.build();
+  EXPECT_THAT(AST.getDiagnostics(), ::testing::IsEmpty());
+  auto T = makeSelectionTree(Case, AST);
+
+  EXPECT_EQ("BreakStmt", T.commonAncestor()->kind());
+  EXPECT_EQ("WhileStmt", T.commonAncestor()->Parent->kind());
+}
+
 TEST(SelectionTest, Implicit) {
   const char* Test = R"cpp(
     struct S { S(const char*); };
diff --git a/clangd/unittests/SourceCodeTests.cpp b/clangd/unittests/SourceCodeTests.cpp
index 47d1f80..a216b6b 100644
--- a/clangd/unittests/SourceCodeTests.cpp
+++ b/clangd/unittests/SourceCodeTests.cpp
@@ -11,9 +11,11 @@
 #include "SourceCode.h"
 #include "TestTU.h"
 #include "clang/Basic/LangOptions.h"
+#include "clang/Basic/SourceLocation.h"
 #include "clang/Format/Format.h"
 #include "llvm/Support/Error.h"
 #include "llvm/Support/raw_os_ostream.h"
+#include "llvm/Testing/Support/Annotations.h"
 #include "llvm/Testing/Support/Error.h"
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
@@ -505,6 +507,53 @@
   CheckRange("f");
 }
 
+TEST(SourceCodeTests, HalfOpenFileRangePathologicalPreprocessor) {
+  const char *Case = R"cpp(
+#define MACRO while(1)
+    void test() {
+[[#include "Expand.inc"
+        br^eak]];
+    }
+  )cpp";
+  Annotations Test(Case);
+  auto TU = TestTU::withCode(Test.code());
+  TU.AdditionalFiles["Expand.inc"] = "MACRO\n";
+  auto AST = TU.build();
+
+  const auto &Func = cast<FunctionDecl>(findDecl(AST, "test"));
+  const auto &Body = cast<CompoundStmt>(Func.getBody());
+  const auto &Loop = cast<WhileStmt>(*Body->child_begin());
+  llvm::Optional<SourceRange> Range = toHalfOpenFileRange(
+      AST.getSourceManager(), AST.getASTContext().getLangOpts(),
+      Loop->getSourceRange());
+  ASSERT_TRUE(Range) << "Failed to get file range";
+  EXPECT_EQ(AST.getSourceManager().getFileOffset(Range->getBegin()),
+            Test.llvm::Annotations::range().Begin);
+  EXPECT_EQ(AST.getSourceManager().getFileOffset(Range->getEnd()),
+            Test.llvm::Annotations::range().End);
+}
+
+TEST(SourceCodeTests, IncludeHashLoc) {
+  const char *Case = R"cpp(
+$foo^#include "foo.inc"
+#define HEADER "bar.inc"
+  $bar^#  include HEADER
+  )cpp";
+  Annotations Test(Case);
+  auto TU = TestTU::withCode(Test.code());
+  TU.AdditionalFiles["foo.inc"] = "int foo;\n";
+  TU.AdditionalFiles["bar.inc"] = "int bar;\n";
+  auto AST = TU.build();
+  const auto& SM = AST.getSourceManager();
+
+  FileID Foo = SM.getFileID(findDecl(AST, "foo").getLocation());
+  EXPECT_EQ(SM.getFileOffset(includeHashLoc(Foo, SM)),
+            Test.llvm::Annotations::point("foo"));
+  FileID Bar = SM.getFileID(findDecl(AST, "bar").getLocation());
+  EXPECT_EQ(SM.getFileOffset(includeHashLoc(Bar, SM)),
+            Test.llvm::Annotations::point("foo"));
+}
+
 } // namespace
 } // namespace clangd
 } // namespace clang