mirror of
https://github.com/RPCSX/llvm.git
synced 2024-11-25 20:59:51 +00:00
PGO: preserve branch-weight metadata when simplifying two branches with a common
destination. Updated previous implementation to fix a case not covered: // PBI: br i1 %x, TrueDest, BB // BI: br i1 %y, TrueDest, FalseDest The other case was handled correctly. // PBI: br i1 %x, BB, FalseDest // BI: br i1 %y, TrueDest, FalseDest Also tried to use 64-bit arithmetic instead of APInt with scale to simplify the computation. Let me know if you have other opinions about this. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@163954 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
parent
3cbd1786ac
commit
062986c2f0
@ -1658,7 +1658,7 @@ static bool SimplifyCondBranchToTwoReturns(BranchInst *BI,
|
||||
/// parameters and return true, or returns false if no or invalid metadata was
|
||||
/// found.
|
||||
static bool ExtractBranchMetadata(BranchInst *BI,
|
||||
APInt &ProbTrue, APInt &ProbFalse) {
|
||||
uint64_t &ProbTrue, uint64_t &ProbFalse) {
|
||||
assert(BI->isConditional() &&
|
||||
"Looking for probabilities on unconditional branch?");
|
||||
MDNode *ProfileData = BI->getMetadata(LLVMContext::MD_prof);
|
||||
@ -1666,35 +1666,11 @@ static bool ExtractBranchMetadata(BranchInst *BI,
|
||||
ConstantInt *CITrue = dyn_cast<ConstantInt>(ProfileData->getOperand(1));
|
||||
ConstantInt *CIFalse = dyn_cast<ConstantInt>(ProfileData->getOperand(2));
|
||||
if (!CITrue || !CIFalse) return false;
|
||||
ProbTrue = CITrue->getValue();
|
||||
ProbFalse = CIFalse->getValue();
|
||||
assert(ProbTrue.getBitWidth() == 32 && ProbFalse.getBitWidth() == 32 &&
|
||||
"Branch probability metadata must be 32-bit integers");
|
||||
ProbTrue = CITrue->getValue().getZExtValue();
|
||||
ProbFalse = CIFalse->getValue().getZExtValue();
|
||||
return true;
|
||||
}
|
||||
|
||||
/// MultiplyAndLosePrecision - Multiplies A and B, then returns the result. In
|
||||
/// the event of overflow, logically-shifts all four inputs right until the
|
||||
/// multiply fits.
|
||||
static APInt MultiplyAndLosePrecision(APInt &A, APInt &B, APInt &C, APInt &D,
|
||||
unsigned &BitsLost) {
|
||||
BitsLost = 0;
|
||||
bool Overflow = false;
|
||||
APInt Result = A.umul_ov(B, Overflow);
|
||||
if (Overflow) {
|
||||
APInt MaxB = APInt::getMaxValue(A.getBitWidth()).udiv(A);
|
||||
do {
|
||||
B = B.lshr(1);
|
||||
++BitsLost;
|
||||
} while (B.ugt(MaxB));
|
||||
A = A.lshr(BitsLost);
|
||||
C = C.lshr(BitsLost);
|
||||
D = D.lshr(BitsLost);
|
||||
Result = A * B;
|
||||
}
|
||||
return Result;
|
||||
}
|
||||
|
||||
/// checkCSEInPredecessor - Return true if the given instruction is available
|
||||
/// in its predecessor block. If yes, the instruction will be removed.
|
||||
///
|
||||
@ -1919,14 +1895,53 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI) {
|
||||
New, "or.cond"));
|
||||
PBI->setCondition(NewCond);
|
||||
|
||||
uint64_t PredTrueWeight, PredFalseWeight, SuccTrueWeight, SuccFalseWeight;
|
||||
bool PredHasWeights = ExtractBranchMetadata(PBI, PredTrueWeight,
|
||||
PredFalseWeight);
|
||||
bool SuccHasWeights = ExtractBranchMetadata(BI, SuccTrueWeight,
|
||||
SuccFalseWeight);
|
||||
SmallVector<uint64_t, 8> NewWeights;
|
||||
|
||||
if (PBI->getSuccessor(0) == BB) {
|
||||
if (PredHasWeights && SuccHasWeights) {
|
||||
// PBI: br i1 %x, BB, FalseDest
|
||||
// BI: br i1 %y, TrueDest, FalseDest
|
||||
//TrueWeight is TrueWeight for PBI * TrueWeight for BI.
|
||||
NewWeights.push_back(PredTrueWeight * SuccTrueWeight);
|
||||
//FalseWeight is FalseWeight for PBI * TotalWeight for BI +
|
||||
// TrueWeight for PBI * FalseWeight for BI.
|
||||
// We assume that total weights of a BranchInst can fit into 32 bits.
|
||||
// Therefore, we will not have overflow using 64-bit arithmetic.
|
||||
NewWeights.push_back(PredFalseWeight * (SuccFalseWeight +
|
||||
SuccTrueWeight) + PredTrueWeight * SuccFalseWeight);
|
||||
}
|
||||
AddPredecessorToBlock(TrueDest, PredBlock, BB);
|
||||
PBI->setSuccessor(0, TrueDest);
|
||||
}
|
||||
if (PBI->getSuccessor(1) == BB) {
|
||||
if (PredHasWeights && SuccHasWeights) {
|
||||
// PBI: br i1 %x, TrueDest, BB
|
||||
// BI: br i1 %y, TrueDest, FalseDest
|
||||
//TrueWeight is TrueWeight for PBI * TotalWeight for BI +
|
||||
// FalseWeight for PBI * TrueWeight for BI.
|
||||
NewWeights.push_back(PredTrueWeight * (SuccFalseWeight +
|
||||
SuccTrueWeight) + PredFalseWeight * SuccTrueWeight);
|
||||
//FalseWeight is FalseWeight for PBI * FalseWeight for BI.
|
||||
NewWeights.push_back(PredFalseWeight * SuccFalseWeight);
|
||||
}
|
||||
AddPredecessorToBlock(FalseDest, PredBlock, BB);
|
||||
PBI->setSuccessor(1, FalseDest);
|
||||
}
|
||||
if (NewWeights.size() == 2) {
|
||||
// Halve the weights if any of them cannot fit in an uint32_t
|
||||
FitWeights(NewWeights);
|
||||
|
||||
SmallVector<uint32_t, 8> MDWeights(NewWeights.begin(),NewWeights.end());
|
||||
PBI->setMetadata(LLVMContext::MD_prof,
|
||||
MDBuilder(BI->getContext()).
|
||||
createBranchWeights(MDWeights));
|
||||
} else
|
||||
PBI->setMetadata(LLVMContext::MD_prof, NULL);
|
||||
} else {
|
||||
// Update PHI nodes in the common successors.
|
||||
for (unsigned i = 0, e = PHIs.size(); i != e; ++i) {
|
||||
@ -1981,90 +1996,6 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI) {
|
||||
// TODO: If BB is reachable from all paths through PredBlock, then we
|
||||
// could replace PBI's branch probabilities with BI's.
|
||||
|
||||
// Merge probability data into PredBlock's branch.
|
||||
APInt A, B, C, D;
|
||||
if (PBI->isConditional() && BI->isConditional() &&
|
||||
ExtractBranchMetadata(PBI, C, D) && ExtractBranchMetadata(BI, A, B)) {
|
||||
// Given IR which does:
|
||||
// bbA:
|
||||
// br i1 %x, label %bbB, label %bbC
|
||||
// bbB:
|
||||
// br i1 %y, label %bbD, label %bbC
|
||||
// Let's call the probability that we take the edge from %bbA to %bbB
|
||||
// 'a', from %bbA to %bbC, 'b', from %bbB to %bbD 'c' and from %bbB to
|
||||
// %bbC probability 'd'.
|
||||
//
|
||||
// We transform the IR into:
|
||||
// bbA:
|
||||
// br i1 %z, label %bbD, label %bbC
|
||||
// where the probability of going to %bbD is (a*c) and going to bbC is
|
||||
// (b+a*d).
|
||||
//
|
||||
// Probabilities aren't stored as ratios directly. Using branch weights,
|
||||
// we get:
|
||||
// (a*c)% = A*C, (b+(a*d))% = A*D+B*C+B*D.
|
||||
|
||||
// In the event of overflow, we want to drop the LSB of the input
|
||||
// probabilities.
|
||||
unsigned BitsLost;
|
||||
|
||||
// Ignore overflow result on ProbTrue.
|
||||
APInt ProbTrue = MultiplyAndLosePrecision(A, C, B, D, BitsLost);
|
||||
|
||||
APInt Tmp1 = MultiplyAndLosePrecision(B, D, A, C, BitsLost);
|
||||
if (BitsLost) {
|
||||
ProbTrue = ProbTrue.lshr(BitsLost*2);
|
||||
}
|
||||
|
||||
APInt Tmp2 = MultiplyAndLosePrecision(A, D, C, B, BitsLost);
|
||||
if (BitsLost) {
|
||||
ProbTrue = ProbTrue.lshr(BitsLost*2);
|
||||
Tmp1 = Tmp1.lshr(BitsLost*2);
|
||||
}
|
||||
|
||||
APInt Tmp3 = MultiplyAndLosePrecision(B, C, A, D, BitsLost);
|
||||
if (BitsLost) {
|
||||
ProbTrue = ProbTrue.lshr(BitsLost*2);
|
||||
Tmp1 = Tmp1.lshr(BitsLost*2);
|
||||
Tmp2 = Tmp2.lshr(BitsLost*2);
|
||||
}
|
||||
|
||||
bool Overflow1 = false, Overflow2 = false;
|
||||
APInt Tmp4 = Tmp2.uadd_ov(Tmp3, Overflow1);
|
||||
APInt ProbFalse = Tmp4.uadd_ov(Tmp1, Overflow2);
|
||||
|
||||
if (Overflow1 || Overflow2) {
|
||||
ProbTrue = ProbTrue.lshr(1);
|
||||
Tmp1 = Tmp1.lshr(1);
|
||||
Tmp2 = Tmp2.lshr(1);
|
||||
Tmp3 = Tmp3.lshr(1);
|
||||
Tmp4 = Tmp2 + Tmp3;
|
||||
ProbFalse = Tmp4 + Tmp1;
|
||||
}
|
||||
|
||||
// The sum of branch weights must fit in 32-bits.
|
||||
if (ProbTrue.isNegative() && ProbFalse.isNegative()) {
|
||||
ProbTrue = ProbTrue.lshr(1);
|
||||
ProbFalse = ProbFalse.lshr(1);
|
||||
}
|
||||
|
||||
if (ProbTrue != ProbFalse) {
|
||||
// Normalize the result.
|
||||
APInt GCD = APIntOps::GreatestCommonDivisor(ProbTrue, ProbFalse);
|
||||
ProbTrue = ProbTrue.udiv(GCD);
|
||||
ProbFalse = ProbFalse.udiv(GCD);
|
||||
|
||||
MDBuilder MDB(BI->getContext());
|
||||
MDNode *N = MDB.createBranchWeights(ProbTrue.getZExtValue(),
|
||||
ProbFalse.getZExtValue());
|
||||
PBI->setMetadata(LLVMContext::MD_prof, N);
|
||||
} else {
|
||||
PBI->setMetadata(LLVMContext::MD_prof, NULL);
|
||||
}
|
||||
} else {
|
||||
PBI->setMetadata(LLVMContext::MD_prof, NULL);
|
||||
}
|
||||
|
||||
// Copy any debug value intrinsics into the end of PredBlock.
|
||||
for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I)
|
||||
if (isa<DbgInfoIntrinsic>(*I))
|
||||
|
@ -154,6 +154,26 @@ sw.epilog:
|
||||
ret void
|
||||
}
|
||||
|
||||
;; This test is based on test1 but swapped the targets of the second branch.
|
||||
define void @test1_swap(i1 %a, i1 %b) {
|
||||
; CHECK: @test1_swap
|
||||
entry:
|
||||
br i1 %a, label %Y, label %X, !prof !0
|
||||
; CHECK: br i1 %or.cond, label %Y, label %Z, !prof !4
|
||||
|
||||
X:
|
||||
%c = or i1 %b, false
|
||||
br i1 %c, label %Y, label %Z, !prof !1
|
||||
|
||||
Y:
|
||||
call void @helper(i32 0)
|
||||
ret void
|
||||
|
||||
Z:
|
||||
call void @helper(i32 1)
|
||||
ret void
|
||||
}
|
||||
|
||||
!0 = metadata !{metadata !"branch_weights", i32 3, i32 5}
|
||||
!1 = metadata !{metadata !"branch_weights", i32 1, i32 1}
|
||||
!2 = metadata !{metadata !"branch_weights", i32 1, i32 2}
|
||||
@ -165,4 +185,5 @@ sw.epilog:
|
||||
; CHECK: !1 = metadata !{metadata !"branch_weights", i32 1, i32 5}
|
||||
; CHECK: !2 = metadata !{metadata !"branch_weights", i32 7, i32 1, i32 2}
|
||||
; CHECK: !3 = metadata !{metadata !"branch_weights", i32 49, i32 12, i32 24, i32 35}
|
||||
; CHECK-NOT: !4
|
||||
; CHECK: !4 = metadata !{metadata !"branch_weights", i32 11, i32 5}
|
||||
; CHECK-NOT: !5
|
||||
|
Loading…
Reference in New Issue
Block a user