AVX-512, X86: Added lowering for shift operations for SKX.

The other changes in the LowerShift() are not functional,
just to make the code more convenient.
So, the functional changes for SKX only.

llvm-svn: 237129
This commit is contained in:
Elena Demikhovsky 2015-05-12 13:25:46 +00:00
parent c74bcc2521
commit bd877a1b65
2 changed files with 94 additions and 101 deletions

View File

@ -1507,6 +1507,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::AND, MVT::v4i32, Legal);
setOperationAction(ISD::OR, MVT::v4i32, Legal);
setOperationAction(ISD::XOR, MVT::v4i32, Legal);
setOperationAction(ISD::SRA, MVT::v2i64, Custom);
setOperationAction(ISD::SRA, MVT::v4i64, Custom);
}
// We want to custom lower some of our intrinsics.
@ -16328,6 +16330,53 @@ static SDValue LowerMUL_LOHI(SDValue Op, const X86Subtarget *Subtarget,
return DAG.getMergeValues(Ops, dl);
}
// Return true if the requred (according to Opcode) shift-imm form is natively
// supported by the Subtarget
static bool SupportedVectorShiftWithImm(MVT VT, const X86Subtarget *Subtarget,
unsigned Opcode) {
if (VT.getScalarSizeInBits() < 16)
return false;
if (VT.is512BitVector() &&
(VT.getScalarSizeInBits() > 16 || Subtarget->hasBWI()))
return true;
bool LShift = VT.is128BitVector() ||
(VT.is256BitVector() && Subtarget->hasInt256());
bool AShift = LShift && (Subtarget->hasVLX() ||
(VT != MVT::v2i64 && VT != MVT::v4i64));
return (Opcode == ISD::SRA) ? AShift : LShift;
}
// The shift amount is a variable, but it is the same for all vector lanes.
// These instrcutions are defined together with shift-immediate.
static
bool SupportedVectorShiftWithBaseAmnt(MVT VT, const X86Subtarget *Subtarget,
unsigned Opcode) {
return SupportedVectorShiftWithImm(VT, Subtarget, Opcode);
}
// Return true if the requred (according to Opcode) variable-shift form is
// natively supported by the Subtarget
static bool SupportedVectorVarShift(MVT VT, const X86Subtarget *Subtarget,
unsigned Opcode) {
if (!Subtarget->hasInt256() || VT.getScalarSizeInBits() < 16)
return false;
// vXi16 supported only on AVX-512, BWI
if (VT.getScalarSizeInBits() == 16 && !Subtarget->hasBWI())
return false;
if (VT.is512BitVector() || Subtarget->hasVLX())
return true;
bool LShift = VT.is128BitVector() || VT.is256BitVector();
bool AShift = LShift && VT != MVT::v2i64 && VT != MVT::v4i64;
return (Opcode == ISD::SRA) ? AShift : LShift;
}
static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
const X86Subtarget *Subtarget) {
MVT VT = Op.getSimpleValueType();
@ -16335,26 +16384,16 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
SDValue R = Op.getOperand(0);
SDValue Amt = Op.getOperand(1);
unsigned X86Opc = (Op.getOpcode() == ISD::SHL) ? X86ISD::VSHLI :
(Op.getOpcode() == ISD::SRL) ? X86ISD::VSRLI : X86ISD::VSRAI;
// Optimize shl/srl/sra with constant shift amount.
if (auto *BVAmt = dyn_cast<BuildVectorSDNode>(Amt)) {
if (auto *ShiftConst = BVAmt->getConstantSplatNode()) {
uint64_t ShiftAmt = ShiftConst->getZExtValue();
if (VT == MVT::v2i64 || VT == MVT::v4i32 || VT == MVT::v8i16 ||
(Subtarget->hasInt256() &&
(VT == MVT::v4i64 || VT == MVT::v8i32 || VT == MVT::v16i16)) ||
(Subtarget->hasAVX512() &&
(VT == MVT::v8i64 || VT == MVT::v16i32))) {
if (Op.getOpcode() == ISD::SHL)
return getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, R, ShiftAmt,
DAG);
if (Op.getOpcode() == ISD::SRL)
return getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, R, ShiftAmt,
DAG);
if (Op.getOpcode() == ISD::SRA && VT != MVT::v2i64 && VT != MVT::v4i64)
return getTargetVShiftByConstNode(X86ISD::VSRAI, dl, VT, R, ShiftAmt,
DAG);
}
if (SupportedVectorShiftWithImm(VT, Subtarget, Op.getOpcode()))
return getTargetVShiftByConstNode(X86Opc, dl, VT, R, ShiftAmt, DAG);
if (VT == MVT::v16i8 || (Subtarget->hasInt256() && VT == MVT::v32i8)) {
unsigned NumElts = VT.getVectorNumElements();
@ -16435,19 +16474,7 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
if (ShAmt != ShiftAmt)
return SDValue();
}
switch (Op.getOpcode()) {
default:
llvm_unreachable("Unknown shift opcode!");
case ISD::SHL:
return getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, R, ShiftAmt,
DAG);
case ISD::SRL:
return getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, R, ShiftAmt,
DAG);
case ISD::SRA:
return getTargetVShiftByConstNode(X86ISD::VSRAI, dl, VT, R, ShiftAmt,
DAG);
}
return getTargetVShiftByConstNode(X86Opc, dl, VT, R, ShiftAmt, DAG);
}
return SDValue();
@ -16460,12 +16487,13 @@ static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG,
SDValue R = Op.getOperand(0);
SDValue Amt = Op.getOperand(1);
if ((VT == MVT::v2i64 && Op.getOpcode() != ISD::SRA) ||
VT == MVT::v4i32 || VT == MVT::v8i16 ||
(Subtarget->hasInt256() &&
((VT == MVT::v4i64 && Op.getOpcode() != ISD::SRA) ||
VT == MVT::v8i32 || VT == MVT::v16i16)) ||
(Subtarget->hasAVX512() && (VT == MVT::v8i64 || VT == MVT::v16i32))) {
unsigned X86OpcI = (Op.getOpcode() == ISD::SHL) ? X86ISD::VSHLI :
(Op.getOpcode() == ISD::SRL) ? X86ISD::VSRLI : X86ISD::VSRAI;
unsigned X86OpcV = (Op.getOpcode() == ISD::SHL) ? X86ISD::VSHL :
(Op.getOpcode() == ISD::SRL) ? X86ISD::VSRL : X86ISD::VSRA;
if (SupportedVectorShiftWithBaseAmnt(VT, Subtarget, Op.getOpcode())) {
SDValue BaseShAmt;
EVT EltVT = VT.getVectorElementType();
@ -16509,47 +16537,7 @@ static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG,
else if (EltVT.bitsLT(MVT::i32))
BaseShAmt = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, BaseShAmt);
switch (Op.getOpcode()) {
default:
llvm_unreachable("Unknown shift opcode!");
case ISD::SHL:
switch (VT.SimpleTy) {
default: return SDValue();
case MVT::v2i64:
case MVT::v4i32:
case MVT::v8i16:
case MVT::v4i64:
case MVT::v8i32:
case MVT::v16i16:
case MVT::v16i32:
case MVT::v8i64:
return getTargetVShiftNode(X86ISD::VSHLI, dl, VT, R, BaseShAmt, DAG);
}
case ISD::SRA:
switch (VT.SimpleTy) {
default: return SDValue();
case MVT::v4i32:
case MVT::v8i16:
case MVT::v8i32:
case MVT::v16i16:
case MVT::v16i32:
case MVT::v8i64:
return getTargetVShiftNode(X86ISD::VSRAI, dl, VT, R, BaseShAmt, DAG);
}
case ISD::SRL:
switch (VT.SimpleTy) {
default: return SDValue();
case MVT::v2i64:
case MVT::v4i32:
case MVT::v8i16:
case MVT::v4i64:
case MVT::v8i32:
case MVT::v16i16:
case MVT::v16i32:
case MVT::v8i64:
return getTargetVShiftNode(X86ISD::VSRLI, dl, VT, R, BaseShAmt, DAG);
}
}
return getTargetVShiftNode(X86OpcI, dl, VT, R, BaseShAmt, DAG);
}
}
@ -16568,18 +16556,8 @@ static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG,
if (Vals[j] != Amt.getOperand(i + j))
return SDValue();
}
switch (Op.getOpcode()) {
default:
llvm_unreachable("Unknown shift opcode!");
case ISD::SHL:
return DAG.getNode(X86ISD::VSHL, dl, VT, R, Op.getOperand(1));
case ISD::SRL:
return DAG.getNode(X86ISD::VSRL, dl, VT, R, Op.getOperand(1));
case ISD::SRA:
return DAG.getNode(X86ISD::VSRA, dl, VT, R, Op.getOperand(1));
}
return DAG.getNode(X86OpcV, dl, VT, R, Op.getOperand(1));
}
return SDValue();
}
@ -16599,23 +16577,9 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget* Subtarget,
if (SDValue V = LowerScalarVariableShift(Op, DAG, Subtarget))
return V;
if (Subtarget->hasAVX512() && (VT == MVT::v16i32 || VT == MVT::v8i64))
if (SupportedVectorVarShift(VT, Subtarget, Op.getOpcode()))
return Op;
// AVX2 has VPSLLV/VPSRAV/VPSRLV.
if (Subtarget->hasInt256()) {
if (Op.getOpcode() == ISD::SRL &&
(VT == MVT::v2i64 || VT == MVT::v4i32 ||
VT == MVT::v4i64 || VT == MVT::v8i32))
return Op;
if (Op.getOpcode() == ISD::SHL &&
(VT == MVT::v2i64 || VT == MVT::v4i32 ||
VT == MVT::v4i64 || VT == MVT::v8i32))
return Op;
if (Op.getOpcode() == ISD::SRA && (VT == MVT::v4i32 || VT == MVT::v8i32))
return Op;
}
// 2i64 vector logical shifts can efficiently avoid scalarization - do the
// shifts per-lane and then shuffle the partial results back together.
if (VT == MVT::v2i64 && Op.getOpcode() != ISD::SRA) {

View File

@ -1,4 +1,5 @@
;RUN: llc < %s -mtriple=x86_64-apple-darwin -mcpu=knl | FileCheck %s
;RUN: llc < %s -mtriple=x86_64-apple-darwin -mcpu=skx | FileCheck --check-prefix=SKX %s
;CHECK-LABEL: shift_16_i32
;CHECK: vpsrld
@ -24,6 +25,18 @@ define <8 x i64> @shift_8_i64(<8 x i64> %a) {
ret <8 x i64> %d;
}
;SKX-LABEL: shift_4_i64
;SKX: vpsrlq
;SKX: vpsllq
;SKX: vpsraq
;SKX: ret
define <4 x i64> @shift_4_i64(<4 x i64> %a) {
%b = lshr <4 x i64> %a, <i64 1, i64 1, i64 1, i64 1>
%c = shl <4 x i64> %b, <i64 12, i64 12, i64 12, i64 12>
%d = ashr <4 x i64> %c, <i64 12, i64 12, i64 12, i64 12>
ret <4 x i64> %d;
}
; CHECK-LABEL: variable_shl4
; CHECK: vpsllvq %zmm
; CHECK: ret
@ -72,6 +85,22 @@ define <8 x i64> @variable_sra2(<8 x i64> %x, <8 x i64> %y) {
ret <8 x i64> %k
}
; SKX-LABEL: variable_sra3
; SKX: vpsravq %ymm
; SKX: ret
define <4 x i64> @variable_sra3(<4 x i64> %x, <4 x i64> %y) {
%k = ashr <4 x i64> %x, %y
ret <4 x i64> %k
}
; SKX-LABEL: variable_sra4
; SKX: vpsravw %xmm
; SKX: ret
define <8 x i16> @variable_sra4(<8 x i16> %x, <8 x i16> %y) {
%k = ashr <8 x i16> %x, %y
ret <8 x i16> %k
}
; CHECK-LABEL: variable_sra01_load
; CHECK: vpsravd (%
; CHECK: ret