[CSSPGO] Sorting nodes in a cycle of profiled call graph.

For nodes that are in a cycle of a profiled call graph, the order the underlying scc_iter currently computes depends purely on how those nodes are reached from outside and inside the SCC, based on the Tarjan algorithm. This does not honor profile edge hotness, and thus does not guarantee that hot callsites are inlined prior to cold callsites. To mitigate that, I'm adding an extra sorter on top of scc_iter to sort SCC functions in the order of callsite hotness, instead of changing the internals of scc_iter.

Sorting on callsite hotness would optimally be based on detecting cycles in the directed call graph, i.e., removing the coldest edge until a cycle is broken. However, detecting cycles isn't cheap. I'm using an MST-based approach, which is faster and appears to deliver some performance wins.

Reviewed By: wenlei

Differential Revision: https://reviews.llvm.org/D114204
diff --git a/llvm/include/llvm/ADT/SCCIterator.h b/llvm/include/llvm/ADT/SCCIterator.h
index 8a7c0a7..ad35e09 100644
--- a/llvm/include/llvm/ADT/SCCIterator.h
+++ b/llvm/include/llvm/ADT/SCCIterator.h
@@ -28,6 +28,10 @@
 #include <cassert>
 #include <cstddef>
 #include <iterator>
+#include <queue>
+#include <set>
+#include <unordered_map>
+#include <unordered_set>
 #include <vector>
 
 namespace llvm {
@@ -234,6 +238,135 @@
   return scc_iterator<T>::end(G);
 }
 
+/// Sort the nodes of a directed SCC in the decreasing order of the edge
+/// weights. The instantiating GraphT type must declare NodeType and a weighted
+/// EdgeType (with Source, Target and Weight) in its graph traits.
+///
+/// This is implemented using Kruskal's spanning tree algorithm followed by a
+/// BFS walk. First a maximum-weight spanning forest is built from all edges
+/// whose endpoints are both within the SCC. Then a BFS walk is started from the
+/// tree nodes that have no predecessor in the forest, so that high-weighted
+/// edges are visited first during the traversal. The node list exposed through
+/// operator* is that BFS order reversed.
+template <class GraphT, class GT = GraphTraits<GraphT>>
+class scc_member_iterator {
+  using NodeType = typename GT::NodeType;
+  using EdgeType = typename GT::EdgeType;
+  using NodesType = std::vector<NodeType *>;
+
+  // Auxiliary node information used during the MST calculation.
+  struct NodeInfo {
+    NodeInfo *Group = this; // Union-find parent; a root points to itself.
+    uint32_t Rank = 0;
+    bool Visited = true;    // Cleared for nodes with an incoming MST edge.
+  };
+
+  // Find the root group of the node and compress the path from node to the
+  // root.
+  NodeInfo *find(NodeInfo *Node) {
+    if (Node->Group != Node)
+      Node->Group = find(Node->Group);
+    return Node->Group;
+  }
+
+  // Union the source and target node into the same group and return true.
+  // Returns false if they are already in the same group.
+  bool unionGroups(const EdgeType *Edge) {
+    NodeInfo *G1 = find(&NodeInfoMap[Edge->Source]);
+    NodeInfo *G2 = find(&NodeInfoMap[Edge->Target]);
+
+    // If the edge forms a cycle, do not add it to MST
+    if (G1 == G2)
+      return false;
+
+    // Make the smaller rank tree a direct child of the root of high rank tree.
+    if (G1->Rank < G2->Rank)
+      G1->Group = G2;
+    else {
+      G2->Group = G1;
+      // If the ranks are the same, increment the rank of the new root (G1).
+      if (G1->Rank == G2->Rank)
+        G1->Rank++;
+    }
+    return true;
+  }
+
+  std::unordered_map<NodeType *, NodeInfo> NodeInfoMap;
+  NodesType Nodes;
+
+public:
+  scc_member_iterator(const NodesType &InputNodes);
+
+  NodesType &operator*() { return Nodes; }
+};
+
+/// Compute the member order of the SCC \p InputNodes. SCCs of at most one node
+/// are copied as-is; larger ones get the reversed BFS order of their MST.
+template <class GraphT, class GT>
+scc_member_iterator<GraphT, GT>::scc_member_iterator(
+    const NodesType &InputNodes) {
+  if (InputNodes.size() <= 1) {
+    Nodes = InputNodes;
+    return;
+  }
+
+  // Initialize auxiliary node information.
+  NodeInfoMap.clear();
+  for (auto *Node : InputNodes) {
+    // This is specifically used to construct a `NodeInfo` object in place. An
+    // insert operation will involve a copy construction which invalidates the
+    // initial value of the `Group` field which should be `this`.
+    (void)NodeInfoMap[Node].Group;
+  }
+
+  // Sort edges by weights, heaviest first.
+  struct EdgeComparer {
+    bool operator()(const EdgeType *L, const EdgeType *R) const {
+      return L->Weight > R->Weight;
+    }
+  };
+
+  std::multiset<const EdgeType *, EdgeComparer> SortedEdges;
+  for (auto *Node : InputNodes)
+    for (auto &Edge : Node->Edges)
+      if (NodeInfoMap.count(Edge.Target))
+        SortedEdges.insert(&Edge);
+
+  // Traverse all the edges and compute the Maximum Weight Spanning Tree
+  // using Kruskal's algorithm.
+  std::unordered_set<const EdgeType *> MSTEdges;
+  for (auto *Edge : SortedEdges)
+    if (unionGroups(Edge))
+      MSTEdges.insert(Edge);
+
+  // Do BFS on MST, starting from nodes that have no incoming edge. These nodes
+  // are "roots" of the MST forest. This ensures that nodes are visited before
+  // their descendants are, thus ensures hot edges are processed before cold
+  // edges, based on how MST is computed.
+  for (const auto *Edge : MSTEdges)
+    NodeInfoMap[Edge->Target].Visited = false;
+
+  std::queue<NodeType *> Queue;
+  // Seed roots in InputNodes order; hash-map order is address-dependent.
+  for (auto *Node : InputNodes)
+    if (NodeInfoMap[Node].Visited)
+      Queue.push(Node);
+
+  while (!Queue.empty()) {
+    auto *Node = Queue.front();
+    Queue.pop();
+    Nodes.push_back(Node);
+    for (auto &Edge : Node->Edges) {
+      if (MSTEdges.count(&Edge) && !NodeInfoMap[Edge.Target].Visited) {
+        NodeInfoMap[Edge.Target].Visited = true;
+        Queue.push(Edge.Target);
+      }
+    }
+  }
+
+  assert(InputNodes.size() == Nodes.size() && "missing nodes in MST");
+  std::reverse(Nodes.begin(), Nodes.end());
+}
 } // end namespace llvm
 
 #endif // LLVM_ADT_SCCITERATOR_H
diff --git a/llvm/include/llvm/Transforms/IPO/ProfiledCallGraph.h b/llvm/include/llvm/Transforms/IPO/ProfiledCallGraph.h
index 6e45f8f..429fcbd 100644
--- a/llvm/include/llvm/Transforms/IPO/ProfiledCallGraph.h
+++ b/llvm/include/llvm/Transforms/IPO/ProfiledCallGraph.h
@@ -24,22 +24,47 @@
 namespace llvm {
 namespace sampleprof {
 
-struct ProfiledCallGraphNode {
-  ProfiledCallGraphNode(StringRef FName = StringRef()) : Name(FName) {}
-  StringRef Name;
+struct ProfiledCallGraphNode;
 
-  struct ProfiledCallGraphNodeComparer {
-    bool operator()(const ProfiledCallGraphNode *L,
-                    const ProfiledCallGraphNode *R) const {
-      return L->Name < R->Name;
+struct ProfiledCallGraphEdge {
+  ProfiledCallGraphEdge(ProfiledCallGraphNode *Src, ProfiledCallGraphNode *Dst,
+                        uint64_t W)
+      : Source(Src), Target(Dst), Weight(W) {}
+  ProfiledCallGraphNode *Source;
+  ProfiledCallGraphNode *Target;
+  uint64_t Weight;
+
+  // Graph walkers only care about the callee, so an edge converts implicitly
+  // to its target node.
+  operator ProfiledCallGraphNode *() const { return Target; }
+};
+
+struct ProfiledCallGraphNode {
+  // A profiled function (by name) and its outgoing weighted call edges.
+  // Sort edges by callee names only since all edges to be compared are from
+  // same caller. Edge weights are not considered either because for the same
+  // callee only the edge with the largest weight is added to the edge set.
+  struct ProfiledCallGraphEdgeComparer {
+    bool operator()(const ProfiledCallGraphEdge &L,
+                    const ProfiledCallGraphEdge &R) const {
+      return L.Target->Name < R.Target->Name;
     }
   };
-  std::set<ProfiledCallGraphNode *, ProfiledCallGraphNodeComparer> Callees;
+
+  using edge = ProfiledCallGraphEdge;
+  using edges = std::set<edge, ProfiledCallGraphEdgeComparer>;
+  using iterator = edges::iterator;
+  using const_iterator = edges::const_iterator;
+
+  ProfiledCallGraphNode(StringRef FName = StringRef()) : Name(FName) {}
+
+  StringRef Name;
+  edges Edges;
 };
 
 class ProfiledCallGraph {
 public:
-  using iterator = std::set<ProfiledCallGraphNode *>::iterator;
+  using iterator = std::set<ProfiledCallGraphEdge>::iterator;
 
   // Constructor for non-CS profile.
   ProfiledCallGraph(SampleProfileMap &ProfileMap) {
@@ -63,8 +88,9 @@
     while (!Queue.empty()) {
       ContextTrieNode *Caller = Queue.front();
       Queue.pop();
-      // Add calls for context. When AddNodeWithSamplesOnly is true, both caller
-      // and callee need to have context profile.
+      FunctionSamples *CallerSamples = Caller->getFunctionSamples();
+
+      // Add calls for context.
       // Note that callsite target samples are completely ignored since they can
       // conflict with the context edges, which are formed by context
       // compression during profile generation, for cyclic SCCs. This may
@@ -74,31 +100,61 @@
         ContextTrieNode *Callee = &Child.second;
         addProfiledFunction(ContextTracker.getFuncNameFor(Callee));
         Queue.push(Callee);
+
+        // Fetch edge weight from the profile.
+        uint64_t Weight;
+        FunctionSamples *CalleeSamples = Callee->getFunctionSamples();
+        if (!CalleeSamples || !CallerSamples) {
+          Weight = 0;
+        } else {
+          uint64_t CalleeEntryCount = CalleeSamples->getEntrySamples();
+          uint64_t CallsiteCount = 0;
+          LineLocation Callsite = Callee->getCallSiteLoc();
+          if (auto CallTargets = CallerSamples->findCallTargetMapAt(Callsite)) {
+            SampleRecord::CallTargetMap &TargetCounts = CallTargets.get();
+            auto It = TargetCounts.find(CalleeSamples->getName());
+            if (It != TargetCounts.end())
+              CallsiteCount = It->second;
+          }
+          Weight = std::max(CallsiteCount, CalleeEntryCount);
+        }
+
         addProfiledCall(ContextTracker.getFuncNameFor(Caller),
-                        ContextTracker.getFuncNameFor(Callee));
+                        ContextTracker.getFuncNameFor(Callee), Weight);
       }
     }
   }
 
-  iterator begin() { return Root.Callees.begin(); }
-  iterator end() { return Root.Callees.end(); }
+  iterator begin() { return Root.Edges.begin(); }
+  iterator end() { return Root.Edges.end(); }
   ProfiledCallGraphNode *getEntryNode() { return &Root; }
   void addProfiledFunction(StringRef Name) {
     if (!ProfiledFunctions.count(Name)) {
       // Link to synthetic root to make sure every node is reachable
       // from root. This does not affect SCC order.
       ProfiledFunctions[Name] = ProfiledCallGraphNode(Name);
-      Root.Callees.insert(&ProfiledFunctions[Name]);
+      Root.Edges.emplace(&Root, &ProfiledFunctions[Name], 0);
     }
   }
 
-  void addProfiledCall(StringRef CallerName, StringRef CalleeName) {
+private:
+  void addProfiledCall(StringRef CallerName, StringRef CalleeName,
+                       uint64_t Weight = 0) {
     assert(ProfiledFunctions.count(CallerName));
     auto CalleeIt = ProfiledFunctions.find(CalleeName);
-    if (CalleeIt == ProfiledFunctions.end()) {
+    if (CalleeIt == ProfiledFunctions.end())
       return;
+    ProfiledCallGraphNode &CallerNode = ProfiledFunctions[CallerName];
+    ProfiledCallGraphEdge NewEdge(&CallerNode, &CalleeIt->second, Weight);
+    auto &CallerEdges = CallerNode.Edges;
+    auto Existing = CallerEdges.find(NewEdge);
+    if (Existing == CallerEdges.end()) {
+      CallerEdges.insert(NewEdge);
+    } else if (Existing->Weight < NewEdge.Weight) {
+      // Keep only the heaviest edge per callee: swap out the lighter one.
+      CallerEdges.erase(Existing);
+      CallerEdges.insert(NewEdge);
     }
-    ProfiledFunctions[CallerName].Callees.insert(&CalleeIt->second);
   }
 
   void addProfiledCalls(const FunctionSamples &Samples) {
@@ -107,20 +163,20 @@
     for (const auto &Sample : Samples.getBodySamples()) {
       for (const auto &Target : Sample.second.getCallTargets()) {
         addProfiledFunction(Target.first());
-        addProfiledCall(Samples.getFuncName(), Target.first());
+        addProfiledCall(Samples.getFuncName(), Target.first(), Target.second);
       }
     }
 
     for (const auto &CallsiteSamples : Samples.getCallsiteSamples()) {
       for (const auto &InlinedSamples : CallsiteSamples.second) {
         addProfiledFunction(InlinedSamples.first);
-        addProfiledCall(Samples.getFuncName(), InlinedSamples.first);
+        addProfiledCall(Samples.getFuncName(), InlinedSamples.first,
+                        InlinedSamples.second.getEntrySamples());
         addProfiledCalls(InlinedSamples.second);
       }
     }
   }
 
-private:
   ProfiledCallGraphNode Root;
   StringMap<ProfiledCallGraphNode> ProfiledFunctions;
 };
@@ -128,12 +184,14 @@
 } // end namespace sampleprof
 
 template <> struct GraphTraits<ProfiledCallGraphNode *> {
+  using NodeType = ProfiledCallGraphNode;
   using NodeRef = ProfiledCallGraphNode *;
-  using ChildIteratorType = std::set<ProfiledCallGraphNode *>::iterator;
+  using EdgeType = NodeType::edge;
+  using ChildIteratorType = NodeType::const_iterator;
 
   static NodeRef getEntryNode(NodeRef PCGN) { return PCGN; }
-  static ChildIteratorType child_begin(NodeRef N) { return N->Callees.begin(); }
-  static ChildIteratorType child_end(NodeRef N) { return N->Callees.end(); }
+  static ChildIteratorType child_begin(NodeRef N) { return N->Edges.begin(); }
+  static ChildIteratorType child_end(NodeRef N) { return N->Edges.end(); }
 };
 
 template <>
diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index a961c47..5430731 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -173,6 +173,9 @@
                          cl::desc("Process functions in a top-down order "
                                   "defined by the profiled call graph when "
                                   "-sample-profile-top-down-load is on."));
+cl::opt<bool>
+    SortProfiledSCC("sort-profiled-scc-member", cl::init(true), cl::Hidden,
+                    cl::desc("Sort profiled recursion by edge weights."));
 
 static cl::opt<bool> ProfileSizeInline(
     "sample-profile-inline-size", cl::Hidden, cl::init(false),
@@ -1853,7 +1856,13 @@
     std::unique_ptr<ProfiledCallGraph> ProfiledCG = buildProfiledCallGraph(*CG);
     scc_iterator<ProfiledCallGraph *> CGI = scc_begin(ProfiledCG.get());
     while (!CGI.isAtEnd()) {
-      for (ProfiledCallGraphNode *Node : *CGI) {
+      auto Range = *CGI;
+      if (SortProfiledSCC) {
+        // Sort nodes in one SCC based on callsite hotness.
+        scc_member_iterator<ProfiledCallGraph *> SI(*CGI);
+        Range = *SI;
+      }
+      for (auto *Node : Range) {
         Function *F = SymbolMap.lookup(Node->Name);
         if (F && !F->isDeclaration() && F->hasFnAttribute("use-sample-profile"))
           FunctionOrderList.push_back(F);
diff --git a/llvm/test/Transforms/SampleProfile/Inputs/profile-context-order-scc.prof b/llvm/test/Transforms/SampleProfile/Inputs/profile-context-order-scc.prof
new file mode 100644
index 0000000..166d830
--- /dev/null
+++ b/llvm/test/Transforms/SampleProfile/Inputs/profile-context-order-scc.prof
@@ -0,0 +1,43 @@
+[main:3 @ _Z5funcAi:1 @ _Z8funcLeafi]:1467299:11
+ 0: 6
+ 1: 6
+ 3: 287884
+ 15: 23
+[main:3.1 @ _Z5funcBi:1 @ _Z8funcLeafi]:500853:20
+ 0: 15
+ 1: 15
+ 3: 74946
+ 10: 23324
+ 15: 11
+[main]:154:0
+ 2: 12
+ 3: 18 _Z5funcAi:11
+ 3.1: 18 _Z5funcBi:19
+[external:12 @ main]:154:12
+ 2: 12
+ 3: 10 _Z5funcAi:7
+ 3.1: 10 _Z5funcBi:11
+[main:3.1 @ _Z5funcBi]:120:19
+ 0: 19
+ 1: 19 _Z8funcLeafi:20
+ 3: 12
+[externalA:17 @ _Z5funcBi]:120:3
+ 0: 3
+ 1: 3
+[external:10 @ _Z5funcBi]:120:10
+ 0: 10
+ 1: 10
+[main:3 @ _Z5funcAi]:99:11
+ 0: 10
+ 1: 10 _Z8funcLeafi:11
+ 2: 287864 _Z3fibi:315608
+ 3: 24
+[main:3 @ _Z5funcAi:2 @ _Z3fibi]:287864:315608
+ 0: 362839
+ 1: 6
+ 3: 287884
+[main:3 @ _Z5funcAi:1 @ _Z8funcLeafi:1 @ _Z5funcBi]:1467299:6
+ 0: 6
+ 1: 6
+ 3: 287884
+ 15: 23
\ No newline at end of file
diff --git a/llvm/test/Transforms/SampleProfile/profile-context-order.ll b/llvm/test/Transforms/SampleProfile/profile-context-order.ll
index 22d0f7c..ff70728 100644
--- a/llvm/test/Transforms/SampleProfile/profile-context-order.ll
+++ b/llvm/test/Transforms/SampleProfile/profile-context-order.ll
@@ -17,6 +17,14 @@
 ;; _Z3fibi inlined into _Z5funcAi.
 ; RUN: opt < %s -passes=sample-profile -use-profiled-call-graph=1 -sample-profile-file=%S/Inputs/profile-context-order.prof -S | FileCheck %s -check-prefix=ICALL-INLINE
 
+;; When a cycle is formed by profiled edges between _Z5funcBi and _Z8funcLeafi,
+;; the function processing order matters. Without considering call edge weights,
+;; _Z8funcLeafi can be processed before _Z5funcBi, which leads to suboptimal
+;; inlining.
+; RUN: opt < %s -passes=sample-profile -use-profiled-call-graph=1 -sort-profiled-scc-member=0 -sample-profile-file=%S/Inputs/profile-context-order-scc.prof -S | FileCheck %s -check-prefix=NOINLINEB
+; RUN: opt < %s -passes=sample-profile -use-profiled-call-graph=1 -sort-profiled-scc-member=1 -sample-profile-file=%S/Inputs/profile-context-order-scc.prof -S | FileCheck %s -check-prefix=INLINEB
+
+
 @factor = dso_local global i32 3, align 4, !dbg !0
 @fp = dso_local global i32 (i32)* null, align 8
 
@@ -47,6 +55,10 @@
 ; NOINLINE: call i32 @_Z8funcLeafi
 ; ICALL-INLINE: define dso_local i32 @_Z5funcAi
 ; ICALL-INLINE: call i32 @_Z3foo
+; INLINEB: define dso_local i32 @_Z5funcBi
+; INLINEB-NOT: call i32 @_Z8funcLeafi
+; NOINLINEB: define dso_local i32 @_Z5funcBi
+; NOINLINEB: call i32 @_Z8funcLeafi
 define dso_local i32 @_Z5funcAi(i32 %x) local_unnamed_addr #0 !dbg !40 {
 entry:
   %add = add nsw i32 %x, 100000, !dbg !44
diff --git a/llvm/tools/llvm-profgen/CSPreInliner.cpp b/llvm/tools/llvm-profgen/CSPreInliner.cpp
index 1928da8..435c1f9 100644
--- a/llvm/tools/llvm-profgen/CSPreInliner.cpp
+++ b/llvm/tools/llvm-profgen/CSPreInliner.cpp
@@ -38,6 +38,7 @@
 extern cl::opt<int> ProfileInlineGrowthLimit;
 extern cl::opt<int> ProfileInlineLimitMin;
 extern cl::opt<int> ProfileInlineLimitMax;
+extern cl::opt<bool> SortProfiledSCC;
 
 cl::opt<bool> EnableCSPreInliner(
     "csspgo-preinliner", cl::Hidden, cl::init(false),
@@ -70,7 +71,13 @@
   // by building up SCC and reversing SCC order.
   scc_iterator<ProfiledCallGraph *> I = scc_begin(&ProfiledCG);
   while (!I.isAtEnd()) {
-    for (ProfiledCallGraphNode *Node : *I) {
+    auto Range = *I;
+    if (SortProfiledSCC) {
+      // Sort nodes in one SCC based on callsite hotness.
+      scc_member_iterator<ProfiledCallGraph *> SI(*I);
+      Range = *SI;
+    }
+    for (auto *Node : Range) {
       if (Node != ProfiledCG.getEntryNode())
         Order.push_back(Node->Name);
     }