Add several more icmp simplifications. Transform signed comparisons

into unsigned ones when the operands are known to have the same
sign bit value.

llvm-svn: 70053
This commit is contained in:
Dan Gohman 2009-04-25 17:12:48 +00:00
parent 60349bb21e
commit a7fae1f865
2 changed files with 187 additions and 82 deletions

View File

@ -708,15 +708,13 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo,
// set of known zero and one bits, compute the maximum and minimum values that // set of known zero and one bits, compute the maximum and minimum values that
// could have the specified known zero and known one bits, returning them in // could have the specified known zero and known one bits, returning them in
// min/max. // min/max.
static void ComputeSignedMinMaxValuesFromKnownBits(const Type *Ty, static void ComputeSignedMinMaxValuesFromKnownBits(const APInt& KnownZero,
const APInt& KnownZero,
const APInt& KnownOne, const APInt& KnownOne,
APInt& Min, APInt& Max) { APInt& Min, APInt& Max) {
uint32_t BitWidth = cast<IntegerType>(Ty)->getBitWidth(); assert(KnownZero.getBitWidth() == KnownOne.getBitWidth() &&
assert(KnownZero.getBitWidth() == BitWidth && KnownZero.getBitWidth() == Min.getBitWidth() &&
KnownOne.getBitWidth() == BitWidth && KnownZero.getBitWidth() == Max.getBitWidth() &&
Min.getBitWidth() == BitWidth && Max.getBitWidth() == BitWidth && "KnownZero, KnownOne and Min, Max must have equal bitwidth.");
"Ty, KnownZero, KnownOne and Min, Max must have equal bitwidth.");
APInt UnknownBits = ~(KnownZero|KnownOne); APInt UnknownBits = ~(KnownZero|KnownOne);
// The minimum value is when all unknown bits are zeros, EXCEPT for the sign // The minimum value is when all unknown bits are zeros, EXCEPT for the sign
@ -724,9 +722,9 @@ static void ComputeSignedMinMaxValuesFromKnownBits(const Type *Ty,
Min = KnownOne; Min = KnownOne;
Max = KnownOne|UnknownBits; Max = KnownOne|UnknownBits;
if (UnknownBits[BitWidth-1]) { // Sign bit is unknown if (UnknownBits.isNegative()) { // Sign bit is unknown
Min.set(BitWidth-1); Min.set(Min.getBitWidth()-1);
Max.clear(BitWidth-1); Max.clear(Max.getBitWidth()-1);
} }
} }
@ -734,14 +732,12 @@ static void ComputeSignedMinMaxValuesFromKnownBits(const Type *Ty,
// a set of known zero and one bits, compute the maximum and minimum values that // a set of known zero and one bits, compute the maximum and minimum values that
// could have the specified known zero and known one bits, returning them in // could have the specified known zero and known one bits, returning them in
// min/max. // min/max.
static void ComputeUnsignedMinMaxValuesFromKnownBits(const Type *Ty, static void ComputeUnsignedMinMaxValuesFromKnownBits(const APInt &KnownZero,
const APInt &KnownZero,
const APInt &KnownOne, const APInt &KnownOne,
APInt &Min, APInt &Max) { APInt &Min, APInt &Max) {
uint32_t BitWidth = cast<IntegerType>(Ty)->getBitWidth(); BitWidth = BitWidth; assert(KnownZero.getBitWidth() == KnownOne.getBitWidth() &&
assert(KnownZero.getBitWidth() == BitWidth && KnownZero.getBitWidth() == Min.getBitWidth() &&
KnownOne.getBitWidth() == BitWidth && KnownZero.getBitWidth() == Max.getBitWidth() &&
Min.getBitWidth() == BitWidth && Max.getBitWidth() &&
"Ty, KnownZero, KnownOne and Min, Max must have equal bitwidth."); "Ty, KnownZero, KnownOne and Min, Max must have equal bitwidth.");
APInt UnknownBits = ~(KnownZero|KnownOne); APInt UnknownBits = ~(KnownZero|KnownOne);
@ -808,9 +804,13 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
assert(V != 0 && "Null pointer of Value???"); assert(V != 0 && "Null pointer of Value???");
assert(Depth <= 6 && "Limit Search Depth"); assert(Depth <= 6 && "Limit Search Depth");
uint32_t BitWidth = DemandedMask.getBitWidth(); uint32_t BitWidth = DemandedMask.getBitWidth();
const IntegerType *VTy = cast<IntegerType>(V->getType()); const Type *VTy = V->getType();
assert(VTy->getBitWidth() == BitWidth && assert((TD || !isa<PointerType>(VTy)) &&
KnownZero.getBitWidth() == BitWidth && "SimplifyDemandedBits needs to know bit widths!");
assert((!TD || TD->getTypeSizeInBits(VTy) == BitWidth) &&
(!isa<IntegerType>(VTy) ||
VTy->getPrimitiveSizeInBits() == BitWidth) &&
KnownZero.getBitWidth() == BitWidth &&
KnownOne.getBitWidth() == BitWidth && KnownOne.getBitWidth() == BitWidth &&
"Value *V, DemandedMask, KnownZero and KnownOne \ "Value *V, DemandedMask, KnownZero and KnownOne \
must have same BitWidth"); must have same BitWidth");
@ -820,7 +820,13 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
KnownZero = ~KnownOne & DemandedMask; KnownZero = ~KnownOne & DemandedMask;
return 0; return 0;
} }
if (isa<ConstantPointerNull>(V)) {
// We know all of the bits for a constant!
KnownOne.clear();
KnownZero = DemandedMask;
return 0;
}
KnownZero.clear(); KnownZero.clear();
KnownOne.clear(); KnownOne.clear();
if (DemandedMask == 0) { // Not demanding any bits from V. if (DemandedMask == 0) { // Not demanding any bits from V.
@ -832,12 +838,15 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (Depth == 6) // Limit search depth. if (Depth == 6) // Limit search depth.
return 0; return 0;
Instruction *I = dyn_cast<Instruction>(V);
if (!I) return 0; // Only analyze instructions.
APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0); APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0);
APInt &RHSKnownZero = KnownZero, &RHSKnownOne = KnownOne; APInt &RHSKnownZero = KnownZero, &RHSKnownOne = KnownOne;
Instruction *I = dyn_cast<Instruction>(V);
if (!I) {
ComputeMaskedBits(V, DemandedMask, RHSKnownZero, RHSKnownOne, Depth);
return 0; // Only analyze instructions.
}
// If there are multiple uses of this value and we aren't at the root, then // If there are multiple uses of this value and we aren't at the root, then
// we can't do any simplifications of the operands, because DemandedMask // we can't do any simplifications of the operands, because DemandedMask
// only reflects the bits demanded by *one* of the users. // only reflects the bits demanded by *one* of the users.
@ -1399,8 +1408,12 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// If the client is only demanding bits that we know, return the known // If the client is only demanding bits that we know, return the known
// constant. // constant.
if ((DemandedMask & (RHSKnownZero|RHSKnownOne)) == DemandedMask) if ((DemandedMask & (RHSKnownZero|RHSKnownOne)) == DemandedMask) {
return ConstantInt::get(RHSKnownOne); Constant *C = ConstantInt::get(RHSKnownOne);
if (isa<PointerType>(V->getType()))
C = ConstantExpr::getIntToPtr(C, V->getType());
return C;
}
return false; return false;
} }
@ -5831,6 +5844,14 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
} }
} }
unsigned BitWidth = 0;
if (TD)
BitWidth = TD->getTypeSizeInBits(Ty);
else if (isa<IntegerType>(Ty))
BitWidth = Ty->getPrimitiveSizeInBits();
bool isSignBit = false;
// See if we are doing a comparison with a constant. // See if we are doing a comparison with a constant.
if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
Value *A = 0, *B = 0; Value *A = 0, *B = 0;
@ -5865,105 +5886,161 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
return new ICmpInst(ICmpInst::ICMP_SGT, Op0, SubOne(CI)); return new ICmpInst(ICmpInst::ICMP_SGT, Op0, SubOne(CI));
} }
// See if we can fold the comparison based on range information we can get
// by checking whether bits are known to be zero or one in the input.
uint32_t BitWidth = cast<IntegerType>(Ty)->getBitWidth();
APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
// If this comparison is a normal comparison, it demands all // If this comparison is a normal comparison, it demands all
// bits, if it is a sign bit comparison, it only demands the sign bit. // bits, if it is a sign bit comparison, it only demands the sign bit.
bool UnusedBit; bool UnusedBit;
bool isSignBit = isSignBitCheck(I.getPredicate(), CI, UnusedBit); isSignBit = isSignBitCheck(I.getPredicate(), CI, UnusedBit);
}
if (SimplifyDemandedBits(I.getOperandUse(0),
// See if we can fold the comparison based on range information we can get
// by checking whether bits are known to be zero or one in the input.
if (BitWidth != 0) {
APInt Op0KnownZero(BitWidth, 0), Op0KnownOne(BitWidth, 0);
APInt Op1KnownZero(BitWidth, 0), Op1KnownOne(BitWidth, 0);
if (SimplifyDemandedBits(I.getOperandUse(0),
isSignBit ? APInt::getSignBit(BitWidth) isSignBit ? APInt::getSignBit(BitWidth)
: APInt::getAllOnesValue(BitWidth), : APInt::getAllOnesValue(BitWidth),
KnownZero, KnownOne, 0)) Op0KnownZero, Op0KnownOne, 0))
return &I; return &I;
if (SimplifyDemandedBits(I.getOperandUse(1),
APInt::getAllOnesValue(BitWidth),
Op1KnownZero, Op1KnownOne, 0))
return &I;
// Given the known and unknown bits, compute a range that the LHS could be // Given the known and unknown bits, compute a range that the LHS could be
// in. Compute the Min, Max and RHS values based on the known bits. For the // in. Compute the Min, Max and RHS values based on the known bits. For the
// EQ and NE we use unsigned values. // EQ and NE we use unsigned values.
APInt Min(BitWidth, 0), Max(BitWidth, 0); APInt Op0Min(BitWidth, 0), Op0Max(BitWidth, 0);
if (ICmpInst::isSignedPredicate(I.getPredicate())) APInt Op1Min(BitWidth, 0), Op1Max(BitWidth, 0);
ComputeSignedMinMaxValuesFromKnownBits(Ty, KnownZero, KnownOne, Min, Max); if (ICmpInst::isSignedPredicate(I.getPredicate())) {
else ComputeSignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne,
ComputeUnsignedMinMaxValuesFromKnownBits(Ty, KnownZero, KnownOne,Min,Max); Op0Min, Op0Max);
ComputeSignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne,
Op1Min, Op1Max);
} else {
ComputeUnsignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne,
Op0Min, Op0Max);
ComputeUnsignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne,
Op1Min, Op1Max);
}
// If Min and Max are known to be the same, then SimplifyDemandedBits // If Min and Max are known to be the same, then SimplifyDemandedBits
// figured out that the LHS is a constant. Just constant fold this now so // figured out that the LHS is a constant. Just constant fold this now so
// that code below can assume that Min != Max. // that code below can assume that Min != Max.
if (Min == Max) if (!isa<Constant>(Op0) && Op0Min == Op0Max)
return ReplaceInstUsesWith(I, ConstantExpr::getICmp(I.getPredicate(), return new ICmpInst(I.getPredicate(), ConstantInt::get(Op0Min), Op1);
ConstantInt::get(Min), if (!isa<Constant>(Op1) && Op1Min == Op1Max)
CI)); return new ICmpInst(I.getPredicate(), Op0, ConstantInt::get(Op1Min));
// Based on the range information we know about the LHS, see if we can // Based on the range information we know about the LHS, see if we can
// simplify this comparison. For example, (x&4) < 8 is always true. // simplify this comparison. For example, (x&4) < 8 is always true.
const APInt &RHSVal = CI->getValue(); switch (I.getPredicate()) {
switch (I.getPredicate()) { // LE/GE have been folded already.
default: assert(0 && "Unknown icmp opcode!"); default: assert(0 && "Unknown icmp opcode!");
case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_EQ:
if (Max.ult(RHSVal) || Min.ugt(RHSVal)) if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
return ReplaceInstUsesWith(I, ConstantInt::getFalse()); return ReplaceInstUsesWith(I, ConstantInt::getFalse());
break; break;
case ICmpInst::ICMP_NE: case ICmpInst::ICMP_NE:
if (Max.ult(RHSVal) || Min.ugt(RHSVal)) if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
return ReplaceInstUsesWith(I, ConstantInt::getTrue()); return ReplaceInstUsesWith(I, ConstantInt::getTrue());
break; break;
case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULT:
if (Max.ult(RHSVal)) // A <u C -> true iff max(A) < C if (Op0Max.ult(Op1Min)) // A <u B -> true if max(A) < min(B)
return ReplaceInstUsesWith(I, ConstantInt::getTrue()); return ReplaceInstUsesWith(I, ConstantInt::getTrue());
if (Min.uge(RHSVal)) // A <u C -> false iff min(A) >= C if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B)
return ReplaceInstUsesWith(I, ConstantInt::getFalse()); return ReplaceInstUsesWith(I, ConstantInt::getFalse());
if (RHSVal == Max) // A <u MAX -> A != MAX if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
if (RHSVal == Min+1) // A <u MIN+1 -> A == MIN if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, SubOne(CI)); if (Op1Max == Op0Min+1) // A <u C -> A == C-1 if min(A)+1 == C
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, SubOne(CI));
// (x <u 2147483648) -> (x >s -1) -> true if sign bit clear
if (CI->isMinValue(true)) // (x <u 2147483648) -> (x >s -1) -> true if sign bit clear
return new ICmpInst(ICmpInst::ICMP_SGT, Op0, if (CI->isMinValue(true))
return new ICmpInst(ICmpInst::ICMP_SGT, Op0,
ConstantInt::getAllOnesValue(Op0->getType())); ConstantInt::getAllOnesValue(Op0->getType()));
}
break; break;
case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_UGT:
if (Min.ugt(RHSVal)) // A >u C -> true iff min(A) > C if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B)
return ReplaceInstUsesWith(I, ConstantInt::getTrue()); return ReplaceInstUsesWith(I, ConstantInt::getTrue());
if (Max.ule(RHSVal)) // A >u C -> false iff max(A) <= C if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B)
return ReplaceInstUsesWith(I, ConstantInt::getFalse()); return ReplaceInstUsesWith(I, ConstantInt::getFalse());
if (RHSVal == Min) // A >u MIN -> A != MIN if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
if (RHSVal == Max-1) // A >u MAX-1 -> A == MAX if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, AddOne(CI)); if (Op1Min == Op0Max-1) // A >u C -> A == C+1 if max(a)-1 == C
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, AddOne(CI));
// (x >u 2147483647) -> (x <s 0) -> true if sign bit set
if (CI->isMaxValue(true)) // (x >u 2147483647) -> (x <s 0) -> true if sign bit set
return new ICmpInst(ICmpInst::ICMP_SLT, Op0, if (CI->isMaxValue(true))
ConstantInt::getNullValue(Op0->getType())); return new ICmpInst(ICmpInst::ICMP_SLT, Op0,
ConstantInt::getNullValue(Op0->getType()));
}
break; break;
case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLT:
if (Max.slt(RHSVal)) // A <s C -> true iff max(A) < C if (Op0Max.slt(Op1Min)) // A <s B -> true if max(A) < min(C)
return ReplaceInstUsesWith(I, ConstantInt::getTrue()); return ReplaceInstUsesWith(I, ConstantInt::getTrue());
if (Min.sge(RHSVal)) // A <s C -> false iff min(A) >= C if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(C)
return ReplaceInstUsesWith(I, ConstantInt::getFalse()); return ReplaceInstUsesWith(I, ConstantInt::getFalse());
if (RHSVal == Max) // A <s MAX -> A != MAX if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
if (RHSVal == Min+1) // A <s MIN+1 -> A == MIN if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, SubOne(CI)); if (Op1Max == Op0Min+1) // A <s C -> A == C-1 if min(A)+1 == C
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, SubOne(CI));
}
break; break;
case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGT:
if (Min.sgt(RHSVal)) // A >s C -> true iff min(A) > C if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B)
return ReplaceInstUsesWith(I, ConstantInt::getTrue()); return ReplaceInstUsesWith(I, ConstantInt::getTrue());
if (Max.sle(RHSVal)) // A >s C -> false iff max(A) <= C if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B)
return ReplaceInstUsesWith(I, ConstantInt::getFalse()); return ReplaceInstUsesWith(I, ConstantInt::getFalse());
if (RHSVal == Min) // A >s MIN -> A != MIN if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
if (RHSVal == Max-1) // A >s MAX-1 -> A == MAX if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, AddOne(CI)); if (Op1Min == Op0Max-1) // A >s C -> A == C+1 if max(A)-1 == C
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, AddOne(CI));
}
break;
case ICmpInst::ICMP_SGE:
assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!");
if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B)
return ReplaceInstUsesWith(I, ConstantInt::getTrue());
if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B)
return ReplaceInstUsesWith(I, ConstantInt::getFalse());
break;
case ICmpInst::ICMP_SLE:
assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!");
if (Op0Max.sle(Op1Min)) // A <=s B -> true if max(A) <= min(B)
return ReplaceInstUsesWith(I, ConstantInt::getTrue());
if (Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B)
return ReplaceInstUsesWith(I, ConstantInt::getFalse());
break;
case ICmpInst::ICMP_UGE:
assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!");
if (Op0Min.uge(Op1Max)) // A >=u B -> true if min(A) >= max(B)
return ReplaceInstUsesWith(I, ConstantInt::getTrue());
if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B)
return ReplaceInstUsesWith(I, ConstantInt::getFalse());
break;
case ICmpInst::ICMP_ULE:
assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!");
if (Op0Max.ule(Op1Min)) // A <=u B -> true if max(A) <= min(B)
return ReplaceInstUsesWith(I, ConstantInt::getTrue());
if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B)
return ReplaceInstUsesWith(I, ConstantInt::getFalse());
break; break;
} }
// Turn a signed comparison into an unsigned one if both operands
// are known to have the same sign.
if (I.isSignedPredicate() &&
((Op0KnownZero.isNegative() && Op1KnownZero.isNegative()) ||
(Op0KnownOne.isNegative() && Op1KnownOne.isNegative())))
return new ICmpInst(I.getUnsignedPredicate(), Op0, Op1);
} }
// Test if the ICmpInst instruction is used exclusively by a select as // Test if the ICmpInst instruction is used exclusively by a select as

View File

@ -0,0 +1,28 @@
; RUN: llvm-as < %s | opt -instcombine | llvm-dis > %t
; RUN: not grep zext %t
; RUN: not grep slt %t
; RUN: grep {icmp ult} %t
; Instcombine should convert the zext+slt into a simple ult.
define void @foo(double* %p) nounwind {
entry:
br label %bb
bb:
%indvar = phi i64 [ 0, %entry ], [ %indvar.next, %bb ]
%t0 = and i64 %indvar, 65535
%t1 = getelementptr double* %p, i64 %t0
%t2 = load double* %t1, align 8
%t3 = mul double %t2, 2.2
store double %t3, double* %t1, align 8
%i.04 = trunc i64 %indvar to i16
%t4 = add i16 %i.04, 1
%t5 = zext i16 %t4 to i32
%t6 = icmp slt i32 %t5, 500
%indvar.next = add i64 %indvar, 1
br i1 %t6, label %bb, label %return
return:
ret void
}