From 79654737d7de8041d770d8dc0c148f8b62c6c21b Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Tue, 30 May 2017 17:47:51 +0000 Subject: [PATCH] [Hexagon] Improve code generation for 32x32-bit multiplication For multiplications of 64-bit values (giving 64-bit result), detect cases where the arguments are sign-extended 32-bit values, on a per- operand basis. This will allow few patterns to match a wider variety of combinations in which extensions can occur. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@304223 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Target/Hexagon/HexagonISelDAGToDAG.cpp | 189 ++++++++++----------- lib/Target/Hexagon/HexagonISelLowering.cpp | 6 +- lib/Target/Hexagon/HexagonPatterns.td | 59 +++---- test/CodeGen/Hexagon/mul64-sext.ll | 93 ++++++++++ 4 files changed, 208 insertions(+), 139 deletions(-) create mode 100644 test/CodeGen/Hexagon/mul64-sext.ll diff --git a/lib/Target/Hexagon/HexagonISelDAGToDAG.cpp b/lib/Target/Hexagon/HexagonISelDAGToDAG.cpp index 8e10c521a77..e4434136bf8 100644 --- a/lib/Target/Hexagon/HexagonISelDAGToDAG.cpp +++ b/lib/Target/Hexagon/HexagonISelDAGToDAG.cpp @@ -71,6 +71,9 @@ public: return true; } + bool ComplexPatternFuncMutatesDAG() const override { + return true; + } void PreprocessISelDAG() override; void EmitFunctionEntryCode() override; @@ -81,6 +84,7 @@ public: inline bool SelectAddrGP(SDValue &N, SDValue &R); bool SelectGlobalAddress(SDValue &N, SDValue &R, bool UseGP); bool SelectAddrFI(SDValue &N, SDValue &R); + bool DetectUseSxtw(SDValue &N, SDValue &R); StringRef getPassName() const override { return "Hexagon DAG->DAG Pattern Instruction Selection"; @@ -106,7 +110,6 @@ public: void SelectIndexedStore(StoreSDNode *ST, const SDLoc &dl); void SelectStore(SDNode *N); void SelectSHL(SDNode *N); - void SelectMul(SDNode *N); void SelectZeroExtend(SDNode *N); void SelectIntrinsicWChain(SDNode *N); void SelectIntrinsicWOChain(SDNode *N); @@ -118,7 +121,7 @@ public: #include "HexagonGenDAGISel.inc" private: - bool isValueExtension(const SDValue &Val, unsigned FromBits, SDValue &Src); + bool keepsLowBits(const SDValue &Val, unsigned NumBits, SDValue &Src); bool isOrEquivalentToAdd(const SDNode *N) const; bool isAlignedMemNode(const MemSDNode *N) const; bool isPositiveHalfWord(const SDNode *N) const; @@ -597,90 +600,6 @@ void HexagonDAGToDAGISel::SelectStore(SDNode *N) { SelectCode(ST); } -void HexagonDAGToDAGISel::SelectMul(SDNode *N) { - SDLoc dl(N); - - // %conv.i = sext i32 %tmp1 to i64 - // %conv2.i = sext i32 %add to i64 - // %mul.i = mul nsw i64 %conv2.i, %conv.i - // - // --- match with the following --- - // - // %mul.i = mpy (%tmp1, %add) - // - - if (N->getValueType(0) == MVT::i64) { - // Shifting a i64 signed multiply. - SDValue MulOp0 = N->getOperand(0); - SDValue MulOp1 = N->getOperand(1); - - SDValue OP0; - SDValue OP1; - - // Handle sign_extend and sextload. - if (MulOp0.getOpcode() == ISD::SIGN_EXTEND) { - SDValue Sext0 = MulOp0.getOperand(0); - if (Sext0.getNode()->getValueType(0) != MVT::i32) { - SelectCode(N); - return; - } - OP0 = Sext0; - } else if (MulOp0.getOpcode() == ISD::LOAD) { - LoadSDNode *LD = cast(MulOp0.getNode()); - if (LD->getMemoryVT() != MVT::i32 || - LD->getExtensionType() != ISD::SEXTLOAD || - LD->getAddressingMode() != ISD::UNINDEXED) { - SelectCode(N); - return; - } - SDValue Chain = LD->getChain(); - SDValue TargetConst0 = CurDAG->getTargetConstant(0, dl, MVT::i32); - OP0 = SDValue(CurDAG->getMachineNode(Hexagon::L2_loadri_io, dl, MVT::i32, - MVT::Other, - LD->getBasePtr(), TargetConst0, - Chain), 0); - } else { - SelectCode(N); - return; - } - - // Same goes for the second operand. - if (MulOp1.getOpcode() == ISD::SIGN_EXTEND) { - SDValue Sext1 = MulOp1.getOperand(0); - if (Sext1.getNode()->getValueType(0) != MVT::i32) { - SelectCode(N); - return; - } - OP1 = Sext1; - } else if (MulOp1.getOpcode() == ISD::LOAD) { - LoadSDNode *LD = cast(MulOp1.getNode()); - if (LD->getMemoryVT() != MVT::i32 || - LD->getExtensionType() != ISD::SEXTLOAD || - LD->getAddressingMode() != ISD::UNINDEXED) { - SelectCode(N); - return; - } - SDValue Chain = LD->getChain(); - SDValue TargetConst0 = CurDAG->getTargetConstant(0, dl, MVT::i32); - OP1 = SDValue(CurDAG->getMachineNode(Hexagon::L2_loadri_io, dl, MVT::i32, - MVT::Other, - LD->getBasePtr(), TargetConst0, - Chain), 0); - } else { - SelectCode(N); - return; - } - - // Generate a mpy instruction. - SDNode *Result = CurDAG->getMachineNode(Hexagon::M2_dpmpyss_s0, dl, - MVT::i64, OP0, OP1); - ReplaceNode(N, Result); - return; - } - - SelectCode(N); -} - void HexagonDAGToDAGISel::SelectSHL(SDNode *N) { SDLoc dl(N); SDValue Shl_0 = N->getOperand(0); @@ -843,7 +762,7 @@ void HexagonDAGToDAGISel::SelectIntrinsicWOChain(SDNode *N) { SDValue V = N->getOperand(1); SDValue U; - if (isValueExtension(V, Bits, U)) { + if (keepsLowBits(V, Bits, U)) { SDValue R = CurDAG->getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), N->getOperand(0), U); ReplaceNode(N, R.getNode()); @@ -949,7 +868,6 @@ void HexagonDAGToDAGISel::Select(SDNode *N) { case ISD::SHL: return SelectSHL(N); case ISD::LOAD: return SelectLoad(N); case ISD::STORE: return SelectStore(N); - case ISD::MUL: return SelectMul(N); case ISD::ZERO_EXTEND: return SelectZeroExtend(N); case ISD::INTRINSIC_W_CHAIN: return SelectIntrinsicWChain(N); case ISD::INTRINSIC_WO_CHAIN: return SelectIntrinsicWOChain(N); @@ -1327,7 +1245,7 @@ void HexagonDAGToDAGISel::EmitFunctionEntryCode() { } // Match a frame index that can be used in an addressing mode. -bool HexagonDAGToDAGISel::SelectAddrFI(SDValue& N, SDValue &R) { +bool HexagonDAGToDAGISel::SelectAddrFI(SDValue &N, SDValue &R) { if (N.getOpcode() != ISD::FrameIndex) return false; auto &HFI = *HST->getFrameLowering(); @@ -1388,16 +1306,83 @@ bool HexagonDAGToDAGISel::SelectGlobalAddress(SDValue &N, SDValue &R, return false; } -bool HexagonDAGToDAGISel::isValueExtension(const SDValue &Val, - unsigned FromBits, SDValue &Src) { +bool HexagonDAGToDAGISel::DetectUseSxtw(SDValue &N, SDValue &R) { + // This (complex pattern) function is meant to detect a sign-extension + // i32->i64 on a per-operand basis. This would allow writing single + // patterns that would cover a number of combinations of different ways + // a sign-extensions could be written. For example: + // (mul (DetectUseSxtw x) (DetectUseSxtw y)) -> (M2_dpmpyss_s0 x y) + // could match either one of these: + // (mul (sext x) (sext_inreg y)) + // (mul (sext-load *p) (sext_inreg y)) + // (mul (sext_inreg x) (sext y)) + // etc. + // + // The returned value will have type i64 and its low word will + // contain the value being extended. The high bits are not specified. + // The returned type is i64 because the original type of N was i64, + // but the users of this function should only use the low-word of the + // result, e.g. + // (mul sxtw:x, sxtw:y) -> (M2_dpmpyss_s0 (LoReg sxtw:x), (LoReg sxtw:y)) + + if (N.getValueType() != MVT::i64) + return false; + EVT SrcVT; + unsigned Opc = N.getOpcode(); + switch (Opc) { + case ISD::SIGN_EXTEND: + case ISD::SIGN_EXTEND_INREG: { + // sext_inreg has the source type as a separate operand. + EVT T = Opc == ISD::SIGN_EXTEND + ? N.getOperand(0).getValueType() + : cast(N.getOperand(1))->getVT(); + if (T.getSizeInBits() != 32) + return false; + R = N.getOperand(0); + break; + } + case ISD::LOAD: { + LoadSDNode *L = cast(N); + if (L->getExtensionType() != ISD::SEXTLOAD) + return false; + // All extending loads extend to i32, so even if the value in + // memory is shorter than 32 bits, it will be i32 after the load. + if (L->getMemoryVT().getSizeInBits() > 32) + return false; + R = N; + break; + } + default: + return false; + } + EVT RT = R.getValueType(); + if (RT == MVT::i64) + return true; + assert(RT == MVT::i32); + // This is only to produce a value of type i64. Do not rely on the + // high bits produced by this. + const SDLoc &dl(N); + SDValue Ops[] = { + CurDAG->getTargetConstant(Hexagon::DoubleRegsRegClassID, dl, MVT::i32), + R, CurDAG->getTargetConstant(Hexagon::isub_hi, dl, MVT::i32), + R, CurDAG->getTargetConstant(Hexagon::isub_lo, dl, MVT::i32) + }; + SDNode *T = CurDAG->getMachineNode(TargetOpcode::REG_SEQUENCE, dl, + MVT::i64, Ops); + R = SDValue(T, 0); + return true; +} + +bool HexagonDAGToDAGISel::keepsLowBits(const SDValue &Val, unsigned NumBits, + SDValue &Src) { unsigned Opc = Val.getOpcode(); switch (Opc) { case ISD::SIGN_EXTEND: case ISD::ZERO_EXTEND: case ISD::ANY_EXTEND: { - SDValue const &Op0 = Val.getOperand(0); + const SDValue &Op0 = Val.getOperand(0); EVT T = Op0.getValueType(); - if (T.isInteger() && T.getSizeInBits() == FromBits) { + if (T.isInteger() && T.getSizeInBits() == NumBits) { Src = Op0; return true; } @@ -1408,23 +1393,23 @@ bool HexagonDAGToDAGISel::isValueExtension(const SDValue &Val, case ISD::AssertZext: if (Val.getOperand(0).getValueType().isInteger()) { VTSDNode *T = cast(Val.getOperand(1)); - if (T->getVT().getSizeInBits() == FromBits) { + if (T->getVT().getSizeInBits() == NumBits) { Src = Val.getOperand(0); return true; } } break; case ISD::AND: { - // Check if this is an AND with "FromBits" of lower bits set to 1. - uint64_t FromMask = (1 << FromBits) - 1; + // Check if this is an AND with NumBits of lower bits set to 1. + uint64_t Mask = (1 << NumBits) - 1; if (ConstantSDNode *C = dyn_cast(Val.getOperand(0))) { - if (C->getZExtValue() == FromMask) { + if (C->getZExtValue() == Mask) { Src = Val.getOperand(1); return true; } } if (ConstantSDNode *C = dyn_cast(Val.getOperand(1))) { - if (C->getZExtValue() == FromMask) { + if (C->getZExtValue() == Mask) { Src = Val.getOperand(0); return true; } @@ -1433,16 +1418,16 @@ bool HexagonDAGToDAGISel::isValueExtension(const SDValue &Val, } case ISD::OR: case ISD::XOR: { - // OR/XOR with the lower "FromBits" bits set to 0. - uint64_t FromMask = (1 << FromBits) - 1; + // OR/XOR with the lower NumBits bits set to 0. + uint64_t Mask = (1 << NumBits) - 1; if (ConstantSDNode *C = dyn_cast(Val.getOperand(0))) { - if ((C->getZExtValue() & FromMask) == 0) { + if ((C->getZExtValue() & Mask) == 0) { Src = Val.getOperand(1); return true; } } if (ConstantSDNode *C = dyn_cast(Val.getOperand(1))) { - if ((C->getZExtValue() & FromMask) == 0) { + if ((C->getZExtValue() & Mask) == 0) { Src = Val.getOperand(0); return true; } diff --git a/lib/Target/Hexagon/HexagonISelLowering.cpp b/lib/Target/Hexagon/HexagonISelLowering.cpp index 5ecf9320d5c..4c6c6eeafbe 100644 --- a/lib/Target/Hexagon/HexagonISelLowering.cpp +++ b/lib/Target/Hexagon/HexagonISelLowering.cpp @@ -1928,11 +1928,7 @@ HexagonTargetLowering::HexagonTargetLowering(const TargetMachine &TM, setOperationAction(ISD::BITREVERSE, MVT::i64, Legal); setOperationAction(ISD::BSWAP, MVT::i32, Legal); setOperationAction(ISD::BSWAP, MVT::i64, Legal); - - // We custom lower i64 to i64 mul, so that it is not considered as a legal - // operation. There is a pattern that will match i64 mul and transform it - // to a series of instructions. - setOperationAction(ISD::MUL, MVT::i64, Expand); + setOperationAction(ISD::MUL, MVT::i64, Legal); for (unsigned IntExpOp : { ISD::SDIV, ISD::UDIV, ISD::SREM, ISD::UREM, diff --git a/lib/Target/Hexagon/HexagonPatterns.td b/lib/Target/Hexagon/HexagonPatterns.td index 81b5e10c117..70ed123bc89 100644 --- a/lib/Target/Hexagon/HexagonPatterns.td +++ b/lib/Target/Hexagon/HexagonPatterns.td @@ -382,48 +382,42 @@ def: T_MType_acc_pat3 ; def: T_MType_acc_pat3 ; def: T_MType_acc_pat3 ; +// This complex pattern is really only to detect various forms of +// sign-extension i32->i64. The selected value will be of type i64 +// whose low word is the value being extended. The high word is +// unspecified. +def Usxtw : ComplexPattern; + def Aext64: PatFrag<(ops node:$Rs), (i64 (anyext node:$Rs))>; -def Sext64: PatFrag<(ops node:$Rs), (i64 (sext node:$Rs))>; def Zext64: PatFrag<(ops node:$Rs), (i64 (zext node:$Rs))>; +def Sext64: PatLeaf<(i64 Usxtw:$Rs)>; -// Return true if for a 32 to 64-bit sign-extended load. -def Sext64Ld : PatLeaf<(i64 DoubleRegs:$src1), [{ - LoadSDNode *LD = dyn_cast(N); - if (!LD) - return false; - return LD->getExtensionType() == ISD::SEXTLOAD && - LD->getMemoryVT().getScalarType() == MVT::i32; -}]>; +def: Pat<(mul (Aext64 I32:$Rs), (Aext64 I32:$Rt)), + (M2_dpmpyuu_s0 I32:$Rs, I32:$Rt)>; -def: Pat<(mul (Aext64 I32:$src1), (Aext64 I32:$src2)), - (M2_dpmpyuu_s0 IntRegs:$src1, IntRegs:$src2)>; - -def: Pat<(mul (Sext64 I32:$src1), (Sext64 I32:$src2)), - (M2_dpmpyss_s0 IntRegs:$src1, IntRegs:$src2)>; - -def: Pat<(mul Sext64Ld:$src1, Sext64Ld:$src2), - (M2_dpmpyss_s0 (LoReg DoubleRegs:$src1), (LoReg DoubleRegs:$src2))>; +def: Pat<(mul Sext64:$Rs, Sext64:$Rt), + (M2_dpmpyss_s0 (LoReg Sext64:$Rs), (LoReg Sext64:$Rt))>; // Multiply and accumulate, use full result. // Rxx[+-]=mpy(Rs,Rt) -def: Pat<(add I64:$src1, (mul (Sext64 I32:$src2), (Sext64 I32:$src3))), - (M2_dpmpyss_acc_s0 DoubleRegs:$src1, IntRegs:$src2, IntRegs:$src3)>; +def: Pat<(add I64:$Rx, (mul Sext64:$Rs, Sext64:$Rt)), + (M2_dpmpyss_acc_s0 I64:$Rx, (LoReg Sext64:$Rs), (LoReg Sext64:$Rt))>; -def: Pat<(sub I64:$src1, (mul (Sext64 I32:$src2), (Sext64 I32:$src3))), - (M2_dpmpyss_nac_s0 DoubleRegs:$src1, IntRegs:$src2, IntRegs:$src3)>; +def: Pat<(sub I64:$Rx, (mul Sext64:$Rs, Sext64:$Rt)), + (M2_dpmpyss_nac_s0 I64:$Rx, (LoReg Sext64:$Rs), (LoReg Sext64:$Rt))>; -def: Pat<(add I64:$src1, (mul (Aext64 I32:$src2), (Aext64 I32:$src3))), - (M2_dpmpyuu_acc_s0 DoubleRegs:$src1, IntRegs:$src2, IntRegs:$src3)>; +def: Pat<(add I64:$Rx, (mul (Aext64 I32:$Rs), (Aext64 I32:$Rt))), + (M2_dpmpyuu_acc_s0 I64:$Rx, I32:$Rs, I32:$Rt)>; -def: Pat<(add I64:$src1, (mul (Zext64 I32:$src2), (Zext64 I32:$src3))), - (M2_dpmpyuu_acc_s0 DoubleRegs:$src1, IntRegs:$src2, IntRegs:$src3)>; +def: Pat<(add I64:$Rx, (mul (Zext64 I32:$Rs), (Zext64 I32:$Rt))), + (M2_dpmpyuu_acc_s0 I64:$Rx, I32:$Rs, I32:$Rt)>; -def: Pat<(sub I64:$src1, (mul (Aext64 I32:$src2), (Aext64 I32:$src3))), - (M2_dpmpyuu_nac_s0 DoubleRegs:$src1, IntRegs:$src2, IntRegs:$src3)>; +def: Pat<(sub I64:$Rx, (mul (Aext64 I32:$Rs), (Aext64 I32:$Rt))), + (M2_dpmpyuu_nac_s0 I64:$Rx, I32:$Rs, I32:$Rt)>; -def: Pat<(sub I64:$src1, (mul (Zext64 I32:$src2), (Zext64 I32:$src3))), - (M2_dpmpyuu_nac_s0 DoubleRegs:$src1, IntRegs:$src2, IntRegs:$src3)>; +def: Pat<(sub I64:$Rx, (mul (Zext64 I32:$Rs), (Zext64 I32:$Rt))), + (M2_dpmpyuu_nac_s0 I64:$Rx, I32:$Rs, I32:$Rt)>; class Storepi_pat @@ -545,7 +539,8 @@ def: Storexm_simple_pat; def: Storexm_simple_pat; def: Storexm_simple_pat; -def: Pat <(Sext64 I32:$src), (A2_sxtw I32:$src)>; +def: Pat <(i64 (sext I32:$src)), (A2_sxtw I32:$src)>; +def: Pat <(i64 (sext_inreg I64:$src, i32)), (A2_sxtw (LoReg I64:$src))>; def: Pat<(select (i1 (setlt I32:$src, 0)), (sub 0, I32:$src), I32:$src), (A2_abs IntRegs:$src)>; @@ -1159,8 +1154,8 @@ multiclass MinMax_pats_p { defm: T_MinMax_pats; } -def: Pat<(add (Sext64 I32:$Rs), I64:$Rt), - (A2_addsp IntRegs:$Rs, DoubleRegs:$Rt)>; +def: Pat<(add Sext64:$Rs, I64:$Rt), + (A2_addsp (LoReg Sext64:$Rs), DoubleRegs:$Rt)>; let AddedComplexity = 200 in { defm: MinMax_pats_p; diff --git a/test/CodeGen/Hexagon/mul64-sext.ll b/test/CodeGen/Hexagon/mul64-sext.ll new file mode 100644 index 00000000000..8bbe6649a1f --- /dev/null +++ b/test/CodeGen/Hexagon/mul64-sext.ll @@ -0,0 +1,93 @@ +; RUN: llc -march=hexagon < %s | FileCheck %s + +target triple = "hexagon-unknown--elf" + +; CHECK-LABEL: mul_1 +; CHECK: r1:0 = mpy(r2,r0) +define i64 @mul_1(i64 %a0, i64 %a1) #0 { +b2: + %v3 = shl i64 %a0, 32 + %v4 = ashr exact i64 %v3, 32 + %v5 = shl i64 %a1, 32 + %v6 = ashr exact i64 %v5, 32 + %v7 = mul nsw i64 %v6, %v4 + ret i64 %v7 +} + +; CHECK-LABEL: mul_2 +; CHECK: r0 = memb(r0+#0) +; CHECK: r1:0 = mpy(r2,r0) +; CHECK: jumpr r31 +define i64 @mul_2(i8* %a0, i64 %a1) #0 { +b2: + %v3 = load i8, i8* %a0 + %v4 = sext i8 %v3 to i64 + %v5 = shl i64 %a1, 32 + %v6 = ashr exact i64 %v5, 32 + %v7 = mul nsw i64 %v6, %v4 + ret i64 %v7 +} + +; CHECK-LABEL: mul_acc_1 +; CHECK: r5:4 += mpy(r2,r0) +; CHECK: r1:0 = combine(r5,r4) +; CHECK: jumpr r31 +define i64 @mul_acc_1(i64 %a0, i64 %a1, i64 %a2) #0 { +b3: + %v4 = shl i64 %a0, 32 + %v5 = ashr exact i64 %v4, 32 + %v6 = shl i64 %a1, 32 + %v7 = ashr exact i64 %v6, 32 + %v8 = mul nsw i64 %v7, %v5 + %v9 = add i64 %a2, %v8 + ret i64 %v9 +} + +; CHECK-LABEL: mul_acc_2 +; CHECK: r2 = memw(r2+#0) +; CHECK: r5:4 += mpy(r2,r0) +; CHECK: r1:0 = combine(r5,r4) +; CHECK: jumpr r31 +define i64 @mul_acc_2(i64 %a0, i32* %a1, i64 %a2) #0 { +b3: + %v4 = shl i64 %a0, 32 + %v5 = ashr exact i64 %v4, 32 + %v6 = load i32, i32* %a1 + %v7 = sext i32 %v6 to i64 + %v8 = mul nsw i64 %v7, %v5 + %v9 = add i64 %a2, %v8 + ret i64 %v9 +} + +; CHECK-LABEL: mul_nac_1 +; CHECK: r5:4 -= mpy(r2,r0) +; CHECK: r1:0 = combine(r5,r4) +; CHECK: jumpr r31 +define i64 @mul_nac_1(i64 %a0, i64 %a1, i64 %a2) #0 { +b3: + %v4 = shl i64 %a0, 32 + %v5 = ashr exact i64 %v4, 32 + %v6 = shl i64 %a1, 32 + %v7 = ashr exact i64 %v6, 32 + %v8 = mul nsw i64 %v7, %v5 + %v9 = sub i64 %a2, %v8 + ret i64 %v9 +} + +; CHECK-LABEL: mul_nac_2 +; CHECK: r0 = memw(r0+#0) +; CHECK: r5:4 -= mpy(r2,r0) +; CHECK: r1:0 = combine(r5,r4) +; CHECK: jumpr r31 +define i64 @mul_nac_2(i32* %a0, i64 %a1, i64 %a2) #0 { +b3: + %v4 = load i32, i32* %a0 + %v5 = sext i32 %v4 to i64 + %v6 = shl i64 %a1, 32 + %v7 = ashr exact i64 %v6, 32 + %v8 = mul nsw i64 %v7, %v5 + %v9 = sub i64 %a2, %v8 + ret i64 %v9 +} + +attributes #0 = { nounwind }