[demangler] Add support for C++20 modules

Add support for module name demangling.  We have two new demangler
nodes -- ModuleName and ModuleEntity. The former represents a module
name in a hierarchical fashion. The latter is the combination of a
(name) node and a module name. Because module names and entity
identities use the same substitution encoding, we have to adjust the
flow of how substitutions are handled, and examine the substituted
node to know how to deal with it.

Reviewed By: dblaikie

Differential Revision: https://reviews.llvm.org/D119933

GitOrigin-RevId: c354167ae217f90399bff9a644fffb3e9a6b4334
diff --git a/src/demangle/ItaniumDemangle.h b/src/demangle/ItaniumDemangle.h
index a3693b5..e084cbf 100644
--- a/src/demangle/ItaniumDemangle.h
+++ b/src/demangle/ItaniumDemangle.h
@@ -16,10 +16,6 @@
 #ifndef DEMANGLE_ITANIUMDEMANGLE_H
 #define DEMANGLE_ITANIUMDEMANGLE_H
 
-// FIXME: (possibly) incomplete list of features that clang mangles that this
-// file does not yet support:
-//   - C++ modules TS
-
 #include "DemangleConfig.h"
 #include "StringView.h"
 #include "Utility.h"
@@ -58,6 +54,8 @@
   X(QualifiedName)                                                             \
   X(NestedName)                                                                \
   X(LocalName)                                                                 \
+  X(ModuleName)                                                                \
+  X(ModuleEntity)                                                              \
   X(VectorType)                                                                \
   X(PixelVectorType)                                                           \
   X(BinaryFPType)                                                              \
@@ -1000,6 +998,44 @@
   }
 };
 
+struct ModuleName : Node {
+  ModuleName *Parent;
+  Node *Name;
+  bool IsPartition;
+
+  ModuleName(ModuleName *Parent_, Node *Name_, bool IsPartition_ = false)
+      : Node(KModuleName), Parent(Parent_), Name(Name_),
+        IsPartition(IsPartition_) {}
+
+  template <typename Fn> void match(Fn F) const { F(Parent, Name); }
+
+  void printLeft(OutputBuffer &OB) const override {
+    if (Parent)
+      Parent->print(OB);
+    if (Parent || IsPartition)
+      OB += IsPartition ? ':' : '.';
+    Name->print(OB);
+  }
+};
+
+struct ModuleEntity : Node {
+  ModuleName *Module;
+  Node *Name;
+
+  ModuleEntity(ModuleName *Module_, Node *Name_)
+      : Node(KModuleEntity), Module(Module_), Name(Name_) {}
+
+  template <typename Fn> void match(Fn F) const { F(Module, Name); }
+
+  StringView getBaseName() const override { return Name->getBaseName(); }
+
+  void printLeft(OutputBuffer &OB) const override {
+    Name->print(OB);
+    OB += '@';
+    Module->print(OB);
+  }
+};
+
 struct LocalName : Node {
   Node *Encoding;
   Node *Entity;
@@ -2547,10 +2583,11 @@
   Node *parseName(NameState *State = nullptr);
   Node *parseLocalName(NameState *State);
   Node *parseOperatorName(NameState *State);
-  Node *parseUnqualifiedName(NameState *State, Node *Scope);
+  bool parseModuleNameOpt(ModuleName *&Module);
+  Node *parseUnqualifiedName(NameState *State, Node *Scope, ModuleName *Module);
   Node *parseUnnamedTypeName(NameState *State);
   Node *parseSourceName(NameState *State);
-  Node *parseUnscopedName(NameState *State);
+  Node *parseUnscopedName(NameState *State, bool *isSubstName);
   Node *parseNestedName(NameState *State);
   Node *parseCtorDtorName(Node *&SoFar, NameState *State);
 
@@ -2641,18 +2678,10 @@
     return getDerived().parseLocalName(State);
 
   Node *Result = nullptr;
-  bool IsSubst = look() == 'S' && look(1) != 't';
-  if (IsSubst) {
-    // A substitution must lead to:
-    //        ::= <unscoped-template-name> <template-args>
-    Result = getDerived().parseSubstitution();
-  } else {
-    // An unscoped name can be one of:
-    //        ::= <unscoped-name>
-    //        ::= <unscoped-template-name> <template-args>
-    Result = getDerived().parseUnscopedName(State);
-  }
-  if (Result == nullptr)
+  bool IsSubst = false;
+
+  Result = getDerived().parseUnscopedName(State, &IsSubst);
+  if (!Result)
     return nullptr;
 
   if (look() == 'I') {
@@ -2715,7 +2744,9 @@
 // [*] extension
 template <typename Derived, typename Alloc>
 Node *
-AbstractManglingParser<Derived, Alloc>::parseUnscopedName(NameState *State) {
+AbstractManglingParser<Derived, Alloc>::parseUnscopedName(NameState *State,
+                                                          bool *IsSubst) {
+
   Node *Std = nullptr;
   if (consumeIf("St")) {
     Std = make<NameType>("std");
@@ -2724,24 +2755,46 @@
   }
   consumeIf('L');
 
-  return getDerived().parseUnqualifiedName(State, Std);
+  Node *Res = nullptr;
+  ModuleName *Module = nullptr;
+  if (look() == 'S') {
+    Node *S = getDerived().parseSubstitution();
+    if (!S)
+      return nullptr;
+    if (S->getKind() == Node::KModuleName)
+      Module = static_cast<ModuleName *>(S);
+    else if (IsSubst && Std == nullptr) {
+      Res = S;
+      *IsSubst = true;
+    } else {
+      return nullptr;
+    }
+  }
+
+  if (Res == nullptr)
+    Res = getDerived().parseUnqualifiedName(State, Std, Module);
+
+  return Res;
 }
 
-// <unqualified-name> ::= <operator-name> [abi-tags]
-//                    ::= <ctor-dtor-name> [<abi-tags>]
-//                    ::= <source-name> [<abi-tags>]
-//                    ::= <unnamed-type-name> [<abi-tags>]
-//                    ::= DC <source-name>+ E      # structured binding declaration
+// <unqualified-name> ::= [<module-name>] <operator-name> [<abi-tags>]
+//                    ::= [<module-name>] <ctor-dtor-name> [<abi-tags>]
+//                    ::= [<module-name>] <source-name> [<abi-tags>]
+//                    ::= [<module-name>] <unnamed-type-name> [<abi-tags>]
+//			# structured binding declaration
+//                    ::= [<module-name>] DC <source-name>+ E
 template <typename Derived, typename Alloc>
-Node *
-AbstractManglingParser<Derived, Alloc>::parseUnqualifiedName(NameState *State,
-                                                             Node *Scope) {
+Node *AbstractManglingParser<Derived, Alloc>::parseUnqualifiedName(
+    NameState *State, Node *Scope, ModuleName *Module) {
+  if (getDerived().parseModuleNameOpt(Module))
+    return nullptr;
+
   Node *Result;
-  if (look() == 'U')
+  if (look() == 'U') {
     Result = getDerived().parseUnnamedTypeName(State);
-  else if (look() >= '1' && look() <= '9')
+  } else if (look() >= '1' && look() <= '9') {
     Result = getDerived().parseSourceName(State);
-  else if (consumeIf("DC")) {
+  } else if (consumeIf("DC")) {
     // Structured binding
     size_t BindingsBegin = Names.size();
     do {
@@ -2753,19 +2806,44 @@
     Result = make<StructuredBindingName>(popTrailingNodeArray(BindingsBegin));
   } else if (look() == 'C' || look() == 'D') {
     // A <ctor-dtor-name>.
-    if (Scope == nullptr)
+    if (Scope == nullptr || Module != nullptr)
       return nullptr;
     Result = getDerived().parseCtorDtorName(Scope, State);
   } else {
     Result = getDerived().parseOperatorName(State);
   }
+
+  if (Module)
+    Result = make<ModuleEntity>(Module, Result);
   if (Result != nullptr)
     Result = getDerived().parseAbiTags(Result);
   if (Result != nullptr && Scope != nullptr)
     Result = make<NestedName>(Scope, Result);
+
   return Result;
 }
 
+// <module-name> ::= <module-subname>
+// 	 	 ::= <module-name> <module-subname>
+//		 ::= <substitution>  # passed in by caller
+// <module-subname> ::= W <source-name>
+//		    ::= W P <source-name>
+template <typename Derived, typename Alloc>
+bool AbstractManglingParser<Derived, Alloc>::parseModuleNameOpt(
+    ModuleName *&Module) {
+  while (consumeIf('W')) {
+    bool IsPartition = consumeIf('P');
+    Node *Sub = getDerived().parseSourceName(nullptr);
+    if (!Sub)
+      return true;
+    Module =
+        static_cast<ModuleName *>(make<ModuleName>(Module, Sub, IsPartition));
+    Subs.push_back(Module);
+  }
+
+  return false;
+}
+
 // <unnamed-type-name> ::= Ut [<nonnegative number>] _
 //                     ::= <closure-type-name>
 //
@@ -3139,25 +3217,35 @@
       if (SoFar != nullptr)
         return nullptr; // Cannot have a prefix.
       SoFar = getDerived().parseDecltype();
-    } else if (look() == 'S') {
-      //          ::= <substitution>
-      if (SoFar != nullptr)
-        return nullptr; // Cannot have a prefix.
-      if (look(1) == 't') {
-        // parseSubstition does not handle 'St'.
-        First += 2;
-        SoFar = make<NameType>("std");
-      } else {
-        SoFar = getDerived().parseSubstitution();
-      }
-      if (SoFar == nullptr)
-        return nullptr;
-      continue; // Do not push a new substitution.
     } else {
-      consumeIf('L'); // extension
+      ModuleName *Module = nullptr;
+      bool IsLocal = consumeIf('L'); // extension
+
+      if (look() == 'S') {
+        //          ::= <substitution>
+        Node *S = nullptr;
+        if (look(1) == 't') {
+          First += 2;
+          S = make<NameType>("std");
+        } else {
+          S = getDerived().parseSubstitution();
+        }
+        if (!S)
+          return nullptr;
+        if (S->getKind() == Node::KModuleName) {
+          Module = static_cast<ModuleName *>(S);
+        } else if (SoFar != nullptr || IsLocal) {
+          return nullptr; // Cannot have a prefix.
+        } else {
+          SoFar = S;
+          continue; // Do not push a new substitution.
+        }
+      }
+
       //          ::= [<prefix>] <unqualified-name>
-      SoFar = getDerived().parseUnqualifiedName(State, SoFar);
+      SoFar = getDerived().parseUnqualifiedName(State, SoFar, Module);
     }
+
     if (SoFar == nullptr)
       return nullptr;
     Subs.push_back(SoFar);
@@ -3970,8 +4058,9 @@
   //             ::= <substitution>  # See Compression below
   case 'S': {
     if (look(1) != 't') {
-      Result = getDerived().parseSubstitution();
-      if (Result == nullptr)
+      bool IsSubst = false;
+      Result = getDerived().parseUnscopedName(nullptr, &IsSubst);
+      if (!Result)
         return nullptr;
 
       // Sub could be either of:
@@ -3984,12 +4073,14 @@
       // If this is followed by some <template-args>, and we're permitted to
       // parse them, take the second production.
 
-      if (TryToParseTemplateArgs && look() == 'I') {
+      if (look() == 'I' && (!IsSubst || TryToParseTemplateArgs)) {
+        if (!IsSubst)
+          Subs.push_back(Result);
         Node *TA = getDerived().parseTemplateArgs();
         if (TA == nullptr)
           return nullptr;
         Result = make<NameWithTemplateArgs>(Result, TA);
-      } else {
+      } else if (IsSubst) {
         // If all we parsed was a substitution, don't re-insert into the
         // substitution table.
         return Result;
@@ -4738,14 +4829,17 @@
 //                    # second call-offset is result adjustment
 //                ::= T <call-offset> <base encoding>
 //                    # base is the nominal target function of thunk
-//                ::= GV <object name> # Guard variable for one-time initialization
+//                # Guard variable for one-time initialization
+//                ::= GV <object name>
 //                                     # No <type>
 //                ::= TW <object name> # Thread-local wrapper
 //                ::= TH <object name> # Thread-local initialization
 //                ::= GR <object name> _             # First temporary
 //                ::= GR <object name> <seq-id> _    # Subsequent temporaries
-//      extension ::= TC <first type> <number> _ <second type> # construction vtable for second-in-first
+//                # construction vtable for second-in-first
+//      extension ::= TC <first type> <number> _ <second type>
 //      extension ::= GR <object name> # reference temporary for object
+//      extension ::= GI <module name> # module global initializer
 template <typename Derived, typename Alloc>
 Node *AbstractManglingParser<Derived, Alloc>::parseSpecialName() {
   switch (look()) {
@@ -4872,6 +4966,16 @@
         return nullptr;
       return make<SpecialName>("reference temporary for ", Name);
     }
+    // GI <module-name> v
+    case 'I': {
+      First += 2;
+      ModuleName *Module = nullptr;
+      if (getDerived().parseModuleNameOpt(Module))
+        return nullptr;
+      if (Module == nullptr)
+        return nullptr;
+      return make<SpecialName>("initializer for module ", Module);
+    }
     }
   }
   return nullptr;
diff --git a/test/test_demangle.pass.cpp b/test/test_demangle.pass.cpp
index 9ad6929..258bf1a 100644
--- a/test/test_demangle.pass.cpp
+++ b/test/test_demangle.pass.cpp
@@ -29867,6 +29867,31 @@
     {"_Z3TPLIiET_S0_", "int TPL<int>(int)"},
 
     {"_ZN1XawEv", "X::operator co_await()"},
+
+    // C++20 modules
+    {"_ZN5Outer5InnerW3FOO2FnERNS0_1XE", "Outer::Inner::Fn@FOO(Outer::Inner::X&)"},
+    {"_ZN5OuterW3FOO5Inner2FnERNS1_1XE", "Outer::Inner@FOO::Fn(Outer::Inner@FOO::X&)"},
+    {"_ZN4Quux4TotoW3FooW3Bar3BazEPNS0_S2_5PlughE", "Quux::Toto::Baz@Foo.Bar(Quux::Toto::Plugh@Foo.Bar*)"},
+    {"_ZW6Module1fNS_1a1bENS0_1cE", "f@Module(a@Module::b, a@Module::c)"},
+    {"_ZN3BobW3FOOW3BAR3BarEPS1_1APNS_S1_1BE", "Bob::Bar@FOO.BAR(A@FOO.BAR*, Bob::B@FOO.BAR*)"},
+    {"_ZW3FOOW3BAR3FooPS0_1APN3BobS0_1BE", "Foo@FOO.BAR(A@FOO.BAR*, Bob::B@FOO.BAR*)"},
+    {"_ZN3BobW3FOOW3BAZ3FooEPS0_W3BAR1APNS_S2_1BE", "Bob::Foo@FOO.BAZ(A@FOO.BAR*, Bob::B@FOO.BAR*)"},
+    {"_ZW3FOOW3BAZ3BarPS_W3BAR1APN3BobS1_1BE", "Bar@FOO.BAZ(A@FOO.BAR*, Bob::B@FOO.BAR*)"},
+    {"_ZNW3FOO3TPLIS_3OneE1MEPS1_", "TPL@FOO<One@FOO>::M(One@FOO*)"},
+    {"_ZNW3FOO3TPLIS_3OneE1NIS_3TwoEEvPS1_PT_", "void TPL@FOO<One@FOO>::N<Two@FOO>(One@FOO*, Two@FOO*)"},
+    {"_ZN3NMSW3FOO3TPLINS_S0_3OneEE1MEPS2_", "NMS::TPL@FOO<NMS::One@FOO>::M(NMS::One@FOO*)"},
+    {"_ZN3NMSW3FOO3TPLINS_S0_3OneEE1NINS_S0_3TwoEEEvPS2_PT_",
+     "void NMS::TPL@FOO<NMS::One@FOO>::N<NMS::Two@FOO>(NMS::One@FOO*, NMS::Two@FOO*)"},
+    {"_ZNStW3STD9allocatorIiE1MEPi", "std::allocator@STD<int>::M(int*)"},
+    {"_ZNStW3STD9allocatorIiE1NIfEEPT_Pi", "float* std::allocator@STD<int>::N<float>(int*)"},
+    {"_ZNStW3STD9allocatorI4PoohE1MEPS1_", "std::allocator@STD<Pooh>::M(Pooh*)"},
+    {"_ZNStW3STD9allocatorI4PoohE1NI6PigletEEPT_PS1_", "Piglet* std::allocator@STD<Pooh>::N<Piglet>(Pooh*)"},
+    {"_ZW3FooDC1a1bE", "[a, b]@Foo"},
+    {"_ZN1NW3FooDC1a1bEE", "N::[a, b]@Foo"},
+    {"_ZN3NMSW3MOD3FooB3ABIEv", "NMS::Foo@MOD[abi:ABI]()"},
+    {"_ZGIW3Foo", "initializer for module Foo"},
+    {"_ZGIW3FooW3Bar", "initializer for module Foo.Bar"},
+    {"_ZGIW3FooWP3BarW3Baz", "initializer for module Foo:Bar.Baz"},
 };
 
 const unsigned N = sizeof(cases) / sizeof(cases[0]);
@@ -29954,6 +29979,11 @@
     "_ZNDTUt_Ev",
 
     "_ZN1fIXawLi0EEEEvv",
+
+    "_ZNWUt_3FOOEv",
+    "_ZWDC3FOOEv",
+    "_ZGI3Foo",
+    "_ZGIW3Foov",
 };
 
 const unsigned NI = sizeof(invalid_cases) / sizeof(invalid_cases[0]);
@@ -30006,8 +30036,6 @@
 }
 
 const char *xfail_cases[] = {
-    "_ZW6FooBarE2f3v", // C++ modules TS
-
     // FIXME: Why does clang generate the "cp" expr?
     "_ZN5test11bIsEEDTcp3foocvT__EEES1_",
 };