[SDAG] try harder to fold casts into vector compare

sext (vsetcc X, Y) --> vsetcc (zext X), (zext Y)
(when the zexts are free and a bunch of other conditions hold)

We have a couple of similar folds to this already for vector selects,
but this pattern slips through because it is only a setcc.

The tests are based on the motivating case from:
https://llvm.org/PR50055
...but we need extra logic to handle that example, so I've left it as
a TODO for now.
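
For illustration, the simplest pattern this fold catches (taken verbatim
from the updated test file below; 'ne' qualifies because only signed
predicates are excluded):

  define <8 x i16> @cmp_ne_load_const(<8 x i8>* %x) nounwind {
    %loadx = load <8 x i8>, <8 x i8>* %x
    %icmp = icmp ne <8 x i8> %loadx, zeroinitializer
    %sext = sext <8 x i1> %icmp to <8 x i16>
    ret <8 x i16> %sext
  }

Here the v8i8 load becomes a free zext-load to v8i16 and the compare is
done directly in the wide type, so the sign-extension of the compare
result disappears (see the AVX check lines in the test diff below).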

Differential Revision: https://reviews.llvm.org/D103280
Sanjay Patel 2021-05-31 07:14:01 -04:00
parent a723ca32af
commit 434c8e013a
2 changed files with 83 additions and 75 deletions

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

@@ -10937,6 +10937,36 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
           return DAG.getSExtOrTrunc(VsetCC, DL, VT);
         }
       }
+      // Try to eliminate the sext of a setcc by zexting the compare operands.
+      // TODO: Handle signed compare by sexting the ops.
+      if (!ISD::isSignedIntSetCC(CC) && N0.hasOneUse() &&
+          TLI.isOperationLegalOrCustom(ISD::SETCC, VT) /*&&
+          !TLI.isOperationLegalOrCustom(ISD::SETCC, SVT)*/) {
+        // We have an unsupported narrow vector compare op that would be legal
+        // if extended to the destination type. See if the compare operands
+        // can be freely extended to the destination type.
+        auto IsFreeToZext = [&](SDValue V) {
+          if (isConstantOrConstantVector(V, /*NoOpaques*/ true))
+            return true;
+          // Match a simple, non-extended load that can be converted to a
+          // legal zext-load.
+          // TODO: Handle more than one use if the other uses are free to zext.
+          // TODO: Allow widening of an existing zext-load?
+          return ISD::isNON_EXTLoad(V.getNode()) &&
+                 ISD::isUNINDEXEDLoad(V.getNode()) &&
+                 cast<LoadSDNode>(V)->isSimple() &&
+                 TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, V.getValueType()) &&
+                 V.hasOneUse();
+        };
+        if (IsFreeToZext(N00) && IsFreeToZext(N01)) {
+          SDValue Ext0 = DAG.getZExtOrTrunc(N00, DL, VT);
+          SDValue Ext1 = DAG.getZExtOrTrunc(N01, DL, VT);
+          return DAG.getSetCC(DL, VT, Ext0, Ext1, CC);
+        }
+      }
     }

     // sext(setcc x, y, cc) -> (select (setcc x, y, cc), T, 0)

llvm/test/CodeGen/X86/sext-vsetcc.ll

@@ -18,31 +18,20 @@ define <8 x i16> @cmp_ne_load_const(<8 x i8>* %x) nounwind {
 ; SSE-NEXT: psraw $8, %xmm0
 ; SSE-NEXT: retq
 ;
-; AVX2-LABEL: cmp_ne_load_const:
-; AVX2: # %bb.0:
-; AVX2-NEXT: vmovq {{.*#+}} xmm0 = mem[0],zero
-; AVX2-NEXT: vpxor %xmm1, %xmm1, %xmm1
-; AVX2-NEXT: vpcmpeqb %xmm1, %xmm0, %xmm0
-; AVX2-NEXT: vpcmpeqd %xmm1, %xmm1, %xmm1
-; AVX2-NEXT: vpxor %xmm1, %xmm0, %xmm0
-; AVX2-NEXT: vpmovsxbw %xmm0, %xmm0
-; AVX2-NEXT: retq
-;
-; AVX512-LABEL: cmp_ne_load_const:
-; AVX512: # %bb.0:
-; AVX512-NEXT: vmovq {{.*#+}} xmm0 = mem[0],zero
-; AVX512-NEXT: vpxor %xmm1, %xmm1, %xmm1
-; AVX512-NEXT: vpcmpeqb %xmm1, %xmm0, %xmm0
-; AVX512-NEXT: vpternlogq $15, %zmm0, %zmm0, %zmm0
-; AVX512-NEXT: vpmovsxbw %xmm0, %xmm0
-; AVX512-NEXT: vzeroupper
-; AVX512-NEXT: retq
+; AVX-LABEL: cmp_ne_load_const:
+; AVX: # %bb.0:
+; AVX-NEXT: vpmovzxbw {{.*#+}} xmm0 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero
+; AVX-NEXT: vpxor %xmm1, %xmm1, %xmm1
+; AVX-NEXT: vpcmpgtw %xmm1, %xmm0, %xmm0
+; AVX-NEXT: retq
   %loadx = load <8 x i8>, <8 x i8>* %x
   %icmp = icmp ne <8 x i8> %loadx, zeroinitializer
   %sext = sext <8 x i1> %icmp to <8 x i16>
   ret <8 x i16> %sext
 }

+; negative test - simple loads only
 define <8 x i16> @cmp_ne_load_const_volatile(<8 x i8>* %x) nounwind {
 ; SSE-LABEL: cmp_ne_load_const_volatile:
 ; SSE: # %bb.0:
@@ -80,6 +69,8 @@ define <8 x i16> @cmp_ne_load_const_volatile(<8 x i8>* %x) nounwind {
   ret <8 x i16> %sext
 }

+; negative test - don't create extra load
 define <8 x i16> @cmp_ne_load_const_extra_use1(<8 x i8>* %x) nounwind {
 ; SSE-LABEL: cmp_ne_load_const_extra_use1:
 ; SSE: # %bb.0:
@@ -130,6 +121,8 @@ define <8 x i16> @cmp_ne_load_const_extra_use1(<8 x i8>* %x) nounwind {
   ret <8 x i16> %sext
 }

+; negative test - don't create extra compare
 define <8 x i16> @cmp_ne_load_const_extra_use2(<8 x i8>* %x) nounwind {
 ; SSE-LABEL: cmp_ne_load_const_extra_use2:
 ; SSE: # %bb.0:
@@ -184,6 +177,8 @@ define <8 x i16> @cmp_ne_load_const_extra_use2(<8 x i8>* %x) nounwind {
   ret <8 x i16> %sext
 }

+; negative test - not free extend
 define <8 x i16> @cmp_ne_no_load_const(i64 %x) nounwind {
 ; SSE-LABEL: cmp_ne_no_load_const:
 ; SSE: # %bb.0:
@@ -235,31 +230,20 @@ define <4 x i32> @cmp_ult_load_const(<4 x i8>* %x) nounwind {
 ; SSE-NEXT: psrad $24, %xmm0
 ; SSE-NEXT: retq
 ;
-; AVX2-LABEL: cmp_ult_load_const:
-; AVX2: # %bb.0:
-; AVX2-NEXT: vmovd {{.*#+}} xmm0 = mem[0],zero,zero,zero
-; AVX2-NEXT: vpmaxub {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm1
-; AVX2-NEXT: vpcmpeqb %xmm1, %xmm0, %xmm0
-; AVX2-NEXT: vpcmpeqd %xmm1, %xmm1, %xmm1
-; AVX2-NEXT: vpxor %xmm1, %xmm0, %xmm0
-; AVX2-NEXT: vpmovsxbd %xmm0, %xmm0
-; AVX2-NEXT: retq
-;
-; AVX512-LABEL: cmp_ult_load_const:
-; AVX512: # %bb.0:
-; AVX512-NEXT: vmovd {{.*#+}} xmm0 = mem[0],zero,zero,zero
-; AVX512-NEXT: vpmaxub {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm1
-; AVX512-NEXT: vpcmpeqb %xmm1, %xmm0, %xmm0
-; AVX512-NEXT: vpternlogq $15, %zmm0, %zmm0, %zmm0
-; AVX512-NEXT: vpmovsxbd %xmm0, %xmm0
-; AVX512-NEXT: vzeroupper
-; AVX512-NEXT: retq
+; AVX-LABEL: cmp_ult_load_const:
+; AVX: # %bb.0:
+; AVX-NEXT: vpmovzxbd {{.*#+}} xmm0 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero
+; AVX-NEXT: vmovdqa {{.*#+}} xmm1 = [42,214,0,255]
+; AVX-NEXT: vpcmpgtd %xmm0, %xmm1, %xmm0
+; AVX-NEXT: retq
   %loadx = load <4 x i8>, <4 x i8>* %x
   %icmp = icmp ult <4 x i8> %loadx, <i8 42, i8 -42, i8 0, i8 -1>
   %sext = sext <4 x i1> %icmp to <4 x i32>
   ret <4 x i32> %sext
 }

+; negative test - type must be legal
 define <3 x i32> @cmp_ult_load_const_bad_type(<3 x i8>* %x) nounwind {
 ; SSE-LABEL: cmp_ult_load_const_bad_type:
 ; SSE: # %bb.0:
@@ -299,6 +283,8 @@ define <3 x i32> @cmp_ult_load_const_bad_type(<3 x i8>* %x) nounwind {
   ret <3 x i32> %sext
 }

+; negative test - signed cmp (TODO)
 define <4 x i32> @cmp_slt_load_const(<4 x i8>* %x) nounwind {
 ; SSE-LABEL: cmp_slt_load_const:
 ; SSE: # %bb.0:
@@ -338,21 +324,20 @@ define <2 x i64> @cmp_ne_zextload(<2 x i32>* %x, <2 x i32>* %y) nounwind {
 ;
 ; AVX2-LABEL: cmp_ne_zextload:
 ; AVX2: # %bb.0:
-; AVX2-NEXT: vmovq {{.*#+}} xmm0 = mem[0],zero
-; AVX2-NEXT: vmovq {{.*#+}} xmm1 = mem[0],zero
-; AVX2-NEXT: vpcmpeqd %xmm1, %xmm0, %xmm0
+; AVX2-NEXT: vpmovzxdq {{.*#+}} xmm0 = mem[0],zero,mem[1],zero
+; AVX2-NEXT: vpmovzxdq {{.*#+}} xmm1 = mem[0],zero,mem[1],zero
+; AVX2-NEXT: vpcmpeqq %xmm1, %xmm0, %xmm0
 ; AVX2-NEXT: vpcmpeqd %xmm1, %xmm1, %xmm1
 ; AVX2-NEXT: vpxor %xmm1, %xmm0, %xmm0
-; AVX2-NEXT: vpmovsxdq %xmm0, %xmm0
 ; AVX2-NEXT: retq
 ;
 ; AVX512-LABEL: cmp_ne_zextload:
 ; AVX512: # %bb.0:
-; AVX512-NEXT: vmovq {{.*#+}} xmm0 = mem[0],zero
-; AVX512-NEXT: vmovq {{.*#+}} xmm1 = mem[0],zero
-; AVX512-NEXT: vpcmpeqd %xmm1, %xmm0, %xmm0
+; AVX512-NEXT: vpmovzxdq {{.*#+}} xmm0 = mem[0],zero,mem[1],zero
+; AVX512-NEXT: vpmovzxdq {{.*#+}} xmm1 = mem[0],zero,mem[1],zero
+; AVX512-NEXT: vpcmpeqq %xmm1, %xmm0, %xmm0
 ; AVX512-NEXT: vpternlogq $15, %zmm0, %zmm0, %zmm0
-; AVX512-NEXT: vpmovsxdq %xmm0, %xmm0
+; AVX512-NEXT: # kill: def $xmm0 killed $xmm0 killed $zmm0
 ; AVX512-NEXT: vzeroupper
 ; AVX512-NEXT: retq
   %loadx = load <2 x i32>, <2 x i32>* %x
@@ -375,27 +360,12 @@ define <8 x i16> @cmp_ugt_zextload(<8 x i8>* %x, <8 x i8>* %y) nounwind {
 ; SSE-NEXT: psraw $8, %xmm0
 ; SSE-NEXT: retq
 ;
-; AVX2-LABEL: cmp_ugt_zextload:
-; AVX2: # %bb.0:
-; AVX2-NEXT: vmovq {{.*#+}} xmm0 = mem[0],zero
-; AVX2-NEXT: vmovq {{.*#+}} xmm1 = mem[0],zero
-; AVX2-NEXT: vpminub %xmm1, %xmm0, %xmm1
-; AVX2-NEXT: vpcmpeqb %xmm1, %xmm0, %xmm0
-; AVX2-NEXT: vpcmpeqd %xmm1, %xmm1, %xmm1
-; AVX2-NEXT: vpxor %xmm1, %xmm0, %xmm0
-; AVX2-NEXT: vpmovsxbw %xmm0, %xmm0
-; AVX2-NEXT: retq
-;
-; AVX512-LABEL: cmp_ugt_zextload:
-; AVX512: # %bb.0:
-; AVX512-NEXT: vmovq {{.*#+}} xmm0 = mem[0],zero
-; AVX512-NEXT: vmovq {{.*#+}} xmm1 = mem[0],zero
-; AVX512-NEXT: vpminub %xmm1, %xmm0, %xmm1
-; AVX512-NEXT: vpcmpeqb %xmm1, %xmm0, %xmm0
-; AVX512-NEXT: vpternlogq $15, %zmm0, %zmm0, %zmm0
-; AVX512-NEXT: vpmovsxbw %xmm0, %xmm0
-; AVX512-NEXT: vzeroupper
-; AVX512-NEXT: retq
+; AVX-LABEL: cmp_ugt_zextload:
+; AVX: # %bb.0:
+; AVX-NEXT: vpmovzxbw {{.*#+}} xmm0 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero
+; AVX-NEXT: vpmovzxbw {{.*#+}} xmm1 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero
+; AVX-NEXT: vpcmpgtw %xmm1, %xmm0, %xmm0
+; AVX-NEXT: retq
   %loadx = load <8 x i8>, <8 x i8>* %x
   %loady = load <8 x i8>, <8 x i8>* %y
   %icmp = icmp ugt <8 x i8> %loadx, %loady
@@ -403,6 +373,8 @@ define <8 x i16> @cmp_ugt_zextload(<8 x i8>* %x, <8 x i8>* %y) nounwind {
   ret <8 x i16> %sext
 }

+; negative test - signed cmp (TODO)
 define <8 x i16> @cmp_sgt_zextload(<8 x i8>* %x, <8 x i8>* %y) nounwind {
 ; SSE-LABEL: cmp_sgt_zextload:
 ; SSE: # %bb.0:
@@ -427,6 +399,9 @@ define <8 x i16> @cmp_sgt_zextload(<8 x i8>* %x, <8 x i8>* %y) nounwind {
   ret <8 x i16> %sext
 }

+; negative test - don't change a legal op
+; TODO: Or should we? We can eliminate the vpmovsxwd at the cost of a 256-bit ymm vpcmpeqw.
 define <8 x i32> @cmp_ne_zextload_from_legal_op(<8 x i16>* %x, <8 x i16>* %y) {
 ; SSE-LABEL: cmp_ne_zextload_from_legal_op:
 ; SSE: # %bb.0:
@@ -442,19 +417,20 @@ define <8 x i32> @cmp_ne_zextload_from_legal_op(<8 x i16>* %x, <8 x i16>* %y) {
 ;
 ; AVX2-LABEL: cmp_ne_zextload_from_legal_op:
 ; AVX2: # %bb.0:
-; AVX2-NEXT: vmovdqa (%rdi), %xmm0
-; AVX2-NEXT: vpcmpeqw (%rsi), %xmm0, %xmm0
-; AVX2-NEXT: vpcmpeqd %xmm1, %xmm1, %xmm1
-; AVX2-NEXT: vpxor %xmm1, %xmm0, %xmm0
-; AVX2-NEXT: vpmovsxwd %xmm0, %ymm0
+; AVX2-NEXT: vpmovzxwd {{.*#+}} ymm0 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero
+; AVX2-NEXT: vpmovzxwd {{.*#+}} ymm1 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero
+; AVX2-NEXT: vpcmpeqd %ymm1, %ymm0, %ymm0
+; AVX2-NEXT: vpcmpeqd %ymm1, %ymm1, %ymm1
+; AVX2-NEXT: vpxor %ymm1, %ymm0, %ymm0
 ; AVX2-NEXT: retq
 ;
 ; AVX512-LABEL: cmp_ne_zextload_from_legal_op:
 ; AVX512: # %bb.0:
-; AVX512-NEXT: vmovdqa (%rdi), %xmm0
-; AVX512-NEXT: vpcmpeqw (%rsi), %xmm0, %xmm0
+; AVX512-NEXT: vpmovzxwd {{.*#+}} ymm0 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero
+; AVX512-NEXT: vpmovzxwd {{.*#+}} ymm1 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero
+; AVX512-NEXT: vpcmpeqd %ymm1, %ymm0, %ymm0
 ; AVX512-NEXT: vpternlogq $15, %zmm0, %zmm0, %zmm0
-; AVX512-NEXT: vpmovsxwd %xmm0, %ymm0
+; AVX512-NEXT: # kill: def $ymm0 killed $ymm0 killed $zmm0
 ; AVX512-NEXT: retq
   %loadx = load <8 x i16>, <8 x i16>* %x
   %loady = load <8 x i16>, <8 x i16>* %y
@@ -463,6 +439,8 @@ define <8 x i32> @cmp_ne_zextload_from_legal_op(<8 x i16>* %x, <8 x i16>* %y) {
   ret <8 x i32> %sext
 }

+; negative test - extra use (TODO)
 define <8 x i32> @PR50055(<8 x i8>* %src, <8 x i32>* %dst) nounwind {
 ; SSE-LABEL: PR50055:
 ; SSE: # %bb.0: