[X86][SSE] Optimize the truncation of vector comparison results with PACKSS

We currently default to using either generic shuffles or MASK+PACKUS/PACKSS to truncate all integer vectors. For vector comparisons, we know that the result will be either all or zero bits in every element, which can be efficiently truncated by directly using PACKSS to repeatedly halve the size of each element.

Due to the limited input values (-1 or 0) we don't need to account for vector element size, so for simplicity we just use the PACKSS(vXi16,vXi16) implementation in all cases. Additionally for AVX2 PACKSS of 256bit data we must perform a PERMQ shuffle to reorder the data into the correct order. I did investigate performing a single shuffle after all the PACKSS calls but the need to cross 128bit lanes makes this difficult to achieve efficiently.

We avoid performing this on AVX512 as it should have better alternative truncation instructions.

Differential Revision: https://reviews.llvm.org/D22814

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@277132 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Simon Pilgrim 2016-07-29 10:23:10 +00:00
parent bf172ec934
commit e6abaac391
3 changed files with 475 additions and 729 deletions

View File

@ -4423,8 +4423,6 @@ static SDValue getZeroVector(MVT VT, const X86Subtarget &Subtarget,
static SDValue extractSubVector(SDValue Vec, unsigned IdxVal, SelectionDAG &DAG,
const SDLoc &dl, unsigned vectorWidth) {
assert((vectorWidth == 128 || vectorWidth == 256) &&
"Unsupported vector width");
EVT VT = Vec.getValueType();
EVT ElVT = VT.getVectorElementType();
unsigned Factor = VT.getSizeInBits()/vectorWidth;
@ -14132,6 +14130,85 @@ static SDValue LowerZERO_EXTEND(SDValue Op, const X86Subtarget &Subtarget,
return SDValue();
}
/// Helper to recursively truncate vector elements in half with PACKSS.
/// It makes use of the fact that vector comparison results will be all-zeros
/// or all-ones to use (vXi8 PACKSS(vYi16, vYi16)) instead of matching types.
/// AVX2 (Int256) sub-targets require extra shuffling as the PACKSS operates
/// within each 128-bit lane.
static SDValue truncateVectorCompareWithPACKSS(EVT DstVT, SDValue In,
const SDLoc &DL,
SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
// AVX512 has fast truncate.
if (Subtarget.hasAVX512())
return SDValue();
EVT SrcVT = In.getValueType();
// No truncation required, we might get here due to recursive calls.
if (SrcVT == DstVT)
return In;
// We only support vector truncation to 128bits or greater from a
// 256bits or greater source.
if ((DstVT.getSizeInBits() % 128) != 0)
return SDValue();
if ((SrcVT.getSizeInBits() % 256) != 0)
return SDValue();
unsigned NumElems = SrcVT.getVectorNumElements();
assert(DstVT.getVectorNumElements() == NumElems && "Illegal truncation");
assert(SrcVT.getSizeInBits() > DstVT.getSizeInBits() && "Illegal truncation");
EVT PackedSVT =
EVT::getIntegerVT(*DAG.getContext(), SrcVT.getScalarSizeInBits() / 2);
// Extract lower/upper subvectors.
unsigned NumSubElts = NumElems / 2;
unsigned SrcSizeInBits = SrcVT.getSizeInBits();
SDValue Lo = extractSubVector(In, 0 * NumSubElts, DAG, DL, SrcSizeInBits / 2);
SDValue Hi = extractSubVector(In, 1 * NumSubElts, DAG, DL, SrcSizeInBits / 2);
// 256bit -> 128bit truncate - PACKSS lower/upper 128-bit subvectors.
if (SrcVT.is256BitVector()) {
Lo = DAG.getBitcast(MVT::v8i16, Lo);
Hi = DAG.getBitcast(MVT::v8i16, Hi);
SDValue Res = DAG.getNode(X86ISD::PACKSS, DL, MVT::v16i8, Lo, Hi);
return DAG.getBitcast(DstVT, Res);
}
// AVX2: 512bit -> 256bit truncate - PACKSS lower/upper 256-bit subvectors.
// AVX2: 512bit -> 128bit truncate - PACKSS(PACKSS, PACKSS).
if (SrcVT.is512BitVector() && Subtarget.hasInt256()) {
Lo = DAG.getBitcast(MVT::v16i16, Lo);
Hi = DAG.getBitcast(MVT::v16i16, Hi);
SDValue Res = DAG.getNode(X86ISD::PACKSS, DL, MVT::v32i8, Lo, Hi);
// 256-bit PACKSS(ARG0, ARG1) leaves us with ((LO0,LO1),(HI0,HI1)),
// so we need to shuffle to get ((LO0,HI0),(LO1,HI1)).
Res = DAG.getBitcast(MVT::v4i64, Res);
Res = DAG.getVectorShuffle(MVT::v4i64, DL, Res, Res, {0, 2, 1, 3});
if (DstVT.is256BitVector())
return DAG.getBitcast(DstVT, Res);
// If 512bit -> 128bit truncate another stage.
EVT PackedVT = EVT::getVectorVT(*DAG.getContext(), PackedSVT, NumElems);
Res = DAG.getBitcast(PackedVT, Res);
return truncateVectorCompareWithPACKSS(DstVT, Res, DL, DAG, Subtarget);
}
// Recursively pack lower/upper subvectors, concat result and pack again.
assert(SrcVT.getSizeInBits() >= 512 && "Expected 512-bit vector or greater");
EVT PackedVT = EVT::getVectorVT(*DAG.getContext(), PackedSVT, NumElems / 2);
Lo = truncateVectorCompareWithPACKSS(PackedVT, Lo, DL, DAG, Subtarget);
Hi = truncateVectorCompareWithPACKSS(PackedVT, Hi, DL, DAG, Subtarget);
PackedVT = EVT::getVectorVT(*DAG.getContext(), PackedSVT, NumElems);
SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, DL, PackedVT, Lo, Hi);
return truncateVectorCompareWithPACKSS(DstVT, Res, DL, DAG, Subtarget);
}
static SDValue LowerTruncateVecI1(SDValue Op, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
@ -14198,6 +14275,23 @@ SDValue X86TargetLowering::LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const {
DAG.getNode(X86ISD::VSEXT, DL, MVT::v16i32, In));
return DAG.getNode(X86ISD::VTRUNC, DL, VT, In);
}
// Truncate with PACKSS if we are trucating a vector comparison result.
// TODO: We should be able to support other operations as long as we
// we are saturating+packing zero/all bits only.
auto IsPackableComparison = [](SDValue V) {
unsigned Opcode = V.getOpcode();
return (Opcode == X86ISD::PCMPGT || Opcode == X86ISD::PCMPEQ ||
Opcode == X86ISD::CMPP);
};
if (IsPackableComparison(In) ||
(In.getOpcode() == ISD::CONCAT_VECTORS &&
std::all_of(In->op_begin(), In->op_end(), IsPackableComparison))) {
if (SDValue V = truncateVectorCompareWithPACKSS(VT, In, DL, DAG, Subtarget))
return V;
}
if ((VT == MVT::v4i32) && (InVT == MVT::v4i64)) {
// On AVX2, v4i64 -> v4i32 becomes VPERMD.
if (Subtarget.hasInt256()) {
@ -29652,6 +29746,45 @@ static SDValue combineVectorTruncation(SDNode *N, SelectionDAG &DAG,
return SDValue();
}
/// This function transforms vector truncation of comparison results from
/// vXi16/vXi32/vXi64 to vXi8/vXi16/vXi32 into X86ISD::PACKSS operations.
static SDValue combineVectorCompareTruncation(SDNode *N, SDLoc &DL,
SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
// AVX512 has fast truncate.
if (Subtarget.hasAVX512())
return SDValue();
if (!N->getValueType(0).isVector() || !N->getValueType(0).isSimple())
return SDValue();
// TODO: we should be able to support sources other than compares as long
// as we are saturating+packing zero/all bits only.
SDValue In = N->getOperand(0);
if (In.getOpcode() != ISD::SETCC || !In.getValueType().isSimple())
return SDValue();
MVT VT = N->getValueType(0).getSimpleVT();
MVT SVT = VT.getScalarType();
MVT InVT = In.getValueType().getSimpleVT();
MVT InSVT = InVT.getScalarType();
assert(DAG.getTargetLoweringInfo().getBooleanContents(InVT) ==
llvm::TargetLoweringBase::ZeroOrNegativeOneBooleanContent &&
"Expected comparison result to be zero/all bits");
// Check we have a truncation suited for PACKSS.
if (!VT.is128BitVector() && !VT.is256BitVector())
return SDValue();
if (SVT != MVT::i8 && SVT != MVT::i16 && SVT != MVT::i32)
return SDValue();
if (InSVT != MVT::i16 && InSVT != MVT::i32 && InSVT != MVT::i64)
return SDValue();
return truncateVectorCompareWithPACKSS(VT, In, DL, DAG, Subtarget);
}
static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
EVT VT = N->getValueType(0);
@ -29670,6 +29803,10 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(X86ISD::MMX_MOVD2W, DL, MVT::i32, BCSrc);
}
// Try to truncate vector comparison results with PACKSS.
if (SDValue V = combineVectorCompareTruncation(N, DL, DAG, Subtarget))
return V;
return combineVectorTruncation(N, DAG, Subtarget);
}

View File

@ -13,11 +13,8 @@ define <8 x i16> @pr25080(<8 x i32> %a) {
; AVX-NEXT: vextractf128 $1, %ymm0, %xmm1
; AVX-NEXT: vpxor %xmm2, %xmm2, %xmm2
; AVX-NEXT: vpcmpeqd %xmm2, %xmm1, %xmm1
; AVX-NEXT: vmovdqa {{.*#+}} xmm3 = [0,1,4,5,8,9,12,13,8,9,12,13,12,13,14,15]
; AVX-NEXT: vpshufb %xmm3, %xmm1, %xmm1
; AVX-NEXT: vpcmpeqd %xmm2, %xmm0, %xmm0
; AVX-NEXT: vpshufb %xmm3, %xmm0, %xmm0
; AVX-NEXT: vpunpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm1[0]
; AVX-NEXT: vpacksswb %xmm1, %xmm0, %xmm0
; AVX-NEXT: vpor {{.*}}(%rip), %xmm0, %xmm0
; AVX-NEXT: vpsllw $15, %xmm0, %xmm0
; AVX-NEXT: vpsraw $15, %xmm0, %xmm0

File diff suppressed because it is too large Load Diff