diff --git a/lib/Target/AArch64/AArch64ISelLowering.cpp b/lib/Target/AArch64/AArch64ISelLowering.cpp
index 5ccc4bf9ff4..f2054cb1dc0 100644
--- a/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -7903,10 +7903,271 @@ static SDValue performNEONPostLDSTCombine(SDNode *N,
   return SDValue();
 }
 
+// Checks to see if the value is the prescribed width and returns information
+// about its extension mode.
+static
+bool checkValueWidth(SDValue V, unsigned width, ISD::LoadExtType &ExtType) {
+  ExtType = ISD::NON_EXTLOAD;
+  switch(V.getNode()->getOpcode()) {
+  default:
+    return false;
+  case ISD::LOAD: {
+    LoadSDNode *LoadNode = cast<LoadSDNode>(V.getNode());
+    if ((LoadNode->getMemoryVT() == MVT::i8 && width == 8)
+        || (LoadNode->getMemoryVT() == MVT::i16 && width == 16)) {
+      ExtType = LoadNode->getExtensionType();
+      return true;
+    }
+    return false;
+  }
+  case ISD::AssertSext: {
+    VTSDNode *TypeNode = cast<VTSDNode>(V.getNode()->getOperand(1));
+    if ((TypeNode->getVT() == MVT::i8 && width == 8)
+        || (TypeNode->getVT() == MVT::i16 && width == 16)) {
+      ExtType = ISD::SEXTLOAD;
+      return true;
+    }
+    return false;
+  }
+  case ISD::AssertZext: {
+    VTSDNode *TypeNode = cast<VTSDNode>(V.getNode()->getOperand(1));
+    if ((TypeNode->getVT() == MVT::i8 && width == 8)
+        || (TypeNode->getVT() == MVT::i16 && width == 16)) {
+      ExtType = ISD::ZEXTLOAD;
+      return true;
+    }
+    return false;
+  }
+  case ISD::Constant:
+  case ISD::TargetConstant: {
+    if (abs(cast<ConstantSDNode>(V.getNode())->getSExtValue()) <
+        (1 << (width - 1)))
+      return true;
+    return false;
+  }
+  }
+
+  return true;
+}
+
+// This function does a whole lot of voodoo to determine if the tests are
+// equivalent without and with a mask. Essentially what happens is that given a
+// DAG resembling:
+//
+//  +-------------+ +-------------+ +-------------+ +-------------+
+//  |    Input    | | AddConstant | | CompConstant| |     CC      |
+//  +-------------+ +-------------+ +-------------+ +-------------+
+//         |              |              |               |
+//         V              V              |    +----------+
+//  +-------------+    +----+            |    |
+//  |     ADD     |    |0xff|            |    |
+//  +-------------+    +----+            |    |
+//         |              |              |    |
+//         V              V              |    |
+//  +-------------+                      |    |
+//  |     AND     |                      |    |
+//  +-------------+                      |    |
+//              |                        |    |
+//              +-----+                  |    |
+//                    |                  |    |
+//                    V                  V    V
+//                   +-------------+
+//                   |     CMP     |
+//                   +-------------+
+//
+// The AND node may be safely removed for some combinations of inputs. In
+// particular we need to take into account the extension type of the Input,
+// the exact values of AddConstant, CompConstant, and CC, along with the
+// nominal width of the input (this can work for any width of input; the above
+// graph is specific to 8 bits).
+//
+// The specific equations were worked out by generating output tables for each
+// AArch64CC value in terms of AddConstant (w1) and CompConstant (w2). The
+// problem was simplified by working with 4 bit inputs, which means we only
+// needed to reason about 24 distinct bit patterns: 8 patterns unique to zero
+// extension (8,15), 8 patterns unique to sign extension (-8,-1), and 8
+// patterns present in both extensions (0,7). For every distinct set of
+// AddConstant and CompConstant bit patterns we can consider the masked and
+// unmasked versions to be equivalent if the result of this function is true
+// for all 16 distinct bit patterns of the current extension type of Input (w0).
+//
+// sub w8, w0, w1
+// and w10, w8, #0x0f
+// cmp w8, w2
+// cset w9, AArch64CC
+// cmp w10, w2
+// cset w11, AArch64CC
+// cmp w9, w11
+// cset w0, eq
+// ret
+//
+// Since the above function shows when the outputs are equivalent it defines
+// when it is safe to remove the AND. Unfortunately it only runs on AArch64 and
+// would be expensive to run during compiles. The equations below were written
+// in a test harness that confirmed they gave equivalent outputs to the above
+// function for all inputs, so they can be used to determine if the removal is
+// legal instead.
+//
+// isEquivalentMaskless() is the code for testing if the AND can be removed,
+// factored out of the DAG recognition as the DAG can take several forms.
+
+static
+bool isEquivalentMaskless(unsigned CC, unsigned width,
+                          ISD::LoadExtType ExtType, signed AddConstant,
+                          signed CompConstant) {
+  // By being careful about our equations and only writing them in terms of
+  // symbolic values and well known constants (0, 1, -1, MaxUInt) we can
+  // make them generally applicable to all bit widths.
+  signed MaxUInt = (1 << width);
+
+  // For the purposes of these comparisons sign extending the type is
+  // equivalent to zero extending the add and displacing it by half the integer
+  // width. Provided we are careful and make sure our equations are valid over
+  // the whole range we can just adjust the input and avoid writing equations
+  // for sign extended inputs.
+  if (ExtType == ISD::SEXTLOAD)
+    AddConstant -= (1 << (width-1));
+
+  switch(CC) {
+  case AArch64CC::LE:
+  case AArch64CC::GT: {
+    if ((AddConstant == 0) ||
+        (CompConstant == MaxUInt - 1 && AddConstant < 0) ||
+        (AddConstant >= 0 && CompConstant < 0) ||
+        (AddConstant <= 0 && CompConstant <= 0 && CompConstant < AddConstant))
+      return true;
+  } break;
+  case AArch64CC::LT:
+  case AArch64CC::GE: {
+    if ((AddConstant == 0) ||
+        (AddConstant >= 0 && CompConstant <= 0) ||
+        (AddConstant <= 0 && CompConstant <= 0 && CompConstant <= AddConstant))
+      return true;
+  } break;
+  case AArch64CC::HI:
+  case AArch64CC::LS: {
+    if ((AddConstant >= 0 && CompConstant < 0) ||
+        (AddConstant <= 0 && CompConstant >= -1 &&
+         CompConstant < AddConstant + MaxUInt))
+      return true;
+  } break;
+  case AArch64CC::PL:
+  case AArch64CC::MI: {
+    if ((AddConstant == 0) ||
+        (AddConstant > 0 && CompConstant <= 0) ||
+        (AddConstant < 0 && CompConstant <= AddConstant))
+      return true;
+  } break;
+  case AArch64CC::LO:
+  case AArch64CC::HS: {
+    if ((AddConstant >= 0 && CompConstant <= 0) ||
+        (AddConstant <= 0 && CompConstant >= 0 &&
+         CompConstant <= AddConstant + MaxUInt))
+      return true;
+  } break;
+  case AArch64CC::EQ:
+  case AArch64CC::NE: {
+    if ((AddConstant > 0 && CompConstant < 0) ||
+        (AddConstant < 0 && CompConstant >= 0 &&
+         CompConstant < AddConstant + MaxUInt) ||
+        (AddConstant >= 0 && CompConstant >= 0 &&
+         CompConstant >= AddConstant) ||
+        (AddConstant <= 0 && CompConstant < 0 && CompConstant < AddConstant))
+      return true;
+  } break;
+  case AArch64CC::VS:
+  case AArch64CC::VC:
+  case AArch64CC::AL:
+  case AArch64CC::NV:
+    return true;
+  case AArch64CC::Invalid:
+    break;
+  }
+
+  return false;
+}
+
+static
+SDValue performCONDCombine(SDNode *N,
+                           TargetLowering::DAGCombinerInfo &DCI,
+                           SelectionDAG &DAG, unsigned CCIndex,
+                           unsigned CmpIndex) {
+  unsigned CC = cast<ConstantSDNode>(N->getOperand(CCIndex))->getSExtValue();
+  SDNode *SubsNode = N->getOperand(CmpIndex).getNode();
+  unsigned CondOpcode = SubsNode->getOpcode();
+
+  if (CondOpcode != AArch64ISD::SUBS)
+    return SDValue();
+
+  // There is a SUBS feeding this condition. Is it fed by a mask we can
+  // use?
+
+  SDNode *AndNode = SubsNode->getOperand(0).getNode();
+  unsigned MaskBits = 0;
+
+  if (AndNode->getOpcode() != ISD::AND)
+    return SDValue();
+
+  if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(AndNode->getOperand(1))) {
+    uint32_t CNV = CN->getZExtValue();
+    if (CNV == 255)
+      MaskBits = 8;
+    else if (CNV == 65535)
+      MaskBits = 16;
+  }
+
+  if (!MaskBits)
+    return SDValue();
+
+  SDValue AddValue = AndNode->getOperand(0);
+
+  if (AddValue.getOpcode() != ISD::ADD)
+    return SDValue();
+
+  // The basic DAG structure is correct; grab the inputs and validate them.
+
+  SDValue AddInputValue1 = AddValue.getNode()->getOperand(0);
+  SDValue AddInputValue2 = AddValue.getNode()->getOperand(1);
+  SDValue SubsInputValue = SubsNode->getOperand(1);
+
+  // The mask is present and the provenance of all the values is a smaller
+  // type; let's see if the mask is superfluous.
+
+  if (!isa<ConstantSDNode>(AddInputValue2.getNode()) ||
+      !isa<ConstantSDNode>(SubsInputValue.getNode()))
+    return SDValue();
+
+  ISD::LoadExtType ExtType;
+
+  if (!checkValueWidth(SubsInputValue, MaskBits, ExtType) ||
+      !checkValueWidth(AddInputValue2, MaskBits, ExtType) ||
+      !checkValueWidth(AddInputValue1, MaskBits, ExtType))
+    return SDValue();
+
+  if (!isEquivalentMaskless(CC, MaskBits, ExtType,
+                            cast<ConstantSDNode>(AddInputValue2.getNode())->getSExtValue(),
+                            cast<ConstantSDNode>(SubsInputValue.getNode())->getSExtValue()))
+    return SDValue();
+
+  // The AND is not necessary, remove it.
+
+  SDVTList VTs = DAG.getVTList(SubsNode->getValueType(0),
+                               SubsNode->getValueType(1));
+  SDValue Ops[] = { AddValue, SubsNode->getOperand(1) };
+
+  SDValue NewValue = DAG.getNode(CondOpcode, SDLoc(SubsNode), VTs, Ops);
+  DAG.ReplaceAllUsesWith(SubsNode, NewValue.getNode());
+
+  return SDValue(N, 0);
+}
+
 // Optimize compare with zero and branch.
 static SDValue performBRCONDCombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI,
                                     SelectionDAG &DAG) {
+  SDValue NV = performCONDCombine(N, DCI, DAG, 2, 3);
+  if (NV.getNode())
+    N = NV.getNode();
   SDValue Chain = N->getOperand(0);
   SDValue Dest = N->getOperand(1);
   SDValue CCVal = N->getOperand(2);
@@ -8063,6 +8324,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     return performSTORECombine(N, DCI, DAG, Subtarget);
   case AArch64ISD::BRCOND:
     return performBRCONDCombine(N, DCI, DAG);
+  case AArch64ISD::CSEL:
+    return performCONDCombine(N, DCI, DAG, 2, 3);
   case AArch64ISD::DUP:
     return performPostLD1Combine(N, DCI, false);
   case ISD::INSERT_VECTOR_ELT:
diff --git a/test/CodeGen/AArch64/and-mask-removal.ll b/test/CodeGen/AArch64/and-mask-removal.ll
new file mode 100644
index 00000000000..f803b85f733
--- /dev/null
+++ b/test/CodeGen/AArch64/and-mask-removal.ll
@@ -0,0 +1,269 @@
+; RUN: llc -O0 -fast-isel=false -mtriple=arm64-apple-darwin < %s | FileCheck %s
+
+@board = common global [400 x i8] zeroinitializer, align 1
+@next_string = common global i32 0, align 4
+@string_number = common global [400 x i32] zeroinitializer, align 4
+
+; Function Attrs: nounwind ssp
+define void @new_position(i32 %pos) {
+entry:
+  %idxprom = sext i32 %pos to i64
+  %arrayidx = getelementptr inbounds [400 x i8]* @board, i64 0, i64 %idxprom
+  %tmp = load i8* %arrayidx, align 1
+  %.off = add i8 %tmp, -1
+  %switch = icmp ult i8 %.off, 2
+  br i1 %switch, label %if.then, label %if.end
+
+if.then:                                          ; preds = %entry
+  %tmp1 = load i32* @next_string, align 4
+  %arrayidx8 = getelementptr inbounds [400 x i32]* @string_number, i64 0, i64 %idxprom
+  store i32 %tmp1, i32* %arrayidx8, align 4
+  br label %if.end
+
+if.end:                                           ; preds = %if.then, %entry
+  ret void
+; CHECK-LABEL: new_position
+; CHECK-NOT: and
+; CHECK: ret
+}
+
+define zeroext i1 @test8_0(i8 zeroext %x) align 2 {
+entry:
+  %0 = add i8 %x, 74
+  %1 = icmp ult i8 %0, -20
+  br i1 %1, label %ret_true, label %ret_false
+ret_false:
+  ret i1 false
+ret_true:
+  ret i1 true
+; CHECK-LABEL: test8_0
+; CHECK: and
+; CHECK: ret
+}
+
+define zeroext i1 @test8_1(i8 zeroext %x) align 2 {
+entry:
+  %0 = add i8 %x, 246
+  %1 = icmp uge i8 %0, 90
+  br i1 %1, label %ret_true, label %ret_false
+ret_false:
+  ret i1 false
+ret_true:
+  ret i1 true
+; CHECK-LABEL: test8_1
+; CHECK-NOT: and
+; CHECK: ret
+}
+
+define zeroext i1 @test8_2(i8 zeroext %x) align 2 {
+entry:
+  %0 = add i8 %x, 227
+  %1 = icmp ne i8 %0, 179
+  br i1 %1, label %ret_true, label %ret_false
+ret_false:
+  ret i1 false
+ret_true:
+  ret i1 true
+; CHECK-LABEL: test8_2
+; CHECK-NOT: and
+; CHECK: ret
+}
+
+define zeroext i1 @test8_3(i8 zeroext %x) align 2 {
+entry:
+  %0 = add i8 %x, 201
+  %1 = icmp eq i8 %0, 154
+  br i1 %1, label %ret_true, label %ret_false
+ret_false:
+  ret i1 false
+ret_true:
+  ret i1 true
+; CHECK-LABEL: test8_3
+; CHECK-NOT: and
+; CHECK: ret
+}
+
+define zeroext i1 @test8_4(i8 zeroext %x) align 2 {
+entry:
+  %0 = add i8 %x, -79
+  %1 = icmp ne i8 %0, -40
+  br i1 %1, label %ret_true, label %ret_false
+ret_false:
+  ret i1 false
+ret_true:
+  ret i1 true
+; CHECK-LABEL: test8_4
+; CHECK-NOT: and
+; CHECK: ret
+}
+
+define zeroext i1 @test8_5(i8 zeroext %x) align 2 {
+entry:
+  %0 = add i8 %x, 133
+  %1 = icmp uge i8 %0, -105
+  br i1 %1, label %ret_true, label %ret_false
+ret_false:
+  ret i1 false
+ret_true:
+  ret i1 true
+; CHECK-LABEL: test8_5
+; CHECK: and
+; CHECK: ret
+}
+
+define zeroext i1 @test8_6(i8 zeroext %x) align 2 {
+entry:
+  %0 = add i8 %x, -58
+  %1 = icmp uge i8 %0, 155
+  br i1 %1, label %ret_true, label %ret_false
+ret_false:
+  ret i1 false
+ret_true:
+  ret i1 true
+; CHECK-LABEL: test8_6
+; CHECK: and
+; CHECK: ret
+}
+
+define zeroext i1 @test8_7(i8 zeroext %x) align 2 {
+entry:
+  %0 = add i8 %x, 225
+  %1 = icmp ult i8 %0, 124
+  br i1 %1, label %ret_true, label %ret_false
+ret_false:
+  ret i1 false
+ret_true:
+  ret i1 true
+; CHECK-LABEL: test8_7
+; CHECK-NOT: and
+; CHECK: ret
+}
+
+
+
+define zeroext i1 @test8_8(i8 zeroext %x) align 2 {
+entry:
+  %0 = add i8 %x, 190
+  %1 = icmp uge i8 %0, 1
+  br i1 %1, label %ret_true, label %ret_false
+ret_false:
+  ret i1 false
+ret_true:
+  ret i1 true
+; CHECK-LABEL: test8_8
+; CHECK-NOT: and
+; CHECK: ret
+}
+
+define zeroext i1 @test16_0(i16 zeroext %x) align 2 {
+entry:
+  %0 = add i16 %x, -46989
+  %1 = icmp ne i16 %0, -41903
+  br i1 %1, label %ret_true, label %ret_false
+ret_false:
+  ret i1 false
+ret_true:
+  ret i1 true
+; CHECK-LABEL: test16_0
+; CHECK-NOT: and
+; CHECK: ret
+}
+
+define zeroext i1 @test16_2(i16 zeroext %x) align 2 {
+entry:
+  %0 = add i16 %x, 16882
+  %1 = icmp ule i16 %0, -24837
+  br i1 %1, label %ret_true, label %ret_false
+ret_false:
+  ret i1 false
+ret_true:
+  ret i1 true
+; CHECK-LABEL: test16_2
+; CHECK: and
+; CHECK: ret
+}
+
+define zeroext i1 @test16_3(i16 zeroext %x) align 2 {
+entry:
+  %0 = add i16 %x, 29283
+  %1 = icmp ne i16 %0, 16947
+  br i1 %1, label %ret_true, label %ret_false
+ret_false:
+  ret i1 false
+ret_true:
+  ret i1 true
+; CHECK-LABEL: test16_3
+; CHECK-NOT: and
+; CHECK: ret
+}
+
+define zeroext i1 @test16_4(i16 zeroext %x) align 2 {
+entry:
+  %0 = add i16 %x, -35551
+  %1 = icmp uge i16 %0, 15677
+  br i1 %1, label %ret_true, label %ret_false
+ret_false:
+  ret i1 false
+ret_true:
+  ret i1 true
+; CHECK-LABEL: test16_4
+; CHECK: and
+; CHECK: ret
+}
+
+define zeroext i1 @test16_5(i16 zeroext %x) align 2 {
+entry:
+  %0 = add i16 %x, -25214
+  %1 = icmp ne i16 %0, -1932
+  br i1 %1, label %ret_true, label %ret_false
+ret_false:
+  ret i1 false
+ret_true:
+  ret i1 true
+; CHECK-LABEL: test16_5
+; CHECK-NOT: and
+; CHECK: ret
+}
+
+define zeroext i1 @test16_6(i16 zeroext %x) align 2 {
+entry:
+  %0 = add i16 %x, -32194
+  %1 = icmp uge i16 %0, -41215
+  br i1 %1, label %ret_true, label %ret_false
+ret_false:
+  ret i1 false
+ret_true:
+  ret i1 true
+; CHECK-LABEL: test16_6
+; CHECK-NOT: and
+; CHECK: ret
+}
+
+define zeroext i1 @test16_7(i16 zeroext %x) align 2 {
+entry:
+  %0 = add i16 %x, 9272
+  %1 = icmp uge i16 %0, -42916
+  br i1 %1, label %ret_true, label %ret_false
+ret_false:
+  ret i1 false
+ret_true:
+  ret i1 true
+; CHECK-LABEL: test16_7
+; CHECK: and
+; CHECK: ret
+}
+
+define zeroext i1 @test16_8(i16 zeroext %x) align 2 {
+entry:
+  %0 = add i16 %x, -63749
+  %1 = icmp ne i16 %0, 6706
+  br i1 %1, label %ret_true, label %ret_false
+ret_false:
+  ret i1 false
+ret_true:
+  ret i1 true
+; CHECK-LABEL: test16_8
+; CHECK-NOT: and
+; CHECK: ret
+}
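
Note (not part of the patch): the block comment in AArch64ISelLowering.cpp above says the per-condition-code equations were validated by brute force over 4-bit inputs. The standalone C++ sketch below is not the original test harness; it is a minimal illustration, with hypothetical helper names (cmpFlags, holds, equivalentMaskless), of that style of check. It models the ADD node, the optional 4-bit AND, and the flag-setting compare, then reports whether the masked and unmasked compares choose the same branch for all 16 input patterns of a given extension kind.

#include <cstdint>
#include <cstdio>
#include <string>

struct Flags { bool N, Z, C, V; };

// Flag results of an AArch64 "cmp a, b" (a SUBS whose result is discarded).
static Flags cmpFlags(uint32_t a, uint32_t b) {
  uint32_t r = a - b;
  Flags f;
  f.N = (r >> 31) & 1;                      // negative
  f.Z = (r == 0);                           // zero
  f.C = (a >= b);                           // unsigned: no borrow
  f.V = (((a ^ b) & (a ^ r)) >> 31) & 1;    // signed overflow
  return f;
}

// Evaluate an AArch64 condition code against the flags.
static bool holds(const std::string &cc, Flags f) {
  if (cc == "eq") return f.Z;
  if (cc == "ne") return !f.Z;
  if (cc == "hs") return f.C;
  if (cc == "lo") return !f.C;
  if (cc == "hi") return f.C && !f.Z;
  if (cc == "ls") return !(f.C && !f.Z);
  if (cc == "mi") return f.N;
  if (cc == "pl") return !f.N;
  if (cc == "ge") return f.N == f.V;
  if (cc == "lt") return f.N != f.V;
  if (cc == "gt") return !f.Z && f.N == f.V;
  if (cc == "le") return f.Z || f.N != f.V;
  return false;
}

// True if "cmp (Input + AddConstant), CompConstant" and
// "cmp ((Input + AddConstant) & 0xf), CompConstant" agree under the given
// condition code for every 4-bit input pattern of the chosen extension kind
// (zero-extended: 0..15, sign-extended: -8..7).
static bool equivalentMaskless(int32_t AddConstant, int32_t CompConstant,
                               const std::string &cc, bool signExtended) {
  for (int i = 0; i < 16; ++i) {
    int32_t Input = signExtended ? i - 8 : i;
    uint32_t Sum = (uint32_t)Input + (uint32_t)AddConstant; // the ADD node
    uint32_t Masked = Sum & 0xf;                            // the AND node
    bool withoutMask = holds(cc, cmpFlags(Sum, (uint32_t)CompConstant));
    bool withMask = holds(cc, cmpFlags(Masked, (uint32_t)CompConstant));
    if (withoutMask != withMask)
      return false;
  }
  return true;
}

int main() {
  // Example: with AddConstant == 0 the masked and unmasked values coincide,
  // matching the (AddConstant == 0) term in the EQ/NE case of the patch.
  std::printf("eq, AddConstant=0: %s\n",
              equivalentMaskless(0, 3, "eq", false) ? "maskless ok"
                                                    : "needs mask");
  return 0;
}

Sweeping AddConstant, CompConstant, and the condition codes with a driver like this produces the kind of output table the comment refers to; the harness described in the patch differed in that it ran the actual AArch64 instruction sequence shown in the comment.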