[clang][ExtractAPI] Refactor serializer to the CRTP

Refactor SerializerBase and SymbolGraphSerializer to use a visitor pattern described by the CRTP.

Reviewed By: dang

Differential Revision: https://reviews.llvm.org/D151477

GitOrigin-RevId: 06ff9770477d8c7378047b0171db4b25eba5d8dd
diff --git a/include/clang/ExtractAPI/Serialization/SerializerBase.h b/include/clang/ExtractAPI/Serialization/SerializerBase.h
index d8aa826..006e92b 100644
--- a/include/clang/ExtractAPI/Serialization/SerializerBase.h
+++ b/include/clang/ExtractAPI/Serialization/SerializerBase.h
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 ///
 /// \file
-/// This file defines the ExtractAPI APISerializer interface.
+/// This file defines the ExtractAPI APISetVisitor interface.
 ///
 //===----------------------------------------------------------------------===//
 
@@ -15,47 +15,107 @@
 #define LLVM_CLANG_EXTRACTAPI_SERIALIZATION_SERIALIZERBASE_H
 
 #include "clang/ExtractAPI/API.h"
-#include "clang/ExtractAPI/APIIgnoresList.h"
-#include "llvm/Support/raw_ostream.h"
 
 namespace clang {
 namespace extractapi {
 
-/// Common options to customize the serializer output.
-struct APISerializerOption {
-  /// Do not include unnecessary whitespaces to save space.
-  bool Compact;
-};
-
-/// The base interface of serializers for API information.
-class APISerializer {
+/// The base interface of visitors for API information.
+template <typename Derived> class APISetVisitor {
 public:
-  /// Serialize the API information to \p os.
-  virtual void serialize(raw_ostream &os) = 0;
+  void traverseAPISet() {
+    getDerived()->traverseGlobalVariableRecords();
+
+    getDerived()->traverseGlobalFunctionRecords();
+
+    getDerived()->traverseEnumRecords();
+
+    getDerived()->traverseStructRecords();
+
+    getDerived()->traverseObjCInterfaces();
+
+    getDerived()->traverseObjCProtocols();
+
+    getDerived()->traverseMacroDefinitionRecords();
+
+    getDerived()->traverseTypedefRecords();
+  }
+
+  void traverseGlobalFunctionRecords() {
+    for (const auto &GlobalFunction : API.getGlobalFunctions())
+      getDerived()->visitGlobalFunctionRecord(*GlobalFunction.second);
+  }
+
+  void traverseGlobalVariableRecords() {
+    for (const auto &GlobalVariable : API.getGlobalVariables())
+      getDerived()->visitGlobalVariableRecord(*GlobalVariable.second);
+  }
+
+  void traverseEnumRecords() {
+    for (const auto &Enum : API.getEnums())
+      getDerived()->visitEnumRecord(*Enum.second);
+  }
+
+  void traverseStructRecords() {
+    for (const auto &Struct : API.getStructs())
+      getDerived()->visitStructRecord(*Struct.second);
+  }
+
+  void traverseObjCInterfaces() {
+    for (const auto &Interface : API.getObjCInterfaces())
+      getDerived()->visitObjCContainerRecord(*Interface.second);
+  }
+
+  void traverseObjCProtocols() {
+    for (const auto &Protocol : API.getObjCProtocols())
+      getDerived()->visitObjCContainerRecord(*Protocol.second);
+  }
+
+  void traverseMacroDefinitionRecords() {
+    for (const auto &Macro : API.getMacros())
+      getDerived()->visitMacroDefinitionRecord(*Macro.second);
+  }
+
+  void traverseTypedefRecords() {
+    for (const auto &Typedef : API.getTypedefs())
+      getDerived()->visitTypedefRecord(*Typedef.second);
+  }
+
+  /// Visit a global function record.
+  void visitGlobalFunctionRecord(const GlobalFunctionRecord &Record){};
+
+  /// Visit a global variable record.
+  void visitGlobalVariableRecord(const GlobalVariableRecord &Record){};
+
+  /// Visit an enum record.
+  void visitEnumRecord(const EnumRecord &Record){};
+
+  /// Visit a struct record.
+  void visitStructRecord(const StructRecord &Record){};
+
+  /// Visit an Objective-C container record.
+  void visitObjCContainerRecord(const ObjCContainerRecord &Record){};
+
+  /// Visit a macro definition record.
+  void visitMacroDefinitionRecord(const MacroDefinitionRecord &Record){};
+
+  /// Visit a typedef record.
+  void visitTypedefRecord(const TypedefRecord &Record){};
 
 protected:
   const APISet &API;
 
-  /// The list of symbols to ignore.
-  ///
-  /// Note: This should be consulted before emitting a symbol.
-  const APIIgnoresList &IgnoresList;
-
-  APISerializerOption Options;
-
 public:
-  APISerializer() = delete;
-  APISerializer(const APISerializer &) = delete;
-  APISerializer(APISerializer &&) = delete;
-  APISerializer &operator=(const APISerializer &) = delete;
-  APISerializer &operator=(APISerializer &&) = delete;
+  APISetVisitor() = delete;
+  APISetVisitor(const APISetVisitor &) = delete;
+  APISetVisitor(APISetVisitor &&) = delete;
+  APISetVisitor &operator=(const APISetVisitor &) = delete;
+  APISetVisitor &operator=(APISetVisitor &&) = delete;
 
 protected:
-  APISerializer(const APISet &API, const APIIgnoresList &IgnoresList,
-                APISerializerOption Options = {})
-      : API(API), IgnoresList(IgnoresList), Options(Options) {}
+  APISetVisitor(const APISet &API) : API(API) {}
+  ~APISetVisitor() = default;
 
-  virtual ~APISerializer() = default;
+  Derived *getDerived() { return static_cast<Derived *>(this); };
 };
 
 } // namespace extractapi
diff --git a/include/clang/ExtractAPI/Serialization/SymbolGraphSerializer.h b/include/clang/ExtractAPI/Serialization/SymbolGraphSerializer.h
index 6639082..e77903f 100644
--- a/include/clang/ExtractAPI/Serialization/SymbolGraphSerializer.h
+++ b/include/clang/ExtractAPI/Serialization/SymbolGraphSerializer.h
@@ -9,8 +9,8 @@
 /// \file
 /// This file defines the SymbolGraphSerializer class.
 ///
-/// Implement an APISerializer for the Symbol Graph format for ExtractAPI.
-/// See https://github.com/apple/swift-docc-symbolkit.
+/// Implement an APISetVisitor to serialize the APISet into the Symbol Graph
+/// format for ExtractAPI. See https://github.com/apple/swift-docc-symbolkit.
 ///
 //===----------------------------------------------------------------------===//
 
@@ -31,14 +31,18 @@
 
 using namespace llvm::json;
 
-/// The serializer that organizes API information in the Symbol Graph format.
+/// Common options to customize the visitor output.
+struct SymbolGraphSerializerOption {
+  /// Do not include unnecessary whitespaces to save space.
+  bool Compact;
+};
+
+/// The visitor that organizes API information in the Symbol Graph format.
 ///
 /// The Symbol Graph format (https://github.com/apple/swift-docc-symbolkit)
 /// models an API set as a directed graph, where nodes are symbol declarations,
 /// and edges are relationships between the connected symbols.
-class SymbolGraphSerializer : public APISerializer {
-  virtual void anchor();
-
+class SymbolGraphSerializer : public APISetVisitor<SymbolGraphSerializer> {
   /// A JSON array of formatted symbols in \c APISet.
   Array Symbols;
 
@@ -48,7 +52,7 @@
   /// The Symbol Graph format version used by this serializer.
   static const VersionTuple FormatVersion;
 
-  /// Indicates whether child symbols should be serialized. This is mainly
+  /// Indicates whether child symbols should be visited. This is mainly
   /// useful for \c serializeSingleSymbolSGF.
   bool ShouldRecurse;
 
@@ -59,9 +63,8 @@
   /// Symbol Graph.
   Object serialize();
 
-  /// Implement the APISerializer::serialize interface. Wrap serialize(void) and
-  /// write out the serialized JSON object to \p os.
-  void serialize(raw_ostream &os) override;
+  ///  Wrap serialize(void) and write out the serialized JSON object to \p os.
+  void serialize(raw_ostream &os);
 
   /// Serialize a single symbol SGF. This is primarily used for libclang.
   ///
@@ -136,35 +139,44 @@
   void serializeRelationship(RelationshipKind Kind, SymbolReference Source,
                              SymbolReference Target);
 
-  /// Serialize a global function record.
-  void serializeGlobalFunctionRecord(const GlobalFunctionRecord &Record);
+protected:
+  /// The list of symbols to ignore.
+  ///
+  /// Note: This should be consulted before emitting a symbol.
+  const APIIgnoresList &IgnoresList;
 
-  /// Serialize a global variable record.
-  void serializeGlobalVariableRecord(const GlobalVariableRecord &Record);
-
-  /// Serialize an enum record.
-  void serializeEnumRecord(const EnumRecord &Record);
-
-  /// Serialize a struct record.
-  void serializeStructRecord(const StructRecord &Record);
-
-  /// Serialize an Objective-C container record.
-  void serializeObjCContainerRecord(const ObjCContainerRecord &Record);
-
-  /// Serialize a macro definition record.
-  void serializeMacroDefinitionRecord(const MacroDefinitionRecord &Record);
-
-  /// Serialize a typedef record.
-  void serializeTypedefRecord(const TypedefRecord &Record);
-
-  void serializeSingleRecord(const APIRecord *Record);
+  SymbolGraphSerializerOption Options;
 
 public:
+  /// Visit a global function record.
+  void visitGlobalFunctionRecord(const GlobalFunctionRecord &Record);
+
+  /// Visit a global variable record.
+  void visitGlobalVariableRecord(const GlobalVariableRecord &Record);
+
+  /// Visit an enum record.
+  void visitEnumRecord(const EnumRecord &Record);
+
+  /// Visit a struct record.
+  void visitStructRecord(const StructRecord &Record);
+
+  /// Visit an Objective-C container record.
+  void visitObjCContainerRecord(const ObjCContainerRecord &Record);
+
+  /// Visit a macro definition record.
+  void visitMacroDefinitionRecord(const MacroDefinitionRecord &Record);
+
+  /// Visit a typedef record.
+  void visitTypedefRecord(const TypedefRecord &Record);
+
+  /// Serialize a single record.
+  void serializeSingleRecord(const APIRecord *Record);
+
   SymbolGraphSerializer(const APISet &API, const APIIgnoresList &IgnoresList,
-                        APISerializerOption Options = {},
+                        SymbolGraphSerializerOption Options = {},
                         bool ShouldRecurse = true)
-      : APISerializer(API, IgnoresList, Options), ShouldRecurse(ShouldRecurse) {
-  }
+      : APISetVisitor(API), ShouldRecurse(ShouldRecurse),
+        IgnoresList(IgnoresList), Options(Options) {}
 };
 
 } // namespace extractapi
diff --git a/lib/ExtractAPI/CMakeLists.txt b/lib/ExtractAPI/CMakeLists.txt
index 153d4b9..b43fe74 100644
--- a/lib/ExtractAPI/CMakeLists.txt
+++ b/lib/ExtractAPI/CMakeLists.txt
@@ -9,7 +9,6 @@
   AvailabilityInfo.cpp
   ExtractAPIConsumer.cpp
   DeclarationFragments.cpp
-  Serialization/SerializerBase.cpp
   Serialization/SymbolGraphSerializer.cpp
   TypedefUnderlyingTypeResolver.cpp
 
diff --git a/lib/ExtractAPI/Serialization/SerializerBase.cpp b/lib/ExtractAPI/Serialization/SerializerBase.cpp
deleted file mode 100644
index 71fd25b..0000000
--- a/lib/ExtractAPI/Serialization/SerializerBase.cpp
+++ /dev/null
@@ -1,19 +0,0 @@
-//===- ExtractAPI/Serialization/SerializerBase.cpp --------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-///
-/// \file
-/// This file implements the APISerializer interface.
-///
-//===----------------------------------------------------------------------===//
-
-#include "clang/ExtractAPI/Serialization/SerializerBase.h"
-#include "llvm/Support/raw_ostream.h"
-
-using namespace clang::extractapi;
-
-void APISerializer::serialize(llvm::raw_ostream &os) {}
diff --git a/lib/ExtractAPI/Serialization/SymbolGraphSerializer.cpp b/lib/ExtractAPI/Serialization/SymbolGraphSerializer.cpp
index 7676c74..534e928 100644
--- a/lib/ExtractAPI/Serialization/SymbolGraphSerializer.cpp
+++ b/lib/ExtractAPI/Serialization/SymbolGraphSerializer.cpp
@@ -14,16 +14,11 @@
 #include "clang/ExtractAPI/Serialization/SymbolGraphSerializer.h"
 #include "clang/Basic/SourceLocation.h"
 #include "clang/Basic/Version.h"
-#include "clang/ExtractAPI/API.h"
-#include "clang/ExtractAPI/APIIgnoresList.h"
 #include "clang/ExtractAPI/DeclarationFragments.h"
-#include "clang/ExtractAPI/Serialization/SerializerBase.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/STLFunctionalExtras.h"
-#include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Compiler.h"
-#include "llvm/Support/JSON.h"
 #include "llvm/Support/Path.h"
 #include "llvm/Support/VersionTuple.h"
 #include <optional>
@@ -541,19 +536,16 @@
 Array generateParentContexts(const RecordTy &Record, const APISet &API,
                              Language Lang) {
   Array ParentContexts;
-  generatePathComponents(Record, API,
-                         [Lang, &ParentContexts](const PathComponent &PC) {
-                           ParentContexts.push_back(
-                               serializeParentContext(PC, Lang));
-                         });
+  generatePathComponents(
+      Record, API, [Lang, &ParentContexts](const PathComponent &PC) {
+        ParentContexts.push_back(serializeParentContext(PC, Lang));
+      });
 
   return ParentContexts;
 }
 
 } // namespace
 
-void SymbolGraphSerializer::anchor() {}
-
 /// Defines the format version emitted by SymbolGraphSerializer.
 const VersionTuple SymbolGraphSerializer::FormatVersion{0, 5, 3};
 
@@ -670,7 +662,7 @@
   Relationships.emplace_back(std::move(Relationship));
 }
 
-void SymbolGraphSerializer::serializeGlobalFunctionRecord(
+void SymbolGraphSerializer::visitGlobalFunctionRecord(
     const GlobalFunctionRecord &Record) {
   auto Obj = serializeAPIRecord(Record);
   if (!Obj)
@@ -679,7 +671,7 @@
   Symbols.emplace_back(std::move(*Obj));
 }
 
-void SymbolGraphSerializer::serializeGlobalVariableRecord(
+void SymbolGraphSerializer::visitGlobalVariableRecord(
     const GlobalVariableRecord &Record) {
   auto Obj = serializeAPIRecord(Record);
   if (!Obj)
@@ -688,7 +680,7 @@
   Symbols.emplace_back(std::move(*Obj));
 }
 
-void SymbolGraphSerializer::serializeEnumRecord(const EnumRecord &Record) {
+void SymbolGraphSerializer::visitEnumRecord(const EnumRecord &Record) {
   auto Enum = serializeAPIRecord(Record);
   if (!Enum)
     return;
@@ -697,7 +689,7 @@
   serializeMembers(Record, Record.Constants);
 }
 
-void SymbolGraphSerializer::serializeStructRecord(const StructRecord &Record) {
+void SymbolGraphSerializer::visitStructRecord(const StructRecord &Record) {
   auto Struct = serializeAPIRecord(Record);
   if (!Struct)
     return;
@@ -706,7 +698,7 @@
   serializeMembers(Record, Record.Fields);
 }
 
-void SymbolGraphSerializer::serializeObjCContainerRecord(
+void SymbolGraphSerializer::visitObjCContainerRecord(
     const ObjCContainerRecord &Record) {
   auto ObjCContainer = serializeAPIRecord(Record);
   if (!ObjCContainer)
@@ -743,7 +735,7 @@
   }
 }
 
-void SymbolGraphSerializer::serializeMacroDefinitionRecord(
+void SymbolGraphSerializer::visitMacroDefinitionRecord(
     const MacroDefinitionRecord &Record) {
   auto Macro = serializeAPIRecord(Record);
 
@@ -758,28 +750,28 @@
   case APIRecord::RK_Unknown:
     llvm_unreachable("Records should have a known kind!");
   case APIRecord::RK_GlobalFunction:
-    serializeGlobalFunctionRecord(*cast<GlobalFunctionRecord>(Record));
+    visitGlobalFunctionRecord(*cast<GlobalFunctionRecord>(Record));
     break;
   case APIRecord::RK_GlobalVariable:
-    serializeGlobalVariableRecord(*cast<GlobalVariableRecord>(Record));
+    visitGlobalVariableRecord(*cast<GlobalVariableRecord>(Record));
     break;
   case APIRecord::RK_Enum:
-    serializeEnumRecord(*cast<EnumRecord>(Record));
+    visitEnumRecord(*cast<EnumRecord>(Record));
     break;
   case APIRecord::RK_Struct:
-    serializeStructRecord(*cast<StructRecord>(Record));
+    visitStructRecord(*cast<StructRecord>(Record));
     break;
   case APIRecord::RK_ObjCInterface:
-    serializeObjCContainerRecord(*cast<ObjCInterfaceRecord>(Record));
+    visitObjCContainerRecord(*cast<ObjCInterfaceRecord>(Record));
     break;
   case APIRecord::RK_ObjCProtocol:
-    serializeObjCContainerRecord(*cast<ObjCProtocolRecord>(Record));
+    visitObjCContainerRecord(*cast<ObjCProtocolRecord>(Record));
     break;
   case APIRecord::RK_MacroDefinition:
-    serializeMacroDefinitionRecord(*cast<MacroDefinitionRecord>(Record));
+    visitMacroDefinitionRecord(*cast<MacroDefinitionRecord>(Record));
     break;
   case APIRecord::RK_Typedef:
-    serializeTypedefRecord(*cast<TypedefRecord>(Record));
+    visitTypedefRecord(*cast<TypedefRecord>(Record));
     break;
   default:
     if (auto Obj = serializeAPIRecord(*Record)) {
@@ -793,8 +785,7 @@
   }
 }
 
-void SymbolGraphSerializer::serializeTypedefRecord(
-    const TypedefRecord &Record) {
+void SymbolGraphSerializer::visitTypedefRecord(const TypedefRecord &Record) {
   // Typedefs of anonymous types have their entries unified with the underlying
   // type.
   bool ShouldDrop = Record.UnderlyingType.Name.empty();
@@ -814,35 +805,7 @@
 }
 
 Object SymbolGraphSerializer::serialize() {
-  // Serialize global variables in the API set.
-  for (const auto &GlobalVar : API.getGlobalVariables())
-    serializeGlobalVariableRecord(*GlobalVar.second);
-
-  for (const auto &GlobalFunction : API.getGlobalFunctions())
-    serializeGlobalFunctionRecord(*GlobalFunction.second);
-
-  // Serialize enum records in the API set.
-  for (const auto &Enum : API.getEnums())
-    serializeEnumRecord(*Enum.second);
-
-  // Serialize struct records in the API set.
-  for (const auto &Struct : API.getStructs())
-    serializeStructRecord(*Struct.second);
-
-  // Serialize Objective-C interface records in the API set.
-  for (const auto &ObjCInterface : API.getObjCInterfaces())
-    serializeObjCContainerRecord(*ObjCInterface.second);
-
-  // Serialize Objective-C protocol records in the API set.
-  for (const auto &ObjCProtocol : API.getObjCProtocols())
-    serializeObjCContainerRecord(*ObjCProtocol.second);
-
-  for (const auto &Macro : API.getMacros())
-    serializeMacroDefinitionRecord(*Macro.second);
-
-  for (const auto &Typedef : API.getTypedefs())
-    serializeTypedefRecord(*Typedef.second);
-
+  traverseAPISet();
   return serializeCurrentGraph();
 }