diff --git a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index b4d3e625ae0..1a5655b0966 100644 --- a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -496,6 +496,38 @@ static unsigned getTypeOfMaskedICmp(Value* A, Value* B, Value* C, return result; } +/// decomposeBitTestICmp - Decompose an icmp into the form ((X & Y) pred Z) +/// if possible. The returned predicate is either == or !=. Returns false if +/// decomposition fails. +static bool decomposeBitTestICmp(const ICmpInst *I, ICmpInst::Predicate &Pred, + Value *&X, Value *&Y, Value *&Z) { + // X < 0 is equivalent to (X & SignBit) != 0. + if (I->getPredicate() == ICmpInst::ICMP_SLT) + if (ConstantInt *C = dyn_cast(I->getOperand(1))) + if (C->isZero()) { + X = I->getOperand(0); + Y = ConstantInt::get(I->getContext(), + APInt::getSignBit(C->getBitWidth())); + Pred = ICmpInst::ICMP_NE; + Z = C; + return true; + } + + // X > -1 is equivalent to (X & SignBit) == 0. + if (I->getPredicate() == ICmpInst::ICMP_SGT) + if (ConstantInt *C = dyn_cast(I->getOperand(1))) + if (C->isAllOnesValue()) { + X = I->getOperand(0); + Y = ConstantInt::get(I->getContext(), + APInt::getSignBit(C->getBitWidth())); + Pred = ICmpInst::ICMP_EQ; + Z = ConstantInt::getNullValue(C->getType()); + return true; + } + + return false; +} + /// foldLogOpOfMaskedICmpsHelper: /// handle (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E) /// return the set of pattern classes (from MaskedICmpType) @@ -503,10 +535,9 @@ static unsigned getTypeOfMaskedICmp(Value* A, Value* B, Value* C, static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A, Value*& B, Value*& C, Value*& D, Value*& E, - ICmpInst *LHS, ICmpInst *RHS) { - ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate(); - if (LHSCC != ICmpInst::ICMP_EQ && LHSCC != ICmpInst::ICMP_NE) return 0; - if (RHSCC != ICmpInst::ICMP_EQ && RHSCC != ICmpInst::ICMP_NE) return 0; + ICmpInst *LHS, ICmpInst *RHS, + ICmpInst::Predicate &LHSCC, + ICmpInst::Predicate &RHSCC) { if (LHS->getOperand(0)->getType() != RHS->getOperand(0)->getType()) return 0; // vectors are not (yet?) supported if (LHS->getOperand(0)->getType()->isVectorTy()) return 0; @@ -520,40 +551,60 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A, Value *L1 = LHS->getOperand(0); Value *L2 = LHS->getOperand(1); Value *L11,*L12,*L21,*L22; - if (match(L1, m_And(m_Value(L11), m_Value(L12)))) { - if (!match(L2, m_And(m_Value(L21), m_Value(L22)))) + // Check whether the icmp can be decomposed into a bit test. + if (decomposeBitTestICmp(LHS, LHSCC, L11, L12, L2)) { + L21 = L22 = L1 = 0; + } else { + // Look for ANDs in the LHS icmp. + if (match(L1, m_And(m_Value(L11), m_Value(L12)))) { + if (!match(L2, m_And(m_Value(L21), m_Value(L22)))) + L21 = L22 = 0; + } else { + if (!match(L2, m_And(m_Value(L11), m_Value(L12)))) + return 0; + std::swap(L1, L2); L21 = L22 = 0; + } } - else { - if (!match(L2, m_And(m_Value(L11), m_Value(L12)))) - return 0; - std::swap(L1, L2); - L21 = L22 = 0; - } + + // Bail if LHS was a icmp that can't be decomposed into an equality. + if (!ICmpInst::isEquality(LHSCC)) + return 0; Value *R1 = RHS->getOperand(0); Value *R2 = RHS->getOperand(1); Value *R11,*R12; bool ok = false; - if (match(R1, m_And(m_Value(R11), m_Value(R12)))) { - if (R11 != 0 && (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22)) { - A = R11; D = R12; E = R2; ok = true; + if (decomposeBitTestICmp(RHS, RHSCC, R11, R12, R2)) { + if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) { + A = R11; D = R12; + } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) { + A = R12; D = R11; + } else { + return 0; } - else - if (R12 != 0 && (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22)) { + E = R2; R1 = 0; ok = true; + } else if (match(R1, m_And(m_Value(R11), m_Value(R12)))) { + if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) { + A = R11; D = R12; E = R2; ok = true; + } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) { A = R12; D = R11; E = R2; ok = true; } } + + // Bail if RHS was a icmp that can't be decomposed into an equality. + if (!ICmpInst::isEquality(RHSCC)) + return 0; + + // Look for ANDs in on the right side of the RHS icmp. if (!ok && match(R2, m_And(m_Value(R11), m_Value(R12)))) { - if (R11 != 0 && (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22)) { - A = R11; D = R12; E = R1; ok = true; - } - else - if (R12 != 0 && (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22)) { + if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) { + A = R11; D = R12; E = R1; ok = true; + } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) { A = R12; D = R11; E = R1; ok = true; - } - else + } else { return 0; + } } if (!ok) return 0; @@ -582,7 +633,11 @@ static Value* foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, ICmpInst::Predicate NEWCC, llvm::InstCombiner::BuilderTy* Builder) { Value *A = 0, *B = 0, *C = 0, *D = 0, *E = 0; - unsigned mask = foldLogOpOfMaskedICmpsHelper(A, B, C, D, E, LHS, RHS); + ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate(); + unsigned mask = foldLogOpOfMaskedICmpsHelper(A, B, C, D, E, LHS, RHS, + LHSCC, RHSCC); + assert(ICmpInst::isEquality(LHSCC) && ICmpInst::isEquality(RHSCC) && + "foldLogOpOfMaskedICmpsHelper must return an equality predicate."); if (mask == 0) return 0; if (NEWCC == ICmpInst::ICMP_NE) @@ -631,11 +686,11 @@ static Value* foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, ConstantInt *CCst = dyn_cast(C); if (CCst == 0) return 0; - if (LHS->getPredicate() != NEWCC) + if (LHSCC != NEWCC) CCst = dyn_cast( ConstantExpr::getXor(BCst, CCst) ); ConstantInt *ECst = dyn_cast(E); if (ECst == 0) return 0; - if (RHS->getPredicate() != NEWCC) + if (RHSCC != NEWCC) ECst = dyn_cast( ConstantExpr::getXor(DCst, ECst) ); ConstantInt* MCst = dyn_cast( ConstantExpr::getAnd(ConstantExpr::getAnd(BCst, DCst), @@ -694,18 +749,6 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { Value *NewOr = Builder->CreateOr(Val, Val2); return Builder->CreateICmp(LHSCC, NewOr, LHSCst); } - - // (icmp slt A, 0) & (icmp slt B, 0) --> (icmp slt (A&B), 0) - if (LHSCC == ICmpInst::ICMP_SLT && LHSCst->isZero()) { - Value *NewAnd = Builder->CreateAnd(Val, Val2); - return Builder->CreateICmp(LHSCC, NewAnd, LHSCst); - } - - // (icmp sgt A, -1) & (icmp sgt B, -1) --> (icmp sgt (A|B), -1) - if (LHSCC == ICmpInst::ICMP_SGT && LHSCst->isAllOnesValue()) { - Value *NewOr = Builder->CreateOr(Val, Val2); - return Builder->CreateICmp(LHSCC, NewOr, LHSCst); - } } // (trunc x) == C1 & (and x, CA) == C2 -> (and x, CA|CMAX) == C1|C2 @@ -744,21 +787,6 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { } } - // (X & C) == 0 & X > -1 -> (X & (C | SignBit)) == 0 - if ((LHSCC == ICmpInst::ICMP_EQ && LHSCst->isZero() && - RHSCC == ICmpInst::ICMP_SGT && RHSCst->isAllOnesValue()) || - (RHSCC == ICmpInst::ICMP_EQ && RHSCst->isZero() && - LHSCC == ICmpInst::ICMP_SGT && LHSCst->isAllOnesValue())) { - ICmpInst *I = LHSCC == ICmpInst::ICMP_EQ ? LHS : RHS; - Value *X; ConstantInt *C; - if (I->hasOneUse() && - match(I->getOperand(0), m_OneUse(m_And(m_Value(X), m_ConstantInt(C))))){ - APInt New = C->getValue() | APInt::getSignBit(C->getBitWidth()); - return Builder->CreateICmpEQ(Builder->CreateAnd(X, Builder->getInt(New)), - I->getOperand(1)); - } - } - // From here on, we only handle: // (icmp1 A, C1) & (icmp2 A, C2) --> something simpler. if (Val != Val2) return 0; @@ -1443,33 +1471,6 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS) { Value *NewOr = Builder->CreateOr(Val, Val2); return Builder->CreateICmp(LHSCC, NewOr, LHSCst); } - - // (icmp slt A, 0) | (icmp slt B, 0) --> (icmp slt (A|B), 0) - if (LHSCC == ICmpInst::ICMP_SLT && LHSCst->isZero()) { - Value *NewOr = Builder->CreateOr(Val, Val2); - return Builder->CreateICmp(LHSCC, NewOr, LHSCst); - } - - // (icmp sgt A, -1) | (icmp sgt B, -1) --> (icmp sgt (A&B), -1) - if (LHSCC == ICmpInst::ICMP_SGT && LHSCst->isAllOnesValue()) { - Value *NewAnd = Builder->CreateAnd(Val, Val2); - return Builder->CreateICmp(LHSCC, NewAnd, LHSCst); - } - } - - // (X & C) != 0 | X < 0 -> (X & (C | SignBit)) != 0 - if ((LHSCC == ICmpInst::ICMP_NE && LHSCst->isZero() && - RHSCC == ICmpInst::ICMP_SLT && RHSCst->isZero()) || - (RHSCC == ICmpInst::ICMP_NE && RHSCst->isZero() && - LHSCC == ICmpInst::ICMP_SLT && LHSCst->isZero())) { - ICmpInst *I = LHSCC == ICmpInst::ICMP_NE ? LHS : RHS; - Value *X; ConstantInt *C; - if (I->hasOneUse() && - match(I->getOperand(0), m_OneUse(m_And(m_Value(X), m_ConstantInt(C))))){ - APInt New = C->getValue() | APInt::getSignBit(C->getBitWidth()); - return Builder->CreateICmpNE(Builder->CreateAnd(X, Builder->getInt(New)), - I->getOperand(1)); - } } // (icmp ult (X + CA), C1) | (icmp eq X, C2) -> (icmp ule (X + CA), C1) diff --git a/test/Transforms/InstCombine/sign-test-and-or.ll b/test/Transforms/InstCombine/sign-test-and-or.ll index 3f2141d7a9f..a6066d80020 100644 --- a/test/Transforms/InstCombine/sign-test-and-or.ll +++ b/test/Transforms/InstCombine/sign-test-and-or.ll @@ -157,3 +157,23 @@ if.then: if.end: ret void } + +define void @test9(i32 %a) nounwind { + %1 = and i32 %a, 1073741824 + %2 = icmp ne i32 %1, 0 + %3 = icmp sgt i32 %a, -1 + %or.cond = and i1 %2, %3 + br i1 %or.cond, label %if.then, label %if.end + +; CHECK: @test9 +; CHECK-NEXT: %1 = and i32 %a, -1073741824 +; CHECK-NEXT: %2 = icmp eq i32 %1, 1073741824 +; CHECK-NEXT: br i1 %2, label %if.then, label %if.end + +if.then: + tail call void @foo() nounwind + ret void + +if.end: + ret void +}