diff --git a/include/llvm/Analysis/ValueTracking.h b/include/llvm/Analysis/ValueTracking.h index 2b3176dca25..6df1693c78e 100644 --- a/include/llvm/Analysis/ValueTracking.h +++ b/include/llvm/Analysis/ValueTracking.h @@ -39,6 +39,23 @@ namespace llvm { APInt &KnownOne, const TargetData *TD = 0, unsigned Depth = 0); + /// ComputeSignBit - Determine whether the sign bit is known to be zero or + /// one. Convenience wrapper around ComputeMaskedBits. + void ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne, + const TargetData *TD = 0, unsigned Depth = 0); + + /// isPowerOfTwo - Return true if the given value is known to have exactly one + /// bit set when defined. For vectors return true if every element is known to + /// be a power of two when defined. Supports values with integer or pointer + /// type and vectors of integers. + bool isPowerOfTwo(Value *V, const TargetData *TD = 0, unsigned Depth = 0); + + /// isKnownNonZero - Return true if the given value is known to be non-zero + /// when defined. For vectors return true if every element is known to be + /// non-zero when defined. Supports values with integer or pointer type and + /// vectors of integers. + bool isKnownNonZero(Value *V, const TargetData *TD = 0, unsigned Depth = 0); + /// MaskedValueIsZero - Return true if 'V & Mask' is known to be zero. We use /// this predicate to simplify operations downstream. Mask is known to be /// zero for bits that V cannot have. diff --git a/lib/Analysis/InstructionSimplify.cpp b/lib/Analysis/InstructionSimplify.cpp index 0cbaea3f93e..833f6caf0ee 100644 --- a/lib/Analysis/InstructionSimplify.cpp +++ b/lib/Analysis/InstructionSimplify.cpp @@ -22,6 +22,7 @@ #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/Dominators.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/Support/PatternMatch.h" #include "llvm/Support/ValueHandle.h" #include "llvm/Target/TargetData.h" @@ -1153,7 +1154,69 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, } } - // See if we are doing a comparison with a constant. + // icmp , - Different stack variables have + // different addresses, and what's more the address of a stack variable is + // never null or equal to the address of a global. Note that generalizing + // to the case where LHS is a global variable address or null is pointless, + // since if both LHS and RHS are constants then we already constant folded + // the compare, and if only one of them is then we moved it to RHS already. + if (isa(LHS) && (isa(RHS) || isa(RHS) || + isa(RHS))) + // We already know that LHS != LHS. + return ConstantInt::get(ITy, CmpInst::isFalseWhenEqual(Pred)); + + // If we are comparing with zero then try hard since this is a common case. + if (match(RHS, m_Zero())) { + bool LHSKnownNonNegative, LHSKnownNegative; + switch (Pred) { + default: + assert(false && "Unknown ICmp predicate!"); + case ICmpInst::ICMP_ULT: + return ConstantInt::getFalse(LHS->getContext()); + case ICmpInst::ICMP_UGE: + return ConstantInt::getTrue(LHS->getContext()); + case ICmpInst::ICMP_EQ: + case ICmpInst::ICMP_ULE: + if (isKnownNonZero(LHS, TD)) + return ConstantInt::getFalse(LHS->getContext()); + break; + case ICmpInst::ICMP_NE: + case ICmpInst::ICMP_UGT: + if (isKnownNonZero(LHS, TD)) + return ConstantInt::getTrue(LHS->getContext()); + break; + case ICmpInst::ICMP_SLT: + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, TD); + if (LHSKnownNegative) + return ConstantInt::getTrue(LHS->getContext()); + if (LHSKnownNonNegative) + return ConstantInt::getFalse(LHS->getContext()); + break; + case ICmpInst::ICMP_SLE: + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, TD); + if (LHSKnownNegative) + return ConstantInt::getTrue(LHS->getContext()); + if (LHSKnownNonNegative && isKnownNonZero(LHS, TD)) + return ConstantInt::getFalse(LHS->getContext()); + break; + case ICmpInst::ICMP_SGE: + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, TD); + if (LHSKnownNegative) + return ConstantInt::getFalse(LHS->getContext()); + if (LHSKnownNonNegative) + return ConstantInt::getTrue(LHS->getContext()); + break; + case ICmpInst::ICMP_SGT: + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, TD); + if (LHSKnownNegative) + return ConstantInt::getFalse(LHS->getContext()); + if (LHSKnownNonNegative && isKnownNonZero(LHS, TD)) + return ConstantInt::getTrue(LHS->getContext()); + break; + } + } + + // See if we are doing a comparison with a constant integer. if (ConstantInt *CI = dyn_cast(RHS)) { switch (Pred) { default: break; @@ -1192,17 +1255,6 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, } } - // icmp , - Different stack variables have - // different addresses, and what's more the address of a stack variable is - // never null or equal to the address of a global. Note that generalizing - // to the case where LHS is a global variable address or null is pointless, - // since if both LHS and RHS are constants then we already constant folded - // the compare, and if only one of them is then we moved it to RHS already. - if (isa(LHS) && (isa(RHS) || isa(RHS) || - isa(RHS))) - // We already know that LHS != LHS. - return ConstantInt::get(ITy, CmpInst::isFalseWhenEqual(Pred)); - // Compare of cast, for example (zext X) != 0 -> X != 0 if (isa(LHS) && (isa(RHS) || isa(RHS))) { Instruction *LI = cast(LHS); diff --git a/lib/Analysis/ValueTracking.cpp b/lib/Analysis/ValueTracking.cpp index c982e3e00de..c8a3ae893f5 100644 --- a/lib/Analysis/ValueTracking.cpp +++ b/lib/Analysis/ValueTracking.cpp @@ -24,9 +24,22 @@ #include "llvm/Target/TargetData.h" #include "llvm/Support/GetElementPtrTypeIterator.h" #include "llvm/Support/MathExtras.h" +#include "llvm/Support/PatternMatch.h" #include "llvm/ADT/SmallPtrSet.h" #include using namespace llvm; +using namespace llvm::PatternMatch; + +const unsigned MaxDepth = 6; + +/// getBitWidth - Returns the bitwidth of the given scalar or pointer type (if +/// unknown returns 0). For vector types, returns the element type's bitwidth. +static unsigned getBitWidth(const Type *Ty, const TargetData *TD) { + if (unsigned BitWidth = Ty->getScalarSizeInBits()) + return BitWidth; + assert(isa(Ty) && "Expected a pointer type!"); + return TD ? TD->getPointerSizeInBits() : 0; +} /// ComputeMaskedBits - Determine which of the bits specified in Mask are /// known to be either zero or one and return them in the KnownZero/KnownOne @@ -47,7 +60,6 @@ using namespace llvm; void llvm::ComputeMaskedBits(Value *V, const APInt &Mask, APInt &KnownZero, APInt &KnownOne, const TargetData *TD, unsigned Depth) { - const unsigned MaxDepth = 6; assert(V && "No Value?"); assert(Depth <= MaxDepth && "Limit Search Depth"); unsigned BitWidth = Mask.getBitWidth(); @@ -620,6 +632,157 @@ void llvm::ComputeMaskedBits(Value *V, const APInt &Mask, } } +/// ComputeSignBit - Determine whether the sign bit is known to be zero or +/// one. Convenience wrapper around ComputeMaskedBits. +void llvm::ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne, + const TargetData *TD, unsigned Depth) { + unsigned BitWidth = getBitWidth(V->getType(), TD); + if (!BitWidth) { + KnownZero = false; + KnownOne = false; + return; + } + APInt ZeroBits(BitWidth, 0); + APInt OneBits(BitWidth, 0); + ComputeMaskedBits(V, APInt::getSignBit(BitWidth), ZeroBits, OneBits, TD, + Depth); + KnownOne = OneBits[BitWidth - 1]; + KnownZero = ZeroBits[BitWidth - 1]; +} + +/// isPowerOfTwo - Return true if the given value is known to have exactly one +/// bit set when defined. For vectors return true if every element is known to +/// be a power of two when defined. Supports values with integer or pointer +/// types and vectors of integers. +bool llvm::isPowerOfTwo(Value *V, const TargetData *TD, unsigned Depth) { + if (ConstantInt *CI = dyn_cast(V)) + return CI->getValue().countPopulation() == 1; + // TODO: Handle vector constants. + + // 1 << X is clearly a power of two if the one is not shifted off the end. If + // it is shifted off the end then the result is undefined. + if (match(V, m_Shl(m_One(), m_Value()))) + return true; + + // (signbit) >>l X is clearly a power of two if the one is not shifted off the + // bottom. If it is shifted off the bottom then the result is undefined. + ConstantInt *CI; + if (match(V, m_LShr(m_ConstantInt(CI), m_Value())) && + CI->getValue().isSignBit()) + return true; + + // The remaining tests are all recursive, so bail out if we hit the limit. + if (Depth++ == MaxDepth) + return false; + + if (ZExtInst *ZI = dyn_cast(V)) + return isPowerOfTwo(ZI->getOperand(0), TD, Depth); + + if (SelectInst *SI = dyn_cast(V)) + return isPowerOfTwo(SI->getTrueValue(), TD, Depth) && + isPowerOfTwo(SI->getFalseValue(), TD, Depth); + + return false; +} + +/// isKnownNonZero - Return true if the given value is known to be non-zero +/// when defined. For vectors return true if every element is known to be +/// non-zero when defined. Supports values with integer or pointer type and +/// vectors of integers. +bool llvm::isKnownNonZero(Value *V, const TargetData *TD, unsigned Depth) { + if (Constant *C = dyn_cast(V)) { + if (C->isNullValue()) + return false; + if (isa(C)) + // Must be non-zero due to null test above. + return true; + // TODO: Handle vectors + return false; + } + + // The remaining tests are all recursive, so bail out if we hit the limit. + if (Depth++ == MaxDepth) + return false; + + unsigned BitWidth = getBitWidth(V->getType(), TD); + + // X | Y != 0 if X != 0 or Y != 0. + Value *X = 0, *Y = 0; + if (match(V, m_Or(m_Value(X), m_Value(Y)))) + return isKnownNonZero(X, TD, Depth) || isKnownNonZero(Y, TD, Depth); + + // ext X != 0 if X != 0. + if (isa(V) || isa(V)) + return isKnownNonZero(cast(V)->getOperand(0), TD, Depth); + + // shl X, A != 0 if X is odd. Note that the value of the shift is undefined + // if the lowest bit is shifted off the end. + if (BitWidth && match(V, m_Shl(m_Value(X), m_Value(Y)))) { + APInt KnownZero(BitWidth, 0); + APInt KnownOne(BitWidth, 0); + ComputeMaskedBits(V, APInt(BitWidth, 1), KnownZero, KnownOne, TD, Depth); + if (KnownOne[0]) + return true; + } + // shr X, A != 0 if X is negative. Note that the value of the shift is not + // defined if the sign bit is shifted off the end. + else if (match(V, m_Shr(m_Value(X), m_Value(Y)))) { + bool XKnownNonNegative, XKnownNegative; + ComputeSignBit(X, XKnownNonNegative, XKnownNegative, TD, Depth); + if (XKnownNegative) + return true; + } + // X + Y. + else if (match(V, m_Add(m_Value(X), m_Value(Y)))) { + bool XKnownNonNegative, XKnownNegative; + bool YKnownNonNegative, YKnownNegative; + ComputeSignBit(X, XKnownNonNegative, XKnownNegative, TD, Depth); + ComputeSignBit(Y, YKnownNonNegative, YKnownNegative, TD, Depth); + + // If X and Y are both non-negative (as signed values) then their sum is not + // zero. + if (XKnownNonNegative && YKnownNonNegative) + return true; + + // If X and Y are both negative (as signed values) then their sum is not + // zero unless both X and Y equal INT_MIN. + if (BitWidth && XKnownNegative && YKnownNegative) { + APInt KnownZero(BitWidth, 0); + APInt KnownOne(BitWidth, 0); + APInt Mask = APInt::getSignedMaxValue(BitWidth); + // The sign bit of X is set. If some other bit is set then X is not equal + // to INT_MIN. + ComputeMaskedBits(X, Mask, KnownZero, KnownOne, TD, Depth); + if ((KnownOne & Mask) != 0) + return true; + // The sign bit of Y is set. If some other bit is set then Y is not equal + // to INT_MIN. + ComputeMaskedBits(Y, Mask, KnownZero, KnownOne, TD, Depth); + if ((KnownOne & Mask) != 0) + return true; + } + + // The sum of a non-negative number and a power of two is not zero. + if (XKnownNonNegative && isPowerOfTwo(Y, TD, Depth)) + return true; + if (YKnownNonNegative && isPowerOfTwo(X, TD, Depth)) + return true; + } + // (C ? X : Y) != 0 if X != 0 and Y != 0. + else if (SelectInst *SI = dyn_cast(V)) { + if (isKnownNonZero(SI->getTrueValue(), TD, Depth) && + isKnownNonZero(SI->getFalseValue(), TD, Depth)) + return true; + } + + if (!BitWidth) return false; + APInt KnownZero(BitWidth, 0); + APInt KnownOne(BitWidth, 0); + ComputeMaskedBits(V, APInt::getAllOnesValue(BitWidth), KnownZero, KnownOne, + TD, Depth); + return KnownOne != 0; +} + /// MaskedValueIsZero - Return true if 'V & Mask' is known to be zero. We use /// this predicate to simplify operations downstream. Mask is known to be zero /// for bits that V cannot have. diff --git a/test/Transforms/InstSimplify/2011-01-18-Compare.ll b/test/Transforms/InstSimplify/2011-01-18-Compare.ll index 127957942bc..78e225a8d1c 100644 --- a/test/Transforms/InstSimplify/2011-01-18-Compare.ll +++ b/test/Transforms/InstSimplify/2011-01-18-Compare.ll @@ -27,6 +27,14 @@ define i1 @zext2(i1 %x) { ; CHECK: ret i1 %x } +define i1 @zext3() { +; CHECK: @zext3 + %e = zext i1 1 to i32 + %c = icmp ne i32 %e, 0 + ret i1 %c +; CHECK: ret i1 true +} + define i1 @sext(i32 %x) { ; CHECK: @sext %e1 = sext i32 %x to i64 @@ -43,3 +51,49 @@ define i1 @sext2(i1 %x) { ret i1 %c ; CHECK: ret i1 %x } + +define i1 @sext3() { +; CHECK: @sext3 + %e = sext i1 1 to i32 + %c = icmp ne i32 %e, 0 + ret i1 %c +; CHECK: ret i1 true +} + +define i1 @add(i32 %x, i32 %y) { +; CHECK: @add + %l = lshr i32 %x, 1 + %r = lshr i32 %y, 1 + %s = add i32 %l, %r + %c = icmp eq i32 %s, 0 + ret i1 %c +; CHECK: ret i1 false +} + +define i1 @add2(i8 %x, i8 %y) { +; CHECK: @add2 + %l = or i8 %x, 128 + %r = or i8 %y, 129 + %s = add i8 %l, %r + %c = icmp eq i8 %s, 0 + ret i1 %c +; CHECK: ret i1 false +} + +define i1 @addpowtwo(i32 %x, i32 %y) { +; CHECK: @addpowtwo + %l = lshr i32 %x, 1 + %r = shl i32 1, %y + %s = add i32 %l, %r + %c = icmp eq i32 %s, 0 + ret i1 %c +; CHECK: ret i1 false +} + +define i1 @or(i32 %x) { +; CHECK: @or + %o = or i32 %x, 1 + %c = icmp eq i32 %o, 0 + ret i1 %c +; CHECK: ret i1 false +}