Add several more icmp simplifications. Transform signed comparisons

into unsigned ones when the operands are known to have the same
sign bit value.

llvm-svn: 70053
This commit is contained in:
Dan Gohman 2009-04-25 17:12:48 +00:00
parent 60349bb21e
commit a7fae1f865
2 changed files with 187 additions and 82 deletions

View File

@ -708,15 +708,13 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo,
// set of known zero and one bits, compute the maximum and minimum values that
// could have the specified known zero and known one bits, returning them in
// min/max.
static void ComputeSignedMinMaxValuesFromKnownBits(const Type *Ty,
const APInt& KnownZero,
static void ComputeSignedMinMaxValuesFromKnownBits(const APInt& KnownZero,
const APInt& KnownOne,
APInt& Min, APInt& Max) {
uint32_t BitWidth = cast<IntegerType>(Ty)->getBitWidth();
assert(KnownZero.getBitWidth() == BitWidth &&
KnownOne.getBitWidth() == BitWidth &&
Min.getBitWidth() == BitWidth && Max.getBitWidth() == BitWidth &&
"Ty, KnownZero, KnownOne and Min, Max must have equal bitwidth.");
assert(KnownZero.getBitWidth() == KnownOne.getBitWidth() &&
KnownZero.getBitWidth() == Min.getBitWidth() &&
KnownZero.getBitWidth() == Max.getBitWidth() &&
"KnownZero, KnownOne and Min, Max must have equal bitwidth.");
APInt UnknownBits = ~(KnownZero|KnownOne);
// The minimum value is when all unknown bits are zeros, EXCEPT for the sign
@ -724,9 +722,9 @@ static void ComputeSignedMinMaxValuesFromKnownBits(const Type *Ty,
Min = KnownOne;
Max = KnownOne|UnknownBits;
if (UnknownBits[BitWidth-1]) { // Sign bit is unknown
Min.set(BitWidth-1);
Max.clear(BitWidth-1);
if (UnknownBits.isNegative()) { // Sign bit is unknown
Min.set(Min.getBitWidth()-1);
Max.clear(Max.getBitWidth()-1);
}
}
@ -734,14 +732,12 @@ static void ComputeSignedMinMaxValuesFromKnownBits(const Type *Ty,
// a set of known zero and one bits, compute the maximum and minimum values that
// could have the specified known zero and known one bits, returning them in
// min/max.
static void ComputeUnsignedMinMaxValuesFromKnownBits(const Type *Ty,
const APInt &KnownZero,
static void ComputeUnsignedMinMaxValuesFromKnownBits(const APInt &KnownZero,
const APInt &KnownOne,
APInt &Min, APInt &Max) {
uint32_t BitWidth = cast<IntegerType>(Ty)->getBitWidth(); BitWidth = BitWidth;
assert(KnownZero.getBitWidth() == BitWidth &&
KnownOne.getBitWidth() == BitWidth &&
Min.getBitWidth() == BitWidth && Max.getBitWidth() &&
assert(KnownZero.getBitWidth() == KnownOne.getBitWidth() &&
KnownZero.getBitWidth() == Min.getBitWidth() &&
KnownZero.getBitWidth() == Max.getBitWidth() &&
"Ty, KnownZero, KnownOne and Min, Max must have equal bitwidth.");
APInt UnknownBits = ~(KnownZero|KnownOne);
@ -808,9 +804,13 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
assert(V != 0 && "Null pointer of Value???");
assert(Depth <= 6 && "Limit Search Depth");
uint32_t BitWidth = DemandedMask.getBitWidth();
const IntegerType *VTy = cast<IntegerType>(V->getType());
assert(VTy->getBitWidth() == BitWidth &&
KnownZero.getBitWidth() == BitWidth &&
const Type *VTy = V->getType();
assert((TD || !isa<PointerType>(VTy)) &&
"SimplifyDemandedBits needs to know bit widths!");
assert((!TD || TD->getTypeSizeInBits(VTy) == BitWidth) &&
(!isa<IntegerType>(VTy) ||
VTy->getPrimitiveSizeInBits() == BitWidth) &&
KnownZero.getBitWidth() == BitWidth &&
KnownOne.getBitWidth() == BitWidth &&
"Value *V, DemandedMask, KnownZero and KnownOne \
must have same BitWidth");
@ -820,7 +820,13 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
KnownZero = ~KnownOne & DemandedMask;
return 0;
}
if (isa<ConstantPointerNull>(V)) {
// We know all of the bits for a constant!
KnownOne.clear();
KnownZero = DemandedMask;
return 0;
}
KnownZero.clear();
KnownOne.clear();
if (DemandedMask == 0) { // Not demanding any bits from V.
@ -832,12 +838,15 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (Depth == 6) // Limit search depth.
return 0;
Instruction *I = dyn_cast<Instruction>(V);
if (!I) return 0; // Only analyze instructions.
APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0);
APInt &RHSKnownZero = KnownZero, &RHSKnownOne = KnownOne;
Instruction *I = dyn_cast<Instruction>(V);
if (!I) {
ComputeMaskedBits(V, DemandedMask, RHSKnownZero, RHSKnownOne, Depth);
return 0; // Only analyze instructions.
}
// If there are multiple uses of this value and we aren't at the root, then
// we can't do any simplifications of the operands, because DemandedMask
// only reflects the bits demanded by *one* of the users.
@ -1399,8 +1408,12 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// If the client is only demanding bits that we know, return the known
// constant.
if ((DemandedMask & (RHSKnownZero|RHSKnownOne)) == DemandedMask)
return ConstantInt::get(RHSKnownOne);
if ((DemandedMask & (RHSKnownZero|RHSKnownOne)) == DemandedMask) {
Constant *C = ConstantInt::get(RHSKnownOne);
if (isa<PointerType>(V->getType()))
C = ConstantExpr::getIntToPtr(C, V->getType());
return C;
}
return false;
}
@ -5831,6 +5844,14 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
}
}
unsigned BitWidth = 0;
if (TD)
BitWidth = TD->getTypeSizeInBits(Ty);
else if (isa<IntegerType>(Ty))
BitWidth = Ty->getPrimitiveSizeInBits();
bool isSignBit = false;
// See if we are doing a comparison with a constant.
if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
Value *A = 0, *B = 0;
@ -5865,105 +5886,161 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
return new ICmpInst(ICmpInst::ICMP_SGT, Op0, SubOne(CI));
}
// See if we can fold the comparison based on range information we can get
// by checking whether bits are known to be zero or one in the input.
uint32_t BitWidth = cast<IntegerType>(Ty)->getBitWidth();
APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
// If this comparison is a normal comparison, it demands all
// bits, if it is a sign bit comparison, it only demands the sign bit.
bool UnusedBit;
bool isSignBit = isSignBitCheck(I.getPredicate(), CI, UnusedBit);
if (SimplifyDemandedBits(I.getOperandUse(0),
isSignBit = isSignBitCheck(I.getPredicate(), CI, UnusedBit);
}
// See if we can fold the comparison based on range information we can get
// by checking whether bits are known to be zero or one in the input.
if (BitWidth != 0) {
APInt Op0KnownZero(BitWidth, 0), Op0KnownOne(BitWidth, 0);
APInt Op1KnownZero(BitWidth, 0), Op1KnownOne(BitWidth, 0);
if (SimplifyDemandedBits(I.getOperandUse(0),
isSignBit ? APInt::getSignBit(BitWidth)
: APInt::getAllOnesValue(BitWidth),
KnownZero, KnownOne, 0))
Op0KnownZero, Op0KnownOne, 0))
return &I;
if (SimplifyDemandedBits(I.getOperandUse(1),
APInt::getAllOnesValue(BitWidth),
Op1KnownZero, Op1KnownOne, 0))
return &I;
// Given the known and unknown bits, compute a range that the LHS could be
// in. Compute the Min, Max and RHS values based on the known bits. For the
// EQ and NE we use unsigned values.
APInt Min(BitWidth, 0), Max(BitWidth, 0);
if (ICmpInst::isSignedPredicate(I.getPredicate()))
ComputeSignedMinMaxValuesFromKnownBits(Ty, KnownZero, KnownOne, Min, Max);
else
ComputeUnsignedMinMaxValuesFromKnownBits(Ty, KnownZero, KnownOne,Min,Max);
APInt Op0Min(BitWidth, 0), Op0Max(BitWidth, 0);
APInt Op1Min(BitWidth, 0), Op1Max(BitWidth, 0);
if (ICmpInst::isSignedPredicate(I.getPredicate())) {
ComputeSignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne,
Op0Min, Op0Max);
ComputeSignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne,
Op1Min, Op1Max);
} else {
ComputeUnsignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne,
Op0Min, Op0Max);
ComputeUnsignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne,
Op1Min, Op1Max);
}
// If Min and Max are known to be the same, then SimplifyDemandedBits
// figured out that the LHS is a constant. Just constant fold this now so
// that code below can assume that Min != Max.
if (Min == Max)
return ReplaceInstUsesWith(I, ConstantExpr::getICmp(I.getPredicate(),
ConstantInt::get(Min),
CI));
if (!isa<Constant>(Op0) && Op0Min == Op0Max)
return new ICmpInst(I.getPredicate(), ConstantInt::get(Op0Min), Op1);
if (!isa<Constant>(Op1) && Op1Min == Op1Max)
return new ICmpInst(I.getPredicate(), Op0, ConstantInt::get(Op1Min));
// Based on the range information we know about the LHS, see if we can
// simplify this comparison. For example, (x&4) < 8 is always true.
const APInt &RHSVal = CI->getValue();
switch (I.getPredicate()) { // LE/GE have been folded already.
switch (I.getPredicate()) {
default: assert(0 && "Unknown icmp opcode!");
case ICmpInst::ICMP_EQ:
if (Max.ult(RHSVal) || Min.ugt(RHSVal))
if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
return ReplaceInstUsesWith(I, ConstantInt::getFalse());
break;
case ICmpInst::ICMP_NE:
if (Max.ult(RHSVal) || Min.ugt(RHSVal))
if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
return ReplaceInstUsesWith(I, ConstantInt::getTrue());
break;
case ICmpInst::ICMP_ULT:
if (Max.ult(RHSVal)) // A <u C -> true iff max(A) < C
if (Op0Max.ult(Op1Min)) // A <u B -> true if max(A) < min(B)
return ReplaceInstUsesWith(I, ConstantInt::getTrue());
if (Min.uge(RHSVal)) // A <u C -> false iff min(A) >= C
if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B)
return ReplaceInstUsesWith(I, ConstantInt::getFalse());
if (RHSVal == Max) // A <u MAX -> A != MAX
if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
if (RHSVal == Min+1) // A <u MIN+1 -> A == MIN
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, SubOne(CI));
// (x <u 2147483648) -> (x >s -1) -> true if sign bit clear
if (CI->isMinValue(true))
return new ICmpInst(ICmpInst::ICMP_SGT, Op0,
if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
if (Op1Max == Op0Min+1) // A <u C -> A == C-1 if min(A)+1 == C
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, SubOne(CI));
// (x <u 2147483648) -> (x >s -1) -> true if sign bit clear
if (CI->isMinValue(true))
return new ICmpInst(ICmpInst::ICMP_SGT, Op0,
ConstantInt::getAllOnesValue(Op0->getType()));
}
break;
case ICmpInst::ICMP_UGT:
if (Min.ugt(RHSVal)) // A >u C -> true iff min(A) > C
if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B)
return ReplaceInstUsesWith(I, ConstantInt::getTrue());
if (Max.ule(RHSVal)) // A >u C -> false iff max(A) <= C
if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B)
return ReplaceInstUsesWith(I, ConstantInt::getFalse());
if (RHSVal == Min) // A >u MIN -> A != MIN
if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
if (RHSVal == Max-1) // A >u MAX-1 -> A == MAX
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, AddOne(CI));
// (x >u 2147483647) -> (x <s 0) -> true if sign bit set
if (CI->isMaxValue(true))
return new ICmpInst(ICmpInst::ICMP_SLT, Op0,
ConstantInt::getNullValue(Op0->getType()));
if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
if (Op1Min == Op0Max-1) // A >u C -> A == C+1 if max(a)-1 == C
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, AddOne(CI));
// (x >u 2147483647) -> (x <s 0) -> true if sign bit set
if (CI->isMaxValue(true))
return new ICmpInst(ICmpInst::ICMP_SLT, Op0,
ConstantInt::getNullValue(Op0->getType()));
}
break;
case ICmpInst::ICMP_SLT:
if (Max.slt(RHSVal)) // A <s C -> true iff max(A) < C
if (Op0Max.slt(Op1Min)) // A <s B -> true if max(A) < min(C)
return ReplaceInstUsesWith(I, ConstantInt::getTrue());
if (Min.sge(RHSVal)) // A <s C -> false iff min(A) >= C
if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(C)
return ReplaceInstUsesWith(I, ConstantInt::getFalse());
if (RHSVal == Max) // A <s MAX -> A != MAX
if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
if (RHSVal == Min+1) // A <s MIN+1 -> A == MIN
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, SubOne(CI));
if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
if (Op1Max == Op0Min+1) // A <s C -> A == C-1 if min(A)+1 == C
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, SubOne(CI));
}
break;
case ICmpInst::ICMP_SGT:
if (Min.sgt(RHSVal)) // A >s C -> true iff min(A) > C
case ICmpInst::ICMP_SGT:
if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B)
return ReplaceInstUsesWith(I, ConstantInt::getTrue());
if (Max.sle(RHSVal)) // A >s C -> false iff max(A) <= C
if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B)
return ReplaceInstUsesWith(I, ConstantInt::getFalse());
if (RHSVal == Min) // A >s MIN -> A != MIN
if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
if (RHSVal == Max-1) // A >s MAX-1 -> A == MAX
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, AddOne(CI));
if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
if (Op1Min == Op0Max-1) // A >s C -> A == C+1 if max(A)-1 == C
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, AddOne(CI));
}
break;
case ICmpInst::ICMP_SGE:
assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!");
if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B)
return ReplaceInstUsesWith(I, ConstantInt::getTrue());
if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B)
return ReplaceInstUsesWith(I, ConstantInt::getFalse());
break;
case ICmpInst::ICMP_SLE:
assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!");
if (Op0Max.sle(Op1Min)) // A <=s B -> true if max(A) <= min(B)
return ReplaceInstUsesWith(I, ConstantInt::getTrue());
if (Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B)
return ReplaceInstUsesWith(I, ConstantInt::getFalse());
break;
case ICmpInst::ICMP_UGE:
assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!");
if (Op0Min.uge(Op1Max)) // A >=u B -> true if min(A) >= max(B)
return ReplaceInstUsesWith(I, ConstantInt::getTrue());
if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B)
return ReplaceInstUsesWith(I, ConstantInt::getFalse());
break;
case ICmpInst::ICMP_ULE:
assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!");
if (Op0Max.ule(Op1Min)) // A <=u B -> true if max(A) <= min(B)
return ReplaceInstUsesWith(I, ConstantInt::getTrue());
if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B)
return ReplaceInstUsesWith(I, ConstantInt::getFalse());
break;
}
// Turn a signed comparison into an unsigned one if both operands
// are known to have the same sign.
if (I.isSignedPredicate() &&
((Op0KnownZero.isNegative() && Op1KnownZero.isNegative()) ||
(Op0KnownOne.isNegative() && Op1KnownOne.isNegative())))
return new ICmpInst(I.getUnsignedPredicate(), Op0, Op1);
}
// Test if the ICmpInst instruction is used exclusively by a select as

View File

@ -0,0 +1,28 @@
; RUN: llvm-as < %s | opt -instcombine | llvm-dis > %t
; RUN: not grep zext %t
; RUN: not grep slt %t
; RUN: grep {icmp ult} %t
; Instcombine should convert the zext+slt into a simple ult.
define void @foo(double* %p) nounwind {
entry:
br label %bb
bb:
%indvar = phi i64 [ 0, %entry ], [ %indvar.next, %bb ]
%t0 = and i64 %indvar, 65535
%t1 = getelementptr double* %p, i64 %t0
%t2 = load double* %t1, align 8
%t3 = mul double %t2, 2.2
store double %t3, double* %t1, align 8
%i.04 = trunc i64 %indvar to i16
%t4 = add i16 %i.04, 1
%t5 = zext i16 %t4 to i32
%t6 = icmp slt i32 %t5, 500
%indvar.next = add i64 %indvar, 1
br i1 %t6, label %bb, label %return
return:
ret void
}