[X86][SSE] Pulled out splat detection helper from LowerScalarVariableShift (NFCI)

Created the IsSplatValue helper from the splat detection code in LowerScalarVariableShift as a first NFC step towards improving support for splat rotations, which is an extension of PR37426.

llvm-svn: 333580
This commit is contained in:
Simon Pilgrim 2018-05-30 19:16:59 +00:00
parent 39e5243c4d
commit 693cbc43af

View File

@ -23122,6 +23122,40 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
return SDValue();
}
// Determine if V is a splat value, and return the scalar.
// TODO: Add support for SUB(SPLAT_CST, SPLAT) cases to support rotate patterns.
static SDValue IsSplatValue(SDValue V, const SDLoc &dl, SelectionDAG &DAG) {
// Check if this is a splat build_vector node.
if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(V)) {
SDValue SplatAmt = BV->getSplatValue();
if (SplatAmt && SplatAmt.isUndef())
return SDValue();
return SplatAmt;
}
// Check if this is a shuffle node doing a splat.
ShuffleVectorSDNode *SVN = dyn_cast<ShuffleVectorSDNode>(V);
if (!SVN || !SVN->isSplat())
return SDValue();
unsigned SplatIdx = (unsigned)SVN->getSplatIndex();
SDValue InVec = V.getOperand(0);
if (InVec.getOpcode() == ISD::BUILD_VECTOR) {
assert((SplatIdx < InVec.getSimpleValueType().getVectorNumElements()) &&
"Unexpected shuffle index found!");
return InVec.getOperand(SplatIdx);
} else if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT) {
if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(InVec.getOperand(2)))
if (C->getZExtValue() == SplatIdx)
return InVec.getOperand(1);
}
// Avoid introducing an extract element from a shuffle.
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
V.getValueType().getVectorElementType(), InVec,
DAG.getIntPtrConstant(SplatIdx, dl));
}
static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
MVT VT = Op.getSimpleValueType();
@ -23135,44 +23169,13 @@ static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG,
unsigned X86OpcV = (Op.getOpcode() == ISD::SHL) ? X86ISD::VSHL :
(Op.getOpcode() == ISD::SRL) ? X86ISD::VSRL : X86ISD::VSRA;
// Peek through any EXTRACT_SUBVECTORs.
while (Amt.getOpcode() == ISD::EXTRACT_SUBVECTOR)
Amt = Amt.getOperand(0);
if (SupportedVectorShiftWithBaseAmnt(VT, Subtarget, Op.getOpcode())) {
SDValue BaseShAmt;
MVT EltVT = VT.getVectorElementType();
if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(Amt)) {
// Check if this build_vector node is doing a splat.
// If so, then set BaseShAmt equal to the splat value.
BaseShAmt = BV->getSplatValue();
if (BaseShAmt && BaseShAmt.isUndef())
BaseShAmt = SDValue();
} else {
if (Amt.getOpcode() == ISD::EXTRACT_SUBVECTOR)
Amt = Amt.getOperand(0);
ShuffleVectorSDNode *SVN = dyn_cast<ShuffleVectorSDNode>(Amt);
if (SVN && SVN->isSplat()) {
unsigned SplatIdx = (unsigned)SVN->getSplatIndex();
SDValue InVec = Amt.getOperand(0);
if (InVec.getOpcode() == ISD::BUILD_VECTOR) {
assert((SplatIdx < InVec.getSimpleValueType().getVectorNumElements()) &&
"Unexpected shuffle index found!");
BaseShAmt = InVec.getOperand(SplatIdx);
} else if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT) {
if (ConstantSDNode *C =
dyn_cast<ConstantSDNode>(InVec.getOperand(2))) {
if (C->getZExtValue() == SplatIdx)
BaseShAmt = InVec.getOperand(1);
}
}
if (!BaseShAmt)
// Avoid introducing an extract element from a shuffle.
BaseShAmt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, InVec,
DAG.getIntPtrConstant(SplatIdx, dl));
}
}
if (BaseShAmt.getNode()) {
if (SDValue BaseShAmt = IsSplatValue(Amt, dl, DAG)) {
MVT EltVT = VT.getVectorElementType();
assert(EltVT.bitsLE(MVT::i64) && "Unexpected element type!");
if (EltVT != MVT::i64 && EltVT.bitsGT(MVT::i32))
BaseShAmt = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i64, BaseShAmt);