[SVE][CodeGen] Extend index of masked gathers
This patch changes performMSCATTERCombine to also promote the indices of
masked gathers where the element type is i8 or i16, and adds various tests
for gathers with illegal types.

Reviewed By: sdesmalen

Differential Revision: https://reviews.llvm.org/D91433
commit abe7775f5a
parent 2fc4afda0f
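
For illustration, a minimal IR example of the kind of gather this combine now handles (not part of the commit; the function name and exact types here are hypothetical, chosen for exposition). The i8 index vector is sign-extended to i32 before legalisation, since GEP indices are signed:

define <vscale x 4 x i32> @gather_i8_indices(i32* %base, <vscale x 4 x i8> %offsets, <vscale x 4 x i1> %mask) {
; The i8 index elements are narrower than SVE gather addressing modes allow;
; the combine promotes the index vector to <vscale x 4 x i32> so the gather
; can be lowered (or split, if the result type is illegal).
  %ptrs = getelementptr i32, i32* %base, <vscale x 4 x i8> %offsets
  %data = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32(<vscale x 4 x i32*> %ptrs, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x i32> undef)
  ret <vscale x 4 x i32> %data
}
declare <vscale x 4 x i32> @llvm.masked.gather.nxv4i32(<vscale x 4 x i32*>, i32, <vscale x 4 x i1>, <vscale x 4 x i32>)

Promoting to i32 matches the SVE gather addressing modes, which take 32- or 64-bit offset vectors (the sxtw-extended .s offsets checked in the tests below).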
@@ -849,6 +849,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   if (Subtarget->supportsAddressTopByteIgnored())
     setTargetDAGCombine(ISD::LOAD);
 
+  setTargetDAGCombine(ISD::MGATHER);
   setTargetDAGCombine(ISD::MSCATTER);
 
   setTargetDAGCombine(ISD::MUL);
@@ -14063,20 +14064,19 @@ static SDValue performSTORECombine(SDNode *N,
   return SDValue();
 }
 
-static SDValue performMSCATTERCombine(SDNode *N,
+static SDValue performMaskedGatherScatterCombine(SDNode *N,
                                       TargetLowering::DAGCombinerInfo &DCI,
                                       SelectionDAG &DAG) {
-  MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(N);
-  assert(MSC && "Can only combine scatter store nodes");
+  MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
+  assert(MGS && "Can only combine gather load or scatter store nodes");
 
-  SDLoc DL(MSC);
-  SDValue Chain = MSC->getChain();
-  SDValue Scale = MSC->getScale();
-  SDValue Index = MSC->getIndex();
-  SDValue Data = MSC->getValue();
-  SDValue Mask = MSC->getMask();
-  SDValue BasePtr = MSC->getBasePtr();
-  ISD::MemIndexType IndexType = MSC->getIndexType();
+  SDLoc DL(MGS);
+  SDValue Chain = MGS->getChain();
+  SDValue Scale = MGS->getScale();
+  SDValue Index = MGS->getIndex();
+  SDValue Mask = MGS->getMask();
+  SDValue BasePtr = MGS->getBasePtr();
+  ISD::MemIndexType IndexType = MGS->getIndexType();
 
   EVT IdxVT = Index.getValueType();
 
@@ -14086,16 +14086,27 @@ static SDValue performMSCATTERCombine(SDNode *N,
     if ((IdxVT.getVectorElementType() == MVT::i8) ||
         (IdxVT.getVectorElementType() == MVT::i16)) {
       EVT NewIdxVT = IdxVT.changeVectorElementType(MVT::i32);
-      if (MSC->isIndexSigned())
+      if (MGS->isIndexSigned())
        Index = DAG.getNode(ISD::SIGN_EXTEND, DL, NewIdxVT, Index);
       else
        Index = DAG.getNode(ISD::ZERO_EXTEND, DL, NewIdxVT, Index);
 
-      SDValue Ops[] = { Chain, Data, Mask, BasePtr, Index, Scale };
-      return DAG.getMaskedScatter(DAG.getVTList(MVT::Other),
-                                  MSC->getMemoryVT(), DL, Ops,
-                                  MSC->getMemOperand(), IndexType,
-                                  MSC->isTruncatingStore());
+      if (auto *MGT = dyn_cast<MaskedGatherSDNode>(MGS)) {
+        SDValue PassThru = MGT->getPassThru();
+        SDValue Ops[] = { Chain, PassThru, Mask, BasePtr, Index, Scale };
+        return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
+                                   PassThru.getValueType(), DL, Ops,
+                                   MGT->getMemOperand(),
+                                   MGT->getIndexType(), MGT->getExtensionType());
+      } else {
+        auto *MSC = cast<MaskedScatterSDNode>(MGS);
+        SDValue Data = MSC->getValue();
+        SDValue Ops[] = { Chain, Data, Mask, BasePtr, Index, Scale };
+        return DAG.getMaskedScatter(DAG.getVTList(MVT::Other),
+                                    MSC->getMemoryVT(), DL, Ops,
+                                    MSC->getMemOperand(), IndexType,
+                                    MSC->isTruncatingStore());
+      }
     }
   }
 
@@ -15072,9 +15083,6 @@ static SDValue performGatherLoadCombine(SDNode *N, SelectionDAG &DAG,
 static SDValue
 performSignExtendInRegCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                               SelectionDAG &DAG) {
-  if (DCI.isBeforeLegalizeOps())
-    return SDValue();
-
   SDLoc DL(N);
   SDValue Src = N->getOperand(0);
   unsigned Opc = Src->getOpcode();
@@ -15109,6 +15117,9 @@ performSignExtendInRegCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
     return DAG.getNode(SOpc, DL, N->getValueType(0), Ext);
   }
 
+  if (DCI.isBeforeLegalizeOps())
+    return SDValue();
+
   if (!EnableCombineMGatherIntrinsics)
     return SDValue();
 
@@ -15296,8 +15307,9 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     break;
   case ISD::STORE:
     return performSTORECombine(N, DCI, DAG, Subtarget);
+  case ISD::MGATHER:
   case ISD::MSCATTER:
-    return performMSCATTERCombine(N, DCI, DAG);
+    return performMaskedGatherScatterCombine(N, DCI, DAG);
   case AArch64ISD::BRCOND:
     return performBRCONDCombine(N, DCI, DAG);
   case AArch64ISD::TBNZ:
@@ -54,6 +54,19 @@ define <vscale x 2 x i32> @masked_gather_nxv2i32(<vscale x 2 x i32*> %ptrs, <vsc
   ret <vscale x 2 x i32> %data
 }
 
+; Code generate the worst case scenario when all vector types are legal.
+define <vscale x 16 x i8> @masked_gather_nxv16i8(i8* %base, <vscale x 16 x i8> %indices, <vscale x 16 x i1> %mask) {
+; CHECK-LABEL: masked_gather_nxv16i8:
+; CHECK-DAG: ld1sb { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, {{z[0-9]+}}.s, sxtw]
+; CHECK-DAG: ld1sb { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, {{z[0-9]+}}.s, sxtw]
+; CHECK-DAG: ld1sb { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, {{z[0-9]+}}.s, sxtw]
+; CHECK-DAG: ld1sb { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, {{z[0-9]+}}.s, sxtw]
+; CHECK: ret
+  %ptrs = getelementptr i8, i8* %base, <vscale x 16 x i8> %indices
+  %data = call <vscale x 16 x i8> @llvm.masked.gather.nxv16i8(<vscale x 16 x i8*> %ptrs, i32 1, <vscale x 16 x i1> %mask, <vscale x 16 x i8> undef)
+  ret <vscale x 16 x i8> %data
+}
+
 ; Code generate the worst case scenario when all vector types are illegal.
 define <vscale x 32 x i32> @masked_gather_nxv32i32(i32* %base, <vscale x 32 x i32> %indices, <vscale x 32 x i1> %mask) {
 ; CHECK-LABEL: masked_gather_nxv32i32: