[X86] When removing sign extends from gather/scatter indices, make sure we handle UpdateNodeOperands finding an existing node to CSE with.

If this happens the operands aren't updated and the existing node is returned. Make sure we pass this existing node up to the DAG combiner so that a proper replacement happens. Otherwise we get stuck in an infinite loop with an unoptimized node.

llvm-svn: 338090
This commit is contained in:
Craig Topper 2018-07-27 00:00:30 +00:00
parent 136b0d96f0
commit 2ad9a263b5
2 changed files with 71 additions and 15 deletions

View File

@ -38100,12 +38100,14 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
Index.getOperand(0).getScalarValueSizeInBits() <= 32) {
SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
NewOps[4] = Index.getOperand(0);
DAG.UpdateNodeOperands(N, NewOps);
// The original sign extend has less users, add back to worklist in case
// it needs to be removed
DCI.AddToWorklist(Index.getNode());
DCI.AddToWorklist(N);
return SDValue(N, 0);
SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
if (Res == N) {
// The original sign extend has fewer users, add back to worklist in
// case it needs to be removed
DCI.AddToWorklist(Index.getNode());
DCI.AddToWorklist(N);
}
return SDValue(Res, 0);
}
}
@ -38118,9 +38120,10 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
Index = DAG.getSExtOrTrunc(Index, DL, IndexVT);
SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
NewOps[4] = Index;
DAG.UpdateNodeOperands(N, NewOps);
DCI.AddToWorklist(N);
return SDValue(N, 0);
SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
if (Res == N)
DCI.AddToWorklist(N);
return SDValue(Res, 0);
}
// Try to remove zero extends from 32->64 if we know the sign bit of
@ -38131,12 +38134,14 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
if (DAG.SignBitIsZero(Index.getOperand(0))) {
SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
NewOps[4] = Index.getOperand(0);
DAG.UpdateNodeOperands(N, NewOps);
// The original zero extend has less users, add back to worklist in case
// it needs to be removed
DCI.AddToWorklist(Index.getNode());
DCI.AddToWorklist(N);
return SDValue(N, 0);
SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
if (Res == N) {
// The original zero extend has fewer users, add back to worklist in
// case it needs to be removed
DCI.AddToWorklist(Index.getNode());
DCI.AddToWorklist(N);
}
return SDValue(Res, 0);
}
}
}

View File

@ -2928,3 +2928,54 @@ define void @test_scatter_setcc_split(double* %base, <16 x i32> %ind, <16 x i32>
call void @llvm.masked.scatter.v16f64.v16p0f64(<16 x double> %src0, <16 x double*> %gep.random, i32 4, <16 x i1> %mask)
ret void
}
; This test case previously triggered an infinite loop when the two gathers became identical after DAG combine removed the sign extend.
; Both gathers below read through the same splatted base pointer using the same
; i32 index vector; the first one sign-extends the indices to i64 beforehand.
; The assertions for every run configuration expect exactly one vgatherdps whose
; result is then added to itself.
define <16 x float> @test_sext_cse(float* %base, <16 x i32> %ind, <16 x i32>* %foo) {
; KNL_64-LABEL: test_sext_cse:
; KNL_64: # %bb.0:
; KNL_64-NEXT: vmovaps %zmm0, (%rsi)
; KNL_64-NEXT: kxnorw %k0, %k0, %k1
; KNL_64-NEXT: vgatherdps (%rdi,%zmm0,4), %zmm1 {%k1}
; KNL_64-NEXT: vaddps %zmm1, %zmm1, %zmm0
; KNL_64-NEXT: retq
;
; KNL_32-LABEL: test_sext_cse:
; KNL_32: # %bb.0:
; KNL_32-NEXT: movl {{[0-9]+}}(%esp), %eax
; KNL_32-NEXT: movl {{[0-9]+}}(%esp), %ecx
; KNL_32-NEXT: vmovaps %zmm0, (%ecx)
; KNL_32-NEXT: kxnorw %k0, %k0, %k1
; KNL_32-NEXT: vgatherdps (%eax,%zmm0,4), %zmm1 {%k1}
; KNL_32-NEXT: vaddps %zmm1, %zmm1, %zmm0
; KNL_32-NEXT: retl
;
; SKX-LABEL: test_sext_cse:
; SKX: # %bb.0:
; SKX-NEXT: vmovaps %zmm0, (%rsi)
; SKX-NEXT: kxnorw %k0, %k0, %k1
; SKX-NEXT: vgatherdps (%rdi,%zmm0,4), %zmm1 {%k1}
; SKX-NEXT: vaddps %zmm1, %zmm1, %zmm0
; SKX-NEXT: retq
;
; SKX_32-LABEL: test_sext_cse:
; SKX_32: # %bb.0:
; SKX_32-NEXT: movl {{[0-9]+}}(%esp), %eax
; SKX_32-NEXT: movl {{[0-9]+}}(%esp), %ecx
; SKX_32-NEXT: vmovaps %zmm0, (%ecx)
; SKX_32-NEXT: kxnorw %k0, %k0, %k1
; SKX_32-NEXT: vgatherdps (%eax,%zmm0,4), %zmm1 {%k1}
; SKX_32-NEXT: vaddps %zmm1, %zmm1, %zmm0
; SKX_32-NEXT: retl
; Splat %base into all 16 lanes of a pointer vector.
%broadcast.splatinsert = insertelement <16 x float*> undef, float* %base, i32 0
%broadcast.splat = shufflevector <16 x float*> %broadcast.splatinsert, <16 x float*> undef, <16 x i32> zeroinitializer
; First address vector: indices explicitly sign-extended from i32 to i64.
%sext_ind = sext <16 x i32> %ind to <16 x i64>
%gep.random = getelementptr float, <16 x float*> %broadcast.splat, <16 x i64> %sext_ind
; NOTE(review): this store gives %ind a use besides the two address computations;
; presumably it is part of the original reproducer - confirm before removing.
store <16 x i32> %ind, <16 x i32>* %foo
%res = call <16 x float> @llvm.masked.gather.v16f32.v16p0f32(<16 x float*> %gep.random, i32 4, <16 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <16 x float> undef)
; Second address vector: the same indices used directly as i32, with an
; identical all-true mask and undef pass-through on the gather.
%gep.random2 = getelementptr float, <16 x float*> %broadcast.splat, <16 x i32> %ind
%res2 = call <16 x float> @llvm.masked.gather.v16f32.v16p0f32(<16 x float*> %gep.random2, i32 4, <16 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <16 x float> undef)
; Both gather results feed the returned sum.
%res3 = fadd <16 x float> %res2, %res
ret <16 x float>%res3
}