[DAGCombiner] reassociate reciprocal sqrt expression to eliminate FP division

X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)

In the motivating case from PR46406:
https://bugs.llvm.org/show_bug.cgi?id=46406
...this is restoring the sequence that was originally in the source code.
We extracted a term from within the sqrt because we do not know in
instcombine whether a target will expand a sqrt call.
Note: we could say that the transform in IR should be restricted, but
that would not solve the problem if the source was originally in the
pattern shown here.

This is a gray area for fast-math-flag requirements. I think we should at
least check fast-math-flags on the fdiv and fmul because I view this
transform as 2 pieces: reassociate the fmul operands and form reciprocal
from the fdiv (as with the existing transform). We could argue that the
sqrt also needs FMF, but that was not required before, so we should change
that in a follow-up patch if that seems better.

We don't currently have a way to check that the target will produce a sqrt
or recip estimate without actually creating nodes (the APIs are SDValue
getSqrtEstimate() and SDValue getRecipEstimate()), so we clean up
speculatively created nodes if we are not able to create an estimate.
The x86 test with doubles verifies that we are not changing a test with
no estimate sequence.

Differential Revision: https://reviews.llvm.org/D82716
This commit is contained in:
Sanjay Patel 2020-07-06 18:03:55 -04:00
parent 4029f8ede4
commit ea71ba11ab
2 changed files with 81 additions and 56 deletions

View File

@ -13232,6 +13232,24 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
Y = N1.getOperand(0);
}
if (Sqrt.getNode()) {
// If the other multiply operand is known positive, pull it into the
// sqrt. That will eliminate the division if we convert to an estimate:
// X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)
// TODO: Also fold the case where A == Z (fabs is missing).
if (Flags.hasAllowReassociation() && N1.hasOneUse() &&
N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse() &&
Y.getOpcode() == ISD::FABS && Y.hasOneUse()) {
SDValue AA = DAG.getNode(ISD::FMUL, DL, VT, Y.getOperand(0),
Y.getOperand(0), Flags);
SDValue AAZ =
DAG.getNode(ISD::FMUL, DL, VT, AA, Sqrt.getOperand(0), Flags);
if (SDValue Rsqrt = buildRsqrtEstimate(AAZ, Flags))
return DAG.getNode(ISD::FMUL, DL, VT, N0, Rsqrt, Flags);
// Estimate creation failed. Clean up speculatively created nodes.
recursivelyDeleteUnusedNodes(AAZ.getNode());
}
// We found a FSQRT, so try to make this fold:
// X / (Y * sqrt(Z)) -> X * (rsqrt(Z) / Y)
if (SDValue Rsqrt = buildRsqrtEstimate(Sqrt.getOperand(0), Flags)) {

View File

@ -618,46 +618,47 @@ define <16 x float> @v16f32_estimate(<16 x float> %x) #1 {
ret <16 x float> %div
}
; x / (fabs(y) * sqrt(z))
; x / (fabs(y) * sqrt(z)) --> x * rsqrt(y*y*z)
define float @div_sqrt_fabs_f32(float %x, float %y, float %z) {
; SSE-LABEL: div_sqrt_fabs_f32:
; SSE: # %bb.0:
; SSE-NEXT: rsqrtss %xmm2, %xmm3
; SSE-NEXT: mulss %xmm3, %xmm2
; SSE-NEXT: mulss %xmm3, %xmm2
; SSE-NEXT: addss {{.*}}(%rip), %xmm2
; SSE-NEXT: mulss {{.*}}(%rip), %xmm3
; SSE-NEXT: mulss %xmm2, %xmm3
; SSE-NEXT: andps {{.*}}(%rip), %xmm1
; SSE-NEXT: divss %xmm1, %xmm3
; SSE-NEXT: mulss %xmm3, %xmm0
; SSE-NEXT: mulss %xmm1, %xmm1
; SSE-NEXT: mulss %xmm2, %xmm1
; SSE-NEXT: xorps %xmm2, %xmm2
; SSE-NEXT: rsqrtss %xmm1, %xmm2
; SSE-NEXT: mulss %xmm2, %xmm1
; SSE-NEXT: mulss %xmm2, %xmm1
; SSE-NEXT: addss {{.*}}(%rip), %xmm1
; SSE-NEXT: mulss {{.*}}(%rip), %xmm2
; SSE-NEXT: mulss %xmm0, %xmm2
; SSE-NEXT: mulss %xmm1, %xmm2
; SSE-NEXT: movaps %xmm2, %xmm0
; SSE-NEXT: retq
;
; AVX1-LABEL: div_sqrt_fabs_f32:
; AVX1: # %bb.0:
; AVX1-NEXT: vrsqrtss %xmm2, %xmm2, %xmm3
; AVX1-NEXT: vmulss %xmm3, %xmm2, %xmm2
; AVX1-NEXT: vmulss %xmm3, %xmm2, %xmm2
; AVX1-NEXT: vaddss {{.*}}(%rip), %xmm2, %xmm2
; AVX1-NEXT: vmulss {{.*}}(%rip), %xmm3, %xmm3
; AVX1-NEXT: vmulss %xmm2, %xmm3, %xmm2
; AVX1-NEXT: vandps {{.*}}(%rip), %xmm1, %xmm1
; AVX1-NEXT: vdivss %xmm1, %xmm2, %xmm1
; AVX1-NEXT: vmulss %xmm1, %xmm0, %xmm0
; AVX1-NEXT: vmulss %xmm1, %xmm1, %xmm1
; AVX1-NEXT: vmulss %xmm2, %xmm1, %xmm1
; AVX1-NEXT: vrsqrtss %xmm1, %xmm1, %xmm2
; AVX1-NEXT: vmulss %xmm2, %xmm1, %xmm1
; AVX1-NEXT: vmulss %xmm2, %xmm1, %xmm1
; AVX1-NEXT: vaddss {{.*}}(%rip), %xmm1, %xmm1
; AVX1-NEXT: vmulss {{.*}}(%rip), %xmm2, %xmm2
; AVX1-NEXT: vmulss %xmm0, %xmm2, %xmm0
; AVX1-NEXT: vmulss %xmm0, %xmm1, %xmm0
; AVX1-NEXT: retq
;
; AVX512-LABEL: div_sqrt_fabs_f32:
; AVX512: # %bb.0:
; AVX512-NEXT: vrsqrtss %xmm2, %xmm2, %xmm3
; AVX512-NEXT: vmulss %xmm3, %xmm2, %xmm2
; AVX512-NEXT: vfmadd213ss {{.*#+}} xmm2 = (xmm3 * xmm2) + mem
; AVX512-NEXT: vmulss {{.*}}(%rip), %xmm3, %xmm3
; AVX512-NEXT: vbroadcastss {{.*#+}} xmm4 = [NaN,NaN,NaN,NaN]
; AVX512-NEXT: vmulss %xmm2, %xmm3, %xmm2
; AVX512-NEXT: vandps %xmm4, %xmm1, %xmm1
; AVX512-NEXT: vdivss %xmm1, %xmm2, %xmm1
; AVX512-NEXT: vmulss %xmm1, %xmm0, %xmm0
; AVX512-NEXT: vmulss %xmm1, %xmm1, %xmm1
; AVX512-NEXT: vmulss %xmm2, %xmm1, %xmm1
; AVX512-NEXT: vrsqrtss %xmm1, %xmm1, %xmm2
; AVX512-NEXT: vmulss %xmm2, %xmm1, %xmm1
; AVX512-NEXT: vfmadd213ss {{.*#+}} xmm1 = (xmm2 * xmm1) + mem
; AVX512-NEXT: vmulss {{.*}}(%rip), %xmm2, %xmm2
; AVX512-NEXT: vmulss %xmm0, %xmm2, %xmm0
; AVX512-NEXT: vmulss %xmm0, %xmm1, %xmm0
; AVX512-NEXT: retq
%s = call fast float @llvm.sqrt.f32(float %z)
%a = call fast float @llvm.fabs.f32(float %y)
@ -666,47 +667,46 @@ define float @div_sqrt_fabs_f32(float %x, float %y, float %z) {
ret float %d
}
; x / (fabs(y) * sqrt(z))
; x / (fabs(y) * sqrt(z)) --> x * rsqrt(y*y*z)
define <4 x float> @div_sqrt_fabs_v4f32(<4 x float> %x, <4 x float> %y, <4 x float> %z) {
; SSE-LABEL: div_sqrt_fabs_v4f32:
; SSE: # %bb.0:
; SSE-NEXT: rsqrtps %xmm2, %xmm3
; SSE-NEXT: mulps %xmm3, %xmm2
; SSE-NEXT: mulps %xmm3, %xmm2
; SSE-NEXT: addps {{.*}}(%rip), %xmm2
; SSE-NEXT: mulps {{.*}}(%rip), %xmm3
; SSE-NEXT: mulps %xmm2, %xmm3
; SSE-NEXT: andps {{.*}}(%rip), %xmm1
; SSE-NEXT: divps %xmm1, %xmm3
; SSE-NEXT: mulps %xmm3, %xmm0
; SSE-NEXT: mulps %xmm1, %xmm1
; SSE-NEXT: mulps %xmm2, %xmm1
; SSE-NEXT: rsqrtps %xmm1, %xmm2
; SSE-NEXT: mulps %xmm2, %xmm1
; SSE-NEXT: mulps %xmm2, %xmm1
; SSE-NEXT: addps {{.*}}(%rip), %xmm1
; SSE-NEXT: mulps {{.*}}(%rip), %xmm2
; SSE-NEXT: mulps %xmm1, %xmm2
; SSE-NEXT: mulps %xmm2, %xmm0
; SSE-NEXT: retq
;
; AVX1-LABEL: div_sqrt_fabs_v4f32:
; AVX1: # %bb.0:
; AVX1-NEXT: vrsqrtps %xmm2, %xmm3
; AVX1-NEXT: vmulps %xmm3, %xmm2, %xmm2
; AVX1-NEXT: vmulps %xmm3, %xmm2, %xmm2
; AVX1-NEXT: vaddps {{.*}}(%rip), %xmm2, %xmm2
; AVX1-NEXT: vmulps {{.*}}(%rip), %xmm3, %xmm3
; AVX1-NEXT: vmulps %xmm2, %xmm3, %xmm2
; AVX1-NEXT: vandps {{.*}}(%rip), %xmm1, %xmm1
; AVX1-NEXT: vdivps %xmm1, %xmm2, %xmm1
; AVX1-NEXT: vmulps %xmm1, %xmm1, %xmm1
; AVX1-NEXT: vmulps %xmm2, %xmm1, %xmm1
; AVX1-NEXT: vrsqrtps %xmm1, %xmm2
; AVX1-NEXT: vmulps %xmm2, %xmm1, %xmm1
; AVX1-NEXT: vmulps %xmm2, %xmm1, %xmm1
; AVX1-NEXT: vaddps {{.*}}(%rip), %xmm1, %xmm1
; AVX1-NEXT: vmulps {{.*}}(%rip), %xmm2, %xmm2
; AVX1-NEXT: vmulps %xmm1, %xmm2, %xmm1
; AVX1-NEXT: vmulps %xmm1, %xmm0, %xmm0
; AVX1-NEXT: retq
;
; AVX512-LABEL: div_sqrt_fabs_v4f32:
; AVX512: # %bb.0:
; AVX512-NEXT: vrsqrtps %xmm2, %xmm3
; AVX512-NEXT: vmulps %xmm3, %xmm2, %xmm2
; AVX512-NEXT: vbroadcastss {{.*#+}} xmm4 = [-3.0E+0,-3.0E+0,-3.0E+0,-3.0E+0]
; AVX512-NEXT: vfmadd231ps {{.*#+}} xmm4 = (xmm3 * xmm2) + xmm4
; AVX512-NEXT: vbroadcastss {{.*#+}} xmm2 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
; AVX512-NEXT: vmulps %xmm2, %xmm3, %xmm2
; AVX512-NEXT: vbroadcastss {{.*#+}} xmm3 = [NaN,NaN,NaN,NaN]
; AVX512-NEXT: vmulps %xmm4, %xmm2, %xmm2
; AVX512-NEXT: vandps %xmm3, %xmm1, %xmm1
; AVX512-NEXT: vdivps %xmm1, %xmm2, %xmm1
; AVX512-NEXT: vmulps %xmm1, %xmm1, %xmm1
; AVX512-NEXT: vmulps %xmm2, %xmm1, %xmm1
; AVX512-NEXT: vrsqrtps %xmm1, %xmm2
; AVX512-NEXT: vmulps %xmm2, %xmm1, %xmm1
; AVX512-NEXT: vbroadcastss {{.*#+}} xmm3 = [-3.0E+0,-3.0E+0,-3.0E+0,-3.0E+0]
; AVX512-NEXT: vfmadd231ps {{.*#+}} xmm3 = (xmm2 * xmm1) + xmm3
; AVX512-NEXT: vbroadcastss {{.*#+}} xmm1 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
; AVX512-NEXT: vmulps %xmm1, %xmm2, %xmm1
; AVX512-NEXT: vmulps %xmm3, %xmm1, %xmm1
; AVX512-NEXT: vmulps %xmm1, %xmm0, %xmm0
; AVX512-NEXT: retq
%s = call <4 x float> @llvm.sqrt.v4f32(<4 x float> %z)
@ -716,6 +716,11 @@ define <4 x float> @div_sqrt_fabs_v4f32(<4 x float> %x, <4 x float> %y, <4 x flo
ret <4 x float> %d
}
; This has 'arcp' but does not have 'reassoc' FMF.
; We allow converting the sqrt to an estimate, but
; do not pull the divisor into the estimate.
; x / (fabs(y) * sqrt(z)) --> x * rsqrt(z) / fabs(y)
define <4 x float> @div_sqrt_fabs_v4f32_fmf(<4 x float> %x, <4 x float> %y, <4 x float> %z) {
; SSE-LABEL: div_sqrt_fabs_v4f32_fmf:
; SSE: # %bb.0:
@ -765,6 +770,8 @@ define <4 x float> @div_sqrt_fabs_v4f32_fmf(<4 x float> %x, <4 x float> %y, <4 x
ret <4 x float> %d
}
; No estimates for f64, so do not convert fabs into an fmul.
define double @div_sqrt_fabs_f64(double %x, double %y, double %z) {
; SSE-LABEL: div_sqrt_fabs_f64:
; SSE: # %bb.0: