PGO: preserve branch-weight metadata when simplifying two branches with a common destination.

Updated the previous implementation to fix a case that was not covered:
// PBI: br i1 %x, TrueDest, BB
// BI:  br i1 %y, TrueDest, FalseDest
The other case was already handled correctly:
// PBI: br i1 %x, BB, FalseDest
// BI:  br i1 %y, TrueDest, FalseDest

Also tried using 64-bit arithmetic, instead of APInt with scaling, to simplify
the computation. Let me know if you have other opinions about this.


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@163954 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Manman Ren 2012-09-15 00:39:57 +00:00
parent 3cbd1786ac
commit 062986c2f0
2 changed files with 64 additions and 112 deletions

View File

@ -1658,7 +1658,7 @@ static bool SimplifyCondBranchToTwoReturns(BranchInst *BI,
/// parameters and return true, or returns false if no or invalid metadata was
/// found.
static bool ExtractBranchMetadata(BranchInst *BI,
APInt &ProbTrue, APInt &ProbFalse) {
uint64_t &ProbTrue, uint64_t &ProbFalse) {
assert(BI->isConditional() &&
"Looking for probabilities on unconditional branch?");
MDNode *ProfileData = BI->getMetadata(LLVMContext::MD_prof);
@ -1666,35 +1666,11 @@ static bool ExtractBranchMetadata(BranchInst *BI,
ConstantInt *CITrue = dyn_cast<ConstantInt>(ProfileData->getOperand(1));
ConstantInt *CIFalse = dyn_cast<ConstantInt>(ProfileData->getOperand(2));
if (!CITrue || !CIFalse) return false;
ProbTrue = CITrue->getValue();
ProbFalse = CIFalse->getValue();
assert(ProbTrue.getBitWidth() == 32 && ProbFalse.getBitWidth() == 32 &&
"Branch probability metadata must be 32-bit integers");
ProbTrue = CITrue->getValue().getZExtValue();
ProbFalse = CIFalse->getValue().getZExtValue();
return true;
}
/// MultiplyAndLosePrecision - Multiplies A and B, then returns the result. In
/// the event of overflow, logically-shifts all four inputs right until the
/// multiply fits.
///
/// A and B are the factors of the product. C and D are not used in the
/// product itself; they are only shifted right by the same number of bits as
/// A and B (presumably so that all four values stay on a common scale for the
/// caller -- confirm against the call sites).
///
/// BitsLost is always written: 0 when the first multiply did not overflow,
/// otherwise the number of bits every input was shifted right by.
///
/// All four APInt arguments are taken by reference and may be modified.
static APInt MultiplyAndLosePrecision(APInt &A, APInt &B, APInt &C, APInt &D,
unsigned &BitsLost) {
BitsLost = 0;
bool Overflow = false;
// First attempt at full precision; Overflow reports whether the unsigned
// product did not fit.
APInt Result = A.umul_ov(B, Overflow);
if (Overflow) {
// Largest value B may take so that (unshifted) A * B still fits. A cannot be
// zero here, since a product with a zero factor would not have overflowed.
APInt MaxB = APInt::getMaxValue(A.getBitWidth()).udiv(A);
// Drop low bits of B one at a time until it is within that bound, counting
// how many bits were discarded.
do {
B = B.lshr(1);
++BitsLost;
} while (B.ugt(MaxB));
// Apply the same shift to the other three inputs.
A = A.lshr(BitsLost);
C = C.lshr(BitsLost);
D = D.lshr(BitsLost);
// Both factors have only shrunk relative to the bound computed above, so
// this plain multiply is expected to fit.
Result = A * B;
}
return Result;
}
/// checkCSEInPredecessor - Return true if the given instruction is available
/// in its predecessor block. If yes, the instruction will be removed.
///
@ -1919,14 +1895,53 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI) {
New, "or.cond"));
PBI->setCondition(NewCond);
uint64_t PredTrueWeight, PredFalseWeight, SuccTrueWeight, SuccFalseWeight;
bool PredHasWeights = ExtractBranchMetadata(PBI, PredTrueWeight,
PredFalseWeight);
bool SuccHasWeights = ExtractBranchMetadata(BI, SuccTrueWeight,
SuccFalseWeight);
SmallVector<uint64_t, 8> NewWeights;
if (PBI->getSuccessor(0) == BB) {
if (PredHasWeights && SuccHasWeights) {
// PBI: br i1 %x, BB, FalseDest
// BI: br i1 %y, TrueDest, FalseDest
//TrueWeight is TrueWeight for PBI * TrueWeight for BI.
NewWeights.push_back(PredTrueWeight * SuccTrueWeight);
//FalseWeight is FalseWeight for PBI * TotalWeight for BI +
// TrueWeight for PBI * FalseWeight for BI.
// We assume that total weights of a BranchInst can fit into 32 bits.
// Therefore, we will not have overflow using 64-bit arithmetic.
NewWeights.push_back(PredFalseWeight * (SuccFalseWeight +
SuccTrueWeight) + PredTrueWeight * SuccFalseWeight);
}
AddPredecessorToBlock(TrueDest, PredBlock, BB);
PBI->setSuccessor(0, TrueDest);
}
if (PBI->getSuccessor(1) == BB) {
if (PredHasWeights && SuccHasWeights) {
// PBI: br i1 %x, TrueDest, BB
// BI: br i1 %y, TrueDest, FalseDest
//TrueWeight is TrueWeight for PBI * TotalWeight for BI +
// FalseWeight for PBI * TrueWeight for BI.
NewWeights.push_back(PredTrueWeight * (SuccFalseWeight +
SuccTrueWeight) + PredFalseWeight * SuccTrueWeight);
//FalseWeight is FalseWeight for PBI * FalseWeight for BI.
NewWeights.push_back(PredFalseWeight * SuccFalseWeight);
}
AddPredecessorToBlock(FalseDest, PredBlock, BB);
PBI->setSuccessor(1, FalseDest);
}
if (NewWeights.size() == 2) {
// Halve the weights if any of them cannot fit in a uint32_t
FitWeights(NewWeights);
SmallVector<uint32_t, 8> MDWeights(NewWeights.begin(),NewWeights.end());
PBI->setMetadata(LLVMContext::MD_prof,
MDBuilder(BI->getContext()).
createBranchWeights(MDWeights));
} else
PBI->setMetadata(LLVMContext::MD_prof, NULL);
} else {
// Update PHI nodes in the common successors.
for (unsigned i = 0, e = PHIs.size(); i != e; ++i) {
@ -1981,90 +1996,6 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI) {
// TODO: If BB is reachable from all paths through PredBlock, then we
// could replace PBI's branch probabilities with BI's.
// Merge probability data into PredBlock's branch.
APInt A, B, C, D;
if (PBI->isConditional() && BI->isConditional() &&
ExtractBranchMetadata(PBI, C, D) && ExtractBranchMetadata(BI, A, B)) {
// Given IR which does:
// bbA:
// br i1 %x, label %bbB, label %bbC
// bbB:
// br i1 %y, label %bbD, label %bbC
// Let's call the probability that we take the edge from %bbA to %bbB
// 'a', from %bbA to %bbC, 'b', from %bbB to %bbD 'c' and from %bbB to
// %bbC probability 'd'.
//
// We transform the IR into:
// bbA:
// br i1 %z, label %bbD, label %bbC
// where the probability of going to %bbD is (a*c) and going to bbC is
// (b+a*d).
//
// Probabilities aren't stored as ratios directly. Using branch weights,
// we get:
// (a*c)% = A*C, (b+(a*d))% = A*D+B*C+B*D.
// In the event of overflow, we want to drop the LSB of the input
// probabilities.
unsigned BitsLost;
// Ignore overflow result on ProbTrue.
APInt ProbTrue = MultiplyAndLosePrecision(A, C, B, D, BitsLost);
APInt Tmp1 = MultiplyAndLosePrecision(B, D, A, C, BitsLost);
if (BitsLost) {
ProbTrue = ProbTrue.lshr(BitsLost*2);
}
APInt Tmp2 = MultiplyAndLosePrecision(A, D, C, B, BitsLost);
if (BitsLost) {
ProbTrue = ProbTrue.lshr(BitsLost*2);
Tmp1 = Tmp1.lshr(BitsLost*2);
}
APInt Tmp3 = MultiplyAndLosePrecision(B, C, A, D, BitsLost);
if (BitsLost) {
ProbTrue = ProbTrue.lshr(BitsLost*2);
Tmp1 = Tmp1.lshr(BitsLost*2);
Tmp2 = Tmp2.lshr(BitsLost*2);
}
bool Overflow1 = false, Overflow2 = false;
APInt Tmp4 = Tmp2.uadd_ov(Tmp3, Overflow1);
APInt ProbFalse = Tmp4.uadd_ov(Tmp1, Overflow2);
if (Overflow1 || Overflow2) {
ProbTrue = ProbTrue.lshr(1);
Tmp1 = Tmp1.lshr(1);
Tmp2 = Tmp2.lshr(1);
Tmp3 = Tmp3.lshr(1);
Tmp4 = Tmp2 + Tmp3;
ProbFalse = Tmp4 + Tmp1;
}
// The sum of branch weights must fit in 32-bits.
if (ProbTrue.isNegative() && ProbFalse.isNegative()) {
ProbTrue = ProbTrue.lshr(1);
ProbFalse = ProbFalse.lshr(1);
}
if (ProbTrue != ProbFalse) {
// Normalize the result.
APInt GCD = APIntOps::GreatestCommonDivisor(ProbTrue, ProbFalse);
ProbTrue = ProbTrue.udiv(GCD);
ProbFalse = ProbFalse.udiv(GCD);
MDBuilder MDB(BI->getContext());
MDNode *N = MDB.createBranchWeights(ProbTrue.getZExtValue(),
ProbFalse.getZExtValue());
PBI->setMetadata(LLVMContext::MD_prof, N);
} else {
PBI->setMetadata(LLVMContext::MD_prof, NULL);
}
} else {
PBI->setMetadata(LLVMContext::MD_prof, NULL);
}
// Copy any debug value intrinsics into the end of PredBlock.
for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I)
if (isa<DbgInfoIntrinsic>(*I))

View File

@ -154,6 +154,26 @@ sw.epilog:
ret void
}
;; This test is based on test1 but swapped the targets of the second branch.
;; Here %Y is the true destination of both conditional branches: the branch in
;; entry goes to %Y or %X, and the branch in %X goes to %Y or %Z. The check
;; line inside the function expects the two branches to be merged into a single
;; branch on %or.cond to %Y / %Z that carries branch-weight metadata (!prof !4).
define void @test1_swap(i1 %a, i1 %b) {
; CHECK: @test1_swap
entry:
  br i1 %a, label %Y, label %X, !prof !0
; CHECK: br i1 %or.cond, label %Y, label %Z, !prof !4
X:
  %c = or i1 %b, false
  br i1 %c, label %Y, label %Z, !prof !1
Y:
  call void @helper(i32 0)
  ret void
Z:
  call void @helper(i32 1)
  ret void
}
!0 = metadata !{metadata !"branch_weights", i32 3, i32 5}
!1 = metadata !{metadata !"branch_weights", i32 1, i32 1}
!2 = metadata !{metadata !"branch_weights", i32 1, i32 2}
@ -165,4 +185,5 @@ sw.epilog:
; CHECK: !1 = metadata !{metadata !"branch_weights", i32 1, i32 5}
; CHECK: !2 = metadata !{metadata !"branch_weights", i32 7, i32 1, i32 2}
; CHECK: !3 = metadata !{metadata !"branch_weights", i32 49, i32 12, i32 24, i32 35}
; CHECK-NOT: !4
; CHECK: !4 = metadata !{metadata !"branch_weights", i32 11, i32 5}
; CHECK-NOT: !5