diff --git a/lib/Transforms/Utils/SimplifyCFG.cpp b/lib/Transforms/Utils/SimplifyCFG.cpp index 326bd7a2bd0..d5ae609679f 100644 --- a/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/lib/Transforms/Utils/SimplifyCFG.cpp @@ -1466,6 +1466,29 @@ static bool ExtractBranchMetadata(BranchInst *BI, return true; } +/// MultiplyAndLosePrecision - Multiplies A and B, then returns the result. In +/// the event of overflow, logically-shifts all four inputs right until the +/// multiply fits. +static APInt MultiplyAndLosePrecision(APInt &A, APInt &B, APInt &C, APInt &D, + unsigned &BitsLost) { + BitsLost = 0; + bool Overflow = false; + APInt Result = A.umul_ov(B, Overflow); + if (Overflow) { + APInt MaxB = APInt::getMaxValue(A.getBitWidth()).udiv(A); + do { + B = B.lshr(1); + ++BitsLost; + } while (B.ugt(MaxB)); + A = A.lshr(BitsLost); + C = C.lshr(BitsLost); + D = D.lshr(BitsLost); + Result = A * B; + } + return Result; +} + + /// FoldBranchToCommonDest - If this basic block is simple enough, and if a /// predecessor branches to us and one of our successors, fold the block into /// the predecessor and use logical operations to pick the right destination. @@ -1665,32 +1688,64 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI) { // we get: // (a*c)% = A*C, (b+(a*d))% = A*D+B*C+B*D. - bool Overflow1 = false, Overflow2 = false, Overflow3 = false; - bool Overflow4 = false, Overflow5 = false, Overflow6 = false; - APInt ProbTrue = A.umul_ov(C, Overflow1); + // In the event of overflow, we want to drop the LSB of the input + // probabilities. + unsigned BitsLost; - APInt Tmp1 = A.umul_ov(D, Overflow2); - APInt Tmp2 = B.umul_ov(C, Overflow3); - APInt Tmp3 = B.umul_ov(D, Overflow4); - APInt Tmp4 = Tmp1.uadd_ov(Tmp2, Overflow5); - APInt ProbFalse = Tmp4.uadd_ov(Tmp3, Overflow6); + // Ignore overflow result on ProbTrue. + APInt ProbTrue = MultiplyAndLosePrecision(A, C, B, D, BitsLost); - APInt GCD = APIntOps::GreatestCommonDivisor(ProbTrue, ProbFalse); - ProbTrue = ProbTrue.udiv(GCD); - ProbFalse = ProbFalse.udiv(GCD); + APInt Tmp1 = MultiplyAndLosePrecision(B, D, A, C, BitsLost); + if (BitsLost) { + ProbTrue = ProbTrue.lshr(BitsLost*2); + } + + APInt Tmp2 = MultiplyAndLosePrecision(A, D, C, B, BitsLost); + if (BitsLost) { + ProbTrue = ProbTrue.lshr(BitsLost*2); + Tmp1 = Tmp1.lshr(BitsLost*2); + } + + APInt Tmp3 = MultiplyAndLosePrecision(B, C, A, D, BitsLost); + if (BitsLost) { + ProbTrue = ProbTrue.lshr(BitsLost*2); + Tmp1 = Tmp1.lshr(BitsLost*2); + Tmp2 = Tmp2.lshr(BitsLost*2); + } + + bool Overflow1 = false, Overflow2 = false; + APInt Tmp4 = Tmp2.uadd_ov(Tmp3, Overflow1); + APInt ProbFalse = Tmp4.uadd_ov(Tmp1, Overflow2); + + if (Overflow1 || Overflow2) { + ProbTrue = ProbTrue.lshr(1); + Tmp1 = Tmp1.lshr(1); + Tmp2 = Tmp2.lshr(1); + Tmp3 = Tmp3.lshr(1); + Tmp4 = Tmp2 + Tmp3; + ProbFalse = Tmp4 + Tmp1; + } + + // The sum of branch weights must fit in 32-bits. + if (ProbTrue.isNegative() && ProbFalse.isNegative()) { + ProbTrue = ProbTrue.lshr(1); + ProbFalse = ProbFalse.lshr(1); + } + + if (ProbTrue != ProbFalse) { + // Normalize the result. + APInt GCD = APIntOps::GreatestCommonDivisor(ProbTrue, ProbFalse); + ProbTrue = ProbTrue.udiv(GCD); + ProbFalse = ProbFalse.udiv(GCD); - if (Overflow1 || Overflow2 || Overflow3 || Overflow4 || Overflow5 || - Overflow6) { - DEBUG(dbgs() << "Overflow recomputing branch weight on: " << *PBI - << "when merging with: " << *BI); - PBI->setMetadata(LLVMContext::MD_prof, NULL); - } else { LLVMContext &Context = BI->getContext(); Value *Ops[3]; Ops[0] = BI->getMetadata(LLVMContext::MD_prof)->getOperand(0); Ops[1] = ConstantInt::get(Context, ProbTrue); Ops[2] = ConstantInt::get(Context, ProbFalse); PBI->setMetadata(LLVMContext::MD_prof, MDNode::get(Context, Ops)); + } else { + PBI->setMetadata(LLVMContext::MD_prof, NULL); } } else { PBI->setMetadata(LLVMContext::MD_prof, NULL);