Add setBranchWeigths convenience function. NFC (#72446)

Add `setBranchWeights` convenience function to ProfDataUtils.h and use
it where appropriate.
This commit is contained in:
Matthias Braun 2023-11-16 10:55:19 -08:00 committed by GitHub
parent 186db1bcb0
commit cb4627d150
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 48 additions and 58 deletions

View File

@ -104,5 +104,9 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalWeights);
/// metadata was found.
bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalWeights);
/// Create a new `branch_weights` metadata node and add or overwrite
/// a `prof` metadata reference to instruction `I`.
void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights);
} // namespace llvm
#endif

View File

@ -17,6 +17,7 @@
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Support/BranchProbability.h"
#include "llvm/Support/CommandLine.h"
@ -183,4 +184,10 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
}
void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) {
MDBuilder MDB(I.getContext());
MDNode *BranchWeights = MDB.createBranchWeights(Weights);
I.setMetadata(LLVMContext::MD_prof, BranchWeights);
}
} // namespace llvm

View File

@ -56,6 +56,7 @@
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/ProfDataUtils.h"
#include "llvm/IR/PseudoProbe.h"
#include "llvm/IR/ValueSymbolTable.h"
#include "llvm/ProfileData/InstrProf.h"
@ -1710,9 +1711,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
else if (OverwriteExistingWeights)
I.setMetadata(LLVMContext::MD_prof, nullptr);
} else if (!isa<IntrinsicInst>(&I)) {
I.setMetadata(LLVMContext::MD_prof,
MDB.createBranchWeights(
{static_cast<uint32_t>(BlockWeights[BB])}));
setBranchWeights(I, {static_cast<uint32_t>(BlockWeights[BB])});
}
}
} else if (OverwriteExistingWeights || ProfileSampleBlockAccurate) {
@ -1720,10 +1719,11 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
// clear it for cold code.
for (auto &I : *BB) {
if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
if (cast<CallBase>(I).isIndirectCall())
if (cast<CallBase>(I).isIndirectCall()) {
I.setMetadata(LLVMContext::MD_prof, nullptr);
else
I.setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(0));
} else {
setBranchWeights(I, {uint32_t(0)});
}
}
}
}
@ -1803,7 +1803,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
if (MaxWeight > 0 &&
(!TI->extractProfTotalWeight(TempWeight) || OverwriteExistingWeights)) {
LLVM_DEBUG(dbgs() << "SUCCESS. Found non-zero weights.\n");
TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
setBranchWeights(*TI, Weights);
ORE->emit([&]() {
return OptimizationRemark(DEBUG_TYPE, "PopularDest", MaxDestInst)
<< "most popular destination for conditional branches at "

View File

@ -1878,8 +1878,7 @@ void CHR::fixupBranchesAndSelects(CHRScope *Scope,
static_cast<uint32_t>(CHRBranchBias.scale(1000)),
static_cast<uint32_t>(CHRBranchBias.getCompl().scale(1000)),
};
MDBuilder MDB(F.getContext());
MergedBR->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
setBranchWeights(*MergedBR, Weights);
CHR_DEBUG(dbgs() << "CHR branch bias " << Weights[0] << ":" << Weights[1]
<< "\n");
}

View File

@ -26,6 +26,7 @@
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/ProfDataUtils.h"
#include "llvm/IR/Value.h"
#include "llvm/ProfileData/InstrProf.h"
#include "llvm/Support/Casting.h"
@ -256,10 +257,7 @@ CallBase &llvm::pgo::promoteIndirectCall(CallBase &CB, Function *DirectCallee,
promoteCallWithIfThenElse(CB, DirectCallee, BranchWeights);
if (AttachProfToDirectCall) {
MDBuilder MDB(NewInst.getContext());
NewInst.setMetadata(
LLVMContext::MD_prof,
MDB.createBranchWeights({static_cast<uint32_t>(Count)}));
setBranchWeights(NewInst, {static_cast<uint32_t>(Count)});
}
using namespace ore;

View File

@ -1437,12 +1437,11 @@ void PGOUseFunc::populateCoverage(IndexedInstrProfReader *PGOReader) {
// If A is uncovered, set weight=1.
// This setup will allow BFI to give nonzero profile counts to only covered
// blocks.
SmallVector<unsigned, 4> Weights;
SmallVector<uint32_t, 4> Weights;
for (auto *Succ : successors(&BB))
Weights.push_back((Coverage[Succ] || !Coverage[&BB]) ? 1 : 0);
if (Weights.size() >= 2)
BB.getTerminator()->setMetadata(LLVMContext::MD_prof,
MDB.createBranchWeights(Weights));
llvm::setBranchWeights(*BB.getTerminator(), Weights);
}
unsigned NumCorruptCoverage = 0;
@ -2205,7 +2204,6 @@ static std::string getSimpleNodeName(const BasicBlock *Node) {
void llvm::setProfMetadata(Module *M, Instruction *TI,
ArrayRef<uint64_t> EdgeCounts, uint64_t MaxCount) {
MDBuilder MDB(M->getContext());
assert(MaxCount > 0 && "Bad max count");
uint64_t Scale = calculateCountScale(MaxCount);
SmallVector<unsigned, 4> Weights;
@ -2219,7 +2217,7 @@ void llvm::setProfMetadata(Module *M, Instruction *TI,
misexpect::checkExpectAnnotations(*TI, Weights, /*IsFrontend=*/false);
TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
setBranchWeights(*TI, Weights);
if (EmitBranchProbability) {
std::string BrCondStr = getBranchCondString(TI);
if (BrCondStr.empty())

View File

@ -228,17 +228,15 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
if (BP >= BranchProbability(50, 100))
continue;
SmallVector<uint32_t, 2> Weights;
uint32_t Weights[2];
if (PredBr->getSuccessor(0) == PredOutEdge.second) {
Weights.push_back(BP.getNumerator());
Weights.push_back(BP.getCompl().getNumerator());
Weights[0] = BP.getNumerator();
Weights[1] = BP.getCompl().getNumerator();
} else {
Weights.push_back(BP.getCompl().getNumerator());
Weights.push_back(BP.getNumerator());
Weights[0] = BP.getCompl().getNumerator();
Weights[1] = BP.getNumerator();
}
PredBr->setMetadata(LLVMContext::MD_prof,
MDBuilder(PredBr->getParent()->getContext())
.createBranchWeights(Weights));
setBranchWeights(*PredBr, Weights);
}
}
@ -2574,9 +2572,7 @@ void JumpThreadingPass::updateBlockFreqAndEdgeWeight(BasicBlock *PredBB,
Weights.push_back(Prob.getNumerator());
auto TI = BB->getTerminator();
TI->setMetadata(
LLVMContext::MD_prof,
MDBuilder(TI->getParent()->getContext()).createBranchWeights(Weights));
setBranchWeights(*TI, Weights);
}
}

View File

@ -20,6 +20,7 @@
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/ProfDataUtils.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/MisExpect.h"
@ -101,10 +102,7 @@ static bool handleSwitchExpect(SwitchInst &SI) {
misexpect::checkExpectAnnotations(SI, Weights, /*IsFrontend=*/true);
SI.setCondition(ArgValue);
SI.setMetadata(LLVMContext::MD_prof,
MDBuilder(CI->getContext()).createBranchWeights(Weights));
setBranchWeights(SI, Weights);
return true;
}

View File

@ -227,9 +227,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,
// Remove weight for this case.
std::swap(Weights[Idx + 1], Weights.back());
Weights.pop_back();
SI->setMetadata(LLVMContext::MD_prof,
MDBuilder(BB->getContext()).
createBranchWeights(Weights));
setBranchWeights(*SI, Weights);
}
// Remove this entry.
BasicBlock *ParentBB = SI->getParent();

View File

@ -631,9 +631,7 @@ struct WeightInfo {
/// To avoid dealing with division rounding we can just multiple both part
/// of weights to E and use weight as (F - I * E, E).
static void updateBranchWeights(Instruction *Term, WeightInfo &Info) {
MDBuilder MDB(Term->getContext());
Term->setMetadata(LLVMContext::MD_prof,
MDB.createBranchWeights(Info.Weights));
setBranchWeights(*Term, Info.Weights);
for (auto [Idx, SubWeight] : enumerate(Info.SubWeights))
if (SubWeight != 0)
// Don't set the probability of taking the edge from latch to loop header
@ -690,14 +688,6 @@ static void initBranchWeights(DenseMap<Instruction *, WeightInfo> &WeightInfos,
}
}
/// Update the weights of original exiting block after peeling off all
/// iterations.
static void fixupBranchWeights(Instruction *Term, const WeightInfo &Info) {
MDBuilder MDB(Term->getContext());
Term->setMetadata(LLVMContext::MD_prof,
MDB.createBranchWeights(Info.Weights));
}
/// Clones the body of the loop L, putting it between \p InsertTop and \p
/// InsertBot.
/// \param IterNumber The serial number of the iteration currently being
@ -1033,8 +1023,9 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,
PHI->setIncomingValueForBlock(NewPreHeader, NewVal);
}
for (const auto &[Term, Info] : Weights)
fixupBranchWeights(Term, Info);
for (const auto &[Term, Info] : Weights) {
setBranchWeights(*Term, Info.Weights);
}
// Update Metadata for count of peeled off iterations.
unsigned AlreadyPeeled = 0;

View File

@ -352,16 +352,17 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
LoopBackWeight = 0;
}
MDBuilder MDB(LoopBI.getContext());
MDNode *LoopWeightMD =
MDB.createBranchWeights(SuccsSwapped ? LoopBackWeight : ExitWeight1,
SuccsSwapped ? ExitWeight1 : LoopBackWeight);
LoopBI.setMetadata(LLVMContext::MD_prof, LoopWeightMD);
const uint32_t LoopBIWeights[] = {
SuccsSwapped ? LoopBackWeight : ExitWeight1,
SuccsSwapped ? ExitWeight1 : LoopBackWeight,
};
setBranchWeights(LoopBI, LoopBIWeights);
if (HasConditionalPreHeader) {
MDNode *PreHeaderWeightMD =
MDB.createBranchWeights(SuccsSwapped ? EnterWeight : ExitWeight0,
SuccsSwapped ? ExitWeight0 : EnterWeight);
PreHeaderBI.setMetadata(LLVMContext::MD_prof, PreHeaderWeightMD);
const uint32_t PreHeaderBIWeights[] = {
SuccsSwapped ? EnterWeight : ExitWeight0,
SuccsSwapped ? ExitWeight0 : EnterWeight,
};
setBranchWeights(PreHeaderBI, PreHeaderBIWeights);
}
}