diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 754b791aff87..983e73e40fb3 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -2042,37 +2042,21 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, Pred = ICmpInst::ICMP_UGT; } - // (1 << Y) >= 2147483648 -> Y >= 31 -> Y == 31 - // (1 << Y) < 2147483648 -> Y < 31 -> Y != 31 unsigned CLog2 = C.logBase2(); - if (CLog2 == TypeBits - 1) { - if (Pred == ICmpInst::ICMP_UGE) - Pred = ICmpInst::ICMP_EQ; - else if (Pred == ICmpInst::ICMP_ULT) - Pred = ICmpInst::ICMP_NE; - } return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, CLog2)); } else if (Cmp.isSigned()) { Constant *BitWidthMinusOne = ConstantInt::get(ShiftType, TypeBits - 1); - if (C.isAllOnes()) { - // (1 << Y) <= -1 -> Y == 31 - if (Pred == ICmpInst::ICMP_SLE) - return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne); + // (1 << Y) > 0 -> Y != 31 + // (1 << Y) > -1 -> Y != 31 + // TODO: This can be generalized to any negative constant. + if (Pred == ICmpInst::ICMP_SGT && (C.isZero() || C.isAllOnes())) + return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne); - // (1 << Y) > -1 -> Y != 31 - if (Pred == ICmpInst::ICMP_SGT) - return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne); - } else if (!C) { - // (1 << Y) < 0 -> Y == 31 - // (1 << Y) <= 0 -> Y == 31 - if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) - return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne); - - // (1 << Y) >= 0 -> Y != 31 - // (1 << Y) > 0 -> Y != 31 - if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) - return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne); - } + // (1 << Y) < 0 -> Y == 31 + // (1 << Y) < 1 -> Y == 31 + // TODO: This can be generalized to any negative constant except signed min. + if (Pred == ICmpInst::ICMP_SLT && (C.isZero() || C.isOne())) + return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne); } else if (Cmp.isEquality() && CIsPowerOf2) { return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, C.logBase2())); } diff --git a/llvm/test/Transforms/InstCombine/icmp.ll b/llvm/test/Transforms/InstCombine/icmp.ll index 60ab8ff3d179..4643278dd291 100644 --- a/llvm/test/Transforms/InstCombine/icmp.ll +++ b/llvm/test/Transforms/InstCombine/icmp.ll @@ -2160,8 +2160,7 @@ define <2 x i1> @icmp_shl_1_V_ult_2147483648_vec(<2 x i32> %V) { define i1 @icmp_shl_1_V_sle_0(i32 %V) { ; CHECK-LABEL: @icmp_shl_1_V_sle_0( -; CHECK-NEXT: [[SHL:%.*]] = shl nuw i32 1, [[V:%.*]] -; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[SHL]], 1 +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[V:%.*]], 31 ; CHECK-NEXT: ret i1 [[CMP]] ; %shl = shl i32 1, %V @@ -2171,8 +2170,7 @@ define i1 @icmp_shl_1_V_sle_0(i32 %V) { define <2 x i1> @icmp_shl_1_V_sle_0_vec(<2 x i32> %V) { ; CHECK-LABEL: @icmp_shl_1_V_sle_0_vec( -; CHECK-NEXT: [[SHL:%.*]] = shl nuw <2 x i32> , [[V:%.*]] -; CHECK-NEXT: [[CMP:%.*]] = icmp slt <2 x i32> [[SHL]], +; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i32> [[V:%.*]], ; CHECK-NEXT: ret <2 x i1> [[CMP]] ; %shl = shl <2 x i32> , %V