From a7fae1f86531a4d2206f63d8fa6c34de7e1d69fa Mon Sep 17 00:00:00 2001 From: Dan Gohman Date: Sat, 25 Apr 2009 17:12:48 +0000 Subject: [PATCH] Add several more icmp simplifications. Transform signed comparisons into unsigned ones when the operands are known to have the same sign bit value. llvm-svn: 70053 --- .../Scalar/InstructionCombining.cpp | 241 ++++++++++++------ .../InstCombine/signed-comparison.ll | 28 ++ 2 files changed, 187 insertions(+), 82 deletions(-) create mode 100644 test/Transforms/InstCombine/signed-comparison.ll diff --git a/lib/Transforms/Scalar/InstructionCombining.cpp b/lib/Transforms/Scalar/InstructionCombining.cpp index a2658b3e3f1..c8cdc4c9fcd 100644 --- a/lib/Transforms/Scalar/InstructionCombining.cpp +++ b/lib/Transforms/Scalar/InstructionCombining.cpp @@ -708,15 +708,13 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo, // set of known zero and one bits, compute the maximum and minimum values that // could have the specified known zero and known one bits, returning them in // min/max. -static void ComputeSignedMinMaxValuesFromKnownBits(const Type *Ty, - const APInt& KnownZero, +static void ComputeSignedMinMaxValuesFromKnownBits(const APInt& KnownZero, const APInt& KnownOne, APInt& Min, APInt& Max) { - uint32_t BitWidth = cast(Ty)->getBitWidth(); - assert(KnownZero.getBitWidth() == BitWidth && - KnownOne.getBitWidth() == BitWidth && - Min.getBitWidth() == BitWidth && Max.getBitWidth() == BitWidth && - "Ty, KnownZero, KnownOne and Min, Max must have equal bitwidth."); + assert(KnownZero.getBitWidth() == KnownOne.getBitWidth() && + KnownZero.getBitWidth() == Min.getBitWidth() && + KnownZero.getBitWidth() == Max.getBitWidth() && + "KnownZero, KnownOne and Min, Max must have equal bitwidth."); APInt UnknownBits = ~(KnownZero|KnownOne); // The minimum value is when all unknown bits are zeros, EXCEPT for the sign @@ -724,9 +722,9 @@ static void ComputeSignedMinMaxValuesFromKnownBits(const Type *Ty, Min = KnownOne; Max = KnownOne|UnknownBits; - if (UnknownBits[BitWidth-1]) { // Sign bit is unknown - Min.set(BitWidth-1); - Max.clear(BitWidth-1); + if (UnknownBits.isNegative()) { // Sign bit is unknown + Min.set(Min.getBitWidth()-1); + Max.clear(Max.getBitWidth()-1); } } @@ -734,14 +732,12 @@ static void ComputeSignedMinMaxValuesFromKnownBits(const Type *Ty, // a set of known zero and one bits, compute the maximum and minimum values that // could have the specified known zero and known one bits, returning them in // min/max. -static void ComputeUnsignedMinMaxValuesFromKnownBits(const Type *Ty, - const APInt &KnownZero, +static void ComputeUnsignedMinMaxValuesFromKnownBits(const APInt &KnownZero, const APInt &KnownOne, APInt &Min, APInt &Max) { - uint32_t BitWidth = cast(Ty)->getBitWidth(); BitWidth = BitWidth; - assert(KnownZero.getBitWidth() == BitWidth && - KnownOne.getBitWidth() == BitWidth && - Min.getBitWidth() == BitWidth && Max.getBitWidth() && + assert(KnownZero.getBitWidth() == KnownOne.getBitWidth() && + KnownZero.getBitWidth() == Min.getBitWidth() && + KnownZero.getBitWidth() == Max.getBitWidth() && "Ty, KnownZero, KnownOne and Min, Max must have equal bitwidth."); APInt UnknownBits = ~(KnownZero|KnownOne); @@ -808,9 +804,13 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, assert(V != 0 && "Null pointer of Value???"); assert(Depth <= 6 && "Limit Search Depth"); uint32_t BitWidth = DemandedMask.getBitWidth(); - const IntegerType *VTy = cast(V->getType()); - assert(VTy->getBitWidth() == BitWidth && - KnownZero.getBitWidth() == BitWidth && + const Type *VTy = V->getType(); + assert((TD || !isa(VTy)) && + "SimplifyDemandedBits needs to know bit widths!"); + assert((!TD || TD->getTypeSizeInBits(VTy) == BitWidth) && + (!isa(VTy) || + VTy->getPrimitiveSizeInBits() == BitWidth) && + KnownZero.getBitWidth() == BitWidth && KnownOne.getBitWidth() == BitWidth && "Value *V, DemandedMask, KnownZero and KnownOne \ must have same BitWidth"); @@ -820,7 +820,13 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, KnownZero = ~KnownOne & DemandedMask; return 0; } - + if (isa(V)) { + // We know all of the bits for a constant! + KnownOne.clear(); + KnownZero = DemandedMask; + return 0; + } + KnownZero.clear(); KnownOne.clear(); if (DemandedMask == 0) { // Not demanding any bits from V. @@ -832,12 +838,15 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (Depth == 6) // Limit search depth. return 0; - Instruction *I = dyn_cast(V); - if (!I) return 0; // Only analyze instructions. - APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0); APInt &RHSKnownZero = KnownZero, &RHSKnownOne = KnownOne; + Instruction *I = dyn_cast(V); + if (!I) { + ComputeMaskedBits(V, DemandedMask, RHSKnownZero, RHSKnownOne, Depth); + return 0; // Only analyze instructions. + } + // If there are multiple uses of this value and we aren't at the root, then // we can't do any simplifications of the operands, because DemandedMask // only reflects the bits demanded by *one* of the users. @@ -1399,8 +1408,12 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // If the client is only demanding bits that we know, return the known // constant. - if ((DemandedMask & (RHSKnownZero|RHSKnownOne)) == DemandedMask) - return ConstantInt::get(RHSKnownOne); + if ((DemandedMask & (RHSKnownZero|RHSKnownOne)) == DemandedMask) { + Constant *C = ConstantInt::get(RHSKnownOne); + if (isa(V->getType())) + C = ConstantExpr::getIntToPtr(C, V->getType()); + return C; + } return false; } @@ -5831,6 +5844,14 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { } } + unsigned BitWidth = 0; + if (TD) + BitWidth = TD->getTypeSizeInBits(Ty); + else if (isa(Ty)) + BitWidth = Ty->getPrimitiveSizeInBits(); + + bool isSignBit = false; + // See if we are doing a comparison with a constant. if (ConstantInt *CI = dyn_cast(Op1)) { Value *A = 0, *B = 0; @@ -5865,105 +5886,161 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { return new ICmpInst(ICmpInst::ICMP_SGT, Op0, SubOne(CI)); } - // See if we can fold the comparison based on range information we can get - // by checking whether bits are known to be zero or one in the input. - uint32_t BitWidth = cast(Ty)->getBitWidth(); - APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - // If this comparison is a normal comparison, it demands all // bits, if it is a sign bit comparison, it only demands the sign bit. bool UnusedBit; - bool isSignBit = isSignBitCheck(I.getPredicate(), CI, UnusedBit); - - if (SimplifyDemandedBits(I.getOperandUse(0), + isSignBit = isSignBitCheck(I.getPredicate(), CI, UnusedBit); + } + + // See if we can fold the comparison based on range information we can get + // by checking whether bits are known to be zero or one in the input. + if (BitWidth != 0) { + APInt Op0KnownZero(BitWidth, 0), Op0KnownOne(BitWidth, 0); + APInt Op1KnownZero(BitWidth, 0), Op1KnownOne(BitWidth, 0); + + if (SimplifyDemandedBits(I.getOperandUse(0), isSignBit ? APInt::getSignBit(BitWidth) : APInt::getAllOnesValue(BitWidth), - KnownZero, KnownOne, 0)) + Op0KnownZero, Op0KnownOne, 0)) return &I; - + if (SimplifyDemandedBits(I.getOperandUse(1), + APInt::getAllOnesValue(BitWidth), + Op1KnownZero, Op1KnownOne, 0)) + return &I; + // Given the known and unknown bits, compute a range that the LHS could be // in. Compute the Min, Max and RHS values based on the known bits. For the // EQ and NE we use unsigned values. - APInt Min(BitWidth, 0), Max(BitWidth, 0); - if (ICmpInst::isSignedPredicate(I.getPredicate())) - ComputeSignedMinMaxValuesFromKnownBits(Ty, KnownZero, KnownOne, Min, Max); - else - ComputeUnsignedMinMaxValuesFromKnownBits(Ty, KnownZero, KnownOne,Min,Max); - + APInt Op0Min(BitWidth, 0), Op0Max(BitWidth, 0); + APInt Op1Min(BitWidth, 0), Op1Max(BitWidth, 0); + if (ICmpInst::isSignedPredicate(I.getPredicate())) { + ComputeSignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne, + Op0Min, Op0Max); + ComputeSignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne, + Op1Min, Op1Max); + } else { + ComputeUnsignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne, + Op0Min, Op0Max); + ComputeUnsignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne, + Op1Min, Op1Max); + } + // If Min and Max are known to be the same, then SimplifyDemandedBits // figured out that the LHS is a constant. Just constant fold this now so // that code below can assume that Min != Max. - if (Min == Max) - return ReplaceInstUsesWith(I, ConstantExpr::getICmp(I.getPredicate(), - ConstantInt::get(Min), - CI)); - + if (!isa(Op0) && Op0Min == Op0Max) + return new ICmpInst(I.getPredicate(), ConstantInt::get(Op0Min), Op1); + if (!isa(Op1) && Op1Min == Op1Max) + return new ICmpInst(I.getPredicate(), Op0, ConstantInt::get(Op1Min)); + // Based on the range information we know about the LHS, see if we can // simplify this comparison. For example, (x&4) < 8 is always true. - const APInt &RHSVal = CI->getValue(); - switch (I.getPredicate()) { // LE/GE have been folded already. + switch (I.getPredicate()) { default: assert(0 && "Unknown icmp opcode!"); case ICmpInst::ICMP_EQ: - if (Max.ult(RHSVal) || Min.ugt(RHSVal)) + if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max)) return ReplaceInstUsesWith(I, ConstantInt::getFalse()); break; case ICmpInst::ICMP_NE: - if (Max.ult(RHSVal) || Min.ugt(RHSVal)) + if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max)) return ReplaceInstUsesWith(I, ConstantInt::getTrue()); break; case ICmpInst::ICMP_ULT: - if (Max.ult(RHSVal)) // A true iff max(A) < C + if (Op0Max.ult(Op1Min)) // A true if max(A) < min(B) return ReplaceInstUsesWith(I, ConstantInt::getTrue()); - if (Min.uge(RHSVal)) // A false iff min(A) >= C + if (Op0Min.uge(Op1Max)) // A false if min(A) >= max(B) return ReplaceInstUsesWith(I, ConstantInt::getFalse()); - if (RHSVal == Max) // A A != MAX + if (Op1Min == Op0Max) // A A != B if max(A) == min(B) return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); - if (RHSVal == Min+1) // A A == MIN - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, SubOne(CI)); - - // (x (x >s -1) -> true if sign bit clear - if (CI->isMinValue(true)) - return new ICmpInst(ICmpInst::ICMP_SGT, Op0, + if (ConstantInt *CI = dyn_cast(Op1)) { + if (Op1Max == Op0Min+1) // A A == C-1 if min(A)+1 == C + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, SubOne(CI)); + + // (x (x >s -1) -> true if sign bit clear + if (CI->isMinValue(true)) + return new ICmpInst(ICmpInst::ICMP_SGT, Op0, ConstantInt::getAllOnesValue(Op0->getType())); + } break; case ICmpInst::ICMP_UGT: - if (Min.ugt(RHSVal)) // A >u C -> true iff min(A) > C + if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B) return ReplaceInstUsesWith(I, ConstantInt::getTrue()); - if (Max.ule(RHSVal)) // A >u C -> false iff max(A) <= C + if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B) return ReplaceInstUsesWith(I, ConstantInt::getFalse()); - - if (RHSVal == Min) // A >u MIN -> A != MIN + + if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B) return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); - if (RHSVal == Max-1) // A >u MAX-1 -> A == MAX - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, AddOne(CI)); - - // (x >u 2147483647) -> (x true if sign bit set - if (CI->isMaxValue(true)) - return new ICmpInst(ICmpInst::ICMP_SLT, Op0, - ConstantInt::getNullValue(Op0->getType())); + if (ConstantInt *CI = dyn_cast(Op1)) { + if (Op1Min == Op0Max-1) // A >u C -> A == C+1 if max(a)-1 == C + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, AddOne(CI)); + + // (x >u 2147483647) -> (x true if sign bit set + if (CI->isMaxValue(true)) + return new ICmpInst(ICmpInst::ICMP_SLT, Op0, + ConstantInt::getNullValue(Op0->getType())); + } break; case ICmpInst::ICMP_SLT: - if (Max.slt(RHSVal)) // A true iff max(A) < C + if (Op0Max.slt(Op1Min)) // A true if max(A) < min(C) return ReplaceInstUsesWith(I, ConstantInt::getTrue()); - if (Min.sge(RHSVal)) // A false iff min(A) >= C + if (Op0Min.sge(Op1Max)) // A false if min(A) >= max(C) return ReplaceInstUsesWith(I, ConstantInt::getFalse()); - if (RHSVal == Max) // A A != MAX + if (Op1Min == Op0Max) // A A != B if max(A) == min(B) return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); - if (RHSVal == Min+1) // A A == MIN - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, SubOne(CI)); + if (ConstantInt *CI = dyn_cast(Op1)) { + if (Op1Max == Op0Min+1) // A A == C-1 if min(A)+1 == C + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, SubOne(CI)); + } break; - case ICmpInst::ICMP_SGT: - if (Min.sgt(RHSVal)) // A >s C -> true iff min(A) > C + case ICmpInst::ICMP_SGT: + if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B) return ReplaceInstUsesWith(I, ConstantInt::getTrue()); - if (Max.sle(RHSVal)) // A >s C -> false iff max(A) <= C + if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B) return ReplaceInstUsesWith(I, ConstantInt::getFalse()); - - if (RHSVal == Min) // A >s MIN -> A != MIN + + if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B) return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); - if (RHSVal == Max-1) // A >s MAX-1 -> A == MAX - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, AddOne(CI)); + if (ConstantInt *CI = dyn_cast(Op1)) { + if (Op1Min == Op0Max-1) // A >s C -> A == C+1 if max(A)-1 == C + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, AddOne(CI)); + } + break; + case ICmpInst::ICMP_SGE: + assert(!isa(Op1) && "ICMP_SGE with ConstantInt not folded!"); + if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B) + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B) + return ReplaceInstUsesWith(I, ConstantInt::getFalse()); + break; + case ICmpInst::ICMP_SLE: + assert(!isa(Op1) && "ICMP_SLE with ConstantInt not folded!"); + if (Op0Max.sle(Op1Min)) // A <=s B -> true if max(A) <= min(B) + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + if (Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B) + return ReplaceInstUsesWith(I, ConstantInt::getFalse()); + break; + case ICmpInst::ICMP_UGE: + assert(!isa(Op1) && "ICMP_UGE with ConstantInt not folded!"); + if (Op0Min.uge(Op1Max)) // A >=u B -> true if min(A) >= max(B) + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B) + return ReplaceInstUsesWith(I, ConstantInt::getFalse()); + break; + case ICmpInst::ICMP_ULE: + assert(!isa(Op1) && "ICMP_ULE with ConstantInt not folded!"); + if (Op0Max.ule(Op1Min)) // A <=u B -> true if max(A) <= min(B) + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B) + return ReplaceInstUsesWith(I, ConstantInt::getFalse()); break; } + + // Turn a signed comparison into an unsigned one if both operands + // are known to have the same sign. + if (I.isSignedPredicate() && + ((Op0KnownZero.isNegative() && Op1KnownZero.isNegative()) || + (Op0KnownOne.isNegative() && Op1KnownOne.isNegative()))) + return new ICmpInst(I.getUnsignedPredicate(), Op0, Op1); } // Test if the ICmpInst instruction is used exclusively by a select as diff --git a/test/Transforms/InstCombine/signed-comparison.ll b/test/Transforms/InstCombine/signed-comparison.ll new file mode 100644 index 00000000000..fdf150f9c61 --- /dev/null +++ b/test/Transforms/InstCombine/signed-comparison.ll @@ -0,0 +1,28 @@ +; RUN: llvm-as < %s | opt -instcombine | llvm-dis > %t +; RUN: not grep zext %t +; RUN: not grep slt %t +; RUN: grep {icmp ult} %t + +; Instcombine should convert the zext+slt into a simple ult. + +define void @foo(double* %p) nounwind { +entry: + br label %bb + +bb: + %indvar = phi i64 [ 0, %entry ], [ %indvar.next, %bb ] + %t0 = and i64 %indvar, 65535 + %t1 = getelementptr double* %p, i64 %t0 + %t2 = load double* %t1, align 8 + %t3 = mul double %t2, 2.2 + store double %t3, double* %t1, align 8 + %i.04 = trunc i64 %indvar to i16 + %t4 = add i16 %i.04, 1 + %t5 = zext i16 %t4 to i32 + %t6 = icmp slt i32 %t5, 500 + %indvar.next = add i64 %indvar, 1 + br i1 %t6, label %bb, label %return + +return: + ret void +}