diff --git a/include/llvm/Analysis/SyntheticCountsUtils.h b/include/llvm/Analysis/SyntheticCountsUtils.h index 87f4a0100b3..db80bef001e 100644 --- a/include/llvm/Analysis/SyntheticCountsUtils.h +++ b/include/llvm/Analysis/SyntheticCountsUtils.h @@ -36,16 +36,17 @@ public: using EdgeRef = typename CGT::EdgeRef; using SccTy = std::vector; - using GetRelBBFreqTy = function_ref(EdgeRef)>; - using GetCountTy = function_ref; - using AddCountTy = function_ref; + // Not all EdgeRef have information about the source of the edge. Hence + // NodeRef corresponding to the source of the EdgeRef is explicitly passed. + using GetProfCountTy = function_ref(NodeRef, EdgeRef)>; + using AddCountTy = function_ref; - static void propagate(const CallGraphType &CG, GetRelBBFreqTy GetRelBBFreq, - GetCountTy GetCount, AddCountTy AddCount); + static void propagate(const CallGraphType &CG, GetProfCountTy GetProfCount, + AddCountTy AddCount); private: - static void propagateFromSCC(const SccTy &SCC, GetRelBBFreqTy GetRelBBFreq, - GetCountTy GetCount, AddCountTy AddCount); + static void propagateFromSCC(const SccTy &SCC, GetProfCountTy GetProfCount, + AddCountTy AddCount); }; } // namespace llvm diff --git a/lib/Analysis/SyntheticCountsUtils.cpp b/lib/Analysis/SyntheticCountsUtils.cpp index 386396bcff3..c2d7bb11a4c 100644 --- a/lib/Analysis/SyntheticCountsUtils.cpp +++ b/lib/Analysis/SyntheticCountsUtils.cpp @@ -26,8 +26,7 @@ using namespace llvm; // Given an SCC, propagate entry counts along the edge of the SCC nodes. template void SyntheticCountsUtils::propagateFromSCC( - const SccTy &SCC, GetRelBBFreqTy GetRelBBFreq, GetCountTy GetCount, - AddCountTy AddCount) { + const SccTy &SCC, GetProfCountTy GetProfCount, AddCountTy AddCount) { DenseSet SCCNodes; SmallVector, 8> SCCEdges, NonSCCEdges; @@ -54,17 +53,13 @@ void SyntheticCountsUtils::propagateFromSCC( // This ensures that the order of // traversal of nodes within the SCC doesn't affect the final result. - DenseMap AdditionalCounts; + DenseMap AdditionalCounts; for (auto &E : SCCEdges) { - auto OptRelFreq = GetRelBBFreq(E.second); - if (!OptRelFreq) + auto OptProfCount = GetProfCount(E.first, E.second); + if (!OptProfCount) continue; - Scaled64 RelFreq = OptRelFreq.getValue(); - auto Caller = E.first; auto Callee = CGT::edge_dest(E.second); - RelFreq *= Scaled64(GetCount(Caller), 0); - uint64_t AdditionalCount = RelFreq.toInt(); - AdditionalCounts[Callee] += AdditionalCount; + AdditionalCounts[Callee] += OptProfCount.getValue(); } // Update the counts for the nodes in the SCC. @@ -73,14 +68,11 @@ void SyntheticCountsUtils::propagateFromSCC( // Now update the counts for nodes outside the SCC. for (auto &E : NonSCCEdges) { - auto OptRelFreq = GetRelBBFreq(E.second); - if (!OptRelFreq) + auto OptProfCount = GetProfCount(E.first, E.second); + if (!OptProfCount) continue; - Scaled64 RelFreq = OptRelFreq.getValue(); - auto Caller = E.first; auto Callee = CGT::edge_dest(E.second); - RelFreq *= Scaled64(GetCount(Caller), 0); - AddCount(Callee, RelFreq.toInt()); + AddCount(Callee, OptProfCount.getValue()); } } @@ -94,8 +86,7 @@ void SyntheticCountsUtils::propagateFromSCC( template void SyntheticCountsUtils::propagate(const CallGraphType &CG, - GetRelBBFreqTy GetRelBBFreq, - GetCountTy GetCount, + GetProfCountTy GetProfCount, AddCountTy AddCount) { std::vector SCCs; @@ -107,7 +98,7 @@ void SyntheticCountsUtils::propagate(const CallGraphType &CG, // The scc iterator returns the scc in bottom-up order, so reverse the SCCs // and call propagateFromSCC. for (auto &SCC : reverse(SCCs)) - propagateFromSCC(SCC, GetRelBBFreq, GetCount, AddCount); + propagateFromSCC(SCC, GetProfCount, AddCount); } template class llvm::SyntheticCountsUtils; diff --git a/lib/LTO/SummaryBasedOptimizations.cpp b/lib/LTO/SummaryBasedOptimizations.cpp index 8b1abb78462..bcdd984daa5 100644 --- a/lib/LTO/SummaryBasedOptimizations.cpp +++ b/lib/LTO/SummaryBasedOptimizations.cpp @@ -60,21 +60,27 @@ void llvm::computeSyntheticCounts(ModuleSummaryIndex &Index) { return UINT64_C(0); } }; - auto AddToEntryCount = [](ValueInfo V, uint64_t New) { + auto AddToEntryCount = [](ValueInfo V, Scaled64 New) { if (!V.getSummaryList().size()) return; for (auto &GVS : V.getSummaryList()) { auto S = GVS.get()->getBaseObject(); auto *F = cast(S); - F->setEntryCount(SaturatingAdd(F->entryCount(), New)); + F->setEntryCount( + SaturatingAdd(F->entryCount(), New.template toInt())); } }; + auto GetProfileCount = [&](ValueInfo V, FunctionSummary::EdgeTy &Edge) { + auto RelFreq = GetCallSiteRelFreq(Edge); + Scaled64 EC(GetEntryCount(V), 0); + return RelFreq * EC; + }; // After initializing the counts in initializeCounts above, the counts have to // be propagated across the combined callgraph. // SyntheticCountsUtils::propagate takes care of this propagation on any // callgraph that specialized GraphTraits. - SyntheticCountsUtils::propagate( - &Index, GetCallSiteRelFreq, GetEntryCount, AddToEntryCount); + SyntheticCountsUtils::propagate(&Index, GetProfileCount, + AddToEntryCount); Index.setHasSyntheticEntryCounts(); } diff --git a/lib/Transforms/IPO/SyntheticCountsPropagation.cpp b/lib/Transforms/IPO/SyntheticCountsPropagation.cpp index 64837d4f5d6..ba4efb3ff60 100644 --- a/lib/Transforms/IPO/SyntheticCountsPropagation.cpp +++ b/lib/Transforms/IPO/SyntheticCountsPropagation.cpp @@ -30,6 +30,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/CallGraph.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/SyntheticCountsUtils.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Function.h" @@ -98,13 +99,15 @@ PreservedAnalyses SyntheticCountsPropagation::run(Module &M, ModuleAnalysisManager &MAM) { FunctionAnalysisManager &FAM = MAM.getResult(M).getManager(); - DenseMap Counts; + DenseMap Counts; // Set initial entry counts. - initializeCounts(M, [&](Function *F, uint64_t Count) { Counts[F] = Count; }); + initializeCounts( + M, [&](Function *F, uint64_t Count) { Counts[F] = Scaled64(Count, 0); }); - // Compute the relative block frequency for a call edge. Use scaled numbers - // and not integers since the relative block frequency could be less than 1. - auto GetCallSiteRelFreq = [&](const CallGraphNode::CallRecord &Edge) { + // Edge includes information about the source. Hence ignore the first + // parameter. + auto GetCallSiteProfCount = [&](const CallGraphNode *, + const CallGraphNode::CallRecord &Edge) { Optional Res = None; if (!Edge.first) return Res; @@ -112,29 +115,33 @@ PreservedAnalyses SyntheticCountsPropagation::run(Module &M, CallSite CS(cast(Edge.first)); Function *Caller = CS.getCaller(); auto &BFI = FAM.getResult(*Caller); + + // Now compute the callsite count from relative frequency and + // entry count: BasicBlock *CSBB = CS.getInstruction()->getParent(); Scaled64 EntryFreq(BFI.getEntryFreq(), 0); - Scaled64 BBFreq(BFI.getBlockFreq(CSBB).getFrequency(), 0); - BBFreq /= EntryFreq; - return Optional(BBFreq); + Scaled64 BBCount(BFI.getBlockFreq(CSBB).getFrequency(), 0); + BBCount /= EntryFreq; + BBCount *= Counts[Caller]; + return Optional(BBCount); }; CallGraph CG(M); // Propgate the entry counts on the callgraph. SyntheticCountsUtils::propagate( - &CG, GetCallSiteRelFreq, - [&](const CallGraphNode *N) { return Counts[N->getFunction()]; }, - [&](const CallGraphNode *N, uint64_t New) { + &CG, GetCallSiteProfCount, [&](const CallGraphNode *N, Scaled64 New) { auto F = N->getFunction(); if (!F || F->isDeclaration()) return; + Counts[F] += New; }); // Set the counts as metadata. - for (auto Entry : Counts) - Entry.first->setEntryCount( - ProfileCount(Entry.second, Function::PCT_Synthetic)); + for (auto Entry : Counts) { + Entry.first->setEntryCount(ProfileCount( + Entry.second.template toInt(), Function::PCT_Synthetic)); + } return PreservedAnalyses::all(); }