diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index f484445ec3c..42507e47c0c 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -2230,27 +2230,29 @@ Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) { return nullptr; } -/// Simplify icmp_eq and icmp_ne instructions with binary operator LHS and -/// integer scalar or splat vector constant RHS. -Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant( - ICmpInst &ICI, BinaryOperator *BO, const APInt *RHSV) { - // FIXME: Some of these folds could work with arbitrary constants, but this - // match is limited to scalars and vector splat constants. - if (!ICI.isEquality()) +/// Fold an icmp equality instruction with binary operator LHS and constant RHS: +/// icmp eq/ne BO, C. +Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, + BinaryOperator *BO, + const APInt *C) { + // TODO: Some of these folds could work with arbitrary constants, but this + // function is limited to scalar and vector splat constants. + if (!Cmp.isEquality()) return nullptr; - Constant *RHS = cast(ICI.getOperand(1)); - bool isICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; + ICmpInst::Predicate Pred = Cmp.getPredicate(); + bool isICMP_NE = Pred == ICmpInst::ICMP_NE; + Constant *RHS = cast(Cmp.getOperand(1)); Value *BOp0 = BO->getOperand(0), *BOp1 = BO->getOperand(1); switch (BO->getOpcode()) { case Instruction::SRem: // If we have a signed (X % (2^c)) == 0, turn it into an unsigned one. - if (*RHSV == 0 && BO->hasOneUse()) { + if (*C == 0 && BO->hasOneUse()) { const APInt *BOC; if (match(BOp1, m_APInt(BOC)) && BOC->sgt(1) && BOC->isPowerOf2()) { Value *NewRem = Builder->CreateURem(BOp0, BOp1, BO->getName()); - return new ICmpInst(ICI.getPredicate(), NewRem, + return new ICmpInst(Pred, NewRem, Constant::getNullValue(BO->getType())); } } @@ -2261,19 +2263,19 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant( if (match(BOp1, m_APInt(BOC))) { if (BO->hasOneUse()) { Constant *SubC = ConstantExpr::getSub(RHS, cast(BOp1)); - return new ICmpInst(ICI.getPredicate(), BOp0, SubC); + return new ICmpInst(Pred, BOp0, SubC); } - } else if (*RHSV == 0) { + } else if (*C == 0) { // Replace ((add A, B) != 0) with (A != -B) if A or B is // efficiently invertible, or if the add has just this one use. if (Value *NegVal = dyn_castNegVal(BOp1)) - return new ICmpInst(ICI.getPredicate(), BOp0, NegVal); + return new ICmpInst(Pred, BOp0, NegVal); if (Value *NegVal = dyn_castNegVal(BOp0)) - return new ICmpInst(ICI.getPredicate(), NegVal, BOp1); + return new ICmpInst(Pred, NegVal, BOp1); if (BO->hasOneUse()) { Value *Neg = Builder->CreateNeg(BOp1); Neg->takeName(BO); - return new ICmpInst(ICI.getPredicate(), BOp0, Neg); + return new ICmpInst(Pred, BOp0, Neg); } } break; @@ -2283,11 +2285,10 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant( if (Constant *BOC = dyn_cast(BOp1)) { // For the xor case, we can xor two constants together, eliminating // the explicit xor. - return new ICmpInst(ICI.getPredicate(), BOp0, - ConstantExpr::getXor(RHS, BOC)); - } else if (*RHSV == 0) { + return new ICmpInst(Pred, BOp0, ConstantExpr::getXor(RHS, BOC)); + } else if (*C == 0) { // Replace ((xor A, B) != 0) with (A != B) - return new ICmpInst(ICI.getPredicate(), BOp0, BOp1); + return new ICmpInst(Pred, BOp0, BOp1); } } break; @@ -2297,10 +2298,10 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant( if (match(BOp0, m_APInt(BOC))) { // Replace ((sub A, B) != C) with (B != A-C) if A & C are constants. Constant *SubC = ConstantExpr::getSub(cast(BOp0), RHS); - return new ICmpInst(ICI.getPredicate(), BOp1, SubC); - } else if (*RHSV == 0) { + return new ICmpInst(Pred, BOp1, SubC); + } else if (*C == 0) { // Replace ((sub A, B) != 0) with (A != B) - return new ICmpInst(ICI.getPredicate(), BOp0, BOp1); + return new ICmpInst(Pred, BOp0, BOp1); } } break; @@ -2312,7 +2313,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant( // This removes the -1 constant. Constant *NotBOC = ConstantExpr::getNot(cast(BOp1)); Value *And = Builder->CreateAnd(BOp0, NotBOC); - return new ICmpInst(ICI.getPredicate(), And, NotBOC); + return new ICmpInst(Pred, And, NotBOC); } break; } @@ -2320,7 +2321,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant( const APInt *BOC; if (match(BOp1, m_APInt(BOC))) { // If we have ((X & C) == C), turn it into ((X & C) != 0). - if (RHSV == BOC && RHSV->isPowerOf2()) + if (C == BOC && C->isPowerOf2()) return new ICmpInst(isICMP_NE ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, BO, Constant::getNullValue(RHS->getType())); @@ -2331,39 +2332,35 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant( // Replace (and X, (1 << size(X)-1) != 0) with x s< 0 if (BOC->isSignBit()) { Constant *Zero = Constant::getNullValue(BOp0->getType()); - ICmpInst::Predicate Pred = - isICMP_NE ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGE; - return new ICmpInst(Pred, BOp0, Zero); + auto NewPred = isICMP_NE ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGE; + return new ICmpInst(NewPred, BOp0, Zero); } // ((X & ~7) == 0) --> X < 8 - if (*RHSV == 0 && (~(*BOC) + 1).isPowerOf2()) { + if (*C == 0 && (~(*BOC) + 1).isPowerOf2()) { Constant *NegBOC = ConstantExpr::getNeg(cast(BOp1)); - ICmpInst::Predicate Pred = - isICMP_NE ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT; - return new ICmpInst(Pred, BOp0, NegBOC); + auto NewPred = isICMP_NE ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT; + return new ICmpInst(NewPred, BOp0, NegBOC); } } break; } case Instruction::Mul: - if (*RHSV == 0 && BO->hasNoSignedWrap()) { + if (*C == 0 && BO->hasNoSignedWrap()) { const APInt *BOC; if (match(BOp1, m_APInt(BOC)) && *BOC != 0) { // The trivial case (mul X, 0) is handled by InstSimplify. // General case : (mul X, C) != 0 iff X != 0 // (mul X, C) == 0 iff X == 0 - return new ICmpInst(ICI.getPredicate(), BOp0, - Constant::getNullValue(RHS->getType())); + return new ICmpInst(Pred, BOp0, Constant::getNullValue(RHS->getType())); } } break; case Instruction::UDiv: - if (*RHSV == 0) { + if (*C == 0) { // (icmp eq/ne (udiv A, B), 0) -> (icmp ugt/ule i32 B, A) - ICmpInst::Predicate Pred = - isICMP_NE ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT; - return new ICmpInst(Pred, BOp1, BOp0); + auto NewPred = isICMP_NE ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT; + return new ICmpInst(NewPred, BOp1, BOp0); } break; default: @@ -2372,44 +2369,44 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant( return nullptr; } -Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &ICI, - const APInt *Op1C) { - IntrinsicInst *II = dyn_cast(ICI.getOperand(0)); - if (!II || !ICI.isEquality()) +/// Fold an icmp with LLVM intrinsic and constant operand: icmp Pred II, C. +Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, + const APInt *C) { + IntrinsicInst *II = dyn_cast(Cmp.getOperand(0)); + if (!II || !Cmp.isEquality()) return nullptr; // Handle icmp {eq|ne} , intcst. switch (II->getIntrinsicID()) { case Intrinsic::bswap: Worklist.Add(II); - ICI.setOperand(0, II->getArgOperand(0)); - ICI.setOperand(1, Builder->getInt(Op1C->byteSwap())); - return &ICI; + Cmp.setOperand(0, II->getArgOperand(0)); + Cmp.setOperand(1, Builder->getInt(C->byteSwap())); + return &Cmp; case Intrinsic::ctlz: case Intrinsic::cttz: // ctz(A) == bitwidth(A) -> A == 0 and likewise for != - if (*Op1C == Op1C->getBitWidth()) { + if (*C == C->getBitWidth()) { Worklist.Add(II); - ICI.setOperand(0, II->getArgOperand(0)); - ICI.setOperand(1, ConstantInt::getNullValue(II->getType())); - return &ICI; + Cmp.setOperand(0, II->getArgOperand(0)); + Cmp.setOperand(1, ConstantInt::getNullValue(II->getType())); + return &Cmp; } break; case Intrinsic::ctpop: { // popcount(A) == 0 -> A == 0 and likewise for != // popcount(A) == bitwidth(A) -> A == -1 and likewise for != - bool IsZero = *Op1C == 0; - if (IsZero || *Op1C == Op1C->getBitWidth()) { + bool IsZero = *C == 0; + if (IsZero || *C == C->getBitWidth()) { Worklist.Add(II); - ICI.setOperand(0, II->getArgOperand(0)); - auto *NewOp = IsZero - ? ConstantInt::getNullValue(II->getType()) - : ConstantInt::getAllOnesValue(II->getType()); - ICI.setOperand(1, NewOp); - return &ICI; - } + Cmp.setOperand(0, II->getArgOperand(0)); + auto *NewOp = IsZero ? Constant::getNullValue(II->getType()) + : Constant::getAllOnesValue(II->getType()); + Cmp.setOperand(1, NewOp); + return &Cmp; } break; + } default: break; }