[DemandedBits] Improve accuracy of Add propagator

The current demand propagator for addition marks every input bit at, and to the right of, an alive output bit as alive. However, a carry cannot propagate beyond a bit for which both operands are zero (or one/zero in the case of subtraction), so a more accurate answer is possible when some operand bits are known.

I derived a propagator by working through truth tables and using a bit-reversed addition to make demand ripple to the right, but I'm not sure how to make a convincing argument for its correctness in the comments yet. Nevertheless, here's a minimal implementation and test to get feedback.

This would help in a situation where, for example, four bytes (<128) packed into an int are added with four others SIMD-style but only one of the four results is actually read.

Known A:     0_______0_______0_______0_______
Known B:     0_______0_______0_______0_______
AOut:        00000000001000000000000000000000
AB, current: 00000000001111111111111111111111
AB, patch:   00000000001111111000000000000000

Committed on behalf of: @rrika (Erika)

Differential Revision: https://reviews.llvm.org/D72423
This commit is contained in:
Simon Pilgrim 2020-08-17 12:53:52 +01:00
parent 79d9e2cd93
commit c1f6ce0c73
7 changed files with 248 additions and 49 deletions

View File

@ -61,6 +61,20 @@ public:
void print(raw_ostream &OS);
/// Compute the alive bits of one operand of an addition from the alive bits
/// of its result and the known bits of both operands.
///
/// \param OperandNo Selects the operand whose alive bits are computed.
/// \param AOut Alive bits of the addition's result.
/// \param LHS Known bits of the left-hand operand.
/// \param RHS Known bits of the right-hand operand.
/// \returns The alive bits of the selected operand.
static APInt determineLiveOperandBitsAdd(unsigned OperandNo,
const APInt &AOut,
const KnownBits &LHS,
const KnownBits &RHS);
/// Compute the alive bits of one operand of a subtraction from the alive bits
/// of its result and the known bits of both operands.
///
/// \param OperandNo Selects the operand whose alive bits are computed.
/// \param AOut Alive bits of the subtraction's result.
/// \param LHS Known bits of the left-hand operand.
/// \param RHS Known bits of the right-hand operand.
/// \returns The alive bits of the selected operand.
static APInt determineLiveOperandBitsSub(unsigned OperandNo,
const APInt &AOut,
const KnownBits &LHS,
const KnownBits &RHS);
private:
void performAnalysis();
void determineLiveOperandBits(const Instruction *UserI,

View File

@ -173,7 +173,21 @@ void DemandedBits::determineLiveOperandBits(
}
break;
case Instruction::Add:
if (AOut.isMask()) {
AB = AOut;
} else {
ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
AB = determineLiveOperandBitsAdd(OperandNo, AOut, Known, Known2);
}
break;
case Instruction::Sub:
if (AOut.isMask()) {
AB = AOut;
} else {
ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
AB = determineLiveOperandBitsSub(OperandNo, AOut, Known, Known2);
}
break;
case Instruction::Mul:
// Find the highest live output bit. We don't need any more input
// bits than that (adds, and thus subtracts, ripple only to the
@ -469,6 +483,86 @@ void DemandedBits::print(raw_ostream &OS) {
}
}
/// Shared implementation behind the Add and Sub demand propagators.
///
/// Given the alive bits of the result (\p AOut) and the known bits of both
/// operands, returns the alive bits of the operand picked by \p OperandNo
/// (0 picks \p LHS, any other value picks \p RHS). \p CarryZero and
/// \p CarryOne describe a carry-in that is known to be zero or known to be
/// one respectively; they must not both be set.
static APInt determineLiveOperandBitsAddCarry(unsigned OperandNo,
                                              const APInt &AOut,
                                              const KnownBits &LHS,
                                              const KnownBits &RHS,
                                              bool CarryZero, bool CarryOne) {
  assert(!(CarryZero && CarryOne) &&
         "Carry can't be zero and one at the same time");

  // The AOut.isMask() shortcut (returning AOut unchanged) is deliberately
  // left to the caller: when it applies, LHS and RHS do not have to be
  // computed at all.

  // A boundary bit is one whose carry out does not depend on its carry in:
  // both operand bits are known zero, or both are known one.
  const APInt BothKnownZero = LHS.Zero & RHS.Zero;
  const APInt BothKnownOne = LHS.One & RHS.One;
  const APInt Boundary = BothKnownZero | BothKnownOne;

  // Step 1: alive output bits -> alive carry bits.
  // Demand ripples to the right, but only up to the next set boundary bit:
  //   AOut             = -1----
  //   Boundary         = ----1-
  //   AliveCarry&~AOut = --111-
  // The rippling is done by an addition on bit-reversed values.
  const APInt NonBoundaryRev = ~Boundary.reverseBits();
  const APInt AOutRev = AOut.reverseBits();
  const APInt RippleRev = AOutRev + (AOutRev | NonBoundaryRev);
  const APInt AliveCarry = (RippleRev ^ NonBoundaryRev).reverseBits();

  // Step 2: alive carry bits -> alive input bits of the requested operand.
  const KnownBits &Self = OperandNo == 0 ? LHS : RHS;
  const KnownBits &Other = OperandNo == 0 ? RHS : LHS;
  const APInt NeededToMaintainCarryZero = Self.Zero | ~Other.Zero;
  const APInt NeededToMaintainCarryOne = Self.One | ~Other.One;

  // Same sums as in computeForAddCarry.
  const APInt PossibleSumZero = ~LHS.Zero + ~RHS.Zero + !CarryZero;
  const APInt PossibleSumOne = LHS.One + RHS.One + CarryOne;

  // The two terms below are the simplified form of
  //
  //   APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero);
  //   APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One;
  //   APInt CarryUnknown = ~(CarryKnownZero | CarryKnownOne);
  //
  //   APInt NeededToMaintainCarry =
  //     (CarryKnownZero & NeededToMaintainCarryZero) |
  //     (CarryKnownOne  & NeededToMaintainCarryOne) |
  //     CarryUnknown;
  const APInt ZeroSideTerm = ~PossibleSumZero | NeededToMaintainCarryZero;
  const APInt OneSideTerm = PossibleSumOne | NeededToMaintainCarryOne;
  const APInt NeededToMaintainCarry = ZeroSideTerm & OneSideTerm;

  // Every alive output bit is alive in the operand, plus whatever is needed
  // to keep the alive carries intact.
  return AOut | (AliveCarry & NeededToMaintainCarry);
}
/// Alive bits of operand \p OperandNo of LHS + RHS, given the alive result
/// bits \p AOut. Forwards to the add-with-carry helper with the carry-in
/// flagged as known zero.
APInt DemandedBits::determineLiveOperandBitsAdd(unsigned OperandNo,
                                                const APInt &AOut,
                                                const KnownBits &LHS,
                                                const KnownBits &RHS) {
  const bool CarryInKnownZero = true;
  const bool CarryInKnownOne = false;
  return determineLiveOperandBitsAddCarry(OperandNo, AOut, LHS, RHS,
                                          CarryInKnownZero, CarryInKnownOne);
}
/// Alive bits of operand \p OperandNo of LHS - RHS, given the alive result
/// bits \p AOut. The known-zero and known-one sets of \p RHS are swapped and
/// the result is forwarded to the add-with-carry helper with the carry-in
/// flagged as known one.
APInt DemandedBits::determineLiveOperandBitsSub(unsigned OperandNo,
                                                const APInt &AOut,
                                                const KnownBits &LHS,
                                                const KnownBits &RHS) {
  // Known bits of the bitwise complement of RHS.
  KnownBits InvertedRHS;
  InvertedRHS.One = RHS.Zero;
  InvertedRHS.Zero = RHS.One;

  const bool CarryInKnownZero = false;
  const bool CarryInKnownOne = true;
  return determineLiveOperandBitsAddCarry(OperandNo, AOut, LHS, InvertedRHS,
                                          CarryInKnownZero, CarryInKnownOne);
}
/// Allocates a new DemandedBitsWrapperPass and returns it as a FunctionPass
/// pointer.
FunctionPass *llvm::createDemandedBitsWrapperPass() {
  FunctionPass *WrapperPass = new DemandedBitsWrapperPass();
  return WrapperPass;
}

View File

@ -1,22 +1,22 @@
; RUN: opt -S -demanded-bits -analyze < %s | FileCheck %s
; RUN: opt -S -disable-output -passes="print<demanded-bits>" < %s 2>&1 | FileCheck %s
; CHECK-DAG: DemandedBits: 0x1f for %1 = and i32 %a, 9
; CHECK-DAG: DemandedBits: 0x1f for %2 = and i32 %b, 9
; CHECK-DAG: DemandedBits: 0x1f for %3 = and i32 %c, 13
; CHECK-DAG: DemandedBits: 0x1f for %4 = and i32 %d, 4
; CHECK-DAG: DemandedBits: 0x1f for %5 = or i32 %2, %3
; CHECK-DAG: DemandedBits: 0x1f for %6 = or i32 %4, %5
; RUN: opt -S -demanded-bits -analyze < %s | FileCheck %s
; RUN: opt -S -disable-output -passes="print<demanded-bits>" < %s 2>&1 | FileCheck %s
; CHECK-DAG: DemandedBits: 0x1e for %1 = and i32 %a, 9
; CHECK-DAG: DemandedBits: 0x1a for %2 = and i32 %b, 9
; CHECK-DAG: DemandedBits: 0x1a for %3 = and i32 %c, 13
; CHECK-DAG: DemandedBits: 0x1a for %4 = and i32 %d, 4
; CHECK-DAG: DemandedBits: 0x1a for %5 = or i32 %2, %3
; CHECK-DAG: DemandedBits: 0x1a for %6 = or i32 %4, %5
; CHECK-DAG: DemandedBits: 0x10 for %7 = add i32 %1, %6
; CHECK-DAG: DemandedBits: 0xffffffff for %8 = and i32 %7, 16
; Masks each argument with a constant, ORs three of the masked values
; together, adds the result to the fourth, and returns only bit 4 (mask 16)
; of that sum.
define i32 @test_add(i32 %a, i32 %b, i32 %c, i32 %d) {
%1 = and i32 %a, 9
%2 = and i32 %b, 9
%3 = and i32 %c, 13
%4 = and i32 %d, 4 ; no bit of %d alive, %4 simplifies to zero
%5 = or i32 %2, %3
%6 = or i32 %4, %5
%7 = add i32 %1, %6
%8 = and i32 %7, 16
ret i32 %8
}
; Masks each argument with a constant, ORs three of the masked values
; together, adds the result to the fourth, and returns only bit 4 (mask 16)
; of that sum.
define i32 @test_add(i32 %a, i32 %b, i32 %c, i32 %d) {
%1 = and i32 %a, 9
%2 = and i32 %b, 9
%3 = and i32 %c, 13
%4 = and i32 %d, 4 ; no bit of %d alive, %4 simplifies to zero
%5 = or i32 %2, %3
%6 = or i32 %4, %5
%7 = add i32 %1, %6
%8 = and i32 %7, 16
ret i32 %8
}

View File

@ -18,6 +18,7 @@ add_llvm_unittest(IRTests
DataLayoutTest.cpp
DebugInfoTest.cpp
DebugTypeODRUniquingTest.cpp
DemandedBitsTest.cpp
DominatorTreeTest.cpp
DominatorTreeBatchUpdatesTest.cpp
FunctionTest.cpp

View File

@ -0,0 +1,66 @@
//===- DemandedBitsTest.cpp - DemandedBits tests --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/DemandedBits.h"
#include "../Support/KnownBitsTest.h"
#include "llvm/Support/KnownBits.h"
#include "gtest/gtest.h"
using namespace llvm;
namespace {
/// Exhaustively checks a demand propagator for a binary operation on 4-bit
/// values.
///
/// \p PropagateFn is invoked as PropagateFn(OperandNo, AOut, Known1, Known2)
/// and yields the alive bits of the selected operand. \p EvalFn evaluates the
/// operation on two concrete values. For every pair of known-bits states and
/// every alive-output mask, two properties are checked:
///   * hiding the known bits that were reported as not alive does not change
///     the propagator's answer, and
///   * clearing the non-alive bits of both operand values leaves the alive
///     bits of the result unchanged.
template <typename Fn1, typename Fn2>
static void TestBinOpExhaustive(Fn1 PropagateFn, Fn2 EvalFn) {
  const unsigned Bits = 4;
  const unsigned NumMasks = 1 << Bits;

  // Drops all knowledge about bits outside of Alive.
  auto Redact = [](const KnownBits &Known, const APInt &Alive) {
    KnownBits Redacted;
    Redacted.Zero = Known.Zero & Alive;
    Redacted.One = Known.One & Alive;
    return Redacted;
  };

  ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
    ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
      for (unsigned Mask = 0; Mask != NumMasks; ++Mask) {
        APInt AOut(Bits, Mask);
        APInt AB1 = PropagateFn(0, AOut, Known1, Known2);
        APInt AB2 = PropagateFn(1, AOut, Known1, Known2);

        // If the propagator claims that certain known bits didn't matter,
        // it must not change its mind once they become unknown.
        KnownBits Known1Redacted = Redact(Known1, AB1);
        KnownBits Known2Redacted = Redact(Known2, AB2);
        APInt AB1R = PropagateFn(0, AOut, Known1Redacted, Known2Redacted);
        APInt AB2R = PropagateFn(1, AOut, Known1Redacted, Known2Redacted);
        EXPECT_EQ(AB1, AB1R);
        EXPECT_EQ(AB2, AB2R);

        // Masking the operands down to their alive bits must not disturb
        // any alive bit of the result.
        ForeachNumInKnownBits(Known1, [&](APInt Value1) {
          ForeachNumInKnownBits(Known2, [&](APInt Value2) {
            APInt ReferenceResult = EvalFn((Value1 & AB1), (Value2 & AB2));
            APInt Result = EvalFn(Value1, Value2);
            EXPECT_EQ(Result & AOut, ReferenceResult & AOut);
          });
        });
      }
    });
  });
}
// Exhaustive check of the addition propagator against N1 + N2.
TEST(DemandedBitsTest, Add) {
  auto EvalAdd = [](APInt N1, APInt N2) -> APInt { return N1 + N2; };
  TestBinOpExhaustive(DemandedBits::determineLiveOperandBitsAdd, EvalAdd);
}
// Exhaustive check of the subtraction propagator against N1 - N2.
TEST(DemandedBitsTest, Sub) {
  auto EvalSub = [](APInt N1, APInt N2) -> APInt { return N1 - N2; };
  TestBinOpExhaustive(DemandedBits::determineLiveOperandBitsSub, EvalSub);
}
} // anonymous namespace

View File

@ -11,41 +11,13 @@
//===----------------------------------------------------------------------===//
#include "llvm/Support/KnownBits.h"
#include "KnownBitsTest.h"
#include "gtest/gtest.h"
using namespace llvm;
namespace {
/// Invokes \p Fn once for every conflict-free KnownBits state of width
/// \p Bits, iterating over the known-zero mask in the outer loop and the
/// known-one mask in the inner loop.
template <typename FnTy>
void ForeachKnownBits(unsigned Bits, FnTy Fn) {
  const unsigned Limit = 1 << Bits;
  KnownBits Known(Bits);
  for (unsigned ZeroMask = 0; ZeroMask != Limit; ++ZeroMask) {
    for (unsigned OneMask = 0; OneMask != Limit; ++OneMask) {
      Known.Zero = ZeroMask;
      Known.One = OneMask;
      // A bit cannot be known zero and known one at once; skip such states.
      if (!Known.hasConflict())
        Fn(Known);
    }
  }
}
/// Invokes \p Fn, in increasing order, for every value of Known's bit width
/// that is consistent with \p Known: no known-zero bit is set and no
/// known-one bit is clear.
template <typename FnTy>
void ForeachNumInKnownBits(const KnownBits &Known, FnTy Fn) {
  const unsigned Bits = Known.getBitWidth();
  const unsigned Limit = 1 << Bits;
  for (unsigned Candidate = 0; Candidate != Limit; ++Candidate) {
    APInt Num(Bits, Candidate);
    const bool SetsKnownZeroBit = (Num & Known.Zero) != 0;
    const bool ClearsKnownOneBit = (~Num & Known.One) != 0;
    if (!SetsKnownZeroBit && !ClearsKnownOneBit)
      Fn(Num);
  }
}
TEST(KnownBitsTest, AddCarryExhaustive) {
unsigned Bits = 4;
ForeachKnownBits(Bits, [&](const KnownBits &Known1) {

View File

@ -0,0 +1,52 @@
//===- llvm/unittest/Support/KnownBitsTest.h - KnownBits tests ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements helpers for KnownBits and DemandedBits unit tests.
//
//===----------------------------------------------------------------------===//
#ifndef LLVM_UNITTESTS_SUPPORT_KNOWNBITSTEST_H
#define LLVM_UNITTESTS_SUPPORT_KNOWNBITSTEST_H
#include "llvm/Support/KnownBits.h"
namespace {
using namespace llvm;
/// Invokes \p Fn once for every conflict-free KnownBits state of width
/// \p Bits, iterating over the known-zero mask in the outer loop and the
/// known-one mask in the inner loop.
template <typename FnTy> void ForeachKnownBits(unsigned Bits, FnTy Fn) {
  const unsigned Limit = 1 << Bits;
  KnownBits Known(Bits);
  for (unsigned ZeroMask = 0; ZeroMask != Limit; ++ZeroMask) {
    for (unsigned OneMask = 0; OneMask != Limit; ++OneMask) {
      Known.Zero = ZeroMask;
      Known.One = OneMask;
      // A bit cannot be known zero and known one at once; skip such states.
      if (!Known.hasConflict())
        Fn(Known);
    }
  }
}
/// Invokes \p Fn, in increasing order, for every value of Known's bit width
/// that is consistent with \p Known: no known-zero bit is set and no
/// known-one bit is clear.
template <typename FnTy>
void ForeachNumInKnownBits(const KnownBits &Known, FnTy Fn) {
  const unsigned Bits = Known.getBitWidth();
  const unsigned Limit = 1 << Bits;
  for (unsigned Candidate = 0; Candidate != Limit; ++Candidate) {
    APInt Num(Bits, Candidate);
    const bool SetsKnownZeroBit = (Num & Known.Zero) != 0;
    const bool ClearsKnownOneBit = (~Num & Known.One) != 0;
    if (!SetsKnownZeroBit && !ClearsKnownOneBit)
      Fn(Num);
  }
}
} // end anonymous namespace
#endif