[SVE][CodeGen] Extend index of masked gathers

This patch extends performMSCATTERCombine, renamed here to
performMaskedGatherScatterCombine, to also promote the indices of masked
gathers where the index element type is i8 or i16, and adds various tests
for gathers with illegal types.
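
In essence, the combine now applies to gathers the same index promotion it
already performed for scatters: an i8 or i16 index vector is widened to i32,
with a sign extend when the node's indices are signed and a zero extend
otherwise, before the gather or scatter node is rebuilt with the wider
indices. The standalone C++ sketch below models only that promotion step on
plain integer vectors; promoteIndices is a hypothetical name and nothing in
it is LLVM API.

  #include <cstdint>
  #include <vector>

  // Widen narrow (here i16) indices to i32, mirroring the combine: a sign
  // extend preserves the value of signed indices, a zero extend that of
  // unsigned ones.
  std::vector<int32_t> promoteIndices(const std::vector<int16_t> &Indices,
                                      bool IsIndexSigned) {
    std::vector<int32_t> Promoted;
    Promoted.reserve(Indices.size());
    for (int16_t Idx : Indices)
      Promoted.push_back(IsIndexSigned
                             ? static_cast<int32_t>(Idx) // SIGN_EXTEND
                             : static_cast<int32_t>(
                                   static_cast<uint16_t>(Idx))); // ZERO_EXTEND
    return Promoted;
  }

For example, an i16 index with bit pattern 0xFFFF promotes to -1 under the
signed path and to 65535 under the unsigned one, matching the choice between
ISD::SIGN_EXTEND and ISD::ZERO_EXTEND in the patch.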

Reviewed By: sdesmalen

Differential Revision: https://reviews.llvm.org/D91433
Kerry McLaughlin 2020-12-10 11:45:45 +00:00
parent 2fc4afda0f
commit abe7775f5a
2 changed files with 46 additions and 21 deletions


@@ -849,6 +849,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   if (Subtarget->supportsAddressTopByteIgnored())
     setTargetDAGCombine(ISD::LOAD);
 
+  setTargetDAGCombine(ISD::MGATHER);
   setTargetDAGCombine(ISD::MSCATTER);
 
   setTargetDAGCombine(ISD::MUL);
@@ -14063,20 +14064,19 @@ static SDValue performSTORECombine(SDNode *N,
   return SDValue();
 }
 
-static SDValue performMSCATTERCombine(SDNode *N,
+static SDValue performMaskedGatherScatterCombine(SDNode *N,
                                       TargetLowering::DAGCombinerInfo &DCI,
                                       SelectionDAG &DAG) {
-  MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(N);
-  assert(MSC && "Can only combine scatter store nodes");
+  MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
+  assert(MGS && "Can only combine gather load or scatter store nodes");
 
-  SDLoc DL(MSC);
-  SDValue Chain = MSC->getChain();
-  SDValue Scale = MSC->getScale();
-  SDValue Index = MSC->getIndex();
-  SDValue Data = MSC->getValue();
-  SDValue Mask = MSC->getMask();
-  SDValue BasePtr = MSC->getBasePtr();
-  ISD::MemIndexType IndexType = MSC->getIndexType();
+  SDLoc DL(MGS);
+  SDValue Chain = MGS->getChain();
+  SDValue Scale = MGS->getScale();
+  SDValue Index = MGS->getIndex();
+  SDValue Mask = MGS->getMask();
+  SDValue BasePtr = MGS->getBasePtr();
+  ISD::MemIndexType IndexType = MGS->getIndexType();
 
   EVT IdxVT = Index.getValueType();
@@ -14086,16 +14086,27 @@ static SDValue performMSCATTERCombine(SDNode *N,
   if ((IdxVT.getVectorElementType() == MVT::i8) ||
       (IdxVT.getVectorElementType() == MVT::i16)) {
     EVT NewIdxVT = IdxVT.changeVectorElementType(MVT::i32);
-    if (MSC->isIndexSigned())
+    if (MGS->isIndexSigned())
       Index = DAG.getNode(ISD::SIGN_EXTEND, DL, NewIdxVT, Index);
     else
       Index = DAG.getNode(ISD::ZERO_EXTEND, DL, NewIdxVT, Index);
 
-    SDValue Ops[] = { Chain, Data, Mask, BasePtr, Index, Scale };
-    return DAG.getMaskedScatter(DAG.getVTList(MVT::Other),
-                                MSC->getMemoryVT(), DL, Ops,
-                                MSC->getMemOperand(), IndexType,
-                                MSC->isTruncatingStore());
+    if (auto *MGT = dyn_cast<MaskedGatherSDNode>(MGS)) {
+      SDValue PassThru = MGT->getPassThru();
+      SDValue Ops[] = { Chain, PassThru, Mask, BasePtr, Index, Scale };
+      return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
+                                 PassThru.getValueType(), DL, Ops,
+                                 MGT->getMemOperand(),
+                                 MGT->getIndexType(), MGT->getExtensionType());
+    } else {
+      auto *MSC = cast<MaskedScatterSDNode>(MGS);
+      SDValue Data = MSC->getValue();
+      SDValue Ops[] = { Chain, Data, Mask, BasePtr, Index, Scale };
+      return DAG.getMaskedScatter(DAG.getVTList(MVT::Other),
+                                  MSC->getMemoryVT(), DL, Ops,
+                                  MSC->getMemOperand(), IndexType,
+                                  MSC->isTruncatingStore());
+    }
   }
 }
@@ -15072,9 +15083,6 @@ static SDValue performGatherLoadCombine(SDNode *N, SelectionDAG &DAG,
 static SDValue
 performSignExtendInRegCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                               SelectionDAG &DAG) {
-  if (DCI.isBeforeLegalizeOps())
-    return SDValue();
-
   SDLoc DL(N);
   SDValue Src = N->getOperand(0);
   unsigned Opc = Src->getOpcode();
@@ -15109,6 +15117,9 @@ performSignExtendInRegCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
     return DAG.getNode(SOpc, DL, N->getValueType(0), Ext);
   }
 
+  if (DCI.isBeforeLegalizeOps())
+    return SDValue();
+
   if (!EnableCombineMGatherIntrinsics)
     return SDValue();
@@ -15296,8 +15307,9 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     break;
   case ISD::STORE:
     return performSTORECombine(N, DCI, DAG, Subtarget);
+  case ISD::MGATHER:
   case ISD::MSCATTER:
-    return performMSCATTERCombine(N, DCI, DAG);
+    return performMaskedGatherScatterCombine(N, DCI, DAG);
   case AArch64ISD::BRCOND:
     return performBRCONDCombine(N, DCI, DAG);
   case AArch64ISD::TBNZ:


@@ -54,6 +54,19 @@ define <vscale x 2 x i32> @masked_gather_nxv2i32(<vscale x 2 x i32*> %ptrs, <vsc
   ret <vscale x 2 x i32> %data
 }
 
+; Code generate the worst case scenario when all vector types are legal.
+define <vscale x 16 x i8> @masked_gather_nxv16i8(i8* %base, <vscale x 16 x i8> %indices, <vscale x 16 x i1> %mask) {
+; CHECK-LABEL: masked_gather_nxv16i8:
+; CHECK-DAG: ld1sb { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, {{z[0-9]+}}.s, sxtw]
+; CHECK-DAG: ld1sb { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, {{z[0-9]+}}.s, sxtw]
+; CHECK-DAG: ld1sb { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, {{z[0-9]+}}.s, sxtw]
+; CHECK-DAG: ld1sb { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, {{z[0-9]+}}.s, sxtw]
+; CHECK: ret
+  %ptrs = getelementptr i8, i8* %base, <vscale x 16 x i8> %indices
+  %data = call <vscale x 16 x i8> @llvm.masked.gather.nxv16i8(<vscale x 16 x i8*> %ptrs, i32 1, <vscale x 16 x i1> %mask, <vscale x 16 x i8> undef)
+  ret <vscale x 16 x i8> %data
+}
+
 ; Code generate the worst case scenario when all vector types are illegal.
 define <vscale x 32 x i32> @masked_gather_nxv32i32(i32* %base, <vscale x 32 x i32> %indices, <vscale x 32 x i1> %mask) {
 ; CHECK-LABEL: masked_gather_nxv32i32: