AVX-512: Fixed masked load / store instruction selection for KNL.

Patterns were missing for KNL target for <8 x i32>, <8 x float> masked load/store.

This intrinsic comes with all legal types:
<8 x float> @llvm.masked.load.v8f32(<8 x float>* %addr, i32 align, <8 x i1> %mask, <8 x float> %passThru),
but still requires lowering, because VMASKMOVPS, VMASKMOVDQU32 work with 512-bit vectors only.

All data operands should be widened to 512-bit vector.
The mask operand should be widened to v16i1 with zeroes.

Differential Revision: http://reviews.llvm.org/D15265



git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@254909 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Elena Demikhovsky 2015-12-07 13:39:24 +00:00
parent eea645e49f
commit b06ff9b1e1
4 changed files with 151 additions and 32 deletions

View File

@ -244,7 +244,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
Changed = true;
return LegalizeOp(ExpandStore(Op));
}
} else if (Op.getOpcode() == ISD::MSCATTER)
} else if (Op.getOpcode() == ISD::MSCATTER || Op.getOpcode() == ISD::MSTORE)
HasVectorValue = true;
for (SDNode::value_iterator J = Node->value_begin(), E = Node->value_end();
@ -344,6 +344,9 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
case ISD::MSCATTER:
QueryType = cast<MaskedScatterSDNode>(Node)->getValue().getValueType();
break;
case ISD::MSTORE:
QueryType = cast<MaskedStoreSDNode>(Node)->getValue().getValueType();
break;
}
switch (TLI.getOperationAction(Node->getOpcode(), QueryType)) {

View File

@ -1384,6 +1384,11 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setTruncStoreAction(MVT::v2i64, MVT::v2i32, Legal);
setTruncStoreAction(MVT::v4i32, MVT::v4i8, Legal);
setTruncStoreAction(MVT::v4i32, MVT::v4i16, Legal);
} else {
setOperationAction(ISD::MLOAD, MVT::v8i32, Custom);
setOperationAction(ISD::MLOAD, MVT::v8f32, Custom);
setOperationAction(ISD::MSTORE, MVT::v8i32, Custom);
setOperationAction(ISD::MSTORE, MVT::v8f32, Custom);
}
setOperationAction(ISD::TRUNCATE, MVT::i1, Custom);
setOperationAction(ISD::TRUNCATE, MVT::v16i8, Custom);
@ -1459,6 +1464,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v8i1, Custom);
setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v16i1, Custom);
setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v16i1, Custom);
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v16i1, Custom);
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v8i1, Custom);
setOperationAction(ISD::BUILD_VECTOR, MVT::v8i1, Custom);
@ -19685,6 +19691,47 @@ static SDValue LowerFSINCOS(SDValue Op, const X86Subtarget *Subtarget,
return DAG.getNode(ISD::MERGE_VALUES, dl, Tys, SinVal, CosVal);
}
/// Widen a vector input to a vector of NVT. The
/// input vector must have the same element type as NVT.
static SDValue ExtendToType(SDValue InOp, MVT NVT, SelectionDAG &DAG,
bool FillWithZeroes = false) {
// Check if InOp already has the right width.
MVT InVT = InOp.getSimpleValueType();
if (InVT == NVT)
return InOp;
if (InOp.isUndef())
return DAG.getUNDEF(NVT);
assert(InVT.getVectorElementType() == NVT.getVectorElementType() &&
"input and widen element type must match");
unsigned InNumElts = InVT.getVectorNumElements();
unsigned WidenNumElts = NVT.getVectorNumElements();
assert(WidenNumElts > InNumElts && WidenNumElts % InNumElts == 0 &&
"Unexpected request for vector widening");
EVT EltVT = NVT.getVectorElementType();
SDLoc dl(InOp);
if (ISD::isBuildVectorOfConstantSDNodes(InOp.getNode()) ||
ISD::isBuildVectorOfConstantFPSDNodes(InOp.getNode())) {
SmallVector<SDValue, 16> Ops;
for (unsigned i = 0; i < InNumElts; ++i)
Ops.push_back(InOp.getOperand(i));
SDValue FillVal = FillWithZeroes ? DAG.getConstant(0, dl, EltVT) :
DAG.getUNDEF(EltVT);
for (unsigned i = 0; i < WidenNumElts - InNumElts; ++i)
Ops.push_back(FillVal);
return DAG.getNode(ISD::BUILD_VECTOR, dl, NVT, Ops);
}
SDValue FillVal = FillWithZeroes ? DAG.getConstant(0, dl, NVT) :
DAG.getUNDEF(NVT);
return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, NVT, FillVal,
InOp, DAG.getIntPtrConstant(0, dl));
}
static SDValue LowerMSCATTER(SDValue Op, const X86Subtarget *Subtarget,
SelectionDAG &DAG) {
assert(Subtarget->hasAVX512() &&
@ -19714,6 +19761,62 @@ static SDValue LowerMSCATTER(SDValue Op, const X86Subtarget *Subtarget,
return Op;
}
static SDValue LowerMLOAD(SDValue Op, const X86Subtarget *Subtarget,
SelectionDAG &DAG) {
MaskedLoadSDNode *N = cast<MaskedLoadSDNode>(Op.getNode());
MVT VT = Op.getSimpleValueType();
SDValue Mask = N->getMask();
SDLoc dl(Op);
if (Subtarget->hasAVX512() && !Subtarget->hasVLX() &&
!VT.is512BitVector() && Mask.getValueType() == MVT::v8i1) {
// This operation is legal for targets with VLX, but without
// VLX the vector should be widened to 512 bit
unsigned NumEltsInWideVec = 512/VT.getScalarSizeInBits();
MVT WideDataVT = MVT::getVectorVT(VT.getScalarType(), NumEltsInWideVec);
MVT WideMaskVT = MVT::getVectorVT(MVT::i1, NumEltsInWideVec);
SDValue Src0 = N->getSrc0();
Src0 = ExtendToType(Src0, WideDataVT, DAG);
Mask = ExtendToType(Mask, WideMaskVT, DAG, true);
SDValue NewLoad = DAG.getMaskedLoad(WideDataVT, dl, N->getChain(),
N->getBasePtr(), Mask, Src0,
N->getMemoryVT(), N->getMemOperand(),
N->getExtensionType());
SDValue Exract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT,
NewLoad.getValue(0),
DAG.getIntPtrConstant(0, dl));
SDValue RetOps[] = {Exract, NewLoad.getValue(1)};
return DAG.getMergeValues(RetOps, dl);
}
return Op;
}
static SDValue LowerMSTORE(SDValue Op, const X86Subtarget *Subtarget,
SelectionDAG &DAG) {
MaskedStoreSDNode *N = cast<MaskedStoreSDNode>(Op.getNode());
SDValue DataToStore = N->getValue();
MVT VT = DataToStore.getSimpleValueType();
SDValue Mask = N->getMask();
SDLoc dl(Op);
if (Subtarget->hasAVX512() && !Subtarget->hasVLX() &&
!VT.is512BitVector() && Mask.getValueType() == MVT::v8i1) {
// This operation is legal for targets with VLX, but without
// VLX the vector should be widened to 512 bit
unsigned NumEltsInWideVec = 512/VT.getScalarSizeInBits();
MVT WideDataVT = MVT::getVectorVT(VT.getScalarType(), NumEltsInWideVec);
MVT WideMaskVT = MVT::getVectorVT(MVT::i1, NumEltsInWideVec);
DataToStore = ExtendToType(DataToStore, WideDataVT, DAG);
Mask = ExtendToType(Mask, WideMaskVT, DAG, true);
return DAG.getMaskedStore(N->getChain(), dl, DataToStore, N->getBasePtr(),
Mask, N->getMemoryVT(), N->getMemOperand(),
N->isTruncatingStore());
}
return Op;
}
static SDValue LowerMGATHER(SDValue Op, const X86Subtarget *Subtarget,
SelectionDAG &DAG) {
assert(Subtarget->hasAVX512() &&
@ -19873,6 +19976,8 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
case ISD::UMAX:
case ISD::UMIN: return LowerMINMAX(Op, DAG);
case ISD::FSINCOS: return LowerFSINCOS(Op, Subtarget, DAG);
case ISD::MLOAD: return LowerMLOAD(Op, Subtarget, DAG);
case ISD::MSTORE: return LowerMSTORE(Op, Subtarget, DAG);
case ISD::MGATHER: return LowerMGATHER(Op, Subtarget, DAG);
case ISD::MSCATTER: return LowerMSCATTER(Op, Subtarget, DAG);
case ISD::GC_TRANSITION_START:

View File

@ -2766,22 +2766,6 @@ def: Pat<(int_x86_avx512_mask_store_pd_512 addr:$ptr, (v8f64 VR512:$src),
(VMOVAPDZmrk addr:$ptr, (v8i1 (COPY_TO_REGCLASS GR8:$mask, VK8WM)),
VR512:$src)>;
let Predicates = [HasAVX512, NoVLX] in {
def: Pat<(X86mstore addr:$ptr, VK8WM:$mask, (v8f32 VR256:$src)),
(VMOVUPSZmrk addr:$ptr,
(v16i1 (COPY_TO_REGCLASS VK8WM:$mask, VK16WM)),
(INSERT_SUBREG (v16f32 (IMPLICIT_DEF)), VR256:$src, sub_ymm))>;
def: Pat<(v8f32 (masked_load addr:$ptr, VK8WM:$mask, undef)),
(v8f32 (EXTRACT_SUBREG (v16f32 (VMOVUPSZrmkz
(v16i1 (COPY_TO_REGCLASS VK8WM:$mask, VK16WM)), addr:$ptr)), sub_ymm))>;
def: Pat<(v8f32 (masked_load addr:$ptr, VK8WM:$mask, (v8f32 VR256:$src0))),
(v8f32 (EXTRACT_SUBREG (v16f32 (VMOVUPSZrmk
(INSERT_SUBREG (v16f32 (IMPLICIT_DEF)), VR256:$src0, sub_ymm),
(v16i1 (COPY_TO_REGCLASS VK8WM:$mask, VK16WM)), addr:$ptr)), sub_ymm))>;
}
defm VMOVDQA32 : avx512_alignedload_vl<0x6F, "vmovdqa32", avx512vl_i32_info,
HasAVX512>,
avx512_alignedstore_vl<0x7F, "vmovdqa32", avx512vl_i32_info,
@ -2843,17 +2827,6 @@ def : Pat<(v16i32 (vselect VK16WM:$mask, (v16i32 immAllZerosV),
(v16i32 VR512:$src))),
(VMOVDQU32Zrrkz (KNOTWrr VK16WM:$mask), VR512:$src)>;
}
// NoVLX patterns
let Predicates = [HasAVX512, NoVLX] in {
def: Pat<(X86mstore addr:$ptr, VK8WM:$mask, (v8i32 VR256:$src)),
(VMOVDQU32Zmrk addr:$ptr,
(v16i1 (COPY_TO_REGCLASS VK8WM:$mask, VK16WM)),
(INSERT_SUBREG (v16i32 (IMPLICIT_DEF)), VR256:$src, sub_ymm))>;
def: Pat<(v8i32 (masked_load addr:$ptr, VK8WM:$mask, undef)),
(v8i32 (EXTRACT_SUBREG (v16i32 (VMOVDQU32Zrmkz
(v16i1 (COPY_TO_REGCLASS VK8WM:$mask, VK16WM)), addr:$ptr)), sub_ymm))>;
}
// Move Int Doubleword to Packed Double Int
//

View File

@ -139,18 +139,55 @@ define <4 x double> @test10(<4 x i32> %trigger, <4 x double>* %addr, <4 x double
ret <4 x double> %res
}
; AVX2-LABEL: test11
; AVX2-LABEL: test11a
; AVX2: vmaskmovps
; AVX2: vblendvps
; SKX-LABEL: test11
; SKX: vmovaps {{.*}}{%k1}
define <8 x float> @test11(<8 x i32> %trigger, <8 x float>* %addr, <8 x float> %dst) {
; SKX-LABEL: test11a
; SKX: vmovaps (%rdi), %ymm1 {%k1}
; AVX512-LABEL: test11a
; AVX512: kshiftlw $8
; AVX512: kshiftrw $8
; AVX512: vmovups (%rdi), %zmm1 {%k1}
define <8 x float> @test11a(<8 x i32> %trigger, <8 x float>* %addr, <8 x float> %dst) {
%mask = icmp eq <8 x i32> %trigger, zeroinitializer
%res = call <8 x float> @llvm.masked.load.v8f32(<8 x float>* %addr, i32 32, <8 x i1>%mask, <8 x float>%dst)
ret <8 x float> %res
}
; SKX-LABEL: test11b
; SKX: vmovdqu32 (%rdi), %ymm1 {%k1}
; AVX512-LABEL: test11b
; AVX512: kshiftlw $8
; AVX512: kshiftrw $8
; AVX512: vmovdqu32 (%rdi), %zmm1 {%k1}
define <8 x i32> @test11b(<8 x i1> %mask, <8 x i32>* %addr, <8 x i32> %dst) {
%res = call <8 x i32> @llvm.masked.load.v8i32(<8 x i32>* %addr, i32 4, <8 x i1>%mask, <8 x i32>%dst)
ret <8 x i32> %res
}
; SKX-LABEL: test11c
; SKX: vmovaps (%rdi), %ymm0 {%k1} {z}
; AVX512-LABEL: test11c
; AVX512: kshiftlw $8
; AVX512: kshiftrw $8
; AVX512: vmovups (%rdi), %zmm0 {%k1} {z}
define <8 x float> @test11c(<8 x i1> %mask, <8 x float>* %addr) {
%res = call <8 x float> @llvm.masked.load.v8f32(<8 x float>* %addr, i32 32, <8 x i1> %mask, <8 x float> zeroinitializer)
ret <8 x float> %res
}
; SKX-LABEL: test11d
; SKX: vmovdqu32 (%rdi), %ymm0 {%k1} {z}
; AVX512-LABEL: test11d
; AVX512: kshiftlw $8
; AVX512: kshiftrw $8
; AVX512: vmovdqu32 (%rdi), %zmm0 {%k1} {z}
define <8 x i32> @test11d(<8 x i1> %mask, <8 x i32>* %addr) {
%res = call <8 x i32> @llvm.masked.load.v8i32(<8 x i32>* %addr, i32 4, <8 x i1> %mask, <8 x i32> zeroinitializer)
ret <8 x i32> %res
}
; AVX2-LABEL: test12
; AVX2: vpmaskmovd %ymm
@ -291,6 +328,7 @@ declare void @llvm.masked.store.v16f32(<16 x float>, <16 x float>*, i32, <16 x i
declare void @llvm.masked.store.v16f32p(<16 x float>*, <16 x float>**, i32, <16 x i1>)
declare <16 x float> @llvm.masked.load.v16f32(<16 x float>*, i32, <16 x i1>, <16 x float>)
declare <8 x float> @llvm.masked.load.v8f32(<8 x float>*, i32, <8 x i1>, <8 x float>)
declare <8 x i32> @llvm.masked.load.v8i32(<8 x i32>*, i32, <8 x i1>, <8 x i32>)
declare <4 x float> @llvm.masked.load.v4f32(<4 x float>*, i32, <4 x i1>, <4 x float>)
declare <2 x float> @llvm.masked.load.v2f32(<2 x float>*, i32, <2 x i1>, <2 x float>)
declare <8 x double> @llvm.masked.load.v8f64(<8 x double>*, i32, <8 x i1>, <8 x double>)