[ctxprof] Capture sampling info for context roots (#131201)

When we collect a contextual profile, we sample the threads entering its root and only collect on one thread at a time (see `ContextRoot::Taken`). If we want to compare counter values across contextual profiles, and/or against flat profiles, we have a problem: we don't know how the counter values relate to each other. To address that, we add `ContextRoot::TotalEntries`, which is incremented every time a root is entered and serves as a multiplier for the counter values collected under that root.

We expose this in the profile and leave the normalization to the user of the profile, for a few reasons:

* it's only needed if reasoning about all profiles in aggregate.
* the goal, in compiler-rt, is to flush out the profile as quickly as possible, and performing multiplications adds overhead that may not even be necessary if the consumer of the profile doesn't care about combining profiles.
* the information itself may be interesting as an indication of relative sampling of various contexts.
diff --git a/compiler-rt/lib/ctx_profile/CtxInstrContextNode.h b/compiler-rt/lib/ctx_profile/CtxInstrContextNode.h
index 0fc4883..55962df 100644
--- a/compiler-rt/lib/ctx_profile/CtxInstrContextNode.h
+++ b/compiler-rt/lib/ctx_profile/CtxInstrContextNode.h
@@ -120,7 +120,8 @@
 class ProfileWriter {
 public:
   virtual void startContextSection() = 0;
-  virtual void writeContextual(const ctx_profile::ContextNode &RootNode) = 0;
+  virtual void writeContextual(const ctx_profile::ContextNode &RootNode,
+                               uint64_t TotalRootEntryCount) = 0;
   virtual void endContextSection() = 0;
 
   virtual void startFlatSection() = 0;
diff --git a/compiler-rt/lib/ctx_profile/CtxInstrProfiling.cpp b/compiler-rt/lib/ctx_profile/CtxInstrProfiling.cpp
index d7ec8fd..1c2cad1 100644
--- a/compiler-rt/lib/ctx_profile/CtxInstrProfiling.cpp
+++ b/compiler-rt/lib/ctx_profile/CtxInstrProfiling.cpp
@@ -340,6 +340,9 @@
     ContextRoot *Root, GUID Guid, uint32_t Counters,
     uint32_t Callsites) SANITIZER_NO_THREAD_SAFETY_ANALYSIS {
   IsUnderContext = true;
+  __sanitizer::atomic_fetch_add(&Root->TotalEntries, 1,
+                                __sanitizer::memory_order_relaxed);
+
   if (!Root->FirstMemBlock) {
     setupContext(Root, Guid, Counters, Callsites);
   }
@@ -374,6 +377,7 @@
       ++NumMemUnits;
 
     resetContextNode(*Root->FirstNode);
+    __sanitizer::atomic_store_relaxed(&Root->TotalEntries, 0);
   }
   __sanitizer::atomic_store_relaxed(&ProfilingStarted, true);
   __sanitizer::Printf("[ctxprof] Initial NumMemUnits: %zu \n", NumMemUnits);
@@ -393,7 +397,8 @@
       __sanitizer::Printf("[ctxprof] Contextual Profile is %s\n", "invalid");
       return false;
     }
-    Writer.writeContextual(*Root->FirstNode);
+    Writer.writeContextual(*Root->FirstNode, __sanitizer::atomic_load_relaxed(
+                                                 &Root->TotalEntries));
   }
   Writer.endContextSection();
   Writer.startFlatSection();
diff --git a/compiler-rt/lib/ctx_profile/CtxInstrProfiling.h b/compiler-rt/lib/ctx_profile/CtxInstrProfiling.h
index ab6df6d..72cc60b 100644
--- a/compiler-rt/lib/ctx_profile/CtxInstrProfiling.h
+++ b/compiler-rt/lib/ctx_profile/CtxInstrProfiling.h
@@ -80,6 +80,10 @@
   ContextNode *FirstNode = nullptr;
   Arena *FirstMemBlock = nullptr;
   Arena *CurrentMem = nullptr;
+
+  // Count the number of entries, regardless of whether we could take `Taken`.
+  ::__sanitizer::atomic_uint64_t TotalEntries = {};
+
   // This is init-ed by the static zero initializer in LLVM.
   // Taken is used to ensure only one thread traverses the contextual graph -
   // either to read it or to write it. On server side, the same entrypoint will
diff --git a/compiler-rt/lib/ctx_profile/tests/CtxInstrProfilingTest.cpp b/compiler-rt/lib/ctx_profile/tests/CtxInstrProfilingTest.cpp
index f837424..62c7f53 100644
--- a/compiler-rt/lib/ctx_profile/tests/CtxInstrProfilingTest.cpp
+++ b/compiler-rt/lib/ctx_profile/tests/CtxInstrProfilingTest.cpp
@@ -238,7 +238,9 @@
     TestProfileWriter(ContextRoot *Root, size_t Entries)
         : Root(Root), Entries(Entries) {}
 
-    void writeContextual(const ContextNode &Node) override {
+    void writeContextual(const ContextNode &Node,
+                         uint64_t TotalRootEntryCount) override {
+      EXPECT_EQ(TotalRootEntryCount, Entries);
       EXPECT_EQ(EnteredSectionCount, 1);
       EXPECT_EQ(ExitedSectionCount, 0);
       EXPECT_FALSE(Root->Taken.TryLock());
diff --git a/compiler-rt/test/ctx_profile/TestCases/generate-context.cpp b/compiler-rt/test/ctx_profile/TestCases/generate-context.cpp
index bf33b44..319f17d 100644
--- a/compiler-rt/test/ctx_profile/TestCases/generate-context.cpp
+++ b/compiler-rt/test/ctx_profile/TestCases/generate-context.cpp
@@ -84,7 +84,10 @@
     std::cout << "Exited Context Section" << std::endl;
   }
 
-  void writeContextual(const ContextNode &RootNode) override {
+  void writeContextual(const ContextNode &RootNode,
+                       uint64_t EntryCount) override {
+    std::cout << "Entering Root " << RootNode.guid()
+              << " with total entry count " << EntryCount << std::endl;
     printProfile(RootNode, "", "");
   }
 
@@ -115,6 +118,7 @@
 // The second context is in the loop. We expect 2 entries and each of the
 // branches would be taken once, so the second counter is 1.
 // CHECK-NEXT: Entered Context Section
+// CHECK-NEXT: Entering Root 8657661246551306189 with total entry count 1
 // CHECK-NEXT: Guid: 8657661246551306189
 // CHECK-NEXT: Entries: 1
 // CHECK-NEXT: 2 counters and 3 callsites
diff --git a/llvm/include/llvm/ProfileData/CtxInstrContextNode.h b/llvm/include/llvm/ProfileData/CtxInstrContextNode.h
index 0fc4883..55962df 100644
--- a/llvm/include/llvm/ProfileData/CtxInstrContextNode.h
+++ b/llvm/include/llvm/ProfileData/CtxInstrContextNode.h
@@ -120,7 +120,8 @@
 class ProfileWriter {
 public:
   virtual void startContextSection() = 0;
-  virtual void writeContextual(const ctx_profile::ContextNode &RootNode) = 0;
+  virtual void writeContextual(const ctx_profile::ContextNode &RootNode,
+                               uint64_t TotalRootEntryCount) = 0;
   virtual void endContextSection() = 0;
 
   virtual void startFlatSection() = 0;
diff --git a/llvm/include/llvm/ProfileData/PGOCtxProfReader.h b/llvm/include/llvm/ProfileData/PGOCtxProfReader.h
index 4b0c944..65be543 100644
--- a/llvm/include/llvm/ProfileData/PGOCtxProfReader.h
+++ b/llvm/include/llvm/ProfileData/PGOCtxProfReader.h
@@ -92,10 +92,13 @@
 
   GlobalValue::GUID GUID = 0;
   SmallVector<uint64_t, 16> Counters;
+  const std::optional<uint64_t> RootEntryCount;
   CallsiteMapTy Callsites;
 
-  PGOCtxProfContext(GlobalValue::GUID G, SmallVectorImpl<uint64_t> &&Counters)
-      : GUID(G), Counters(std::move(Counters)) {}
+  PGOCtxProfContext(GlobalValue::GUID G, SmallVectorImpl<uint64_t> &&Counters,
+                    std::optional<uint64_t> RootEntryCount = std::nullopt)
+      : GUID(G), Counters(std::move(Counters)), RootEntryCount(RootEntryCount) {
+  }
 
   Expected<PGOCtxProfContext &>
   getOrEmplace(uint32_t Index, GlobalValue::GUID G,
@@ -115,6 +118,9 @@
   const SmallVectorImpl<uint64_t> &counters() const { return Counters; }
   SmallVectorImpl<uint64_t> &counters() { return Counters; }
 
+  bool isRoot() const { return RootEntryCount.has_value(); }
+  uint64_t getTotalRootEntryCount() const { return RootEntryCount.value(); }
+
   uint64_t getEntrycount() const {
     assert(!Counters.empty() &&
            "Functions are expected to have at their entry BB instrumented, so "
diff --git a/llvm/include/llvm/ProfileData/PGOCtxProfWriter.h b/llvm/include/llvm/ProfileData/PGOCtxProfWriter.h
index c5a724d..b2bb8fe 100644
--- a/llvm/include/llvm/ProfileData/PGOCtxProfWriter.h
+++ b/llvm/include/llvm/ProfileData/PGOCtxProfWriter.h
@@ -19,7 +19,14 @@
 #include "llvm/ProfileData/CtxInstrContextNode.h"
 
 namespace llvm {
-enum PGOCtxProfileRecords { Invalid = 0, Version, Guid, CalleeIndex, Counters };
+enum PGOCtxProfileRecords {
+  Invalid = 0,
+  Version,
+  Guid,
+  CallsiteIndex,
+  Counters,
+  TotalRootEntryCount
+};
 
 enum PGOCtxProfileBlockIDs {
   FIRST_VALID = bitc::FIRST_APPLICATION_BLOCKID,
@@ -73,9 +80,11 @@
   const bool IncludeEmpty;
 
   void writeGuid(ctx_profile::GUID Guid);
+  void writeCallsiteIndex(uint32_t Index);
+  void writeRootEntryCount(uint64_t EntryCount);
   void writeCounters(ArrayRef<uint64_t> Counters);
-  void writeImpl(std::optional<uint32_t> CallerIndex,
-                 const ctx_profile::ContextNode &Node);
+  void writeNode(uint32_t CallerIndex, const ctx_profile::ContextNode &Node);
+  void writeSubcontexts(const ctx_profile::ContextNode &Node);
 
 public:
   PGOCtxProfileWriter(raw_ostream &Out,
@@ -84,7 +93,8 @@
   ~PGOCtxProfileWriter() { Writer.ExitBlock(); }
 
   void startContextSection() override;
-  void writeContextual(const ctx_profile::ContextNode &RootNode) override;
+  void writeContextual(const ctx_profile::ContextNode &RootNode,
+                       uint64_t TotalRootEntryCount) override;
   void endContextSection() override;
 
   void startFlatSection() override;
@@ -94,7 +104,7 @@
 
   // constants used in writing which a reader may find useful.
   static constexpr unsigned CodeLen = 2;
-  static constexpr uint32_t CurrentVersion = 2;
+  static constexpr uint32_t CurrentVersion = 3;
   static constexpr unsigned VBREncodingBits = 6;
   static constexpr StringRef ContainerMagic = "CTXP";
 };
diff --git a/llvm/lib/ProfileData/PGOCtxProfReader.cpp b/llvm/lib/ProfileData/PGOCtxProfReader.cpp
index 5cc4c94..f53f295 100644
--- a/llvm/lib/ProfileData/PGOCtxProfReader.cpp
+++ b/llvm/lib/ProfileData/PGOCtxProfReader.cpp
@@ -96,16 +96,19 @@
   std::optional<ctx_profile::GUID> Guid;
   std::optional<SmallVector<uint64_t, 16>> Counters;
   std::optional<uint32_t> CallsiteIndex;
+  std::optional<uint64_t> TotalEntryCount;
 
   SmallVector<uint64_t, 1> RecordValues;
 
   const bool ExpectIndex = Kind == PGOCtxProfileBlockIDs::ContextNodeBlockID;
+  const bool IsRoot = Kind == PGOCtxProfileBlockIDs::ContextRootBlockID;
   // We don't prescribe the order in which the records come in, and we are ok
   // if other unsupported records appear. We seek in the current subblock until
   // we get all we know.
   auto GotAllWeNeed = [&]() {
     return Guid.has_value() && Counters.has_value() &&
-           (!ExpectIndex || CallsiteIndex.has_value());
+           (!ExpectIndex || CallsiteIndex.has_value()) &&
+           (!IsRoot || TotalEntryCount.has_value());
   };
   while (!GotAllWeNeed()) {
     RecordValues.clear();
@@ -127,13 +130,21 @@
         return wrongValue("Empty counters. At least the entry counter (one "
                           "value) was expected");
       break;
-    case PGOCtxProfileRecords::CalleeIndex:
+    case PGOCtxProfileRecords::CallsiteIndex:
       if (!ExpectIndex)
         return wrongValue("The root context should not have a callee index");
       if (RecordValues.size() != 1)
         return wrongValue("The callee index should have exactly one value");
       CallsiteIndex = RecordValues[0];
       break;
+    case PGOCtxProfileRecords::TotalRootEntryCount:
+      if (!IsRoot)
+        return wrongValue("Non-root has a total entry count record");
+      if (RecordValues.size() != 1)
+        return wrongValue(
+            "The root total entry count record should have exactly one value");
+      TotalEntryCount = RecordValues[0];
+      break;
     default:
       // OK if we see records we do not understand, like records (profile
       // components) introduced later.
@@ -141,7 +152,7 @@
     }
   }
 
-  PGOCtxProfContext Ret(*Guid, std::move(*Counters));
+  PGOCtxProfContext Ret(*Guid, std::move(*Counters), TotalEntryCount);
 
   while (canEnterBlockWithID(PGOCtxProfileBlockIDs::ContextNodeBlockID)) {
     EXPECT_OR_RET(SC, readProfile(PGOCtxProfileBlockIDs::ContextNodeBlockID));
@@ -278,7 +289,8 @@
 
 void toYaml(yaml::Output &Out, GlobalValue::GUID Guid,
             const SmallVectorImpl<uint64_t> &Counters,
-            const PGOCtxProfContext::CallsiteMapTy &Callsites) {
+            const PGOCtxProfContext::CallsiteMapTy &Callsites,
+            std::optional<uint64_t> TotalRootEntryCount = std::nullopt) {
   yaml::EmptyContext Empty;
   Out.beginMapping();
   void *SaveInfo = nullptr;
@@ -289,6 +301,11 @@
     yaml::yamlize(Out, Guid, true, Empty);
     Out.postflightKey(nullptr);
   }
+  if (TotalRootEntryCount) {
+    Out.preflightKey("TotalRootEntryCount", true, false, UseDefault, SaveInfo);
+    yaml::yamlize(Out, *TotalRootEntryCount, true, Empty);
+    Out.postflightKey(nullptr);
+  }
   {
     Out.preflightKey("Counters", true, false, UseDefault, SaveInfo);
     Out.beginFlowSequence();
@@ -308,8 +325,13 @@
   }
   Out.endMapping();
 }
+
 void toYaml(yaml::Output &Out, const PGOCtxProfContext &Ctx) {
-  toYaml(Out, Ctx.guid(), Ctx.counters(), Ctx.callsites());
+  if (Ctx.isRoot())
+    toYaml(Out, Ctx.guid(), Ctx.counters(), Ctx.callsites(),
+           Ctx.getTotalRootEntryCount());
+  else
+    toYaml(Out, Ctx.guid(), Ctx.counters(), Ctx.callsites());
 }
 
 } // namespace
diff --git a/llvm/lib/ProfileData/PGOCtxProfWriter.cpp b/llvm/lib/ProfileData/PGOCtxProfWriter.cpp
index e906836..9108426 100644
--- a/llvm/lib/ProfileData/PGOCtxProfWriter.cpp
+++ b/llvm/lib/ProfileData/PGOCtxProfWriter.cpp
@@ -55,10 +55,12 @@
     DescribeBlock(PGOCtxProfileBlockIDs::ContextsSectionBlockID, "Contexts");
     DescribeBlock(PGOCtxProfileBlockIDs::ContextRootBlockID, "Root");
     DescribeRecord(PGOCtxProfileRecords::Guid, "GUID");
+    DescribeRecord(PGOCtxProfileRecords::TotalRootEntryCount,
+                   "TotalRootEntryCount");
     DescribeRecord(PGOCtxProfileRecords::Counters, "Counters");
     DescribeBlock(PGOCtxProfileBlockIDs::ContextNodeBlockID, "Context");
     DescribeRecord(PGOCtxProfileRecords::Guid, "GUID");
-    DescribeRecord(PGOCtxProfileRecords::CalleeIndex, "CalleeIndex");
+    DescribeRecord(PGOCtxProfileRecords::CallsiteIndex, "CalleeIndex");
     DescribeRecord(PGOCtxProfileRecords::Counters, "Counters");
     DescribeBlock(PGOCtxProfileBlockIDs::FlatProfilesSectionBlockID,
                   "FlatProfiles");
@@ -85,29 +87,39 @@
   Writer.EmitRecord(PGOCtxProfileRecords::Guid, SmallVector<uint64_t, 1>{Guid});
 }
 
+void PGOCtxProfileWriter::writeCallsiteIndex(uint32_t CallsiteIndex) {
+  Writer.EmitRecord(PGOCtxProfileRecords::CallsiteIndex,
+                    SmallVector<uint64_t, 1>{CallsiteIndex});
+}
+
+void PGOCtxProfileWriter::writeRootEntryCount(uint64_t TotalRootEntryCount) {
+  Writer.EmitRecord(PGOCtxProfileRecords::TotalRootEntryCount,
+                    SmallVector<uint64_t, 1>{TotalRootEntryCount});
+}
+
 // recursively write all the subcontexts. We do need to traverse depth first to
 // model the context->subcontext implicitly, and since this captures call
 // stacks, we don't really need to be worried about stack overflow and we can
 // keep the implementation simple.
-void PGOCtxProfileWriter::writeImpl(std::optional<uint32_t> CallerIndex,
+void PGOCtxProfileWriter::writeNode(uint32_t CallsiteIndex,
                                     const ContextNode &Node) {
   // A node with no counters is an error. We don't expect this to happen from
   // the runtime, rather, this is interesting for testing the reader.
   if (!IncludeEmpty && (Node.counters_size() > 0 && Node.entrycount() == 0))
     return;
-  Writer.EnterSubblock(CallerIndex ? PGOCtxProfileBlockIDs::ContextNodeBlockID
-                                   : PGOCtxProfileBlockIDs::ContextRootBlockID,
-                       CodeLen);
+  Writer.EnterSubblock(PGOCtxProfileBlockIDs::ContextNodeBlockID, CodeLen);
   writeGuid(Node.guid());
-  if (CallerIndex)
-    Writer.EmitRecord(PGOCtxProfileRecords::CalleeIndex,
-                      SmallVector<uint64_t, 1>{*CallerIndex});
+  writeCallsiteIndex(CallsiteIndex);
   writeCounters({Node.counters(), Node.counters_size()});
+  writeSubcontexts(Node);
+  Writer.ExitBlock();
+}
+
+void PGOCtxProfileWriter::writeSubcontexts(const ContextNode &Node) {
   for (uint32_t I = 0U; I < Node.callsites_size(); ++I)
     for (const auto *Subcontext = Node.subContexts()[I]; Subcontext;
          Subcontext = Subcontext->next())
-      writeImpl(I, *Subcontext);
-  Writer.ExitBlock();
+      writeNode(I, *Subcontext);
 }
 
 void PGOCtxProfileWriter::startContextSection() {
@@ -122,8 +134,17 @@
 void PGOCtxProfileWriter::endContextSection() { Writer.ExitBlock(); }
 void PGOCtxProfileWriter::endFlatSection() { Writer.ExitBlock(); }
 
-void PGOCtxProfileWriter::writeContextual(const ContextNode &RootNode) {
-  writeImpl(std::nullopt, RootNode);
+void PGOCtxProfileWriter::writeContextual(const ContextNode &RootNode,
+                                          uint64_t TotalRootEntryCount) {
+  if (!IncludeEmpty && (!TotalRootEntryCount || (RootNode.counters_size() > 0 &&
+                                                 RootNode.entrycount() == 0)))
+    return;
+  Writer.EnterSubblock(PGOCtxProfileBlockIDs::ContextRootBlockID, CodeLen);
+  writeGuid(RootNode.guid());
+  writeRootEntryCount(TotalRootEntryCount);
+  writeCounters({RootNode.counters(), RootNode.counters_size()});
+  writeSubcontexts(RootNode);
+  Writer.ExitBlock();
 }
 
 void PGOCtxProfileWriter::writeFlat(ctx_profile::GUID Guid,
@@ -144,11 +165,15 @@
   std::vector<std::vector<SerializableCtxRepresentation>> Callsites;
 };
 
+struct SerializableRootRepresentation : public SerializableCtxRepresentation {
+  uint64_t TotalRootEntryCount = 0;
+};
+
 using SerializableFlatProfileRepresentation =
     std::pair<ctx_profile::GUID, std::vector<uint64_t>>;
 
 struct SerializableProfileRepresentation {
-  std::vector<SerializableCtxRepresentation> Contexts;
+  std::vector<SerializableRootRepresentation> Contexts;
   std::vector<SerializableFlatProfileRepresentation> FlatProfiles;
 };
 
@@ -189,6 +214,7 @@
 
 LLVM_YAML_IS_SEQUENCE_VECTOR(SerializableCtxRepresentation)
 LLVM_YAML_IS_SEQUENCE_VECTOR(std::vector<SerializableCtxRepresentation>)
+LLVM_YAML_IS_SEQUENCE_VECTOR(SerializableRootRepresentation)
 LLVM_YAML_IS_SEQUENCE_VECTOR(SerializableFlatProfileRepresentation)
 template <> struct yaml::MappingTraits<SerializableCtxRepresentation> {
   static void mapping(yaml::IO &IO, SerializableCtxRepresentation &SCR) {
@@ -198,6 +224,13 @@
   }
 };
 
+template <> struct yaml::MappingTraits<SerializableRootRepresentation> {
+  static void mapping(yaml::IO &IO, SerializableRootRepresentation &R) {
+    yaml::MappingTraits<SerializableCtxRepresentation>::mapping(IO, R);
+    IO.mapRequired("TotalRootEntryCount", R.TotalRootEntryCount);
+  }
+};
+
 template <> struct yaml::MappingTraits<SerializableProfileRepresentation> {
   static void mapping(yaml::IO &IO, SerializableProfileRepresentation &SPR) {
     IO.mapOptional("Contexts", SPR.Contexts);
@@ -232,7 +265,7 @@
       if (!TopList)
         return createStringError(
             "Unexpected error converting internal structure to ctx profile");
-      Writer.writeContextual(*TopList);
+      Writer.writeContextual(*TopList, DC.TotalRootEntryCount);
     }
     Writer.endContextSection();
   }
diff --git a/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp b/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
index 9f2b2d6..3c30dd6 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
@@ -120,6 +120,7 @@
                                           PointerTy,          /*FirstNode*/
                                           PointerTy,          /*FirstMemBlock*/
                                           PointerTy,          /*CurrentMem*/
+                                          I64Ty,              /*TotalEntries*/
                                           SanitizerMutexType, /*Taken*/
                                       });
   FunctionDataTy =
diff --git a/llvm/test/Analysis/CtxProfAnalysis/flatten-and-annotate.ll b/llvm/test/Analysis/CtxProfAnalysis/flatten-and-annotate.ll
index 20eaf59..c7b325b 100644
--- a/llvm/test/Analysis/CtxProfAnalysis/flatten-and-annotate.ll
+++ b/llvm/test/Analysis/CtxProfAnalysis/flatten-and-annotate.ll
@@ -61,6 +61,7 @@
 ;--- profile.yaml
 Contexts:
   - Guid: 4909520559318251808
+    TotalRootEntryCount: 100
     Counters: [100, 40]
     Callsites: -
                 - Guid: 11872291593386833696
diff --git a/llvm/test/Analysis/CtxProfAnalysis/flatten-check-path.ll b/llvm/test/Analysis/CtxProfAnalysis/flatten-check-path.ll
index eb697b6..b10eb6a 100644
--- a/llvm/test/Analysis/CtxProfAnalysis/flatten-check-path.ll
+++ b/llvm/test/Analysis/CtxProfAnalysis/flatten-check-path.ll
@@ -41,6 +41,7 @@
 ;--- profile_ok.yaml
 Contexts:
   - Guid: 1234 
+    TotalRootEntryCount: 2
     Counters: [2, 2, 1, 2]
 
 ;--- message_pump.ll
@@ -64,6 +65,7 @@
 ;--- profile_pump.yaml
 Contexts:
   - Guid: 1234
+    TotalRootEntryCount: 2
     Counters: [2, 10, 0]
 
 ;--- unreachable.ll
@@ -88,4 +90,5 @@
 ;--- profile_unreachable.yaml
 Contexts:
   - Guid: 1234
+    TotalRootEntryCount: 2
     Counters: [2, 1, 1, 2]
diff --git a/llvm/test/Analysis/CtxProfAnalysis/flatten-icp.ll b/llvm/test/Analysis/CtxProfAnalysis/flatten-icp.ll
index 18f85e6..deea666 100644
--- a/llvm/test/Analysis/CtxProfAnalysis/flatten-icp.ll
+++ b/llvm/test/Analysis/CtxProfAnalysis/flatten-icp.ll
@@ -48,6 +48,7 @@
 ;--- profile.yaml
 Contexts:
   - Guid: 4000
+    TotalRootEntryCount: 10
     Counters: [10]
     Callsites:  -
                   - Guid: 3000
diff --git a/llvm/test/Analysis/CtxProfAnalysis/flatten-zero-path.ll b/llvm/test/Analysis/CtxProfAnalysis/flatten-zero-path.ll
index 7db4ea2..4c638b6 100644
--- a/llvm/test/Analysis/CtxProfAnalysis/flatten-zero-path.ll
+++ b/llvm/test/Analysis/CtxProfAnalysis/flatten-zero-path.ll
@@ -54,4 +54,5 @@
 ;--- profile.yaml
 Contexts:
   - Guid: 1234
+    TotalRootEntryCount: 12
     Counters: [6,0,0,0]
diff --git a/llvm/test/Analysis/CtxProfAnalysis/full-cycle.ll b/llvm/test/Analysis/CtxProfAnalysis/full-cycle.ll
index 054eef4..63abdd8 100644
--- a/llvm/test/Analysis/CtxProfAnalysis/full-cycle.ll
+++ b/llvm/test/Analysis/CtxProfAnalysis/full-cycle.ll
@@ -68,6 +68,7 @@
 ;--- profile.yaml
 Contexts:
   - Guid: 10507721908651011566
+    TotalRootEntryCount: 10
     Counters: [1]
     Callsites:  -
                   - Guid:  2072045998141807037
@@ -92,6 +93,7 @@
 
 Contexts:
   - Guid:            10507721908651011566
+    TotalRootEntryCount: 10
     Counters:        [ 1 ]
     Callsites:
       - - Guid:            2072045998141807037
diff --git a/llvm/test/Analysis/CtxProfAnalysis/handle-select.ll b/llvm/test/Analysis/CtxProfAnalysis/handle-select.ll
index d1f729b..dfbc5c9 100644
--- a/llvm/test/Analysis/CtxProfAnalysis/handle-select.ll
+++ b/llvm/test/Analysis/CtxProfAnalysis/handle-select.ll
@@ -75,6 +75,7 @@
 ;--- profile.yaml
 Contexts:
   - Guid: 1234
+    TotalRootEntryCount: 100
     Counters: [10, 4]
     Callsites:  -
                   - Guid: 5678
diff --git a/llvm/test/Analysis/CtxProfAnalysis/inline.ll b/llvm/test/Analysis/CtxProfAnalysis/inline.ll
index 31f789b..836ec8b 100644
--- a/llvm/test/Analysis/CtxProfAnalysis/inline.ll
+++ b/llvm/test/Analysis/CtxProfAnalysis/inline.ll
@@ -98,6 +98,7 @@
 ;--- profile.yaml
 Contexts:
   - Guid: 1000
+    TotalRootEntryCount: 24
     Counters: [10, 2, 8]
     Callsites:  -
                   - Guid: 1001
@@ -115,6 +116,7 @@
 
 Contexts:
   - Guid:            1000
+    TotalRootEntryCount: 24
     Counters:        [ 10, 2, 8, 100 ]
     Callsites:
       - [  ]
diff --git a/llvm/test/Analysis/CtxProfAnalysis/load-unapplicable.ll b/llvm/test/Analysis/CtxProfAnalysis/load-unapplicable.ll
index 6e142d9..9c35878 100644
--- a/llvm/test/Analysis/CtxProfAnalysis/load-unapplicable.ll
+++ b/llvm/test/Analysis/CtxProfAnalysis/load-unapplicable.ll
@@ -18,10 +18,13 @@
 ;--- profile.yaml
 Contexts:
   - Guid: 12341
+    TotalRootEntryCount: 24
     Counters: [9]
   - Guid: 1000
+    TotalRootEntryCount: 25
     Counters: [5]
   - Guid: 34234
+    TotalRootEntryCount: 2
     Counters: [1]
     Callsites:  -
                   - Guid: 1000
diff --git a/llvm/test/Analysis/CtxProfAnalysis/load.ll b/llvm/test/Analysis/CtxProfAnalysis/load.ll
index 0163225..6091a99 100644
--- a/llvm/test/Analysis/CtxProfAnalysis/load.ll
+++ b/llvm/test/Analysis/CtxProfAnalysis/load.ll
@@ -26,10 +26,13 @@
 ;--- profile.yaml
 Contexts:
   - Guid: 12341
+    TotalRootEntryCount: 90
     Counters: [9]
   - Guid: 12074870348631550642
+    TotalRootEntryCount: 24
     Counters: [5]
   - Guid: 11872291593386833696
+    TotalRootEntryCount: 4
     Counters: [1]
     Callsites:  -
                   - Guid: 728453322856651412
@@ -44,11 +47,13 @@
 
 Contexts:
   - Guid:            11872291593386833696
+    TotalRootEntryCount: 4
     Counters:        [ 1 ]
     Callsites:
       - - Guid:            728453322856651412
           Counters:        [ 6, 7 ]
   - Guid:            12074870348631550642
+    TotalRootEntryCount: 24
     Counters:        [ 5 ]
 
 Flat Profile:
diff --git a/llvm/test/ThinLTO/X86/ctxprof.ll b/llvm/test/ThinLTO/X86/ctxprof.ll
index a4bc792..fed7a83 100644
--- a/llvm/test/ThinLTO/X86/ctxprof.ll
+++ b/llvm/test/ThinLTO/X86/ctxprof.ll
@@ -53,8 +53,8 @@
 ; RUN: opt -module-summary -passes=assign-guid,ctx-instr-gen %t/m2.ll -o %t/m2-instr.bc
 ;
 ; RUN: echo '{"Contexts": [ \
-; RUN:        {"Guid": 6019442868614718803, "Counters": [1], "Callsites": [[{"Guid": 15593096274670919754, "Counters": [1]}]]}, \
-; RUN:        {"Guid": 15593096274670919754, "Counters": [1], "Callsites": [[{"Guid": 6019442868614718803, "Counters": [1]}]]} \
+; RUN:        {"Guid": 6019442868614718803, "TotalRootEntryCount": 5, "Counters": [1], "Callsites": [[{"Guid": 15593096274670919754, "Counters": [1]}]]}, \
+; RUN:        {"Guid": 15593096274670919754, "TotalRootEntryCount": 2, "Counters": [1], "Callsites": [[{"Guid": 6019442868614718803, "Counters": [1]}]]} \
 ; RUN:  ]}' > %t_exp/ctxprof.yaml
 ; RUN: llvm-ctxprof-util fromYAML --input %t_exp/ctxprof.yaml --output %t_exp/ctxprof.bitstream
 ; RUN: llvm-lto2 run %t/m1-instr.bc %t/m2-instr.bc \
diff --git a/llvm/test/Transforms/EliminateAvailableExternally/transform-to-local.ll b/llvm/test/Transforms/EliminateAvailableExternally/transform-to-local.ll
index 8d0fe5f..b6465f4 100644
--- a/llvm/test/Transforms/EliminateAvailableExternally/transform-to-local.ll
+++ b/llvm/test/Transforms/EliminateAvailableExternally/transform-to-local.ll
@@ -1,7 +1,7 @@
 ; REQUIRES: asserts
 ; RUN: opt -passes=elim-avail-extern -avail-extern-to-local -stats -S 2>&1 < %s | FileCheck %s
 ;
-; RUN: echo '{"Contexts": [{"Guid":1234, "Counters": [1]}]}' | llvm-ctxprof-util fromYAML --input=- --output=%t_profile.ctxprofdata
+; RUN: echo '{"Contexts": [{"Guid":1234, "TotalRootEntryCount": 5, "Counters": [1]}]}' | llvm-ctxprof-util fromYAML --input=- --output=%t_profile.ctxprofdata
 ;
 ; Because we pass a contextual profile with a root defined in this module, we expect the outcome to be the same as-if
 ; we passed -avail-extern-to-local, i.e. available_externally don't get elided and instead get converted to local linkage
@@ -9,7 +9,7 @@
 
 ; If the profile doesn't apply to this module, available_externally won't get converted to internal linkage, and will be
 ; removed instead.
-; RUN: echo '{"Contexts": [{"Guid":5678, "Counters": [1]}]}' | llvm-ctxprof-util fromYAML --input=- --output=%t_profile_bad.ctxprofdata
+; RUN: echo '{"Contexts": [{"Guid":5678, "TotalRootEntryCount": 3, "Counters": [1]}]}' | llvm-ctxprof-util fromYAML --input=- --output=%t_profile_bad.ctxprofdata
 ; RUN: opt -passes='assign-guid,require<ctx-prof-analysis>,elim-avail-extern' -use-ctx-profile=%t_profile_bad.ctxprofdata -stats -S 2>&1 < %s | FileCheck %s --check-prefix=NOOP
 
 declare void @call_out(ptr %fct)
diff --git a/llvm/test/tools/llvm-ctxprof-util/Inputs/invalid-no-entrycount.yaml b/llvm/test/tools/llvm-ctxprof-util/Inputs/invalid-no-entrycount.yaml
new file mode 100644
index 0000000..64e18d0
--- /dev/null
+++ b/llvm/test/tools/llvm-ctxprof-util/Inputs/invalid-no-entrycount.yaml
@@ -0,0 +1,3 @@
+Contexts:
+  - Guid:            1000
+    Counters:        [ 1, 2, 3 ]
diff --git a/llvm/test/tools/llvm-ctxprof-util/Inputs/valid-ctx-only.yaml b/llvm/test/tools/llvm-ctxprof-util/Inputs/valid-ctx-only.yaml
index 0de489d..5e12955 100644
--- a/llvm/test/tools/llvm-ctxprof-util/Inputs/valid-ctx-only.yaml
+++ b/llvm/test/tools/llvm-ctxprof-util/Inputs/valid-ctx-only.yaml
@@ -1,6 +1,7 @@
 
 Contexts:
   - Guid:            1000
+    TotalRootEntryCount: 5
     Counters:        [ 1, 2, 3 ]
     Callsites:
       - [  ]
@@ -11,4 +12,5 @@
       - - Guid:            3000
           Counters:        [ 40, 50 ]
   - Guid:            18446744073709551612
+    TotalRootEntryCount: 45
     Counters:        [ 5, 9, 10 ]
diff --git a/llvm/test/tools/llvm-ctxprof-util/Inputs/valid-flat-first.yaml b/llvm/test/tools/llvm-ctxprof-util/Inputs/valid-flat-first.yaml
index 5567faa..4aef351 100644
--- a/llvm/test/tools/llvm-ctxprof-util/Inputs/valid-flat-first.yaml
+++ b/llvm/test/tools/llvm-ctxprof-util/Inputs/valid-flat-first.yaml
@@ -6,6 +6,7 @@
     Counters:        [ 1 ]
 Contexts:
   - Guid:            1000
+    TotalRootEntryCount: 5
     Counters:        [ 1, 2, 3 ]
     Callsites:
       - [  ]
@@ -16,4 +17,5 @@
       - - Guid:            3000
           Counters:        [ 40, 50 ]
   - Guid:            18446744073709551612
+    TotalRootEntryCount: 45
     Counters:        [ 5, 9, 10 ]
diff --git a/llvm/test/tools/llvm-ctxprof-util/Inputs/valid.yaml b/llvm/test/tools/llvm-ctxprof-util/Inputs/valid.yaml
index 1541b0d..22e20af 100644
--- a/llvm/test/tools/llvm-ctxprof-util/Inputs/valid.yaml
+++ b/llvm/test/tools/llvm-ctxprof-util/Inputs/valid.yaml
@@ -1,6 +1,7 @@
 
 Contexts:
   - Guid:            1000
+    TotalRootEntryCount: 5
     Counters:        [ 1, 2, 3 ]
     Callsites:
       - [  ]
@@ -11,6 +12,7 @@
       - - Guid:            3000
           Counters:        [ 40, 50 ]
   - Guid:            18446744073709551612
+    TotalRootEntryCount: 45
     Counters:        [ 5, 9, 10 ]
 FlatProfiles:
   - Guid:            1234
diff --git a/llvm/test/tools/llvm-ctxprof-util/llvm-ctxprof-util-negative.test b/llvm/test/tools/llvm-ctxprof-util/llvm-ctxprof-util-negative.test
index f312f50..511d4e6 100644
--- a/llvm/test/tools/llvm-ctxprof-util/llvm-ctxprof-util-negative.test
+++ b/llvm/test/tools/llvm-ctxprof-util/llvm-ctxprof-util-negative.test
@@ -10,6 +10,7 @@
 ; RUN: not llvm-ctxprof-util fromYAML --input %S/Inputs/invalid-no-counters.yaml 2>&1 | FileCheck %s --check-prefix=NO_COUNTERS
 ; RUN: not llvm-ctxprof-util fromYAML --input %S/Inputs/invalid-bad-subctx.yaml 2>&1 | FileCheck %s --check-prefix=BAD_SUBCTX
 ; RUN: not llvm-ctxprof-util fromYAML --input %S/Inputs/invalid-flat.yaml 2>&1 | FileCheck %s --check-prefix=BAD_FLAT
+; RUN: not llvm-ctxprof-util fromYAML --input %S/Inputs/invalid-no-entrycount.yaml 2>&1 | FileCheck %s --check-prefix=BAD_NOENTRYCOUNT
 ; RUN: rm -rf %t
 ; RUN: not llvm-ctxprof-util fromYAML --input %S/Inputs/valid.yaml --output %t/output.bitstream 2>&1 | FileCheck %s --check-prefix=NO_DIR
 
@@ -23,4 +24,5 @@
 ; NO_COUNTERS: YAML:2:5: error: missing required key 'Counters'
 ; BAD_SUBCTX: YAML:4:18: error: not a sequence
 ; BAD_FLAT: YAML:2:5: error: missing required key 'Counters'
+; BAD_NOENTRYCOUNT: YAML:2:5: error: missing required key 'TotalRootEntryCount'
 ; NO_DIR: failed to open output
diff --git a/llvm/test/tools/llvm-ctxprof-util/llvm-ctxprof-util.test b/llvm/test/tools/llvm-ctxprof-util/llvm-ctxprof-util.test
index a9e3503..937bf89 100644
--- a/llvm/test/tools/llvm-ctxprof-util/llvm-ctxprof-util.test
+++ b/llvm/test/tools/llvm-ctxprof-util/llvm-ctxprof-util.test
@@ -33,15 +33,16 @@
 
 ; EMPTY: <BLOCKINFO_BLOCK/>
 ; EMPTY-NEXT: <Metadata NumWords=1 BlockCodeSize=2>
-; EMPTY-NEXT:   <Version op0=2/>
+; EMPTY-NEXT:   <Version op0=3/>
 ; EMPTY-NEXT: </Metadata>
 
 ; VALID:      <BLOCKINFO_BLOCK/>
-; VALID-NEXT: <Metadata NumWords=45 BlockCodeSize=2>
-; VALID-NEXT:   <Version op0=2/>
-; VALID-NEXT:   <Contexts NumWords=29 BlockCodeSize=2>
+; VALID-NEXT: <Metadata NumWords=46 BlockCodeSize=2>
+; VALID-NEXT:   <Version op0=3/>
+; VALID-NEXT:   <Contexts NumWords=30 BlockCodeSize=2>
 ; VALID-NEXT:     <Root NumWords=20 BlockCodeSize=2>
 ; VALID-NEXT:       <GUID op0=1000/>
+; VALID-NEXT:       <TotalRootEntryCount op0=5/>
 ; VALID-NEXT:       <Counters op0=1 op1=2 op2=3/>
 ; VALID-NEXT:       <Context NumWords=5 BlockCodeSize=2>
 ; VALID-NEXT:         <GUID op0=-3/>
@@ -59,8 +60,9 @@
 ; VALID-NEXT:         <Counters op0=40 op1=50/>
 ; VALID-NEXT:       </Context>
 ; VALID-NEXT:     </Root>
-; VALID-NEXT:     <Root NumWords=4 BlockCodeSize=2>
+; VALID-NEXT:     <Root NumWords=5 BlockCodeSize=2>
 ; VALID-NEXT:       <GUID op0=-4/>
+; VALID-NEXT:       <TotalRootEntryCount op0=45/>
 ; VALID-NEXT:       <Counters op0=5 op1=9 op2=10/>
 ; VALID-NEXT:     </Root>
 ; VALID-NEXT:   </Contexts>
diff --git a/llvm/unittests/ProfileData/PGOCtxProfReaderWriterTest.cpp b/llvm/unittests/ProfileData/PGOCtxProfReaderWriterTest.cpp
index 8401e5b..5416672 100644
--- a/llvm/unittests/ProfileData/PGOCtxProfReaderWriterTest.cpp
+++ b/llvm/unittests/ProfileData/PGOCtxProfReaderWriterTest.cpp
@@ -103,7 +103,7 @@
       PGOCtxProfileWriter Writer(Out);
       Writer.startContextSection();
       for (auto &[_, R] : roots())
-        Writer.writeContextual(*R);
+        Writer.writeContextual(*R, 1);
       Writer.endContextSection();
     }
   }
@@ -155,7 +155,7 @@
     {
       PGOCtxProfileWriter Writer(Out);
       Writer.startContextSection();
-      Writer.writeContextual(*R);
+      Writer.writeContextual(*R, 2);
       Writer.endContextSection();
     }
   }
@@ -181,7 +181,7 @@
     {
       PGOCtxProfileWriter Writer(Out);
       Writer.startContextSection();
-      Writer.writeContextual(*R);
+      Writer.writeContextual(*R, 42);
       Writer.endContextSection();
     }
   }
@@ -208,7 +208,7 @@
       PGOCtxProfileWriter Writer(Out, /*VersionOverride=*/std::nullopt,
                                  /*IncludeEmpty=*/true);
       Writer.startContextSection();
-      Writer.writeContextual(*R);
+      Writer.writeContextual(*R, 8);
       Writer.endContextSection();
     }
   }
@@ -293,8 +293,8 @@
       PGOCtxProfileWriter Writer(Out, /*VersionOverride=*/std::nullopt,
                                  /*IncludeEmpty=*/true);
       Writer.startContextSection();
-      Writer.writeContextual(*createNode(1, 1, 1));
-      Writer.writeContextual(*createNode(1, 1, 1));
+      Writer.writeContextual(*createNode(1, 1, 1), 1);
+      Writer.writeContextual(*createNode(1, 1, 1), 1);
       Writer.endContextSection();
     }
   }
@@ -322,7 +322,7 @@
       R->subContexts()[0] = L2;
       PGOCtxProfileWriter Writer(Out);
       Writer.startContextSection();
-      Writer.writeContextual(*R);
+      Writer.writeContextual(*R, 1);
       Writer.endContextSection();
     }
   }
diff --git a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
index 4bb521d..b642f37 100644
--- a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
+++ b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
@@ -511,6 +511,7 @@
     [
     {
       "Guid": 1000,
+      "TotalRootEntryCount": 1,
       "Counters": [1],
       "Callsites": [
         [{ "Guid": 1001,
@@ -525,6 +526,7 @@
     },
     {
       "Guid": 1005,
+      "TotalRootEntryCount": 1,
       "Counters": [2],
       "Callsites": [
         [{ "Guid": 1000,
@@ -575,6 +577,7 @@
   const char *Expected = R"yaml(
 Contexts:
   - Guid:            1000
+    TotalRootEntryCount: 1
     Counters:        [ 1, 11, 22 ]
     Callsites:
       - - Guid:            1001
@@ -587,6 +590,7 @@
             - - Guid:            1004
                 Counters:        [ 13 ]
   - Guid:            1005
+    TotalRootEntryCount: 1
     Counters:        [ 2 ]
     Callsites:
       - - Guid:            1000