Refactor synthetic profile count computation. NFC.

Summary:
Instead of using two separate callbacks to return the entry count and the
relative block frequency, use a single callback to return callsite
count. This would allow better supporting hybrid mode in the future as
the count of callsite need not always be derived from entry count (as in
sample PGO).

Reviewers: davidxl

Subscribers: mehdi_amini, steven_wu, dexonsmith, dang, llvm-commits

Differential Revision: https://reviews.llvm.org/D56464

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@350755 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Easwaran Raman 2019-01-09 20:10:27 +00:00
parent 849dd13009
commit 116e08026c
4 changed files with 49 additions and 44 deletions

View File

@ -36,16 +36,17 @@ public:
using EdgeRef = typename CGT::EdgeRef; using EdgeRef = typename CGT::EdgeRef;
using SccTy = std::vector<NodeRef>; using SccTy = std::vector<NodeRef>;
using GetRelBBFreqTy = function_ref<Optional<Scaled64>(EdgeRef)>; // Not all EdgeRef have information about the source of the edge. Hence
using GetCountTy = function_ref<uint64_t(NodeRef)>; // NodeRef corresponding to the source of the EdgeRef is explicitly passed.
using AddCountTy = function_ref<void(NodeRef, uint64_t)>; using GetProfCountTy = function_ref<Optional<Scaled64>(NodeRef, EdgeRef)>;
using AddCountTy = function_ref<void(NodeRef, Scaled64)>;
static void propagate(const CallGraphType &CG, GetRelBBFreqTy GetRelBBFreq, static void propagate(const CallGraphType &CG, GetProfCountTy GetProfCount,
GetCountTy GetCount, AddCountTy AddCount); AddCountTy AddCount);
private: private:
static void propagateFromSCC(const SccTy &SCC, GetRelBBFreqTy GetRelBBFreq, static void propagateFromSCC(const SccTy &SCC, GetProfCountTy GetProfCount,
GetCountTy GetCount, AddCountTy AddCount); AddCountTy AddCount);
}; };
} // namespace llvm } // namespace llvm

View File

@ -26,8 +26,7 @@ using namespace llvm;
// Given an SCC, propagate entry counts along the edge of the SCC nodes. // Given an SCC, propagate entry counts along the edge of the SCC nodes.
template <typename CallGraphType> template <typename CallGraphType>
void SyntheticCountsUtils<CallGraphType>::propagateFromSCC( void SyntheticCountsUtils<CallGraphType>::propagateFromSCC(
const SccTy &SCC, GetRelBBFreqTy GetRelBBFreq, GetCountTy GetCount, const SccTy &SCC, GetProfCountTy GetProfCount, AddCountTy AddCount) {
AddCountTy AddCount) {
DenseSet<NodeRef> SCCNodes; DenseSet<NodeRef> SCCNodes;
SmallVector<std::pair<NodeRef, EdgeRef>, 8> SCCEdges, NonSCCEdges; SmallVector<std::pair<NodeRef, EdgeRef>, 8> SCCEdges, NonSCCEdges;
@ -54,17 +53,13 @@ void SyntheticCountsUtils<CallGraphType>::propagateFromSCC(
// This ensures that the order of // This ensures that the order of
// traversal of nodes within the SCC doesn't affect the final result. // traversal of nodes within the SCC doesn't affect the final result.
DenseMap<NodeRef, uint64_t> AdditionalCounts; DenseMap<NodeRef, Scaled64> AdditionalCounts;
for (auto &E : SCCEdges) { for (auto &E : SCCEdges) {
auto OptRelFreq = GetRelBBFreq(E.second); auto OptProfCount = GetProfCount(E.first, E.second);
if (!OptRelFreq) if (!OptProfCount)
continue; continue;
Scaled64 RelFreq = OptRelFreq.getValue();
auto Caller = E.first;
auto Callee = CGT::edge_dest(E.second); auto Callee = CGT::edge_dest(E.second);
RelFreq *= Scaled64(GetCount(Caller), 0); AdditionalCounts[Callee] += OptProfCount.getValue();
uint64_t AdditionalCount = RelFreq.toInt<uint64_t>();
AdditionalCounts[Callee] += AdditionalCount;
} }
// Update the counts for the nodes in the SCC. // Update the counts for the nodes in the SCC.
@ -73,14 +68,11 @@ void SyntheticCountsUtils<CallGraphType>::propagateFromSCC(
// Now update the counts for nodes outside the SCC. // Now update the counts for nodes outside the SCC.
for (auto &E : NonSCCEdges) { for (auto &E : NonSCCEdges) {
auto OptRelFreq = GetRelBBFreq(E.second); auto OptProfCount = GetProfCount(E.first, E.second);
if (!OptRelFreq) if (!OptProfCount)
continue; continue;
Scaled64 RelFreq = OptRelFreq.getValue();
auto Caller = E.first;
auto Callee = CGT::edge_dest(E.second); auto Callee = CGT::edge_dest(E.second);
RelFreq *= Scaled64(GetCount(Caller), 0); AddCount(Callee, OptProfCount.getValue());
AddCount(Callee, RelFreq.toInt<uint64_t>());
} }
} }
@ -94,8 +86,7 @@ void SyntheticCountsUtils<CallGraphType>::propagateFromSCC(
template <typename CallGraphType> template <typename CallGraphType>
void SyntheticCountsUtils<CallGraphType>::propagate(const CallGraphType &CG, void SyntheticCountsUtils<CallGraphType>::propagate(const CallGraphType &CG,
GetRelBBFreqTy GetRelBBFreq, GetProfCountTy GetProfCount,
GetCountTy GetCount,
AddCountTy AddCount) { AddCountTy AddCount) {
std::vector<SccTy> SCCs; std::vector<SccTy> SCCs;
@ -107,7 +98,7 @@ void SyntheticCountsUtils<CallGraphType>::propagate(const CallGraphType &CG,
// The scc iterator returns the scc in bottom-up order, so reverse the SCCs // The scc iterator returns the scc in bottom-up order, so reverse the SCCs
// and call propagateFromSCC. // and call propagateFromSCC.
for (auto &SCC : reverse(SCCs)) for (auto &SCC : reverse(SCCs))
propagateFromSCC(SCC, GetRelBBFreq, GetCount, AddCount); propagateFromSCC(SCC, GetProfCount, AddCount);
} }
template class llvm::SyntheticCountsUtils<const CallGraph *>; template class llvm::SyntheticCountsUtils<const CallGraph *>;

View File

@ -60,21 +60,27 @@ void llvm::computeSyntheticCounts(ModuleSummaryIndex &Index) {
return UINT64_C(0); return UINT64_C(0);
} }
}; };
auto AddToEntryCount = [](ValueInfo V, uint64_t New) { auto AddToEntryCount = [](ValueInfo V, Scaled64 New) {
if (!V.getSummaryList().size()) if (!V.getSummaryList().size())
return; return;
for (auto &GVS : V.getSummaryList()) { for (auto &GVS : V.getSummaryList()) {
auto S = GVS.get()->getBaseObject(); auto S = GVS.get()->getBaseObject();
auto *F = cast<FunctionSummary>(S); auto *F = cast<FunctionSummary>(S);
F->setEntryCount(SaturatingAdd(F->entryCount(), New)); F->setEntryCount(
SaturatingAdd(F->entryCount(), New.template toInt<uint64_t>()));
} }
}; };
auto GetProfileCount = [&](ValueInfo V, FunctionSummary::EdgeTy &Edge) {
auto RelFreq = GetCallSiteRelFreq(Edge);
Scaled64 EC(GetEntryCount(V), 0);
return RelFreq * EC;
};
// After initializing the counts in initializeCounts above, the counts have to // After initializing the counts in initializeCounts above, the counts have to
// be propagated across the combined callgraph. // be propagated across the combined callgraph.
// SyntheticCountsUtils::propagate takes care of this propagation on any // SyntheticCountsUtils::propagate takes care of this propagation on any
// callgraph that specialized GraphTraits. // callgraph that specialized GraphTraits.
SyntheticCountsUtils<ModuleSummaryIndex *>::propagate( SyntheticCountsUtils<ModuleSummaryIndex *>::propagate(&Index, GetProfileCount,
&Index, GetCallSiteRelFreq, GetEntryCount, AddToEntryCount); AddToEntryCount);
Index.setHasSyntheticEntryCounts(); Index.setHasSyntheticEntryCounts();
} }

View File

@ -30,6 +30,7 @@
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/ProfileSummaryInfo.h"
#include "llvm/Analysis/SyntheticCountsUtils.h" #include "llvm/Analysis/SyntheticCountsUtils.h"
#include "llvm/IR/CallSite.h" #include "llvm/IR/CallSite.h"
#include "llvm/IR/Function.h" #include "llvm/IR/Function.h"
@ -98,13 +99,15 @@ PreservedAnalyses SyntheticCountsPropagation::run(Module &M,
ModuleAnalysisManager &MAM) { ModuleAnalysisManager &MAM) {
FunctionAnalysisManager &FAM = FunctionAnalysisManager &FAM =
MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
DenseMap<Function *, uint64_t> Counts; DenseMap<Function *, Scaled64> Counts;
// Set initial entry counts. // Set initial entry counts.
initializeCounts(M, [&](Function *F, uint64_t Count) { Counts[F] = Count; }); initializeCounts(
M, [&](Function *F, uint64_t Count) { Counts[F] = Scaled64(Count, 0); });
// Compute the relative block frequency for a call edge. Use scaled numbers // Edge includes information about the source. Hence ignore the first
// and not integers since the relative block frequency could be less than 1. // parameter.
auto GetCallSiteRelFreq = [&](const CallGraphNode::CallRecord &Edge) { auto GetCallSiteProfCount = [&](const CallGraphNode *,
const CallGraphNode::CallRecord &Edge) {
Optional<Scaled64> Res = None; Optional<Scaled64> Res = None;
if (!Edge.first) if (!Edge.first)
return Res; return Res;
@ -112,29 +115,33 @@ PreservedAnalyses SyntheticCountsPropagation::run(Module &M,
CallSite CS(cast<Instruction>(Edge.first)); CallSite CS(cast<Instruction>(Edge.first));
Function *Caller = CS.getCaller(); Function *Caller = CS.getCaller();
auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*Caller); auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*Caller);
// Now compute the callsite count from relative frequency and
// entry count:
BasicBlock *CSBB = CS.getInstruction()->getParent(); BasicBlock *CSBB = CS.getInstruction()->getParent();
Scaled64 EntryFreq(BFI.getEntryFreq(), 0); Scaled64 EntryFreq(BFI.getEntryFreq(), 0);
Scaled64 BBFreq(BFI.getBlockFreq(CSBB).getFrequency(), 0); Scaled64 BBCount(BFI.getBlockFreq(CSBB).getFrequency(), 0);
BBFreq /= EntryFreq; BBCount /= EntryFreq;
return Optional<Scaled64>(BBFreq); BBCount *= Counts[Caller];
return Optional<Scaled64>(BBCount);
}; };
CallGraph CG(M); CallGraph CG(M);
// Propgate the entry counts on the callgraph. // Propgate the entry counts on the callgraph.
SyntheticCountsUtils<const CallGraph *>::propagate( SyntheticCountsUtils<const CallGraph *>::propagate(
&CG, GetCallSiteRelFreq, &CG, GetCallSiteProfCount, [&](const CallGraphNode *N, Scaled64 New) {
[&](const CallGraphNode *N) { return Counts[N->getFunction()]; },
[&](const CallGraphNode *N, uint64_t New) {
auto F = N->getFunction(); auto F = N->getFunction();
if (!F || F->isDeclaration()) if (!F || F->isDeclaration())
return; return;
Counts[F] += New; Counts[F] += New;
}); });
// Set the counts as metadata. // Set the counts as metadata.
for (auto Entry : Counts) for (auto Entry : Counts) {
Entry.first->setEntryCount( Entry.first->setEntryCount(ProfileCount(
ProfileCount(Entry.second, Function::PCT_Synthetic)); Entry.second.template toInt<uint64_t>(), Function::PCT_Synthetic));
}
return PreservedAnalyses::all(); return PreservedAnalyses::all();
} }