[X86] Attempt to match multiple binary reduction ops at once. NFCI

matchBinOpReduction currently matches against a single opcode, but we already have a case where we repeat calls to try to match against AND/OR and I'll be shortly adding another case for SMAX/SMIN/UMAX/UMIN (D39729).

This NFCI patch alters matchBinOpReduction to try and pattern match against any of the provided list of candidate bin ops at once to save time.

Differential Revision: https://reviews.llvm.org/D39726

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@317985 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Simon Pilgrim 2017-11-11 18:16:55 +00:00
parent dbf8de9323
commit 90159860ad

View File

@ -30093,16 +30093,22 @@ static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG,
// the elements of a vector.
// Returns the vector that is being reduced on, or SDValue() if a reduction
// was not matched.
static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) {
static SDValue matchBinOpReduction(SDNode *Extract, unsigned &BinOp,
ArrayRef<ISD::NodeType> CandidateBinOps) {
// The pattern must end in an extract from index 0.
if ((Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT) ||
!isNullConstant(Extract->getOperand(1)))
return SDValue();
unsigned Stages =
Log2_32(Extract->getOperand(0).getValueType().getVectorNumElements());
SDValue Op = Extract->getOperand(0);
unsigned Stages = Log2_32(Op.getValueType().getVectorNumElements());
// Match against one of the candidate binary ops.
if (llvm::none_of(CandidateBinOps, [Op](ISD::NodeType BinOp) {
return Op.getOpcode() == BinOp;
}))
return SDValue();
// At each stage, we're looking for something that looks like:
// %s = shufflevector <8 x i32> %op, <8 x i32> undef,
// <8 x i32> <i32 2, i32 3, i32 undef, i32 undef,
@ -30113,8 +30119,9 @@ static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) {
// <4,5,6,7,u,u,u,u>
// <2,3,u,u,u,u,u,u>
// <1,u,u,u,u,u,u,u>
unsigned CandidateBinOp = Op.getOpcode();
for (unsigned i = 0; i < Stages; ++i) {
if (Op.getOpcode() != BinOp)
if (Op.getOpcode() != CandidateBinOp)
return SDValue();
ShuffleVectorSDNode *Shuffle =
@ -30127,8 +30134,8 @@ static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) {
}
// The first operand of the shuffle should be the same as the other operand
// of the add.
if (!Shuffle || (Shuffle->getOperand(0) != Op))
// of the binop.
if (!Shuffle || Shuffle->getOperand(0) != Op)
return SDValue();
// Verify the shuffle has the expected (at this stage of the pyramid) mask.
@ -30137,6 +30144,7 @@ static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) {
return SDValue();
}
BinOp = CandidateBinOp;
return Op;
}
@ -30250,15 +30258,15 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract,
return SDValue();
// Check for OR(any_of) and AND(all_of) horizontal reduction patterns.
for (ISD::NodeType Op : {ISD::OR, ISD::AND}) {
SDValue Match = matchBinOpReduction(Extract, Op);
unsigned BinOp = 0;
SDValue Match = matchBinOpReduction(Extract, BinOp, {ISD::OR, ISD::AND});
if (!Match)
continue;
return SDValue();
// EXTRACT_VECTOR_ELT can require implicit extension of the vector element
// which we can't support here for now.
if (Match.getScalarValueSizeInBits() != BitWidth)
continue;
return SDValue();
// We require AVX2 for PMOVMSKB for v16i16/v32i8;
unsigned MatchSizeInBits = Match.getValueSizeInBits();
@ -30285,7 +30293,7 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract,
APInt CompareBits;
ISD::CondCode CondCode;
if (Op == ISD::OR) {
if (BinOp == ISD::OR) {
// any_of -> MOVMSK != 0
CompareBits = APInt::getNullValue(32);
CondCode = ISD::CondCode::SETNE;
@ -30309,9 +30317,6 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract,
return DAG.getSExtOrTrunc(Res, DL, ExtractVT);
}
return SDValue();
}
static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
// PSADBW is only supported on SSE2 and up.
@ -30336,7 +30341,8 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
return SDValue();
// Match shuffle + add pyramid.
SDValue Root = matchBinOpReduction(Extract, ISD::ADD);
unsigned BinOp = 0;
SDValue Root = matchBinOpReduction(Extract, BinOp, {ISD::ADD});
// The operand is expected to be zero extended from i8
// (verified in detectZextAbsDiff).