Extend the IndVarSimplify support for promoting induction variables:

- Test for signed and unsigned wrapping conditions, instead of just
   testing for non-negative induction ranges. 
 - Handle loops with GT comparisons, in addition to LT comparisons.
 - Support more cases of induction variables that don't start at 0.


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@64532 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Dan Gohman 2009-02-14 02:31:09 +00:00
parent f108e2eaaa
commit aa03649af2
2 changed files with 172 additions and 44 deletions

View File

@ -458,33 +458,98 @@ static const Type *getEffectiveIndvarType(const PHINode *Phi) {
return Ty;
}
/// isOrigIVAlwaysNonNegative - Analyze the original induction variable
/// in the loop to determine whether it would ever have a negative
/// value.
/// TestOrigIVForWrap - Analyze the original induction variable
/// in the loop to determine whether it would ever undergo signed
/// or unsigned overflow.
///
/// TODO: This duplicates a fair amount of ScalarEvolution logic.
/// Perhaps this can be merged with ScalarEvolution::getIterationCount.
/// Perhaps this can be merged with ScalarEvolution::getIterationCount
/// and/or ScalarEvolution::get{Sign,Zero}ExtendExpr.
///
static bool isOrigIVAlwaysNonNegative(const Loop *L,
const Instruction *OrigCond) {
static void TestOrigIVForWrap(const Loop *L,
const BranchInst *BI,
const Instruction *OrigCond,
bool &NoSignedWrap,
bool &NoUnsignedWrap) {
// Verify that the loop is sane and find the exit condition.
const ICmpInst *Cmp = dyn_cast<ICmpInst>(OrigCond);
if (!Cmp) return false;
if (!Cmp) return;
// For now, analyze only SLT loops for signed overflow.
if (Cmp->getPredicate() != ICmpInst::ICMP_SLT) return false;
const Value *CmpLHS = Cmp->getOperand(0);
const Value *CmpRHS = Cmp->getOperand(1);
const BasicBlock *TrueBB = BI->getSuccessor(0);
const BasicBlock *FalseBB = BI->getSuccessor(1);
ICmpInst::Predicate Pred = Cmp->getPredicate();
// Get the increment instruction. Look past SExtInsts if we will
// Canonicalize a constant to the RHS.
if (isa<ConstantInt>(CmpLHS)) {
Pred = ICmpInst::getSwappedPredicate(Pred);
std::swap(CmpLHS, CmpRHS);
}
// Canonicalize SLE to SLT.
if (Pred == ICmpInst::ICMP_SLE)
if (const ConstantInt *CI = dyn_cast<ConstantInt>(CmpRHS))
if (!CI->getValue().isMaxSignedValue()) {
CmpRHS = ConstantInt::get(CI->getValue() + 1);
Pred = ICmpInst::ICMP_SLT;
}
// Canonicalize SGT to SGE.
if (Pred == ICmpInst::ICMP_SGT)
if (const ConstantInt *CI = dyn_cast<ConstantInt>(CmpRHS))
if (!CI->getValue().isMaxSignedValue()) {
CmpRHS = ConstantInt::get(CI->getValue() + 1);
Pred = ICmpInst::ICMP_SGE;
}
// Canonicalize SGE to SLT.
if (Pred == ICmpInst::ICMP_SGE) {
std::swap(TrueBB, FalseBB);
Pred = ICmpInst::ICMP_SLT;
}
// Canonicalize ULE to ULT.
if (Pred == ICmpInst::ICMP_ULE)
if (const ConstantInt *CI = dyn_cast<ConstantInt>(CmpRHS))
if (!CI->getValue().isMaxValue()) {
CmpRHS = ConstantInt::get(CI->getValue() + 1);
Pred = ICmpInst::ICMP_ULT;
}
// Canonicalize UGT to UGE.
if (Pred == ICmpInst::ICMP_UGT)
if (const ConstantInt *CI = dyn_cast<ConstantInt>(CmpRHS))
if (!CI->getValue().isMaxValue()) {
CmpRHS = ConstantInt::get(CI->getValue() + 1);
Pred = ICmpInst::ICMP_UGE;
}
// Canonicalize UGE to ULT.
if (Pred == ICmpInst::ICMP_UGE) {
std::swap(TrueBB, FalseBB);
Pred = ICmpInst::ICMP_ULT;
}
// For now, analyze only LT loops for signed overflow.
if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_ULT)
return;
bool isSigned = Pred == ICmpInst::ICMP_SLT;
// Get the increment instruction. Look past casts if we will
// be able to prove that the original induction variable doesn't
// undergo signed overflow.
const Value *OrigIncrVal = Cmp->getOperand(0);
const Value *IncrVal = OrigIncrVal;
if (SExtInst *SI = dyn_cast<SExtInst>(Cmp->getOperand(0))) {
if (!isa<ConstantInt>(Cmp->getOperand(1)) ||
!cast<ConstantInt>(Cmp->getOperand(1))->getValue()
.isSignedIntN(IncrVal->getType()->getPrimitiveSizeInBits()))
return false;
IncrVal = SI->getOperand(0);
// undergo signed or unsigned overflow, respectively.
const Value *IncrVal = CmpLHS;
if (isSigned) {
if (const SExtInst *SI = dyn_cast<SExtInst>(CmpLHS)) {
if (!isa<ConstantInt>(CmpRHS) ||
!cast<ConstantInt>(CmpRHS)->getValue()
.isSignedIntN(IncrVal->getType()->getPrimitiveSizeInBits()))
return;
IncrVal = SI->getOperand(0);
}
} else {
if (const ZExtInst *ZI = dyn_cast<ZExtInst>(CmpLHS)) {
if (!isa<ConstantInt>(CmpRHS) ||
!cast<ConstantInt>(CmpRHS)->getValue()
.isIntN(IncrVal->getType()->getPrimitiveSizeInBits()))
return;
IncrVal = ZI->getOperand(0);
}
}
// For now, only analyze induction variables that have simple increments.
@ -493,32 +558,36 @@ static bool isOrigIVAlwaysNonNegative(const Loop *L,
IncrOp->getOpcode() != Instruction::Add ||
!isa<ConstantInt>(IncrOp->getOperand(1)) ||
!cast<ConstantInt>(IncrOp->getOperand(1))->equalsInt(1))
return false;
return;
// Make sure the PHI looks like a normal IV.
const PHINode *PN = dyn_cast<PHINode>(IncrOp->getOperand(0));
if (!PN || PN->getNumIncomingValues() != 2)
return false;
return;
unsigned IncomingEdge = L->contains(PN->getIncomingBlock(0));
unsigned BackEdge = !IncomingEdge;
if (!L->contains(PN->getIncomingBlock(BackEdge)) ||
PN->getIncomingValue(BackEdge) != IncrOp)
return false;
return;
if (!L->contains(TrueBB))
return;
// For now, only analyze loops with a constant start value, so that
// we can easily determine if the start value is non-negative and
// not a maximum value which would wrap on the first iteration.
// we can easily determine if the start value is not a maximum value
// which would wrap on the first iteration.
const Value *InitialVal = PN->getIncomingValue(IncomingEdge);
if (!isa<ConstantInt>(InitialVal) ||
cast<ConstantInt>(InitialVal)->getValue().isNegative() ||
cast<ConstantInt>(InitialVal)->getValue().isMaxSignedValue())
return false;
if (!isa<ConstantInt>(InitialVal))
return;
// The original induction variable will start at some non-negative
// non-max value, it counts up by one, and the loop iterates only
// while it remans less than (signed) some value in the same type.
// As such, it will always be non-negative.
return true;
// The original induction variable will start at some non-max value,
// it counts up by one, and the loop iterates only while it remans
// less than some value in the same type. As such, it will never wrap.
if (isSigned &&
!cast<ConstantInt>(InitialVal)->getValue().isMaxSignedValue())
NoSignedWrap = true;
else if (!isSigned &&
!cast<ConstantInt>(InitialVal)->getValue().isMaxValue())
NoUnsignedWrap = true;
}
bool IndVarSimplify::runOnLoop(Loop *L, LPPassManager &LPM) {
@ -596,13 +665,15 @@ bool IndVarSimplify::runOnLoop(Loop *L, LPPassManager &LPM) {
// If we have a trip count expression, rewrite the loop's exit condition
// using it. We can currently only handle loops with a single exit.
bool OrigIVAlwaysNonNegative = false;
bool NoSignedWrap = false;
bool NoUnsignedWrap = false;
if (!isa<SCEVCouldNotCompute>(IterationCount) && ExitingBlock)
// Can't rewrite non-branch yet.
if (BranchInst *BI = dyn_cast<BranchInst>(ExitingBlock->getTerminator())) {
if (Instruction *OrigCond = dyn_cast<Instruction>(BI->getCondition())) {
// Determine if the OrigIV will ever have a non-zero sign bit.
OrigIVAlwaysNonNegative = isOrigIVAlwaysNonNegative(L, OrigCond);
// Determine if the OrigIV will ever undergo overflow.
TestOrigIVForWrap(L, BI, OrigCond,
NoSignedWrap, NoUnsignedWrap);
// We'll be replacing the original condition, so it'll be dead.
DeadInsts.insert(OrigCond);
@ -642,19 +713,38 @@ bool IndVarSimplify::runOnLoop(Loop *L, LPPassManager &LPM) {
/// If the new canonical induction variable is wider than the original,
/// and the original has uses that are casts to wider types, see if the
/// truncate and extend can be omitted.
if (isa<TruncInst>(NewVal))
if (PN->getType() != LargestType)
for (Value::use_iterator UI = PN->use_begin(), UE = PN->use_end();
UI != UE; ++UI)
if (isa<ZExtInst>(UI) ||
(isa<SExtInst>(UI) && OrigIVAlwaysNonNegative)) {
Value *TruncIndVar = IndVar;
if (TruncIndVar->getType() != UI->getType())
TruncIndVar = new TruncInst(IndVar, UI->getType(), "truncindvar",
InsertPt);
UI != UE; ++UI) {
if (isa<SExtInst>(UI) && NoSignedWrap) {
SCEVHandle ExtendedStart =
SE->getSignExtendExpr(cast<SCEVAddRecExpr>(IndVars.back().second)->getStart(), LargestType);
SCEVHandle ExtendedStep =
SE->getSignExtendExpr(cast<SCEVAddRecExpr>(IndVars.back().second)->getStepRecurrence(*SE), LargestType);
SCEVHandle ExtendedAddRec =
SE->getAddRecExpr(ExtendedStart, ExtendedStep, L);
if (LargestType != UI->getType())
ExtendedAddRec = SE->getTruncateExpr(ExtendedAddRec, UI->getType());
Value *TruncIndVar = Rewriter.expandCodeFor(ExtendedAddRec, InsertPt);
UI->replaceAllUsesWith(TruncIndVar);
if (Instruction *DeadUse = dyn_cast<Instruction>(*UI))
DeadInsts.insert(DeadUse);
}
if (isa<ZExtInst>(UI) && NoUnsignedWrap) {
SCEVHandle ExtendedStart =
SE->getZeroExtendExpr(cast<SCEVAddRecExpr>(IndVars.back().second)->getStart(), LargestType);
SCEVHandle ExtendedStep =
SE->getZeroExtendExpr(cast<SCEVAddRecExpr>(IndVars.back().second)->getStepRecurrence(*SE), LargestType);
SCEVHandle ExtendedAddRec =
SE->getAddRecExpr(ExtendedStart, ExtendedStep, L);
if (LargestType != UI->getType())
ExtendedAddRec = SE->getTruncateExpr(ExtendedAddRec, UI->getType());
Value *TruncIndVar = Rewriter.expandCodeFor(ExtendedAddRec, InsertPt);
UI->replaceAllUsesWith(TruncIndVar);
if (Instruction *DeadUse = dyn_cast<Instruction>(*UI))
DeadInsts.insert(DeadUse);
}
}
// Replace the old PHI Node with the inserted computation.
PN->replaceAllUsesWith(NewVal);

View File

@ -60,3 +60,41 @@ bb1.return_crit_edge: ; preds = %bb1
return: ; preds = %bb1.return_crit_edge, %entry
ret void
}
; Test cases from PR1301:
define void @kinds__srangezero([21 x i32]* nocapture %a) nounwind {
bb.thread:
br label %bb
bb: ; preds = %bb, %bb.thread
%i.0.reg2mem.0 = phi i8 [ -10, %bb.thread ], [ %tmp7, %bb ] ; <i8> [#uses=2]
%tmp12 = sext i8 %i.0.reg2mem.0 to i32 ; <i32> [#uses=1]
%tmp4 = add i32 %tmp12, 10 ; <i32> [#uses=1]
%tmp5 = getelementptr [21 x i32]* %a, i32 0, i32 %tmp4 ; <i32*> [#uses=1]
store i32 0, i32* %tmp5
%tmp7 = add i8 %i.0.reg2mem.0, 1 ; <i8> [#uses=2]
%0 = icmp sgt i8 %tmp7, 10 ; <i1> [#uses=1]
br i1 %0, label %return, label %bb
return: ; preds = %bb
ret void
}
define void @kinds__urangezero([21 x i32]* nocapture %a) nounwind {
bb.thread:
br label %bb
bb: ; preds = %bb, %bb.thread
%i.0.reg2mem.0 = phi i8 [ 10, %bb.thread ], [ %tmp7, %bb ] ; <i8> [#uses=2]
%tmp12 = sext i8 %i.0.reg2mem.0 to i32 ; <i32> [#uses=1]
%tmp4 = add i32 %tmp12, -10 ; <i32> [#uses=1]
%tmp5 = getelementptr [21 x i32]* %a, i32 0, i32 %tmp4 ; <i32*> [#uses=1]
store i32 0, i32* %tmp5
%tmp7 = add i8 %i.0.reg2mem.0, 1 ; <i8> [#uses=2]
%0 = icmp sgt i8 %tmp7, 30 ; <i1> [#uses=1]
br i1 %0, label %return, label %bb
return: ; preds = %bb
ret void
}