Remove spurious mask operations from AArch64 add->compare sequences on 8- and 16-bit values

This patch checks for DAG patterns consisting of an add or a sub followed by a
compare on 8- and 16-bit inputs. Since AArch64 does not support those types
natively they are legalized into 32-bit values, which means that mask operations
are inserted into the DAG to emulate overflow behaviour. In many cases those
masks do not change the result of the computation and just introduce a dependent
operation, often in the middle of a hot loop.
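
As an illustration (a hypothetical source-level example, not taken from the
patch itself), this is the kind of code that produces the pattern: the i8 add
is widened to a 32-bit register, and a mask ("and wN, wN, #0xff") is inserted
ahead of the compare to emulate 8-bit wraparound.

    #include <cstdint>

    // Hypothetical example: an i8 add feeding an unsigned i8 compare.
    // After legalization the add happens in a 32-bit register, so the
    // sum is masked with 0xff before the cmp; this patch removes that
    // AND whenever it cannot change the branch outcome.
    bool in_range(uint8_t x) {
      uint8_t sum = x + 74; // wraps mod 256
      return sum < 236;     // unsigned compare on the wrapped value
    }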

This patch detects the relevant DAG patterns and then tests whether the
transforms are equivalent with and without the mask, removing the mask if
possible. The exact mechanism of this patch was discussed in
http://lists.cs.uiuc.edu/pipermail/llvmdev/2014-July/074444.html
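
For reference, a minimal sketch of such an equivalence check (assuming a fixed
unsigned-less-than condition and zero-extended inputs; the name and shape are
illustrative, not code from the patch) looks like this:

    #include <cstdint>

    // Brute-force check: does (x + addC) <u compC give the same answer
    // with and without the wraparound mask, for every w-bit input x
    // (w is 8 or 16 here)? If so, the AND is redundant for this
    // (addC, compC) pair under this condition code.
    bool masklessIsEquivalent(unsigned width, int32_t addC, int32_t compC) {
      const uint32_t mask = (1u << width) - 1;
      for (uint32_t x = 0; x <= mask; ++x) {  // zero-extended inputs
        uint32_t sum = x + (uint32_t)addC;    // the widened 32-bit add
        bool masked = (sum & mask) < (uint32_t)compC;
        bool unmasked = sum < (uint32_t)compC;
        if (masked != unmasked)
          return false;                       // the mask changes the result
      }
      return true;
    }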

There is a reasonably good chance there are missed opportunities due to similar
(but not identical) DAG patterns that could be funneled into this test; adding
them should be simple if we see test cases.

Tests included.

rdar://13754426

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@216776 91177308-0d34-0410-b5e6-96231b3b80d8
Louis Gerbarg 2014-08-29 21:00:22 +00:00
parent 1469e29334
commit 6393b3a677
2 changed files with 532 additions and 0 deletions


@@ -7903,10 +7903,271 @@ static SDValue performNEONPostLDSTCombine(SDNode *N,
  return SDValue();
}

// Checks to see if the value is the prescribed width and returns information
// about its extension mode.
static
bool checkValueWidth(SDValue V, unsigned width, ISD::LoadExtType &ExtType) {
  ExtType = ISD::NON_EXTLOAD;
  switch (V.getNode()->getOpcode()) {
  default:
    return false;
  case ISD::LOAD: {
    LoadSDNode *LoadNode = cast<LoadSDNode>(V.getNode());
    if ((LoadNode->getMemoryVT() == MVT::i8 && width == 8)
        || (LoadNode->getMemoryVT() == MVT::i16 && width == 16)) {
      ExtType = LoadNode->getExtensionType();
      return true;
    }
    return false;
  }
  case ISD::AssertSext: {
    VTSDNode *TypeNode = cast<VTSDNode>(V.getNode()->getOperand(1));
    if ((TypeNode->getVT() == MVT::i8 && width == 8)
        || (TypeNode->getVT() == MVT::i16 && width == 16)) {
      ExtType = ISD::SEXTLOAD;
      return true;
    }
    return false;
  }
  case ISD::AssertZext: {
    VTSDNode *TypeNode = cast<VTSDNode>(V.getNode()->getOperand(1));
    if ((TypeNode->getVT() == MVT::i8 && width == 8)
        || (TypeNode->getVT() == MVT::i16 && width == 16)) {
      ExtType = ISD::ZEXTLOAD;
      return true;
    }
    return false;
  }
  case ISD::Constant:
  case ISD::TargetConstant: {
    if (abs(cast<ConstantSDNode>(V.getNode())->getSExtValue()) <
        1 << (width - 1))
      return true;
    return false;
  }
  }

  return true;
}

// This function does a whole lot of voodoo to determine if the tests are
// equivalent without and with a mask. Essentially what happens is that given a
// DAG resembling:
//
//  +-------------+ +-------------+ +-------------+ +-------------+
//  |    Input    | | AddConstant | | CompConstant| |     CC      |
//  +-------------+ +-------------+ +-------------+ +-------------+
//           |           |                |                |
//           V           V                |     +----------+
//          +-------------+ +----+        |     |
//          |     ADD     | |0xff|        |     |
//          +-------------+ +----+        |     |
//                   |        |           |     |
//                   V        V           |     |
//                 +-------------+        |     |
//                 |     AND     |        |     |
//                 +-------------+        |     |
//                              |         |     |
//                              +-----+   |     |
//                                    |   |     |
//                                    V   V     V
//                                  +-------------+
//                                  |     CMP     |
//                                  +-------------+
//
// The AND node may be safely removed for some combinations of inputs. In
// particular we need to take into account the extension type of the Input,
// the exact values of AddConstant, CompConstant, and CC, along with the nominal
// width of the input (this can work for any width inputs; the above graph is
// specific to 8 bits).
//
// The specific equations were worked out by generating output tables for each
// AArch64CC value in terms of AddConstant (w1) and CompConstant (w2). The
// problem was simplified by working with 4 bit inputs, which means we only
// needed to reason about 24 distinct bit patterns: 8 patterns unique to zero
// extension (8 to 15), 8 patterns unique to sign extension (-8 to -1), and 8
// patterns present in both extensions (0 to 7). For every distinct pair of
// AddConstant and CompConstant bit patterns we can consider the masked and
// unmasked versions to be equivalent if the program below sets w0 for all 16
// distinct bit patterns of the current extension type of Input (w0):
//
//   sub  w8, w0, w1
//   and  w10, w8, #0x0f
//   cmp  w8, w2
//   cset w9, AArch64CC
//   cmp  w10, w2
//   cset w11, AArch64CC
//   cmp  w9, w11
//   cset w0, eq
//   ret
//
// Since the above program shows when the outputs are equivalent it defines
// when it is safe to remove the AND. Unfortunately it only runs on AArch64 and
// would be expensive to run during compiles. The equations below were written
// in a test harness that confirmed they gave equivalent outputs to the above
// program for all inputs, so they can be used to determine if the removal is
// legal instead.
//
// isEquivalentMaskless() is the code for testing if the AND can be removed,
// factored out of the DAG recognition as the DAG can take several forms.
static
bool isEquivalentMaskless(unsigned CC, unsigned width,
                          ISD::LoadExtType ExtType, signed AddConstant,
                          signed CompConstant) {
  // By being careful about our equations and only writing them in terms of
  // symbolic values and well-known constants (0, 1, -1, MaxUInt) we can
  // make them generally applicable to all bit widths.
  signed MaxUInt = (1 << width);

  // For the purposes of these comparisons sign extending the type is
  // equivalent to zero extending the add and displacing it by half the integer
  // width. Provided we are careful and make sure our equations are valid over
  // the whole range we can just adjust the input and avoid writing equations
  // for sign extended inputs.
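  // For example, with width == 8 a sign-extended input covers [-128, 127];
  // shifting AddConstant down by (1 << 7) == 128 lets equations written for
  // the zero-extended range [0, 255] cover that case as well.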
  if (ExtType == ISD::SEXTLOAD)
    AddConstant -= (1 << (width - 1));

  switch (CC) {
  case AArch64CC::LE:
  case AArch64CC::GT: {
    if ((AddConstant == 0) ||
        (CompConstant == MaxUInt - 1 && AddConstant < 0) ||
        (AddConstant >= 0 && CompConstant < 0) ||
        (AddConstant <= 0 && CompConstant <= 0 && CompConstant < AddConstant))
      return true;
  } break;
  case AArch64CC::LT:
  case AArch64CC::GE: {
    if ((AddConstant == 0) ||
        (AddConstant >= 0 && CompConstant <= 0) ||
        (AddConstant <= 0 && CompConstant <= 0 && CompConstant <= AddConstant))
      return true;
  } break;
  case AArch64CC::HI:
  case AArch64CC::LS: {
    if ((AddConstant >= 0 && CompConstant < 0) ||
        (AddConstant <= 0 && CompConstant >= -1 &&
         CompConstant < AddConstant + MaxUInt))
      return true;
  } break;
  case AArch64CC::PL:
  case AArch64CC::MI: {
    if ((AddConstant == 0) ||
        (AddConstant > 0 && CompConstant <= 0) ||
        (AddConstant < 0 && CompConstant <= AddConstant))
      return true;
  } break;
  case AArch64CC::LO:
  case AArch64CC::HS: {
    if ((AddConstant >= 0 && CompConstant <= 0) ||
        (AddConstant <= 0 && CompConstant >= 0 &&
         CompConstant <= AddConstant + MaxUInt))
      return true;
  } break;
  case AArch64CC::EQ:
  case AArch64CC::NE: {
    if ((AddConstant > 0 && CompConstant < 0) ||
        (AddConstant < 0 && CompConstant >= 0 &&
         CompConstant < AddConstant + MaxUInt) ||
        (AddConstant >= 0 && CompConstant >= 0 &&
         CompConstant >= AddConstant) ||
        (AddConstant <= 0 && CompConstant < 0 && CompConstant < AddConstant))
      return true;
  } break;
  case AArch64CC::VS:
  case AArch64CC::VC:
  case AArch64CC::AL:
  case AArch64CC::NV:
    return true;
  case AArch64CC::Invalid:
    break;
  }

  return false;
}

static
SDValue performCONDCombine(SDNode *N,
                           TargetLowering::DAGCombinerInfo &DCI,
                           SelectionDAG &DAG, unsigned CCIndex,
                           unsigned CmpIndex) {
  unsigned CC = cast<ConstantSDNode>(N->getOperand(CCIndex))->getSExtValue();
  SDNode *SubsNode = N->getOperand(CmpIndex).getNode();
  unsigned CondOpcode = SubsNode->getOpcode();

  if (CondOpcode != AArch64ISD::SUBS)
    return SDValue();

  // There is a SUBS feeding this condition. Is it fed by a mask we can
  // use?
  SDNode *AndNode = SubsNode->getOperand(0).getNode();
  unsigned MaskBits = 0;

  if (AndNode->getOpcode() != ISD::AND)
    return SDValue();

  if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(AndNode->getOperand(1))) {
    uint32_t CNV = CN->getZExtValue();
    if (CNV == 255)
      MaskBits = 8;
    else if (CNV == 65535)
      MaskBits = 16;
  }

  if (!MaskBits)
    return SDValue();

  SDValue AddValue = AndNode->getOperand(0);

  if (AddValue.getOpcode() != ISD::ADD)
    return SDValue();

  // The basic DAG structure is correct; grab the inputs and validate them.
  SDValue AddInputValue1 = AddValue.getNode()->getOperand(0);
  SDValue AddInputValue2 = AddValue.getNode()->getOperand(1);
  SDValue SubsInputValue = SubsNode->getOperand(1);

  // The mask is present and the provenance of all the values is a smaller
  // type; let's see if the mask is superfluous.
  if (!isa<ConstantSDNode>(AddInputValue2.getNode()) ||
      !isa<ConstantSDNode>(SubsInputValue.getNode()))
    return SDValue();

  ISD::LoadExtType ExtType;

  if (!checkValueWidth(SubsInputValue, MaskBits, ExtType) ||
      !checkValueWidth(AddInputValue2, MaskBits, ExtType) ||
      !checkValueWidth(AddInputValue1, MaskBits, ExtType))
    return SDValue();

  if (!isEquivalentMaskless(CC, MaskBits, ExtType,
          cast<ConstantSDNode>(AddInputValue2.getNode())->getSExtValue(),
          cast<ConstantSDNode>(SubsInputValue.getNode())->getSExtValue()))
    return SDValue();

  // The AND is not necessary; remove it.
  SDVTList VTs = DAG.getVTList(SubsNode->getValueType(0),
                               SubsNode->getValueType(1));
  SDValue Ops[] = { AddValue, SubsNode->getOperand(1) };

  SDValue NewValue = DAG.getNode(CondOpcode, SDLoc(SubsNode), VTs, Ops);
  DAG.ReplaceAllUsesWith(SubsNode, NewValue.getNode());

  return SDValue(N, 0);
}

// Optimize compare with zero and branch.
static SDValue performBRCONDCombine(SDNode *N,
                                    TargetLowering::DAGCombinerInfo &DCI,
                                    SelectionDAG &DAG) {
  SDValue NV = performCONDCombine(N, DCI, DAG, 2, 3);
  if (NV.getNode())
    N = NV.getNode();
  SDValue Chain = N->getOperand(0);
  SDValue Dest = N->getOperand(1);
  SDValue CCVal = N->getOperand(2);
@@ -8063,6 +8324,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
    return performSTORECombine(N, DCI, DAG, Subtarget);
  case AArch64ISD::BRCOND:
    return performBRCONDCombine(N, DCI, DAG);
  case AArch64ISD::CSEL:
    return performCONDCombine(N, DCI, DAG, 2, 3);
  case AArch64ISD::DUP:
    return performPostLD1Combine(N, DCI, false);
  case ISD::INSERT_VECTOR_ELT:


@@ -0,0 +1,269 @@
; RUN: llc -O0 -fast-isel=false -mtriple=arm64-apple-darwin < %s | FileCheck %s
@board = common global [400 x i8] zeroinitializer, align 1
@next_string = common global i32 0, align 4
@string_number = common global [400 x i32] zeroinitializer, align 4
; Function Attrs: nounwind ssp
define void @new_position(i32 %pos) {
entry:
%idxprom = sext i32 %pos to i64
%arrayidx = getelementptr inbounds [400 x i8]* @board, i64 0, i64 %idxprom
%tmp = load i8* %arrayidx, align 1
%.off = add i8 %tmp, -1
%switch = icmp ult i8 %.off, 2
br i1 %switch, label %if.then, label %if.end
if.then: ; preds = %entry
%tmp1 = load i32* @next_string, align 4
%arrayidx8 = getelementptr inbounds [400 x i32]* @string_number, i64 0, i64 %idxprom
store i32 %tmp1, i32* %arrayidx8, align 4
br label %if.end
if.end: ; preds = %if.then, %entry
ret void
; CHECK-LABEL: new_position
; CHECK-NOT: and
; CHECK: ret
}
define zeroext i1 @test8_0(i8 zeroext %x) align 2 {
entry:
%0 = add i8 %x, 74
%1 = icmp ult i8 %0, -20
br i1 %1, label %ret_true, label %ret_false
ret_false:
ret i1 false
ret_true:
ret i1 true
; CHECK-LABEL: test8_0
; CHECK: and
; CHECK: ret
}
define zeroext i1 @test8_1(i8 zeroext %x) align 2 {
entry:
%0 = add i8 %x, 246
%1 = icmp uge i8 %0, 90
br i1 %1, label %ret_true, label %ret_false
ret_false:
ret i1 false
ret_true:
ret i1 true
; CHECK-LABEL: test8_1
; CHECK-NOT: and
; CHECK: ret
}
define zeroext i1 @test8_2(i8 zeroext %x) align 2 {
entry:
%0 = add i8 %x, 227
%1 = icmp ne i8 %0, 179
br i1 %1, label %ret_true, label %ret_false
ret_false:
ret i1 false
ret_true:
ret i1 true
; CHECK-LABEL: test8_2
; CHECK-NOT: and
; CHECK: ret
}
define zeroext i1 @test8_3(i8 zeroext %x) align 2 {
entry:
%0 = add i8 %x, 201
%1 = icmp eq i8 %0, 154
br i1 %1, label %ret_true, label %ret_false
ret_false:
ret i1 false
ret_true:
ret i1 true
; CHECK-LABEL: test8_3
; CHECK-NOT: and
; CHECK: ret
}
define zeroext i1 @test8_4(i8 zeroext %x) align 2 {
entry:
%0 = add i8 %x, -79
%1 = icmp ne i8 %0, -40
br i1 %1, label %ret_true, label %ret_false
ret_false:
ret i1 false
ret_true:
ret i1 true
; CHECK-LABEL: test8_4
; CHECK-NOT: and
; CHECK: ret
}
define zeroext i1 @test8_5(i8 zeroext %x) align 2 {
entry:
%0 = add i8 %x, 133
%1 = icmp uge i8 %0, -105
br i1 %1, label %ret_true, label %ret_false
ret_false:
ret i1 false
ret_true:
ret i1 true
; CHECK-LABEL: test8_5
; CHECK: and
; CHECK: ret
}
define zeroext i1 @test8_6(i8 zeroext %x) align 2 {
entry:
%0 = add i8 %x, -58
%1 = icmp uge i8 %0, 155
br i1 %1, label %ret_true, label %ret_false
ret_false:
ret i1 false
ret_true:
ret i1 true
; CHECK-LABEL: test8_6
; CHECK: and
; CHECK: ret
}
define zeroext i1 @test8_7(i8 zeroext %x) align 2 {
entry:
%0 = add i8 %x, 225
%1 = icmp ult i8 %0, 124
br i1 %1, label %ret_true, label %ret_false
ret_false:
ret i1 false
ret_true:
ret i1 true
; CHECK-LABEL: test8_7
; CHECK-NOT: and
; CHECK: ret
}
define zeroext i1 @test8_8(i8 zeroext %x) align 2 {
entry:
%0 = add i8 %x, 190
%1 = icmp uge i8 %0, 1
br i1 %1, label %ret_true, label %ret_false
ret_false:
ret i1 false
ret_true:
ret i1 true
; CHECK-LABEL: test8_8
; CHECK-NOT: and
; CHECK: ret
}
define zeroext i1 @test16_0(i16 zeroext %x) align 2 {
entry:
%0 = add i16 %x, -46989
%1 = icmp ne i16 %0, -41903
br i1 %1, label %ret_true, label %ret_false
ret_false:
ret i1 false
ret_true:
ret i1 true
; CHECK-LABEL: test16_0
; CHECK-NOT: and
; CHECK: ret
}
define zeroext i1 @test16_2(i16 zeroext %x) align 2 {
entry:
%0 = add i16 %x, 16882
%1 = icmp ule i16 %0, -24837
br i1 %1, label %ret_true, label %ret_false
ret_false:
ret i1 false
ret_true:
ret i1 true
; CHECK-LABEL: test16_2
; CHECK: and
; CHECK: ret
}
define zeroext i1 @test16_3(i16 zeroext %x) align 2 {
entry:
%0 = add i16 %x, 29283
%1 = icmp ne i16 %0, 16947
br i1 %1, label %ret_true, label %ret_false
ret_false:
ret i1 false
ret_true:
ret i1 true
; CHECK-LABEL: test16_3
; CHECK-NOT: and
; CHECK: ret
}
define zeroext i1 @test16_4(i16 zeroext %x) align 2 {
entry:
%0 = add i16 %x, -35551
%1 = icmp uge i16 %0, 15677
br i1 %1, label %ret_true, label %ret_false
ret_false:
ret i1 false
ret_true:
ret i1 true
; CHECK-LABEL: test16_4
; CHECK: and
; CHECK: ret
}
define zeroext i1 @test16_5(i16 zeroext %x) align 2 {
entry:
%0 = add i16 %x, -25214
%1 = icmp ne i16 %0, -1932
br i1 %1, label %ret_true, label %ret_false
ret_false:
ret i1 false
ret_true:
ret i1 true
; CHECK-LABEL: test16_5
; CHECK-NOT: and
; CHECK: ret
}
define zeroext i1 @test16_6(i16 zeroext %x) align 2 {
entry:
%0 = add i16 %x, -32194
%1 = icmp uge i16 %0, -41215
br i1 %1, label %ret_true, label %ret_false
ret_false:
ret i1 false
ret_true:
ret i1 true
; CHECK-LABEL: test16_6
; CHECK-NOT: and
; CHECK: ret
}
define zeroext i1 @test16_7(i16 zeroext %x) align 2 {
entry:
%0 = add i16 %x, 9272
%1 = icmp uge i16 %0, -42916
br i1 %1, label %ret_true, label %ret_false
ret_false:
ret i1 false
ret_true:
ret i1 true
; CHECK-LABEL: test16_7
; CHECK: and
; CHECK: ret
}
define zeroext i1 @test16_8(i16 zeroext %x) align 2 {
entry:
%0 = add i16 %x, -63749
%1 = icmp ne i16 %0, 6706
br i1 %1, label %ret_true, label %ret_false
ret_false:
ret i1 false
ret_true:
ret i1 true
; CHECK-LABEL: test16_8
; CHECK-NOT: and
; CHECK: ret
}