diff --git a/include/llvm/Analysis/ScalarEvolution.h b/include/llvm/Analysis/ScalarEvolution.h index 6babf01e6c9..6e2e359f42b 100644 --- a/include/llvm/Analysis/ScalarEvolution.h +++ b/include/llvm/Analysis/ScalarEvolution.h @@ -1099,10 +1099,14 @@ namespace llvm { bool splitBinaryAdd(const SCEV *Expr, const SCEV *&L, const SCEV *&R, SCEV::NoWrapFlags &Flags); - /// Return true if More == (Less + C), where C is a constant. This is - /// intended to be used as a cheaper substitute for full SCEV subtraction. - bool computeConstantDifference(const SCEV *Less, const SCEV *More, - APInt &C); + /// Compute \p LHS - \p RHS and returns the result as an APInt if it is a + /// constant, and None if it isn't. + /// + /// This is intended to be a cheaper version of getMinusSCEV. We can be + /// frugal here since we just bail out of actually constructing and + /// canonicalizing an expression in the cases where the result isn't going + /// to be a constant. + Optional computeConstantDifference(const SCEV *LHS, const SCEV *RHS); /// Drop memoized information computed for S. void forgetMemoizedResults(const SCEV *S); diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index ce99f82cc06..89c09a829cb 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -8275,9 +8275,8 @@ bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr, return true; } -bool ScalarEvolution::computeConstantDifference(const SCEV *Less, - const SCEV *More, - APInt &C) { +Optional ScalarEvolution::computeConstantDifference(const SCEV *More, + const SCEV *Less) { // We avoid subtracting expressions here because this function is usually // fairly deep in the call stack (i.e. is called many times). @@ -8286,15 +8285,15 @@ bool ScalarEvolution::computeConstantDifference(const SCEV *Less, const auto *MAR = cast(More); if (LAR->getLoop() != MAR->getLoop()) - return false; + return None; // We look at affine expressions only; not for correctness but to keep // getStepRecurrence cheap. if (!LAR->isAffine() || !MAR->isAffine()) - return false; + return None; if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this)) - return false; + return None; Less = LAR->getStart(); More = MAR->getStart(); @@ -8305,27 +8304,22 @@ bool ScalarEvolution::computeConstantDifference(const SCEV *Less, if (isa(Less) && isa(More)) { const auto &M = cast(More)->getAPInt(); const auto &L = cast(Less)->getAPInt(); - C = M - L; - return true; + return M - L; } const SCEV *L, *R; SCEV::NoWrapFlags Flags; if (splitBinaryAdd(Less, L, R, Flags)) if (const auto *LC = dyn_cast(L)) - if (R == More) { - C = -(LC->getAPInt()); - return true; - } + if (R == More) + return -(LC->getAPInt()); if (splitBinaryAdd(More, L, R, Flags)) if (const auto *LC = dyn_cast(L)) - if (R == Less) { - C = LC->getAPInt(); - return true; - } + if (R == Less) + return LC->getAPInt(); - return false; + return None; } bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow( @@ -8382,22 +8376,21 @@ bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow( // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS + // C)". - APInt LDiff, RDiff; - if (!computeConstantDifference(FoundLHS, LHS, LDiff) || - !computeConstantDifference(FoundRHS, RHS, RDiff) || - LDiff != RDiff) + Optional LDiff = computeConstantDifference(LHS, FoundLHS); + Optional RDiff = computeConstantDifference(RHS, FoundRHS); + if (!LDiff || !RDiff || *LDiff != *RDiff) return false; - if (LDiff == 0) + if (LDiff->isMinValue()) return true; APInt FoundRHSLimit; if (Pred == CmpInst::ICMP_ULT) { - FoundRHSLimit = -RDiff; + FoundRHSLimit = -(*RDiff); } else { assert(Pred == CmpInst::ICMP_SLT && "Checked above!"); - FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - RDiff; + FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff; } // Try to prove (1) or (2), as needed.