[AArch64] Optimize instruction selection for certain vector shuffles

This patch adds code to recognize vector shuffles which can be
represented as a VDUP (splat) of a vector lane of a different
(wider) type than the original vector lane type.

For example:
    shufflevector <4 x i16> %v, <4 x i16> undef, <4 x i32> <i32 0, i32 1, i32 0, i32 1>
is essentially (operating on %v bitcast to <2 x i32>):
    shufflevector <2 x i32> %v, <2 x i32> undef, <2 x i32> <i32 0, i32 0>

Such patterns are generated by the SelectionDAG machinery in some cases
(see DAGCombiner::visitBITCAST in DAGCombiner.cpp, the "Remove double
bitcasts from shuffles" part).

Reviewed By: dmgreen

Differential Revision: https://reviews.llvm.org/D86225
Author: Mikhail Maltsev
Date:   2020-08-27 11:06:45 +01:00
Commit: 63df72fb38 (parent: 368c7592ca)
6 changed files with 274 additions and 54 deletions


@@ -7381,6 +7381,81 @@ static bool isSingletonEXTMask(ArrayRef<int> M, EVT VT, unsigned &Imm) {
  return true;
}

/// Check if a vector shuffle corresponds to a DUP instruction with a larger
/// element width than the vector lane type. If that is the case the function
/// returns true and writes the value of the DUP instruction lane operand into
/// DupLaneOp.
static bool isWideDUPMask(ArrayRef<int> M, EVT VT, unsigned BlockSize,
                          unsigned &DupLaneOp) {
  assert((BlockSize == 16 || BlockSize == 32 || BlockSize == 64) &&
         "Only possible block sizes for wide DUP are: 16, 32, 64");

  if (BlockSize <= VT.getScalarSizeInBits())
    return false;
  if (BlockSize % VT.getScalarSizeInBits() != 0)
    return false;
  if (VT.getSizeInBits() % BlockSize != 0)
    return false;

  size_t SingleVecNumElements = VT.getVectorNumElements();
  size_t NumEltsPerBlock = BlockSize / VT.getScalarSizeInBits();
  size_t NumBlocks = VT.getSizeInBits() / BlockSize;

  // We are looking for masks like
  // [0, 1, 0, 1] or [2, 3, 2, 3] or [4, 5, 6, 7, 4, 5, 6, 7] where any element
  // might be replaced by 'undefined'. BlockElts will eventually contain
  // lane indices of the duplicated block (i.e. [0, 1], [2, 3] and [4, 5, 6, 7]
  // for the above examples)
  SmallVector<int, 8> BlockElts(NumEltsPerBlock, -1);

  for (size_t BlockIndex = 0; BlockIndex < NumBlocks; BlockIndex++)
    for (size_t I = 0; I < NumEltsPerBlock; I++) {
      int Elt = M[BlockIndex * NumEltsPerBlock + I];
      if (Elt < 0)
        continue;
      // For now we don't support shuffles that use the second operand
      if ((unsigned)Elt >= SingleVecNumElements)
        return false;
      if (BlockElts[I] < 0)
        BlockElts[I] = Elt;
      else if (BlockElts[I] != Elt)
        return false;
    }

  // We found a candidate block (possibly with some undefs). It must be a
  // sequence of consecutive integers starting with a value divisible by
  // NumEltsPerBlock with some values possibly replaced by undef-s.

  // Find first non-undef element
  auto FirstRealEltIter = find_if(BlockElts, [](int Elt) { return Elt >= 0; });
  assert(FirstRealEltIter != BlockElts.end() &&
         "Shuffle with all-undefs must have been caught by previous cases, "
         "e.g. isSplat()");
  if (FirstRealEltIter == BlockElts.end()) {
    DupLaneOp = 0;
    return true;
  }

  // Index of FirstRealElt in BlockElts
  size_t FirstRealIndex = FirstRealEltIter - BlockElts.begin();

  if ((unsigned)*FirstRealEltIter < FirstRealIndex)
    return false;
  // BlockElts[0] must have the following value if it isn't undef:
  size_t Elt0 = *FirstRealEltIter - FirstRealIndex;

  // Check the first element
  if (Elt0 % NumEltsPerBlock != 0)
    return false;
  // Check that the sequence indeed consists of consecutive integers (modulo
  // undefs)
  for (size_t I = 0; I < NumEltsPerBlock; I++)
    if (BlockElts[I] >= 0 && (unsigned)BlockElts[I] != Elt0 + I)
      return false;

  DupLaneOp = Elt0 / NumEltsPerBlock;
  return true;
}

// Check if an EXT instruction can handle the shuffle mask when the
// vector sources of the shuffle are different.
static bool isEXTMask(ArrayRef<int> M, EVT VT, bool &ReverseEXT,
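
To make the block-collapsing logic above concrete, here is a minimal standalone C++ sketch of the same check, with plain std::vector and explicit bit widths standing in for ArrayRef and EVT; isWideDUPMaskSketch and its parameter names are illustrative, not LLVM API. It accepts the masks from the comment above and rejects a non-lane-aligned one:

#include <cassert>
#include <cstdio>
#include <vector>

static bool isWideDUPMaskSketch(const std::vector<int> &M, unsigned EltBits,
                                unsigned VecBits, unsigned BlockSize,
                                unsigned &DupLaneOp) {
  if (BlockSize <= EltBits || BlockSize % EltBits != 0 ||
      VecBits % BlockSize != 0)
    return false;
  size_t NumEltsPerBlock = BlockSize / EltBits;
  size_t NumBlocks = VecBits / BlockSize;
  size_t NumElts = VecBits / EltBits;

  // Collapse all blocks into one candidate block, failing on any mismatch
  // between corresponding lanes (-1 = undef, matches anything).
  std::vector<int> BlockElts(NumEltsPerBlock, -1);
  for (size_t B = 0; B < NumBlocks; ++B)
    for (size_t I = 0; I < NumEltsPerBlock; ++I) {
      int Elt = M[B * NumEltsPerBlock + I];
      if (Elt < 0)
        continue;
      if ((unsigned)Elt >= NumElts) // uses the second shuffle operand: give up
        return false;
      if (BlockElts[I] < 0)
        BlockElts[I] = Elt;
      else if (BlockElts[I] != Elt)
        return false;
    }

  // The candidate block must be [K, K+1, ..., K+N-1] (modulo undefs) with
  // K a multiple of NumEltsPerBlock; the wide lane index is then K / N.
  size_t FirstReal = 0;
  while (FirstReal < NumEltsPerBlock && BlockElts[FirstReal] < 0)
    ++FirstReal;
  if (FirstReal == NumEltsPerBlock) { // all-undef: any lane works
    DupLaneOp = 0;
    return true;
  }
  if ((unsigned)BlockElts[FirstReal] < FirstReal)
    return false;
  size_t Elt0 = BlockElts[FirstReal] - FirstReal;
  if (Elt0 % NumEltsPerBlock != 0)
    return false;
  for (size_t I = 0; I < NumEltsPerBlock; ++I)
    if (BlockElts[I] >= 0 && (unsigned)BlockElts[I] != Elt0 + I)
      return false;
  DupLaneOp = Elt0 / NumEltsPerBlock;
  return true;
}

int main() {
  unsigned Lane;
  // <4 x i16> mask [0,1,0,1] checked against 32-bit blocks:
  // recognized as a DUPLANE32 of lane 0.
  assert(isWideDUPMaskSketch({0, 1, 0, 1}, 16, 64, 32, Lane) && Lane == 0);
  // [2,3,2,3] splats the second 32-bit lane.
  assert(isWideDUPMaskSketch({2, 3, 2, 3}, 16, 64, 32, Lane) && Lane == 1);
  // [1,2,1,2] is blockwise-repeating but not lane-aligned: rejected.
  assert(!isWideDUPMaskSketch({1, 2, 1, 2}, 16, 64, 32, Lane));
  std::puts("ok");
}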
@@ -7814,6 +7889,60 @@ static unsigned getDUPLANEOp(EVT EltType) {
  llvm_unreachable("Invalid vector element type?");
}

static SDValue constructDup(SDValue V, int Lane, SDLoc dl, EVT VT,
                            unsigned Opcode, SelectionDAG &DAG) {
  // Try to eliminate a bitcasted extract subvector before a DUPLANE.
  auto getScaledOffsetDup = [](SDValue BitCast, int &LaneC, MVT &CastVT) {
    // Match: dup (bitcast (extract_subv X, C)), LaneC
    if (BitCast.getOpcode() != ISD::BITCAST ||
        BitCast.getOperand(0).getOpcode() != ISD::EXTRACT_SUBVECTOR)
      return false;

    // The extract index must align in the destination type. That may not
    // happen if the bitcast is from narrow to wide type.
    SDValue Extract = BitCast.getOperand(0);
    unsigned ExtIdx = Extract.getConstantOperandVal(1);
    unsigned SrcEltBitWidth = Extract.getScalarValueSizeInBits();
    unsigned ExtIdxInBits = ExtIdx * SrcEltBitWidth;
    unsigned CastedEltBitWidth = BitCast.getScalarValueSizeInBits();
    if (ExtIdxInBits % CastedEltBitWidth != 0)
      return false;

    // Update the lane value by offsetting with the scaled extract index.
    LaneC += ExtIdxInBits / CastedEltBitWidth;

    // Determine the casted vector type of the wide vector input.
    // dup (bitcast (extract_subv X, C)), LaneC --> dup (bitcast X), LaneC'
    // Examples:
    // dup (bitcast (extract_subv v2f64 X, 1) to v2f32), 1 --> dup v4f32 X, 3
    // dup (bitcast (extract_subv v16i8 X, 8) to v4i16), 1 --> dup v8i16 X, 5
    unsigned SrcVecNumElts =
        Extract.getOperand(0).getValueSizeInBits() / CastedEltBitWidth;
    CastVT = MVT::getVectorVT(BitCast.getSimpleValueType().getScalarType(),
                              SrcVecNumElts);
    return true;
  };
  MVT CastVT;
  if (getScaledOffsetDup(V, Lane, CastVT)) {
    V = DAG.getBitcast(CastVT, V.getOperand(0).getOperand(0));
  } else if (V.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
    // The lane is incremented by the index of the extract.
    // Example: dup v2f32 (extract v4f32 X, 2), 1 --> dup v4f32 X, 3
    Lane += V.getConstantOperandVal(1);
    V = V.getOperand(0);
  } else if (V.getOpcode() == ISD::CONCAT_VECTORS) {
    // The lane is decremented if we are splatting from the 2nd operand.
    // Example: dup v4i32 (concat v2i32 X, v2i32 Y), 3 --> dup v4i32 Y, 1
    unsigned Idx = Lane >= (int)VT.getVectorNumElements() / 2;
    Lane -= Idx * VT.getVectorNumElements() / 2;
    V = WidenVector(V.getOperand(Idx), DAG);
  } else if (VT.getSizeInBits() == 64) {
    // Widen the operand to 128-bit register with undef.
    V = WidenVector(V, DAG);
  }
  return DAG.getNode(Opcode, dl, VT, V, DAG.getConstant(Lane, dl, MVT::i64));
}

SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
                                                   SelectionDAG &DAG) const {
  SDLoc dl(Op);
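
The lane arithmetic inside constructDup is easy to check in isolation. Below is a minimal standalone sketch (remapDupLane and concatLane are illustrative names, not LLVM API): getScaledOffsetDup converts the extract offset from source elements into bits, then into lanes of the bitcast type, and adds it to the DUP lane, while a splat from the high half of a concat subtracts half the lane count. It reproduces the worked examples from the comments above:

#include <cassert>

static int remapDupLane(int Lane, unsigned ExtIdx, unsigned SrcEltBits,
                        unsigned CastedEltBits) {
  unsigned ExtIdxInBits = ExtIdx * SrcEltBits;
  assert(ExtIdxInBits % CastedEltBits == 0 &&
         "extract index must align in the destination type");
  return Lane + ExtIdxInBits / CastedEltBits;
}

static int concatLane(int Lane, unsigned NumElts) {
  unsigned Idx = Lane >= (int)NumElts / 2; // 1 iff splatting from 2nd operand
  return Lane - Idx * (NumElts / 2);
}

int main() {
  // dup (bitcast (extract_subv v2f64 X, 1) to v2f32), 1 --> dup v4f32 X, 3
  assert(remapDupLane(1, 1, 64, 32) == 3);
  // dup (bitcast (extract_subv v16i8 X, 8) to v4i16), 1 --> dup v8i16 X, 5
  assert(remapDupLane(1, 8, 8, 16) == 5);
  // dup v4i32 (concat v2i32 X, v2i32 Y), 3 --> dup v4i32 Y, 1
  assert(concatLane(3, 4) == 1);
  return 0;
}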
@@ -7847,57 +7976,26 @@ SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
    // Otherwise, duplicate from the lane of the input vector.
    unsigned Opcode = getDUPLANEOp(V1.getValueType().getVectorElementType());
    return constructDup(V1, Lane, dl, VT, Opcode, DAG);
  }

  // Try to eliminate a bitcasted extract subvector before a DUPLANE.
  auto getScaledOffsetDup = [](SDValue BitCast, int &LaneC, MVT &CastVT) {
    // Match: dup (bitcast (extract_subv X, C)), LaneC
    if (BitCast.getOpcode() != ISD::BITCAST ||
        BitCast.getOperand(0).getOpcode() != ISD::EXTRACT_SUBVECTOR)
      return false;

    // The extract index must align in the destination type. That may not
    // happen if the bitcast is from narrow to wide type.
    SDValue Extract = BitCast.getOperand(0);
    unsigned ExtIdx = Extract.getConstantOperandVal(1);
    unsigned SrcEltBitWidth = Extract.getScalarValueSizeInBits();
    unsigned ExtIdxInBits = ExtIdx * SrcEltBitWidth;
    unsigned CastedEltBitWidth = BitCast.getScalarValueSizeInBits();
    if (ExtIdxInBits % CastedEltBitWidth != 0)
      return false;

    // Update the lane value by offsetting with the scaled extract index.
    LaneC += ExtIdxInBits / CastedEltBitWidth;

    // Determine the casted vector type of the wide vector input.
    // dup (bitcast (extract_subv X, C)), LaneC --> dup (bitcast X), LaneC'
    // Examples:
    // dup (bitcast (extract_subv v2f64 X, 1) to v2f32), 1 --> dup v4f32 X, 3
    // dup (bitcast (extract_subv v16i8 X, 8) to v4i16), 1 --> dup v8i16 X, 5
    unsigned SrcVecNumElts =
        Extract.getOperand(0).getValueSizeInBits() / CastedEltBitWidth;
    CastVT = MVT::getVectorVT(BitCast.getSimpleValueType().getScalarType(),
                              SrcVecNumElts);
    return true;
  };
  MVT CastVT;
  if (getScaledOffsetDup(V1, Lane, CastVT)) {
    V1 = DAG.getBitcast(CastVT, V1.getOperand(0).getOperand(0));
  } else if (V1.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
    // The lane is incremented by the index of the extract.
    // Example: dup v2f32 (extract v4f32 X, 2), 1 --> dup v4f32 X, 3
    Lane += V1.getConstantOperandVal(1);
    V1 = V1.getOperand(0);
  } else if (V1.getOpcode() == ISD::CONCAT_VECTORS) {
    // The lane is decremented if we are splatting from the 2nd operand.
    // Example: dup v4i32 (concat v2i32 X, v2i32 Y), 3 --> dup v4i32 Y, 1
    unsigned Idx = Lane >= (int)VT.getVectorNumElements() / 2;
    Lane -= Idx * VT.getVectorNumElements() / 2;
    V1 = WidenVector(V1.getOperand(Idx), DAG);
  } else if (VT.getSizeInBits() == 64) {
    // Widen the operand to 128-bit register with undef.
    V1 = WidenVector(V1, DAG);

  // Check if the mask matches a DUP for a wider element
  for (unsigned LaneSize : {64U, 32U, 16U}) {
    unsigned Lane = 0;
    if (isWideDUPMask(ShuffleMask, VT, LaneSize, Lane)) {
      unsigned Opcode = LaneSize == 64   ? AArch64ISD::DUPLANE64
                        : LaneSize == 32 ? AArch64ISD::DUPLANE32
                                         : AArch64ISD::DUPLANE16;
      // Cast V1 to an integer vector with the required lane size
      MVT NewEltTy = MVT::getIntegerVT(LaneSize);
      unsigned NewEltCount = VT.getSizeInBits() / LaneSize;
      MVT NewVecTy = MVT::getVectorVT(NewEltTy, NewEltCount);
      V1 = DAG.getBitcast(NewVecTy, V1);
      // Construct the DUP instruction
      V1 = constructDup(V1, Lane, dl, NewVecTy, Opcode, DAG);
      // Cast back to the original type
      return DAG.getBitcast(VT, V1);
    }
    return DAG.getNode(Opcode, dl, VT, V1, DAG.getConstant(Lane, dl, MVT::i64));
  }

  if (isREVMask(ShuffleMask, VT, 64))
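
For the headline example from the commit message, the new loop effectively plans: bitcast the input to a wider integer vector, DUPLANE at the wide lane, bitcast back. A minimal sketch of that computation, assuming illustrative names (planWideDup, WideDup) rather than LLVM API:

#include <cassert>
#include <cstdio>

enum DupOpc { DUPLANE16, DUPLANE32, DUPLANE64 };

struct WideDup { DupOpc Opc; unsigned CastEltBits, CastNumElts, Lane; };

// For a match with block size LaneSize and DUP lane Lane, the input is
// bitcast to a vector of (VecBits / LaneSize) x iLaneSize and DUP'd there.
static WideDup planWideDup(unsigned VecBits, unsigned LaneSize, unsigned Lane) {
  DupOpc Opc = LaneSize == 64   ? DUPLANE64
               : LaneSize == 32 ? DUPLANE32
                                : DUPLANE16;
  return {Opc, LaneSize, VecBits / LaneSize, Lane};
}

int main() {
  // <4 x i16> %v, mask [0,1,0,1]: 64-bit vector, 32-bit blocks, lane 0
  // --> bitcast to v2i32, DUPLANE32 lane 0, bitcast back to v4i16.
  WideDup P = planWideDup(64, 32, 0);
  assert(P.Opc == DUPLANE32 && P.CastNumElts == 2 && P.Lane == 0);
  std::puts("dup v0.2s, v0.s[0]");
}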


@@ -1966,7 +1966,7 @@ define <4 x i16> @test_vadd_laneq5_i16_bitcast(<4 x i16> %a, <2 x double> %v) {
define <4 x i16> @test_vadd_lane2_i16_bitcast_bigger_aligned(<4 x i16> %a, <16 x i8> %v) {
; CHECK-LABEL: test_vadd_lane2_i16_bitcast_bigger_aligned:
; CHECK: // %bb.0:
; CHECK-NEXT: ext v1.8b, v1.8b, v0.8b, #2
; CHECK-NEXT: dup v1.4h, v1.h[2]
; CHECK-NEXT: dup v1.4h, v1.h[1]
; CHECK-NEXT: add v0.4h, v1.4h, v0.4h
; CHECK-NEXT: ret


@@ -14,7 +14,7 @@ entry:
define <4 x i16> @vext_6701_12(<4 x i16> %a1, <4 x i16> %a2) {
entry:
; CHECK-LABEL: vext_6701_12:
; CHECK: ext v0.8b, v0.8b, v0.8b, #4
; CHECK: dup v0.2s, v0.s[0]
%x = shufflevector <4 x i16> %a1, <4 x i16> %a2, <4 x i32> <i32 undef, i32 undef, i32 0, i32 1>
ret <4 x i16> %x
}
@@ -54,7 +54,7 @@ entry:
define <4 x i16> @vext_6701_34(<4 x i16> %a1, <4 x i16> %a2) {
entry:
; CHECK-LABEL: vext_6701_34:
; CHECK: ext v0.8b, v1.8b, v0.8b, #4
; CHECK: dup v0.2s, v1.s[1]
%x = shufflevector <4 x i16> %a1, <4 x i16> %a2, <4 x i32> <i32 6, i32 7, i32 undef, i32 undef>
ret <4 x i16> %x
}


@@ -209,7 +209,7 @@ entry:
define <4 x i16> @test_undef_vext_s16(<4 x i16> %a) {
; CHECK-LABEL: test_undef_vext_s16:
; CHECK: ext {{v[0-9]+}}.8b, {{v[0-9]+}}.8b, {{v[0-9]+}}.8b, #{{0x4|4}}
; CHECK: dup v{{[0-9]+}}.2s, {{v[0-9]+}}.s[1]
entry:
%vext = shufflevector <4 x i16> %a, <4 x i16> undef, <4 x i32> <i32 2, i32 3, i32 4, i32 5>
ret <4 x i16> %vext


@@ -0,0 +1,122 @@
; RUN: llc < %s -verify-machineinstrs -mtriple=aarch64-none-linux-gnu -mattr=+neon | FileCheck %s
define <4 x i16> @shuffle1(<4 x i16> %v) {
; CHECK-LABEL: shuffle1:
; CHECK: dup v0.2s, v0.s[0]
; CHECK-NEXT: ret
entry:
%res = shufflevector <4 x i16> %v, <4 x i16> undef, <4 x i32> <i32 0, i32 undef, i32 0, i32 1>
ret <4 x i16> %res
}
define <4 x i16> @shuffle2(<4 x i16> %v) {
; CHECK-LABEL: shuffle2:
; CHECK: dup v0.2s, v0.s[1]
; CHECK-NEXT: ret
entry:
%res = shufflevector <4 x i16> %v, <4 x i16> undef, <4 x i32> <i32 2, i32 3, i32 undef, i32 3>
ret <4 x i16> %res
}
define <8 x i16> @shuffle3(<8 x i16> %v) {
; CHECK-LABEL: shuffle3:
; CHECK: dup v0.2d, v0.d[0]
; CHECK-NEXT: ret
entry:
%res = shufflevector <8 x i16> %v, <8 x i16> undef, <8 x i32> <i32 undef, i32 undef, i32 2, i32 3,
i32 undef, i32 1, i32 undef, i32 3>
ret <8 x i16> %res
}
define <4 x i32> @shuffle4(<4 x i32> %v) {
; CHECK-LABEL: shuffle4:
; CHECK: dup v0.2d, v0.d[0]
; CHECK-NEXT: ret
entry:
%res = shufflevector <4 x i32> %v, <4 x i32> undef, <4 x i32> <i32 0, i32 1, i32 0, i32 1>
ret <4 x i32> %res
}
define <16 x i8> @shuffle5(<16 x i8> %v) {
; CHECK-LABEL: shuffle5:
; CHECK: dup v0.4s, v0.s[2]
; CHECK-NEXT: ret
entry:
%res = shufflevector <16 x i8> %v, <16 x i8> undef, <16 x i32> <i32 8, i32 9, i32 10, i32 11,
i32 8, i32 9, i32 10, i32 11,
i32 8, i32 9, i32 10, i32 11,
i32 8, i32 9, i32 10, i32 11>
ret <16 x i8> %res
}
define <16 x i8> @shuffle6(<16 x i8> %v) {
; CHECK-LABEL: shuffle6:
; CHECK: dup v0.2d, v0.d[1]
; CHECK-NEXT: ret
entry:
%res = shufflevector <16 x i8> %v, <16 x i8> undef, <16 x i32> <i32 8, i32 9, i32 10, i32 11,
i32 12, i32 13, i32 14, i32 15,
i32 8, i32 9, i32 10, i32 11,
i32 12, i32 13, i32 14, i32 15>
ret <16 x i8> %res
}
define <8 x i8> @shuffle7(<8 x i8> %v) {
; CHECK-LABEL: shuffle7:
; CHECK: dup v0.2s, v0.s[1]
; CHECK-NEXT: ret
entry:
%res = shufflevector <8 x i8> %v, <8 x i8> undef, <8 x i32> <i32 4, i32 5, i32 6, i32 undef,
i32 undef, i32 5, i32 6, i32 undef>
ret <8 x i8> %res
}
define <8 x i8> @shuffle8(<8 x i8> %v) {
; CHECK-LABEL: shuffle8:
; CHECK: dup v0.4h, v0.h[3]
; CHECK-NEXT: ret
entry:
%res = shufflevector <8 x i8> %v, <8 x i8> undef, <8 x i32> <i32 6, i32 7, i32 6, i32 undef,
i32 undef, i32 7, i32 6, i32 undef>
ret <8 x i8> %res
}
; No repeating blocks in the mask
define <8 x i8> @shuffle_not1(<16 x i8> %v) {
; CHECK-LABEL: shuffle_not1:
; CHECK: ext v0.16b, v0.16b, v0.16b, #2
%res = shufflevector <16 x i8> %v, <16 x i8> undef, <8 x i32> <i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9>
ret <8 x i8> %res
}
; Block is not a proper lane
define <4 x i32> @shuffle_not2(<4 x i32> %v) {
; CHECK-LABEL: shuffle_not2:
; CHECK-NOT: dup
; CHECK: ext
; CHECK: ret
entry:
%res = shufflevector <4 x i32> %v, <4 x i32> undef, <4 x i32> <i32 1, i32 2, i32 1, i32 2>
ret <4 x i32> %res
}
; Block size is equal to vector size
define <4 x i16> @shuffle_not3(<4 x i16> %v) {
; CHECK-LABEL: shuffle_not3:
; CHECK-NOT: dup
; CHECK: ret
entry:
%res = shufflevector <4 x i16> %v, <4 x i16> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
ret <4 x i16> %res
}
; Blocks mismatch
define <8 x i8> @shuffle_not4(<8 x i8> %v) {
; CHECK-LABEL: shuffle_not4:
; CHECK-NOT: dup
; CHECK: ret
entry:
%res = shufflevector <8 x i8> %v, <8 x i8> undef, <8 x i32> <i32 4, i32 5, i32 6, i32 undef,
i32 undef, i32 5, i32 5, i32 undef>
ret <8 x i8> %res
}


@@ -77,7 +77,7 @@ define float @test_v16f32(<16 x float> %a) nounwind {
; CHECK-NEXT: fmaxnm v1.4s, v1.4s, v3.4s
; CHECK-NEXT: fmaxnm v0.4s, v0.4s, v2.4s
; CHECK-NEXT: fmaxnm v0.4s, v0.4s, v1.4s
; CHECK-NEXT: ext v1.16b, v0.16b, v0.16b, #8
; CHECK-NEXT: dup v1.2d, v0.d[1]
; CHECK-NEXT: fmaxnm v0.4s, v0.4s, v1.4s
; CHECK-NEXT: dup v1.4s, v0.s[1]
; CHECK-NEXT: fmaxnm v0.4s, v0.4s, v1.4s