[TargetLowering] Improve expansion of FSHL/FSHR by non-zero amount

Use a simpler code sequence when the shift amount is known not to be
zero modulo the bit width.

Nothing much uses this until D77152 changes the translation of fshl and
fshr intrinsics.

Differential Revision: https://reviews.llvm.org/D82540
This commit is contained in:
Jay Foad 2020-05-29 10:57:42 +01:00
parent 5fecaffff0
commit 4e57aaab54

View File

@@ -6117,6 +6117,14 @@ bool TargetLowering::expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT,
return Ok; return Ok;
} }
// Returns true when Z (or, for a vector, each of its elements) is either undef
// or a constant that leaves a non-zero remainder when divided by BW.
static bool isNonZeroModBitWidth(SDValue Z, unsigned BW) {
  // Undef elements are handed to the predicate as a null pointer; treat them
  // as acceptable, otherwise require the constant not to be a multiple of BW.
  auto IsUndefOrNonMultiple = [BW](ConstantSDNode *Elt) {
    if (!Elt)
      return true;
    return Elt->getAPIntValue().urem(BW) != 0;
  };
  return ISD::matchUnaryPredicate(Z, IsUndefOrNonMultiple,
                                  /*AllowUndefs=*/true);
}
bool TargetLowering::expandFunnelShift(SDNode *Node, SDValue &Result, bool TargetLowering::expandFunnelShift(SDNode *Node, SDValue &Result,
SelectionDAG &DAG) const { SelectionDAG &DAG) const {
EVT VT = Node->getValueType(0); EVT VT = Node->getValueType(0);
@@ -6127,40 +6135,52 @@ bool TargetLowering::expandFunnelShift(SDNode *Node, SDValue &Result,
!isOperationLegalOrCustomOrPromote(ISD::OR, VT))) !isOperationLegalOrCustomOrPromote(ISD::OR, VT)))
return false; return false;
// fshl: X << (Z % BW) | Y >> 1 >> (BW - 1 - (Z % BW))
// fshr: X << 1 << (BW - 1 - (Z % BW)) | Y >> (Z % BW)
SDValue X = Node->getOperand(0); SDValue X = Node->getOperand(0);
SDValue Y = Node->getOperand(1); SDValue Y = Node->getOperand(1);
SDValue Z = Node->getOperand(2); SDValue Z = Node->getOperand(2);
unsigned EltSizeInBits = VT.getScalarSizeInBits(); unsigned BW = VT.getScalarSizeInBits();
bool IsFSHL = Node->getOpcode() == ISD::FSHL; bool IsFSHL = Node->getOpcode() == ISD::FSHL;
SDLoc DL(SDValue(Node, 0)); SDLoc DL(SDValue(Node, 0));
EVT ShVT = Z.getValueType(); EVT ShVT = Z.getValueType();
SDValue Mask = DAG.getConstant(EltSizeInBits - 1, DL, ShVT);
SDValue ShAmt, InvShAmt;
if (isPowerOf2_32(EltSizeInBits)) {
// Z % BW -> Z & (BW - 1)
ShAmt = DAG.getNode(ISD::AND, DL, ShVT, Z, Mask);
// (BW - 1) - (Z % BW) -> ~Z & (BW - 1)
InvShAmt = DAG.getNode(ISD::AND, DL, ShVT, DAG.getNOT(DL, Z, ShVT), Mask);
} else {
SDValue BitWidthC = DAG.getConstant(EltSizeInBits, DL, ShVT);
ShAmt = DAG.getNode(ISD::UREM, DL, ShVT, Z, BitWidthC);
InvShAmt = DAG.getNode(ISD::SUB, DL, ShVT, Mask, ShAmt);
}
SDValue One = DAG.getConstant(1, DL, ShVT);
SDValue ShX, ShY; SDValue ShX, ShY;
if (IsFSHL) { SDValue ShAmt, InvShAmt;
ShX = DAG.getNode(ISD::SHL, DL, VT, X, ShAmt); if (isNonZeroModBitWidth(Z, BW)) {
SDValue ShY1 = DAG.getNode(ISD::SRL, DL, VT, Y, One); // fshl: X << C | Y >> (BW - C)
ShY = DAG.getNode(ISD::SRL, DL, VT, ShY1, InvShAmt); // fshr: X << (BW - C) | Y >> C
// where C = Z % BW is not zero
SDValue BitWidthC = DAG.getConstant(BW, DL, ShVT);
ShAmt = DAG.getNode(ISD::UREM, DL, ShVT, Z, BitWidthC);
InvShAmt = DAG.getNode(ISD::SUB, DL, ShVT, BitWidthC, ShAmt);
ShX = DAG.getNode(ISD::SHL, DL, VT, X, IsFSHL ? ShAmt : InvShAmt);
ShY = DAG.getNode(ISD::SRL, DL, VT, Y, IsFSHL ? InvShAmt : ShAmt);
} else { } else {
SDValue ShX1 = DAG.getNode(ISD::SHL, DL, VT, X, One); // fshl: X << (Z % BW) | Y >> 1 >> (BW - 1 - (Z % BW))
ShX = DAG.getNode(ISD::SHL, DL, VT, ShX1, InvShAmt); // fshr: X << 1 << (BW - 1 - (Z % BW)) | Y >> (Z % BW)
ShY = DAG.getNode(ISD::SRL, DL, VT, Y, ShAmt); SDValue Mask = DAG.getConstant(BW - 1, DL, ShVT);
if (isPowerOf2_32(BW)) {
// Z % BW -> Z & (BW - 1)
ShAmt = DAG.getNode(ISD::AND, DL, ShVT, Z, Mask);
// (BW - 1) - (Z % BW) -> ~Z & (BW - 1)
InvShAmt = DAG.getNode(ISD::AND, DL, ShVT, DAG.getNOT(DL, Z, ShVT), Mask);
} else {
SDValue BitWidthC = DAG.getConstant(BW, DL, ShVT);
ShAmt = DAG.getNode(ISD::UREM, DL, ShVT, Z, BitWidthC);
InvShAmt = DAG.getNode(ISD::SUB, DL, ShVT, Mask, ShAmt);
}
SDValue One = DAG.getConstant(1, DL, ShVT);
if (IsFSHL) {
ShX = DAG.getNode(ISD::SHL, DL, VT, X, ShAmt);
SDValue ShY1 = DAG.getNode(ISD::SRL, DL, VT, Y, One);
ShY = DAG.getNode(ISD::SRL, DL, VT, ShY1, InvShAmt);
} else {
SDValue ShX1 = DAG.getNode(ISD::SHL, DL, VT, X, One);
ShX = DAG.getNode(ISD::SHL, DL, VT, ShX1, InvShAmt);
ShY = DAG.getNode(ISD::SRL, DL, VT, Y, ShAmt);
}
} }
Result = DAG.getNode(ISD::OR, DL, VT, ShX, ShY); Result = DAG.getNode(ISD::OR, DL, VT, ShX, ShY);
return true; return true;