[AArch64][SVE] Add legalization support for i32/i64 vector srem/urem

Implement them on top of sdiv/udiv, similar to what we do for integer
types.

Potential future work: implementing i8/i16 srem/urem, optimizations for
constant divisors, optimizing the mul+sub to mls.

Differential Revision: https://reviews.llvm.org/D81511
This commit is contained in:
Eli Friedman 2020-06-08 16:34:15 -07:00
parent 90ad786947
commit e9d4e34ab8
8 changed files with 115 additions and 19 deletions

View File

@ -4421,6 +4421,10 @@ public:
/// only the first Count elements of the vector are used.
SDValue expandVecReduce(SDNode *Node, SelectionDAG &DAG) const;
/// Expand an SREM or UREM using SDIV/UDIV or SDIVREM/UDIVREM, if legal.
/// Returns true if the expansion was successful.
bool expandREM(SDNode *Node, SDValue &Result, SelectionDAG &DAG) const;
//===--------------------------------------------------------------------===//
// Instruction Emitting Hooks
//

View File

@ -3343,26 +3343,10 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
break;
}
case ISD::UREM:
case ISD::SREM: {
EVT VT = Node->getValueType(0);
bool isSigned = Node->getOpcode() == ISD::SREM;
unsigned DivOpc = isSigned ? ISD::SDIV : ISD::UDIV;
unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM;
Tmp2 = Node->getOperand(0);
Tmp3 = Node->getOperand(1);
if (TLI.isOperationLegalOrCustom(DivRemOpc, VT)) {
SDVTList VTs = DAG.getVTList(VT, VT);
Tmp1 = DAG.getNode(DivRemOpc, dl, VTs, Tmp2, Tmp3).getValue(1);
case ISD::SREM:
if (TLI.expandREM(Node, Tmp1, DAG))
Results.push_back(Tmp1);
} else if (TLI.isOperationLegalOrCustom(DivOpc, VT)) {
// X % Y -> X-X/Y*Y
Tmp1 = DAG.getNode(DivOpc, dl, VT, Tmp2, Tmp3);
Tmp1 = DAG.getNode(ISD::MUL, dl, VT, Tmp1, Tmp3);
Tmp1 = DAG.getNode(ISD::SUB, dl, VT, Tmp2, Tmp1);
Results.push_back(Tmp1);
}
break;
}
case ISD::UDIV:
case ISD::SDIV: {
bool isSigned = Node->getOpcode() == ISD::SDIV;

View File

@ -145,6 +145,7 @@ class VectorLegalizer {
void ExpandFixedPointDiv(SDNode *Node, SmallVectorImpl<SDValue> &Results);
SDValue ExpandStrictFPOp(SDNode *Node);
void ExpandStrictFPOp(SDNode *Node, SmallVectorImpl<SDValue> &Results);
void ExpandREM(SDNode *Node, SmallVectorImpl<SDValue> &Results);
void UnrollStrictFPOp(SDNode *Node, SmallVectorImpl<SDValue> &Results);
@ -867,6 +868,10 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
case ISD::VECREDUCE_FMIN:
Results.push_back(TLI.expandVecReduce(Node, DAG));
return;
case ISD::SREM:
case ISD::UREM:
ExpandREM(Node, Results);
return;
}
Results.push_back(DAG.UnrollVectorOp(Node));
@ -1353,6 +1358,17 @@ void VectorLegalizer::ExpandStrictFPOp(SDNode *Node,
UnrollStrictFPOp(Node, Results);
}
void VectorLegalizer::ExpandREM(SDNode *Node,
SmallVectorImpl<SDValue> &Results) {
assert((Node->getOpcode() == ISD::SREM || Node->getOpcode() == ISD::UREM) &&
"Expected REM node");
SDValue Result;
if (!TLI.expandREM(Node, Result, DAG))
Result = DAG.UnrollVectorOp(Node);
Results.push_back(Result);
}
void VectorLegalizer::UnrollStrictFPOp(SDNode *Node,
SmallVectorImpl<SDValue> &Results) {
EVT VT = Node->getValueType(0);

View File

@ -7823,3 +7823,26 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const {
Res = DAG.getNode(ISD::ANY_EXTEND, dl, Node->getValueType(0), Res);
return Res;
}
bool TargetLowering::expandREM(SDNode *Node, SDValue &Result,
SelectionDAG &DAG) const {
EVT VT = Node->getValueType(0);
SDLoc dl(Node);
bool isSigned = Node->getOpcode() == ISD::SREM;
unsigned DivOpc = isSigned ? ISD::SDIV : ISD::UDIV;
unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM;
SDValue Dividend = Node->getOperand(0);
SDValue Divisor = Node->getOperand(1);
if (isOperationLegalOrCustom(DivRemOpc, VT)) {
SDVTList VTs = DAG.getVTList(VT, VT);
Result = DAG.getNode(DivRemOpc, dl, VTs, Dividend, Divisor).getValue(1);
return true;
} else if (isOperationLegalOrCustom(DivOpc, VT)) {
// X % Y -> X-X/Y*Y
SDValue Divide = DAG.getNode(DivOpc, dl, VT, Dividend, Divisor);
SDValue Mul = DAG.getNode(ISD::MUL, dl, VT, Divide, Divisor);
Result = DAG.getNode(ISD::SUB, dl, VT, Dividend, Mul);
return true;
}
return false;
}

View File

@ -199,6 +199,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::UADDSAT, VT, Legal);
setOperationAction(ISD::SSUBSAT, VT, Legal);
setOperationAction(ISD::USUBSAT, VT, Legal);
setOperationAction(ISD::UREM, VT, Expand);
setOperationAction(ISD::SREM, VT, Expand);
setOperationAction(ISD::SDIVREM, VT, Expand);
setOperationAction(ISD::UDIVREM, VT, Expand);
}
for (auto VT :

View File

@ -443,7 +443,7 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::UREM, VT, Expand);
setOperationAction(ISD::SMUL_LOHI, VT, Expand);
setOperationAction(ISD::UMUL_LOHI, VT, Expand);
setOperationAction(ISD::SDIVREM, VT, Custom);
setOperationAction(ISD::SDIVREM, VT, Expand);
setOperationAction(ISD::UDIVREM, VT, Expand);
setOperationAction(ISD::SELECT, VT, Expand);
setOperationAction(ISD::VSELECT, VT, Expand);

View File

@ -210,6 +210,8 @@ void ARMTargetLowering::addTypeForNEON(MVT VT, MVT PromotedLdStVT,
setOperationAction(ISD::SREM, VT, Expand);
setOperationAction(ISD::UREM, VT, Expand);
setOperationAction(ISD::FREM, VT, Expand);
setOperationAction(ISD::SDIVREM, VT, Expand);
setOperationAction(ISD::UDIVREM, VT, Expand);
if (!VT.isFloatingPoint() &&
VT != MVT::v2i64 && VT != MVT::v1i64)
@ -284,6 +286,8 @@ void ARMTargetLowering::addMVEVectorTypes(bool HasMVEFP) {
setOperationAction(ISD::SDIV, VT, Expand);
setOperationAction(ISD::UREM, VT, Expand);
setOperationAction(ISD::SREM, VT, Expand);
setOperationAction(ISD::UDIVREM, VT, Expand);
setOperationAction(ISD::SDIVREM, VT, Expand);
setOperationAction(ISD::CTPOP, VT, Expand);
// Vector reductions

View File

@ -59,6 +59,36 @@ define <vscale x 4 x i64> @sdiv_split_i64(<vscale x 4 x i64> %a, <vscale x 4 x i
ret <vscale x 4 x i64> %div
}
;
; SREM
;
define <vscale x 4 x i32> @srem_i32(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b) {
; CHECK-LABEL: srem_i32:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: mov z2.d, z0.d
; CHECK-NEXT: sdiv z2.s, p0/m, z2.s, z1.s
; CHECK-NEXT: mul z2.s, p0/m, z2.s, z1.s
; CHECK-NEXT: sub z0.s, z0.s, z2.s
; CHECK-NEXT: ret
%div = srem <vscale x 4 x i32> %a, %b
ret <vscale x 4 x i32> %div
}
define <vscale x 2 x i64> @srem_i64(<vscale x 2 x i64> %a, <vscale x 2 x i64> %b) {
; CHECK-LABEL: srem_i64:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: mov z2.d, z0.d
; CHECK-NEXT: sdiv z2.d, p0/m, z2.d, z1.d
; CHECK-NEXT: mul z2.d, p0/m, z2.d, z1.d
; CHECK-NEXT: sub z0.d, z0.d, z2.d
; CHECK-NEXT: ret
%div = srem <vscale x 2 x i64> %a, %b
ret <vscale x 2 x i64> %div
}
;
; UDIV
;
@ -117,6 +147,37 @@ define <vscale x 4 x i64> @udiv_split_i64(<vscale x 4 x i64> %a, <vscale x 4 x i
ret <vscale x 4 x i64> %div
}
;
; UREM
;
define <vscale x 4 x i32> @urem_i32(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b) {
; CHECK-LABEL: urem_i32:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: mov z2.d, z0.d
; CHECK-NEXT: udiv z2.s, p0/m, z2.s, z1.s
; CHECK-NEXT: mul z2.s, p0/m, z2.s, z1.s
; CHECK-NEXT: sub z0.s, z0.s, z2.s
; CHECK-NEXT: ret
%div = urem <vscale x 4 x i32> %a, %b
ret <vscale x 4 x i32> %div
}
define <vscale x 2 x i64> @urem_i64(<vscale x 2 x i64> %a, <vscale x 2 x i64> %b) {
; CHECK-LABEL: urem_i64:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: mov z2.d, z0.d
; CHECK-NEXT: udiv z2.d, p0/m, z2.d, z1.d
; CHECK-NEXT: mul z2.d, p0/m, z2.d, z1.d
; CHECK-NEXT: sub z0.d, z0.d, z2.d
; CHECK-NEXT: ret
%div = urem <vscale x 2 x i64> %a, %b
ret <vscale x 2 x i64> %div
}
;
; SMIN
;