[InstCombine] reassociateShiftAmtsOfTwoSameDirectionShifts(): fix miscompile (PR44802)

As input, we have the following pattern:
  Sh0 (Sh1 X, Q), K
We want to rewrite that as:
  Sh x, (Q+K)  iff (Q+K) u< bitwidth(x)
While we know that originally (Q+K) would not overflow
(because  2 * (N-1) u<= iN -1), we may have looked past extensions of
shift amounts. so it may now overflow in smaller bitwidth.

To ensure that does not happen, we need to ensure that the total maximal
shift amount is still representable in that smaller bitwidth.
If the overflow would happen, (Q+K) u< bitwidth(x) check would be bogus.

https://bugs.llvm.org/show_bug.cgi?id=44802
This commit is contained in:
Roman Lebedev 2020-02-25 16:48:36 +03:00
parent d971ae303f
commit 4e2b207539
2 changed files with 27 additions and 4 deletions

View File

@ -23,8 +23,11 @@ using namespace PatternMatch;
// Given pattern:
// (x shiftopcode Q) shiftopcode K
// we should rewrite it as
// x shiftopcode (Q+K) iff (Q+K) u< bitwidth(x)
// This is valid for any shift, but they must be identical.
// x shiftopcode (Q+K) iff (Q+K) u< bitwidth(x) and
//
// This is valid for any shift, but they must be identical, and we must be
// careful in case we have (zext(Q)+zext(K)) and look past extensions,
// (Q+K) must not overflow or else (Q+K) u< bitwidth(x) is bogus.
//
// AnalyzeForSignBitExtraction indicates that we will only analyze whether this
// pattern has any 2 right-shifts that sum to 1 less than original bit width.
@ -58,6 +61,23 @@ Value *InstCombiner::reassociateShiftAmtsOfTwoSameDirectionShifts(
if (ShAmt0->getType() != ShAmt1->getType())
return nullptr;
// As input, we have the following pattern:
// Sh0 (Sh1 X, Q), K
// We want to rewrite that as:
// Sh x, (Q+K) iff (Q+K) u< bitwidth(x)
// While we know that originally (Q+K) would not overflow
// (because 2 * (N-1) u<= iN -1), we have looked past extensions of
// shift amounts. so it may now overflow in smaller bitwidth.
// To ensure that does not happen, we need to ensure that the total maximal
// shift amount is still representable in that smaller bit width.
unsigned MaximalPossibleTotalShiftAmount =
(Sh0->getType()->getScalarSizeInBits() - 1) +
(Sh1->getType()->getScalarSizeInBits() - 1);
APInt MaximalRepresentableShiftAmount =
APInt::getAllOnesValue(ShAmt0->getType()->getScalarSizeInBits());
if (MaximalRepresentableShiftAmount.ult(MaximalPossibleTotalShiftAmount))
return nullptr;
// We are only looking for signbit extraction if we have two right shifts.
bool HadTwoRightShifts = match(Sh0, m_Shr(m_Value(), m_Value())) &&
match(Sh1, m_Shr(m_Value(), m_Value()));

View File

@ -320,12 +320,15 @@ define i32 @n20(i32 %x, i32 %y) {
ret i32 %t3
}
; FIXME: this is a miscompile. We should not transform this.
; See https://bugs.llvm.org/show_bug.cgi?id=44802
define i3 @pr44802(i3 %t0) {
; CHECK-LABEL: @pr44802(
; CHECK-NEXT: [[T1:%.*]] = sub i3 0, [[T0:%.*]]
; CHECK-NEXT: ret i3 [[T1]]
; CHECK-NEXT: [[T2:%.*]] = icmp ne i3 [[T0]], 0
; CHECK-NEXT: [[T3:%.*]] = zext i1 [[T2]] to i3
; CHECK-NEXT: [[T4:%.*]] = lshr i3 [[T1]], [[T3]]
; CHECK-NEXT: [[T5:%.*]] = lshr i3 [[T4]], [[T3]]
; CHECK-NEXT: ret i3 [[T5]]
;
%t1 = sub i3 0, %t0
%t2 = icmp ne i3 %t0, 0