[InstSimplify] refactor finding limits for icmp with binop; NFCI

llvm-svn: 292812
This commit is contained in:
Sanjay Patel 2017-01-23 18:22:26 +00:00
parent f523ed4b82
commit d945eb19e4

View File

@ -2377,6 +2377,149 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
return nullptr;
}
/// Many binary operators with a constant operand have an easy-to-compute
/// range of outputs. This can be used to fold a comparison to always true or
/// always false.
static void setLimitsForBinOp(BinaryOperator &BO, APInt &Lower, APInt &Upper) {
unsigned Width = Lower.getBitWidth();
const APInt *C;
switch (BO.getOpcode()) {
case Instruction::Add:
if (BO.hasNoUnsignedWrap() && match(BO.getOperand(1), m_APInt(C)))
// 'add nuw x, C' produces [C, UINT_MAX].
Lower = *C;
break;
case Instruction::And:
if (match(BO.getOperand(1), m_APInt(C)))
// 'and x, C' produces [0, C].
Upper = *C + 1;
break;
case Instruction::Or:
if (match(BO.getOperand(1), m_APInt(C)))
// 'or x, C' produces [C, UINT_MAX].
Lower = *C;
break;
case Instruction::AShr:
if (match(BO.getOperand(1), m_APInt(C)) && C->ult(Width)) {
// 'ashr x, C' produces [INT_MIN >> C, INT_MAX >> C].
Lower = APInt::getSignedMinValue(Width).ashr(*C);
Upper = APInt::getSignedMaxValue(Width).ashr(*C) + 1;
} else if (match(BO.getOperand(0), m_APInt(C))) {
unsigned ShiftAmount = Width - 1;
if (*C != 0 && BO.isExact())
ShiftAmount = C->countTrailingZeros();
if (C->isNegative()) {
// 'ashr C, x' produces [C, C >> (Width-1)]
Lower = *C;
Upper = C->ashr(ShiftAmount) + 1;
} else {
// 'ashr C, x' produces [C >> (Width-1), C]
Lower = C->ashr(ShiftAmount);
Upper = *C + 1;
}
}
break;
case Instruction::LShr:
if (match(BO.getOperand(1), m_APInt(C)) && C->ult(Width)) {
// 'lshr x, C' produces [0, UINT_MAX >> C].
Upper = APInt::getAllOnesValue(Width).lshr(*C) + 1;
} else if (match(BO.getOperand(0), m_APInt(C))) {
// 'lshr C, x' produces [C >> (Width-1), C].
unsigned ShiftAmount = Width - 1;
if (*C != 0 && BO.isExact())
ShiftAmount = C->countTrailingZeros();
Lower = C->lshr(ShiftAmount);
Upper = *C + 1;
}
break;
case Instruction::Shl:
if (match(BO.getOperand(0), m_APInt(C))) {
if (BO.hasNoUnsignedWrap()) {
// 'shl nuw C, x' produces [C, C << CLZ(C)]
Lower = *C;
Upper = Lower.shl(Lower.countLeadingZeros()) + 1;
} else if (BO.hasNoSignedWrap()) { // TODO: What if both nuw+nsw?
if (C->isNegative()) {
// 'shl nsw C, x' produces [C << CLO(C)-1, C]
unsigned ShiftAmount = C->countLeadingOnes() - 1;
Lower = C->shl(ShiftAmount);
Upper = *C + 1;
} else {
// 'shl nsw C, x' produces [C, C << CLZ(C)-1]
unsigned ShiftAmount = C->countLeadingZeros() - 1;
Lower = *C;
Upper = C->shl(ShiftAmount) + 1;
}
}
}
break;
case Instruction::SDiv:
if (match(BO.getOperand(1), m_APInt(C))) {
APInt IntMin = APInt::getSignedMinValue(Width);
APInt IntMax = APInt::getSignedMaxValue(Width);
if (C->isAllOnesValue()) {
// 'sdiv x, -1' produces [INT_MIN + 1, INT_MAX]
// where C != -1 and C != 0 and C != 1
Lower = IntMin + 1;
Upper = IntMax + 1;
} else if (C->countLeadingZeros() < Width - 1) {
// 'sdiv x, C' produces [INT_MIN / C, INT_MAX / C]
// where C != -1 and C != 0 and C != 1
Lower = IntMin.sdiv(*C);
Upper = IntMax.sdiv(*C);
if (Lower.sgt(Upper))
std::swap(Lower, Upper);
Upper = Upper + 1;
assert(Upper != Lower && "Upper part of range has wrapped!");
}
} else if (match(BO.getOperand(0), m_APInt(C))) {
if (C->isMinSignedValue()) {
// 'sdiv INT_MIN, x' produces [INT_MIN, INT_MIN / -2].
Lower = *C;
Upper = Lower.lshr(1) + 1;
} else {
// 'sdiv C, x' produces [-|C|, |C|].
Upper = C->abs() + 1;
Lower = (-Upper) + 1;
}
}
break;
case Instruction::UDiv:
if (match(BO.getOperand(1), m_APInt(C)) && *C != 0) {
// 'udiv x, C' produces [0, UINT_MAX / C].
Upper = APInt::getMaxValue(Width).udiv(*C) + 1;
} else if (match(BO.getOperand(0), m_APInt(C))) {
// 'udiv C, x' produces [0, C].
Upper = *C + 1;
}
break;
case Instruction::SRem:
if (match(BO.getOperand(1), m_APInt(C))) {
// 'srem x, C' produces (-|C|, |C|).
Upper = C->abs();
Lower = (-Upper) + 1;
}
break;
case Instruction::URem:
if (match(BO.getOperand(1), m_APInt(C)))
// 'urem x, C' produces [0, C).
Upper = *C;
break;
default:
break;
}
}
static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
Value *RHS) {
const APInt *C;
@ -2390,114 +2533,12 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
if (RHS_CR.isFullSet())
return ConstantInt::getTrue(GetCompareTy(RHS));
// Many binary operators with constant RHS have easy to compute constant
// range. Use them to check whether the comparison is a tautology.
// Find the range of possible values for binary operators.
unsigned Width = C->getBitWidth();
APInt Lower = APInt(Width, 0);
APInt Upper = APInt(Width, 0);
const APInt *C2;
if (match(LHS, m_URem(m_Value(), m_APInt(C2)))) {
// 'urem x, C2' produces [0, C2).
Upper = *C2;
} else if (match(LHS, m_SRem(m_Value(), m_APInt(C2)))) {
// 'srem x, C2' produces (-|C2|, |C2|).
Upper = C2->abs();
Lower = (-Upper) + 1;
} else if (match(LHS, m_UDiv(m_APInt(C2), m_Value()))) {
// 'udiv C2, x' produces [0, C2].
Upper = *C2 + 1;
} else if (match(LHS, m_UDiv(m_Value(), m_APInt(C2)))) {
// 'udiv x, C2' produces [0, UINT_MAX / C2].
APInt NegOne = APInt::getAllOnesValue(Width);
if (*C2 != 0)
Upper = NegOne.udiv(*C2) + 1;
} else if (match(LHS, m_SDiv(m_APInt(C2), m_Value()))) {
if (C2->isMinSignedValue()) {
// 'sdiv INT_MIN, x' produces [INT_MIN, INT_MIN / -2].
Lower = *C2;
Upper = Lower.lshr(1) + 1;
} else {
// 'sdiv C2, x' produces [-|C2|, |C2|].
Upper = C2->abs() + 1;
Lower = (-Upper) + 1;
}
} else if (match(LHS, m_SDiv(m_Value(), m_APInt(C2)))) {
APInt IntMin = APInt::getSignedMinValue(Width);
APInt IntMax = APInt::getSignedMaxValue(Width);
if (C2->isAllOnesValue()) {
// 'sdiv x, -1' produces [INT_MIN + 1, INT_MAX]
// where C2 != -1 and C2 != 0 and C2 != 1
Lower = IntMin + 1;
Upper = IntMax + 1;
} else if (C2->countLeadingZeros() < Width - 1) {
// 'sdiv x, C2' produces [INT_MIN / C2, INT_MAX / C2]
// where C2 != -1 and C2 != 0 and C2 != 1
Lower = IntMin.sdiv(*C2);
Upper = IntMax.sdiv(*C2);
if (Lower.sgt(Upper))
std::swap(Lower, Upper);
Upper = Upper + 1;
assert(Upper != Lower && "Upper part of range has wrapped!");
}
} else if (match(LHS, m_NUWShl(m_APInt(C2), m_Value()))) {
// 'shl nuw C2, x' produces [C2, C2 << CLZ(C2)]
Lower = *C2;
Upper = Lower.shl(Lower.countLeadingZeros()) + 1;
} else if (match(LHS, m_NSWShl(m_APInt(C2), m_Value()))) {
if (C2->isNegative()) {
// 'shl nsw C2, x' produces [C2 << CLO(C2)-1, C2]
unsigned ShiftAmount = C2->countLeadingOnes() - 1;
Lower = C2->shl(ShiftAmount);
Upper = *C2 + 1;
} else {
// 'shl nsw C2, x' produces [C2, C2 << CLZ(C2)-1]
unsigned ShiftAmount = C2->countLeadingZeros() - 1;
Lower = *C2;
Upper = C2->shl(ShiftAmount) + 1;
}
} else if (match(LHS, m_LShr(m_Value(), m_APInt(C2)))) {
// 'lshr x, C2' produces [0, UINT_MAX >> C2].
APInt NegOne = APInt::getAllOnesValue(Width);
if (C2->ult(Width))
Upper = NegOne.lshr(*C2) + 1;
} else if (match(LHS, m_LShr(m_APInt(C2), m_Value()))) {
// 'lshr C2, x' produces [C2 >> (Width-1), C2].
unsigned ShiftAmount = Width - 1;
if (*C2 != 0 && cast<BinaryOperator>(LHS)->isExact())
ShiftAmount = C2->countTrailingZeros();
Lower = C2->lshr(ShiftAmount);
Upper = *C2 + 1;
} else if (match(LHS, m_AShr(m_Value(), m_APInt(C2)))) {
// 'ashr x, C2' produces [INT_MIN >> C2, INT_MAX >> C2].
APInt IntMin = APInt::getSignedMinValue(Width);
APInt IntMax = APInt::getSignedMaxValue(Width);
if (C2->ult(Width)) {
Lower = IntMin.ashr(*C2);
Upper = IntMax.ashr(*C2) + 1;
}
} else if (match(LHS, m_AShr(m_APInt(C2), m_Value()))) {
unsigned ShiftAmount = Width - 1;
if (*C2 != 0 && cast<BinaryOperator>(LHS)->isExact())
ShiftAmount = C2->countTrailingZeros();
if (C2->isNegative()) {
// 'ashr C2, x' produces [C2, C2 >> (Width-1)]
Lower = *C2;
Upper = C2->ashr(ShiftAmount) + 1;
} else {
// 'ashr C2, x' produces [C2 >> (Width-1), C2]
Lower = C2->ashr(ShiftAmount);
Upper = *C2 + 1;
}
} else if (match(LHS, m_Or(m_Value(), m_APInt(C2)))) {
// 'or x, C2' produces [C2, UINT_MAX].
Lower = *C2;
} else if (match(LHS, m_And(m_Value(), m_APInt(C2)))) {
// 'and x, C2' produces [0, C2].
Upper = *C2 + 1;
} else if (match(LHS, m_NUWAdd(m_Value(), m_APInt(C2)))) {
// 'add nuw x, C2' produces [C2, UINT_MAX].
Lower = *C2;
}
if (auto *BO = dyn_cast<BinaryOperator>(LHS))
setLimitsForBinOp(*BO, Lower, Upper);
ConstantRange LHS_CR =
Lower != Upper ? ConstantRange(Lower, Upper) : ConstantRange(Width, true);