| //===- PGOCtxProfFlattening.cpp - Contextual Instr. Flattening ------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // Flattens the contextual profile and lowers it to MD_prof. |
| // This should happen after all IPO (which is assumed to have maintained the |
| // contextual profile) happened. Flattening consists of summing the values at |
| // the same index of the counters belonging to all the contexts of a function. |
| // The lowering consists of materializing the counter values to function |
| // entrypoint counts and branch probabilities. |
| // |
| // This pass also removes contextual instrumentation, which has been kept around |
| // to facilitate its functionality. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "llvm/Transforms/Instrumentation/PGOCtxProfFlattening.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/ScopeExit.h" |
| #include "llvm/Analysis/CtxProfAnalysis.h" |
| #include "llvm/Analysis/ProfileSummaryInfo.h" |
| #include "llvm/IR/Analysis.h" |
| #include "llvm/IR/CFG.h" |
| #include "llvm/IR/Dominators.h" |
| #include "llvm/IR/Instructions.h" |
| #include "llvm/IR/IntrinsicInst.h" |
| #include "llvm/IR/Module.h" |
| #include "llvm/IR/PassManager.h" |
| #include "llvm/IR/ProfileSummary.h" |
| #include "llvm/ProfileData/ProfileCommon.h" |
| #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h" |
| #include "llvm/Transforms/Scalar/DCE.h" |
| #include "llvm/Transforms/Utils/BasicBlockUtils.h" |
| #include <deque> |
| #include <functional> |
| |
| using namespace llvm; |
| |
| #define DEBUG_TYPE "ctx_prof_flatten" |
| |
| namespace { |
| |
| class ProfileAnnotator final { |
| class BBInfo; |
| struct EdgeInfo { |
| BBInfo *const Src; |
| BBInfo *const Dest; |
| std::optional<uint64_t> Count; |
| |
| explicit EdgeInfo(BBInfo &Src, BBInfo &Dest) : Src(&Src), Dest(&Dest) {} |
| }; |
| |
| class BBInfo { |
| std::optional<uint64_t> Count; |
| // OutEdges is dimensioned to match the number of terminator operands. |
| // Entries in the vector match the index in the terminator operand list. In |
| // some cases - see `shouldExcludeEdge` and its implementation - an entry |
| // will be nullptr. |
| // InEdges doesn't have the above constraint. |
| SmallVector<EdgeInfo *> OutEdges; |
| SmallVector<EdgeInfo *> InEdges; |
| size_t UnknownCountOutEdges = 0; |
| size_t UnknownCountInEdges = 0; |
| |
| // Pass AssumeAllKnown when we try to propagate counts from edges to BBs - |
| // because all the edge counters must be known. |
| // Return std::nullopt if there were no edges to sum. The user can decide |
| // how to interpret that. |
| std::optional<uint64_t> getEdgeSum(const SmallVector<EdgeInfo *> &Edges, |
| bool AssumeAllKnown) const { |
| std::optional<uint64_t> Sum; |
| for (const auto *E : Edges) { |
| // `Edges` may be `OutEdges`, case in which `E` could be nullptr. |
| if (E) { |
| if (!Sum.has_value()) |
| Sum = 0; |
| *Sum += (AssumeAllKnown ? *E->Count : E->Count.value_or(0U)); |
| } |
| } |
| return Sum; |
| } |
| |
| bool computeCountFrom(const SmallVector<EdgeInfo *> &Edges) { |
| assert(!Count.has_value()); |
| Count = getEdgeSum(Edges, true); |
| return Count.has_value(); |
| } |
| |
| void setSingleUnknownEdgeCount(SmallVector<EdgeInfo *> &Edges) { |
| uint64_t KnownSum = getEdgeSum(Edges, false).value_or(0U); |
| uint64_t EdgeVal = *Count > KnownSum ? *Count - KnownSum : 0U; |
| EdgeInfo *E = nullptr; |
| for (auto *I : Edges) |
| if (I && !I->Count.has_value()) { |
| E = I; |
| #ifdef NDEBUG |
| break; |
| #else |
| assert((!E || E == I) && |
| "Expected exactly one edge to have an unknown count, " |
| "found a second one"); |
| continue; |
| #endif |
| } |
| assert(E && "Expected exactly one edge to have an unknown count"); |
| assert(!E->Count.has_value()); |
| E->Count = EdgeVal; |
| assert(E->Src->UnknownCountOutEdges > 0); |
| assert(E->Dest->UnknownCountInEdges > 0); |
| --E->Src->UnknownCountOutEdges; |
| --E->Dest->UnknownCountInEdges; |
| } |
| |
| public: |
| BBInfo(size_t NumInEdges, size_t NumOutEdges, std::optional<uint64_t> Count) |
| : Count(Count) { |
| // For in edges, we just want to pre-allocate enough space, since we know |
| // it at this stage. For out edges, we will insert edges at the indices |
| // corresponding to positions in this BB's terminator instruction, so we |
| // construct a default (nullptr values)-initialized vector. A nullptr edge |
| // corresponds to those that are excluded (see shouldExcludeEdge). |
| InEdges.reserve(NumInEdges); |
| OutEdges.resize(NumOutEdges); |
| } |
| |
| bool tryTakeCountFromKnownOutEdges(const BasicBlock &BB) { |
| if (!UnknownCountOutEdges) { |
| return computeCountFrom(OutEdges); |
| } |
| return false; |
| } |
| |
| bool tryTakeCountFromKnownInEdges(const BasicBlock &BB) { |
| if (!UnknownCountInEdges) { |
| return computeCountFrom(InEdges); |
| } |
| return false; |
| } |
| |
| void addInEdge(EdgeInfo &Info) { |
| InEdges.push_back(&Info); |
| ++UnknownCountInEdges; |
| } |
| |
| // For the out edges, we care about the position we place them in, which is |
| // the position in terminator instruction's list (at construction). Later, |
| // we build branch_weights metadata with edge frequency values matching |
| // these positions. |
| void addOutEdge(size_t Index, EdgeInfo &Info) { |
| OutEdges[Index] = &Info; |
| ++UnknownCountOutEdges; |
| } |
| |
| bool hasCount() const { return Count.has_value(); } |
| |
| uint64_t getCount() const { return *Count; } |
| |
| bool trySetSingleUnknownInEdgeCount() { |
| if (UnknownCountInEdges == 1) { |
| setSingleUnknownEdgeCount(InEdges); |
| return true; |
| } |
| return false; |
| } |
| |
| bool trySetSingleUnknownOutEdgeCount() { |
| if (UnknownCountOutEdges == 1) { |
| setSingleUnknownEdgeCount(OutEdges); |
| return true; |
| } |
| return false; |
| } |
| size_t getNumOutEdges() const { return OutEdges.size(); } |
| |
| uint64_t getEdgeCount(size_t Index) const { |
| if (auto *E = OutEdges[Index]) |
| return *E->Count; |
| return 0U; |
| } |
| }; |
| |
| Function &F; |
| const SmallVectorImpl<uint64_t> &Counters; |
| // To be accessed through getBBInfo() after construction. |
| std::map<const BasicBlock *, BBInfo> BBInfos; |
| std::vector<EdgeInfo> EdgeInfos; |
| |
| // This is an adaptation of PGOUseFunc::populateCounters. |
| // FIXME(mtrofin): look into factoring the code to share one implementation. |
| void propagateCounterValues(const SmallVectorImpl<uint64_t> &Counters) { |
| bool KeepGoing = true; |
| while (KeepGoing) { |
| KeepGoing = false; |
| for (const auto &BB : F) { |
| auto &Info = getBBInfo(BB); |
| if (!Info.hasCount()) |
| KeepGoing |= Info.tryTakeCountFromKnownOutEdges(BB) || |
| Info.tryTakeCountFromKnownInEdges(BB); |
| if (Info.hasCount()) { |
| KeepGoing |= Info.trySetSingleUnknownOutEdgeCount(); |
| KeepGoing |= Info.trySetSingleUnknownInEdgeCount(); |
| } |
| } |
| } |
| } |
| // The only criteria for exclusion is faux suspend -> exit edges in presplit |
| // coroutines. The API serves for readability, currently. |
| bool shouldExcludeEdge(const BasicBlock &Src, const BasicBlock &Dest) const { |
| return llvm::isPresplitCoroSuspendExitEdge(Src, Dest); |
| } |
| |
| BBInfo &getBBInfo(const BasicBlock &BB) { return BBInfos.find(&BB)->second; } |
| |
| const BBInfo &getBBInfo(const BasicBlock &BB) const { |
| return BBInfos.find(&BB)->second; |
| } |
| |
| // validation function after we propagate the counters: all BBs and edges' |
| // counters must have a value. |
| bool allCountersAreAssigned() const { |
| for (const auto &BBInfo : BBInfos) |
| if (!BBInfo.second.hasCount()) |
| return false; |
| for (const auto &EdgeInfo : EdgeInfos) |
| if (!EdgeInfo.Count.has_value()) |
| return false; |
| return true; |
| } |
| |
| /// Check that all paths from the entry basic block that use edges with |
| /// non-zero counts arrive at a basic block with no successors (i.e. "exit") |
| bool allTakenPathsExit() const { |
| std::deque<const BasicBlock *> Worklist; |
| DenseSet<const BasicBlock *> Visited; |
| Worklist.push_back(&F.getEntryBlock()); |
| bool HitExit = false; |
| while (!Worklist.empty()) { |
| const auto *BB = Worklist.front(); |
| Worklist.pop_front(); |
| if (!Visited.insert(BB).second) |
| continue; |
| if (succ_size(BB) == 0) { |
| if (isa<UnreachableInst>(BB->getTerminator())) |
| return false; |
| HitExit = true; |
| continue; |
| } |
| if (succ_size(BB) == 1) { |
| Worklist.push_back(BB->getUniqueSuccessor()); |
| continue; |
| } |
| const auto &BBInfo = getBBInfo(*BB); |
| bool HasAWayOut = false; |
| for (auto I = 0U; I < BB->getTerminator()->getNumSuccessors(); ++I) { |
| const auto *Succ = BB->getTerminator()->getSuccessor(I); |
| if (!shouldExcludeEdge(*BB, *Succ)) { |
| if (BBInfo.getEdgeCount(I) > 0) { |
| HasAWayOut = true; |
| Worklist.push_back(Succ); |
| } |
| } |
| } |
| if (!HasAWayOut) |
| return false; |
| } |
| return HitExit; |
| } |
| |
| bool allNonColdSelectsHaveProfile() const { |
| for (const auto &BB : F) { |
| if (getBBInfo(BB).getCount() > 0) { |
| for (const auto &I : BB) { |
| if (const auto *SI = dyn_cast<SelectInst>(&I)) { |
| if (!SI->getMetadata(LLVMContext::MD_prof)) { |
| return false; |
| } |
| } |
| } |
| } |
| } |
| return true; |
| } |
| |
| public: |
| ProfileAnnotator(Function &F, const SmallVectorImpl<uint64_t> &Counters) |
| : F(F), Counters(Counters) { |
| assert(!F.isDeclaration()); |
| assert(!Counters.empty()); |
| size_t NrEdges = 0; |
| for (const auto &BB : F) { |
| std::optional<uint64_t> Count; |
| if (auto *Ins = CtxProfAnalysis::getBBInstrumentation( |
| const_cast<BasicBlock &>(BB))) { |
| auto Index = Ins->getIndex()->getZExtValue(); |
| assert(Index < Counters.size() && |
| "The index must be inside the counters vector by construction - " |
| "tripping this assertion indicates a bug in how the contextual " |
| "profile is managed by IPO transforms"); |
| (void)Index; |
| Count = Counters[Ins->getIndex()->getZExtValue()]; |
| } else if (isa<UnreachableInst>(BB.getTerminator())) { |
| // The program presumably didn't crash. |
| Count = 0; |
| } |
| auto [It, Ins] = |
| BBInfos.insert({&BB, {pred_size(&BB), succ_size(&BB), Count}}); |
| (void)Ins; |
| assert(Ins && "We iterate through the function's BBs, no reason to " |
| "insert one more than once"); |
| NrEdges += llvm::count_if(successors(&BB), [&](const auto *Succ) { |
| return !shouldExcludeEdge(BB, *Succ); |
| }); |
| } |
| // Pre-allocate the vector, we want references to its contents to be stable. |
| EdgeInfos.reserve(NrEdges); |
| for (const auto &BB : F) { |
| auto &Info = getBBInfo(BB); |
| for (auto I = 0U; I < BB.getTerminator()->getNumSuccessors(); ++I) { |
| const auto *Succ = BB.getTerminator()->getSuccessor(I); |
| if (!shouldExcludeEdge(BB, *Succ)) { |
| auto &EI = EdgeInfos.emplace_back(getBBInfo(BB), getBBInfo(*Succ)); |
| Info.addOutEdge(I, EI); |
| getBBInfo(*Succ).addInEdge(EI); |
| } |
| } |
| } |
| assert(EdgeInfos.capacity() == NrEdges && |
| "The capacity of EdgeInfos should have stayed unchanged it was " |
| "populated, because we need pointers to its contents to be stable"); |
| } |
| |
| void setProfileForSelectInstructions(BasicBlock &BB, const BBInfo &BBInfo) { |
| if (BBInfo.getCount() == 0) |
| return; |
| |
| for (auto &I : BB) { |
| if (auto *SI = dyn_cast<SelectInst>(&I)) { |
| if (auto *Step = CtxProfAnalysis::getSelectInstrumentation(*SI)) { |
| auto Index = Step->getIndex()->getZExtValue(); |
| assert(Index < Counters.size() && |
| "The index of the step instruction must be inside the " |
| "counters vector by " |
| "construction - tripping this assertion indicates a bug in " |
| "how the contextual profile is managed by IPO transforms"); |
| auto TotalCount = BBInfo.getCount(); |
| auto TrueCount = Counters[Index]; |
| auto FalseCount = |
| (TotalCount > TrueCount ? TotalCount - TrueCount : 0U); |
| setProfMetadata(F.getParent(), SI, {TrueCount, FalseCount}, |
| std::max(TrueCount, FalseCount)); |
| } |
| } |
| } |
| } |
| |
| /// Assign branch weights and function entry count. Also update the PSI |
| /// builder. |
| void assignProfileData() { |
| assert(!Counters.empty()); |
| propagateCounterValues(Counters); |
| F.setEntryCount(Counters[0]); |
| |
| for (auto &BB : F) { |
| const auto &BBInfo = getBBInfo(BB); |
| setProfileForSelectInstructions(BB, BBInfo); |
| if (succ_size(&BB) < 2) |
| continue; |
| auto *Term = BB.getTerminator(); |
| SmallVector<uint64_t, 2> EdgeCounts(Term->getNumSuccessors(), 0); |
| uint64_t MaxCount = 0; |
| |
| for (unsigned SuccIdx = 0, Size = BBInfo.getNumOutEdges(); SuccIdx < Size; |
| ++SuccIdx) { |
| uint64_t EdgeCount = BBInfo.getEdgeCount(SuccIdx); |
| if (EdgeCount > MaxCount) |
| MaxCount = EdgeCount; |
| EdgeCounts[SuccIdx] = EdgeCount; |
| } |
| |
| if (MaxCount != 0) |
| setProfMetadata(F.getParent(), Term, EdgeCounts, MaxCount); |
| } |
| assert(allCountersAreAssigned() && |
| "[ctx-prof] Expected all counters have been assigned."); |
| assert(allTakenPathsExit() && |
| "[ctx-prof] Encountered a BB with more than one successor, where " |
| "all outgoing edges have a 0 count. This occurs in non-exiting " |
| "functions (message pumps, usually) which are not supported in the " |
| "contextual profiling case"); |
| assert(allNonColdSelectsHaveProfile() && |
| "[ctx-prof] All non-cold select instructions were expected to have " |
| "a profile."); |
| } |
| }; |
| |
| [[maybe_unused]] bool areAllBBsReachable(const Function &F, |
| FunctionAnalysisManager &FAM) { |
| auto &DT = FAM.getResult<DominatorTreeAnalysis>(const_cast<Function &>(F)); |
| return llvm::all_of( |
| F, [&](const BasicBlock &BB) { return DT.isReachableFromEntry(&BB); }); |
| } |
| |
| void clearColdFunctionProfile(Function &F) { |
| for (auto &BB : F) |
| BB.getTerminator()->setMetadata(LLVMContext::MD_prof, nullptr); |
| F.setEntryCount(0U); |
| } |
| |
| void removeInstrumentation(Function &F) { |
| for (auto &BB : F) |
| for (auto &I : llvm::make_early_inc_range(BB)) |
| if (isa<InstrProfCntrInstBase>(I)) |
| I.eraseFromParent(); |
| } |
| |
| void annotateIndirectCall( |
| Module &M, CallBase &CB, |
| const DenseMap<uint32_t, FlatIndirectTargets> &FlatProf, |
| const InstrProfCallsite &Ins) { |
| auto Idx = Ins.getIndex()->getZExtValue(); |
| auto FIt = FlatProf.find(Idx); |
| if (FIt == FlatProf.end()) |
| return; |
| const auto &Targets = FIt->second; |
| SmallVector<InstrProfValueData, 2> Data; |
| uint64_t Sum = 0; |
| for (auto &[Guid, Count] : Targets) { |
| Data.push_back({/*.Value=*/Guid, /*.Count=*/Count}); |
| Sum += Count; |
| } |
| |
| llvm::sort(Data, |
| [](const InstrProfValueData &A, const InstrProfValueData &B) { |
| return A.Count > B.Count; |
| }); |
| llvm::annotateValueSite(M, CB, Data, Sum, |
| InstrProfValueKind::IPVK_IndirectCallTarget, |
| Data.size()); |
| LLVM_DEBUG(dbgs() << "[ctxprof] flat indirect call prof: " << CB |
| << CB.getMetadata(LLVMContext::MD_prof) << "\n"); |
| } |
| |
| // We normally return a "Changed" bool, but the calling pass' run assumes |
| // something will change - some profile will be added - so this won't add much |
| // by returning false when applicable. |
| void annotateIndirectCalls(Module &M, const CtxProfAnalysis::Result &CtxProf) { |
| const auto FlatIndCalls = CtxProf.flattenVirtCalls(); |
| for (auto &F : M) { |
| if (F.isDeclaration()) |
| continue; |
| auto FlatProfIter = FlatIndCalls.find(AssignGUIDPass::getGUID(F)); |
| if (FlatProfIter == FlatIndCalls.end()) |
| continue; |
| const auto &FlatProf = FlatProfIter->second; |
| for (auto &BB : F) { |
| for (auto &I : BB) { |
| auto *CB = dyn_cast<CallBase>(&I); |
| if (!CB || !CB->isIndirectCall()) |
| continue; |
| if (auto *Ins = CtxProfAnalysis::getCallsiteInstrumentation(*CB)) |
| annotateIndirectCall(M, *CB, FlatProf, *Ins); |
| } |
| } |
| } |
| } |
| |
| } // namespace |
| |
| PreservedAnalyses PGOCtxProfFlatteningPass::run(Module &M, |
| ModuleAnalysisManager &MAM) { |
| // Ensure in all cases the instrumentation is removed: if this module had no |
| // roots, the contextual profile would evaluate to false, but there would |
| // still be instrumentation. |
| // Note: in such cases we leave as-is any other profile info (if present - |
| // e.g. synthetic weights, etc) because it wouldn't interfere with the |
| // contextual - based one (which would be in other modules) |
| auto OnExit = llvm::make_scope_exit([&]() { |
| if (IsPreThinlink) |
| return; |
| for (auto &F : M) |
| removeInstrumentation(F); |
| }); |
| auto &CtxProf = MAM.getResult<CtxProfAnalysis>(M); |
| // post-thinlink, we only reprocess for the module(s) containing the |
| // contextual tree. For everything else, OnExit will just clean the |
| // instrumentation. |
| if (!IsPreThinlink && !CtxProf.isInSpecializedModule()) |
| return PreservedAnalyses::none(); |
| |
| if (IsPreThinlink) |
| annotateIndirectCalls(M, CtxProf); |
| const auto FlattenedProfile = CtxProf.flatten(); |
| |
| for (auto &F : M) { |
| if (F.isDeclaration()) |
| continue; |
| |
| assert(areAllBBsReachable( |
| F, MAM.getResult<FunctionAnalysisManagerModuleProxy>(M) |
| .getManager()) && |
| "Function has unreacheable basic blocks. The expectation was that " |
| "DCE was run before."); |
| |
| auto It = FlattenedProfile.find(AssignGUIDPass::getGUID(F)); |
| // If this function didn't appear in the contextual profile, it's cold. |
| if (It == FlattenedProfile.end()) |
| clearColdFunctionProfile(F); |
| else { |
| ProfileAnnotator S(F, It->second); |
| S.assignProfileData(); |
| } |
| } |
| InstrProfSummaryBuilder PB(ProfileSummaryBuilder::DefaultCutoffs); |
| // use here the flat profiles just so the importer doesn't complain about |
| // how different the PSIs are between the module with the roots and the |
| // various modules it imports. |
| for (auto &C : FlattenedProfile) { |
| PB.addEntryCount(C.second[0]); |
| for (auto V : llvm::drop_begin(C.second)) |
| PB.addInternalCount(V); |
| } |
| |
| M.setProfileSummary(PB.getSummary()->getMD(M.getContext()), |
| ProfileSummary::Kind::PSK_Instr); |
| PreservedAnalyses PA; |
| PA.abandon<ProfileSummaryAnalysis>(); |
| MAM.invalidate(M, PA); |
| auto &PSI = MAM.getResult<ProfileSummaryAnalysis>(M); |
| PSI.refresh(PB.getSummary()); |
| return PreservedAnalyses::none(); |
| } |