[mlir][math] Improved math.atan approximation

Used the cephes numerical approximation for `math.atan`. This is a
significant accuracy improvement over the previous taylor series
approximation.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D153656
This commit is contained in:
Robert Suderman 2023-06-23 16:54:48 -07:00 committed by Rob Suderman
parent 929124993a
commit 0bedb667af
3 changed files with 184 additions and 100 deletions

View File

@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include <climits>
#include <cmath>
#include <cstddef>
#include "mlir/Dialect/Arith/IR/Arith.h"
@ -171,7 +172,7 @@ static Value floatCst(ImplicitLocOpBuilder &builder, float value,
builder.getFloatAttr(elementType, value));
}
static Value f32Cst(ImplicitLocOpBuilder &builder, float value) {
static Value f32Cst(ImplicitLocOpBuilder &builder, double value) {
return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value));
}
@ -380,35 +381,76 @@ AtanApproximation::matchAndRewrite(math::AtanOp op,
ArrayRef<int64_t> shape = vectorShape(op.getOperand());
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto one = broadcast(builder, f32Cst(builder, 1.0f), shape);
// Remap the problem over [0.0, 1.0] by looking at the absolute value and the
// handling symmetry.
Value abs = builder.create<math::AbsFOp>(operand);
Value reciprocal = builder.create<arith::DivFOp>(one, abs);
Value compare =
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, abs, reciprocal);
Value x = builder.create<arith::SelectOp>(compare, abs, reciprocal);
auto one = broadcast(builder, f32Cst(builder, 1.0), shape);
// When 0.66 < x <= 2.41 we do (x-1) / (x+1):
auto twoThirds = broadcast(builder, f32Cst(builder, 0.66), shape);
Value cmp2 =
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, abs, twoThirds);
Value addone = builder.create<arith::AddFOp>(abs, one);
Value subone = builder.create<arith::SubFOp>(abs, one);
Value xnum = builder.create<arith::SelectOp>(cmp2, subone, abs);
Value xden = builder.create<arith::SelectOp>(cmp2, addone, one);
auto bcast = [&](Value value) -> Value {
return broadcast(builder, value, shape);
};
// Break into the <= 0.66 or > 2.41 we do x or 1/x:
auto tan3pio8 = bcast(f32Cst(builder, 2.41421356237309504880));
Value cmp1 =
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, abs, tan3pio8);
xnum = builder.create<arith::SelectOp>(cmp1, one, xnum);
xden = builder.create<arith::SelectOp>(cmp1, abs, xden);
Value x = builder.create<arith::DivFOp>(xnum, xden);
Value xx = builder.create<arith::MulFOp>(x, x);
// Perform the Taylor series approximation for atan over the range
// [-1.0, 1.0].
auto n1 = broadcast(builder, f32Cst(builder, 0.14418283f), shape);
auto n2 = broadcast(builder, f32Cst(builder, -0.34999234f), shape);
auto n3 = broadcast(builder, f32Cst(builder, -0.01067831f), shape);
auto n4 = broadcast(builder, f32Cst(builder, 1.00209986f), shape);
// [0.0, 0.66].
auto p0 = bcast(f32Cst(builder, -8.750608600031904122785e-01));
auto p1 = bcast(f32Cst(builder, -1.615753718733365076637e+01));
auto p2 = bcast(f32Cst(builder, -7.500855792314704667340e+01));
auto p3 = bcast(f32Cst(builder, -1.228866684490136173410e+02));
auto p4 = bcast(f32Cst(builder, -6.485021904942025371773e+01));
auto q0 = bcast(f32Cst(builder, +2.485846490142306297962e+01));
auto q1 = bcast(f32Cst(builder, +1.650270098316988542046e+02));
auto q2 = bcast(f32Cst(builder, +4.328810604912902668951e+02));
auto q3 = bcast(f32Cst(builder, +4.853903996359136964868e+02));
auto q4 = bcast(f32Cst(builder, +1.945506571482613964425e+02));
Value p = builder.create<math::FmaOp>(x, n1, n2);
p = builder.create<math::FmaOp>(x, p, n3);
p = builder.create<math::FmaOp>(x, p, n4);
p = builder.create<arith::MulFOp>(x, p);
// Apply the polynomial approximation for the numerator:
Value n = p0;
n = builder.create<math::FmaOp>(xx, n, p1);
n = builder.create<math::FmaOp>(xx, n, p2);
n = builder.create<math::FmaOp>(xx, n, p3);
n = builder.create<math::FmaOp>(xx, n, p4);
n = builder.create<arith::MulFOp>(n, xx);
// Remap the solution for over [0.0, 1.0] to [0.0, inf]
auto halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape);
Value sub = builder.create<arith::SubFOp>(halfPi, p);
Value select = builder.create<arith::SelectOp>(compare, p, sub);
// Apply the polynomial approximation for the denominator:
Value d = q0;
d = builder.create<math::FmaOp>(xx, d, q1);
d = builder.create<math::FmaOp>(xx, d, q2);
d = builder.create<math::FmaOp>(xx, d, q3);
d = builder.create<math::FmaOp>(xx, d, q4);
// Compute approximation of theta:
Value ans0 = builder.create<arith::DivFOp>(n, d);
ans0 = builder.create<math::FmaOp>(ans0, x, x);
// Correct for the input mapping's angles:
Value mpi4 = bcast(f32Cst(builder, M_PI_4));
Value ans2 = builder.create<arith::AddFOp>(mpi4, ans0);
Value ans = builder.create<arith::SelectOp>(cmp2, ans2, ans0);
Value mpi2 = bcast(f32Cst(builder, M_PI_2));
Value ans1 = builder.create<arith::SubFOp>(mpi2, ans0);
ans = builder.create<arith::SelectOp>(cmp1, ans1, ans);
// Correct for signing of the input.
rewriter.replaceOpWithNewOp<math::CopySignOp>(op, select, operand);
rewriter.replaceOpWithNewOp<math::CopySignOp>(op, ans, operand);
return success();
}

View File

@ -587,24 +587,50 @@ func.func @rsqrt_vector_2x16xf32(%arg0: vector<2x16xf32>) -> vector<2x16xf32> {
}
// CHECK-LABEL: @atan_scalar
// CHECK-DAG: %[[ONE:.+]] = arith.constant 1.000000e+00
// CHECK-DAG: %[[N1:.+]] = arith.constant 0.144182831
// CHECK-DAG: %[[N2:.+]] = arith.constant -0.349992335
// CHECK-DAG: %[[N3:.+]] = arith.constant -0.0106783099
// CHECK-DAG: %[[N4:.+]] = arith.constant 1.00209987
// CHECK-DAG: %[[HALF_PI:.+]] = arith.constant 1.57079637
// CHECK-DAG: %[[ABS:.+]] = math.absf %arg0
// CHECK-DAG: %[[DIV:.+]] = arith.divf %cst, %[[ABS]]
// CHECK-DAG: %[[CMP:.+]] = arith.cmpf olt, %[[ABS]], %[[DIV]]
// CHECK-DAG: %[[SEL:.+]] = arith.select %[[CMP]], %[[ABS]], %[[DIV]]
// CHECK-DAG: %[[P0:.+]] = math.fma %[[SEL]], %[[N1]], %[[N2]]
// CHECK-DAG: %[[P1:.+]] = math.fma %[[SEL]], %[[P0]], %[[N3]]
// CHECK-DAG: %[[P2:.+]] = math.fma %[[SEL]], %[[P1]], %[[N4]]
// CHECK-DAG: %[[P3:.+]] = arith.mulf %[[SEL]], %[[P2]]
// CHECK-DAG: %[[SUB:.+]] = arith.subf %[[HALF_PI]], %[[P3]]
// CHECK-DAG: %[[EST:.+]] = arith.select %[[CMP]], %[[P3]], %[[SUB]]
// CHECK-DAG: %[[RES:.+]] = math.copysign %[[EST]], %arg0
// CHECK: return %[[RES]]
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 6.600000e-01 : f32
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 2.41421366 : f32
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant -0.875060856 : f32
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant -16.1575375 : f32
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant -75.0085601 : f32
// CHECK-DAG: %[[VAL_7:.*]] = arith.constant -122.886665 : f32
// CHECK-DAG: %[[VAL_8:.*]] = arith.constant -64.8502197 : f32
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 24.8584652 : f32
// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 165.027008 : f32
// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 432.881073 : f32
// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 485.390411 : f32
// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 194.550659 : f32
// CHECK-DAG: %[[VAL_14:.*]] = arith.constant 0.785398185 : f32
// CHECK-DAG: %[[VAL_15:.*]] = arith.constant 1.57079637 : f32
// CHECK-DAG: %[[VAL_16:.*]] = math.absf %[[VAL_0]] : f32
// CHECK-DAG: %[[VAL_17:.*]] = arith.cmpf ogt, %[[VAL_16]], %[[VAL_2]] : f32
// CHECK-DAG: %[[VAL_18:.*]] = arith.addf %[[VAL_16]], %[[VAL_1]] : f32
// CHECK-DAG: %[[VAL_19:.*]] = arith.subf %[[VAL_16]], %[[VAL_1]] : f32
// CHECK-DAG: %[[VAL_20:.*]] = arith.select %[[VAL_17]], %[[VAL_19]], %[[VAL_16]] : f32
// CHECK-DAG: %[[VAL_21:.*]] = arith.select %[[VAL_17]], %[[VAL_18]], %[[VAL_1]] : f32
// CHECK-DAG: %[[VAL_22:.*]] = arith.cmpf ogt, %[[VAL_16]], %[[VAL_3]] : f32
// CHECK-DAG: %[[VAL_23:.*]] = arith.select %[[VAL_22]], %[[VAL_1]], %[[VAL_20]] : f32
// CHECK-DAG: %[[VAL_24:.*]] = arith.select %[[VAL_22]], %[[VAL_16]], %[[VAL_21]] : f32
// CHECK-DAG: %[[VAL_25:.*]] = arith.divf %[[VAL_23]], %[[VAL_24]] : f32
// CHECK-DAG: %[[VAL_26:.*]] = arith.mulf %[[VAL_25]], %[[VAL_25]] : f32
// CHECK-DAG: %[[VAL_27:.*]] = math.fma %[[VAL_26]], %[[VAL_4]], %[[VAL_5]] : f32
// CHECK-DAG: %[[VAL_28:.*]] = math.fma %[[VAL_26]], %[[VAL_27]], %[[VAL_6]] : f32
// CHECK-DAG: %[[VAL_29:.*]] = math.fma %[[VAL_26]], %[[VAL_28]], %[[VAL_7]] : f32
// CHECK-DAG: %[[VAL_30:.*]] = math.fma %[[VAL_26]], %[[VAL_29]], %[[VAL_8]] : f32
// CHECK-DAG: %[[VAL_31:.*]] = arith.mulf %[[VAL_30]], %[[VAL_26]] : f32
// CHECK-DAG: %[[VAL_32:.*]] = math.fma %[[VAL_26]], %[[VAL_9]], %[[VAL_10]] : f32
// CHECK-DAG: %[[VAL_33:.*]] = math.fma %[[VAL_26]], %[[VAL_32]], %[[VAL_11]] : f32
// CHECK-DAG: %[[VAL_34:.*]] = math.fma %[[VAL_26]], %[[VAL_33]], %[[VAL_12]] : f32
// CHECK-DAG: %[[VAL_35:.*]] = math.fma %[[VAL_26]], %[[VAL_34]], %[[VAL_13]] : f32
// CHECK-DAG: %[[VAL_36:.*]] = arith.divf %[[VAL_31]], %[[VAL_35]] : f32
// CHECK-DAG: %[[VAL_37:.*]] = math.fma %[[VAL_36]], %[[VAL_25]], %[[VAL_25]] : f32
// CHECK-DAG: %[[VAL_38:.*]] = arith.addf %[[VAL_37]], %[[VAL_14]] : f32
// CHECK-DAG: %[[VAL_39:.*]] = arith.select %[[VAL_17]], %[[VAL_38]], %[[VAL_37]] : f32
// CHECK-DAG: %[[VAL_40:.*]] = arith.subf %[[VAL_15]], %[[VAL_37]] : f32
// CHECK-DAG: %[[VAL_41:.*]] = arith.select %[[VAL_22]], %[[VAL_40]], %[[VAL_39]] : f32
// CHECK-DAG: %[[VAL_42:.*]] = math.copysign %[[VAL_41]], %[[VAL_0]] : f32
// CHECK: return %[[VAL_42]] : f3
func.func @atan_scalar(%arg0: f32) -> f32 {
%0 = math.atan %arg0 : f32
return %0 : f32
@ -612,59 +638,75 @@ func.func @atan_scalar(%arg0: f32) -> f32 {
// CHECK-LABEL: @atan2_scalar
// ATan approximation:
// CHECK-DAG: %[[ONE:.+]] = arith.constant 1.000000e+00
// CHECK-DAG: %[[N1:.+]] = arith.constant 0.144182831
// CHECK-DAG: %[[N2:.+]] = arith.constant -0.349992335
// CHECK-DAG: %[[N3:.+]] = arith.constant -0.0106783099
// CHECK-DAG: %[[N4:.+]] = arith.constant 1.00209987
// CHECK-DAG: %[[HALF_PI:.+]] = arith.constant 1.57079637
// CHECK-DAG: %[[ARG0:.+]] = arith.extf %arg0 : f16 to f32
// CHECK-DAG: %[[ARG1:.+]] = arith.extf %arg1 : f16 to f32
// CHECK-DAG: %[[RATIO:.+]] = arith.divf %[[ARG0]], %[[ARG1]]
// CHECK-DAG: %[[ABS:.+]] = math.absf %[[RATIO]]
// CHECK-DAG: %[[DIV:.+]] = arith.divf %cst, %[[ABS]]
// CHECK-DAG: %[[CMP:.+]] = arith.cmpf olt, %[[ABS]], %[[DIV]]
// CHECK-DAG: %[[SEL:.+]] = arith.select %[[CMP]], %[[ABS]], %[[DIV]]
// CHECK-DAG: %[[P0:.+]] = math.fma %[[SEL]], %[[N1]], %[[N2]]
// CHECK-DAG: %[[P1:.+]] = math.fma %[[SEL]], %[[P0]], %[[N3]]
// CHECK-DAG: %[[P2:.+]] = math.fma %[[SEL]], %[[P1]], %[[N4]]
// CHECK-DAG: %[[P3:.+]] = arith.mulf %[[SEL]], %[[P2]]
// CHECK-DAG: %[[SUB:.+]] = arith.subf %[[HALF_PI]], %[[P3]]
// CHECK-DAG: %[[EST:.+]] = arith.select %[[CMP]], %[[P3]], %[[SUB]]
// CHECK-DAG: %[[ATAN:.+]] = math.copysign %[[EST]], %[[RATIO]]
// Handle the case of x < 0:
// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00
// CHECK-DAG: %[[PI:.+]] = arith.constant 3.14159274
// CHECK-DAG: %[[ADD_PI:.+]] = arith.addf %[[ATAN]], %[[PI]]
// CHECK-DAG: %[[SUB_PI:.+]] = arith.subf %[[ATAN]], %[[PI]]
// CHECK-DAG: %[[CMP_ATAN:.+]] = arith.cmpf ogt, %[[ATAN]], %[[ZERO]]
// CHECK-DAG: %[[ATAN_ADJUST:.+]] = arith.select %[[CMP_ATAN]], %[[SUB_PI]], %[[ADD_PI]]
// CHECK-DAG: %[[X_NEG:.+]] = arith.cmpf ogt, %[[ARG1]], %[[ZERO]]
// CHECK-DAG: %[[ATAN_EST:.+]] = arith.select %[[X_NEG]], %[[ATAN]], %[[ATAN_ADJUST]]
// Handle PI / 2 edge case:
// CHECK-DAG: %[[X_ZERO:.+]] = arith.cmpf oeq, %[[ARG1]], %[[ZERO]]
// CHECK-DAG: %[[Y_POS:.+]] = arith.cmpf ogt, %[[ARG0]], %[[ZERO]]
// CHECK-DAG: %[[IS_HALF_PI:.+]] = arith.andi %[[X_ZERO]], %[[Y_POS]]
// CHECK-DAG: %[[EDGE1:.+]] = arith.select %[[IS_HALF_PI]], %[[HALF_PI]], %[[ATAN_EST]]
// Handle -PI / 2 edge case:
// CHECK-DAG: %[[NEG_HALF_PI:.+]] = arith.constant -1.57079637
// CHECK-DAG: %[[Y_NEG:.+]] = arith.cmpf olt, %[[ARG0]], %[[ZERO]]
// CHECK-DAG: %[[IS_NEG_HALF_PI:.+]] = arith.andi %[[X_ZERO]], %[[Y_NEG]]
// CHECK-DAG: %[[EDGE2:.+]] = arith.select %[[IS_NEG_HALF_PI]], %[[NEG_HALF_PI]], %[[EDGE1]]
// Handle Nan edgecase:
// CHECK-DAG: %[[Y_ZERO:.+]] = arith.cmpf oeq, %[[ARG0]], %[[ZERO]]
// CHECK-DAG: %[[X_Y_ZERO:.+]] = arith.andi %[[X_ZERO]], %[[Y_ZERO]]
// CHECK-DAG: %[[NAN:.+]] = arith.constant 0x7FC00000
// CHECK-DAG: %[[EDGE3:.+]] = arith.select %[[X_Y_ZERO]], %[[NAN]], %[[EDGE2]]
// CHECK: %[[RET:.+]] = arith.truncf %[[EDGE3]]
// CHECK: return %[[RET]]
// CHECK-SAME: %[[VAL_0:.*]]: f16,
// CHECK-SAME: %[[VAL_1:.*]]: f16)
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 6.600000e-01 : f32
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2.41421366 : f32
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant -0.875060856 : f32
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant -16.1575375 : f32
// CHECK-DAG: %[[VAL_7:.*]] = arith.constant -75.0085601 : f32
// CHECK-DAG: %[[VAL_8:.*]] = arith.constant -122.886665 : f32
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant -64.8502197 : f32
// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 24.8584652 : f32
// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 165.027008 : f32
// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 432.881073 : f32
// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 485.390411 : f32
// CHECK-DAG: %[[VAL_14:.*]] = arith.constant 194.550659 : f32
// CHECK-DAG: %[[VAL_15:.*]] = arith.constant 0.785398185 : f32
// CHECK-DAG: %[[VAL_16:.*]] = arith.constant 1.57079637 : f32
// CHECK-DAG: %[[VAL_17:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[VAL_18:.*]] = arith.constant 3.14159274 : f32
// CHECK-DAG: %[[VAL_19:.*]] = arith.constant -1.57079637 : f32
// CHECK-DAG: %[[VAL_20:.*]] = arith.constant 0x7FC00000 : f32
// CHECK-DAG: %[[VAL_21:.*]] = arith.extf %[[VAL_0]] : f16 to f32
// CHECK-DAG: %[[VAL_22:.*]] = arith.extf %[[VAL_1]] : f16 to f32
// CHECK-DAG: %[[VAL_23:.*]] = arith.divf %[[VAL_21]], %[[VAL_22]] : f32
// CHECK-DAG: %[[VAL_24:.*]] = math.absf %[[VAL_23]] : f32
// CHECK-DAG: %[[VAL_25:.*]] = arith.cmpf ogt, %[[VAL_24]], %[[VAL_3]] : f32
// CHECK-DAG: %[[VAL_26:.*]] = arith.addf %[[VAL_24]], %[[VAL_2]] : f32
// CHECK-DAG: %[[VAL_27:.*]] = arith.subf %[[VAL_24]], %[[VAL_2]] : f32
// CHECK-DAG: %[[VAL_28:.*]] = arith.select %[[VAL_25]], %[[VAL_27]], %[[VAL_24]] : f32
// CHECK-DAG: %[[VAL_29:.*]] = arith.select %[[VAL_25]], %[[VAL_26]], %[[VAL_2]] : f32
// CHECK-DAG: %[[VAL_30:.*]] = arith.cmpf ogt, %[[VAL_24]], %[[VAL_4]] : f32
// CHECK-DAG: %[[VAL_31:.*]] = arith.select %[[VAL_30]], %[[VAL_2]], %[[VAL_28]] : f32
// CHECK-DAG: %[[VAL_32:.*]] = arith.select %[[VAL_30]], %[[VAL_24]], %[[VAL_29]] : f32
// CHECK-DAG: %[[VAL_33:.*]] = arith.divf %[[VAL_31]], %[[VAL_32]] : f32
// CHECK-DAG: %[[VAL_34:.*]] = arith.mulf %[[VAL_33]], %[[VAL_33]] : f32
// CHECK-DAG: %[[VAL_35:.*]] = math.fma %[[VAL_34]], %[[VAL_5]], %[[VAL_6]] : f32
// CHECK-DAG: %[[VAL_36:.*]] = math.fma %[[VAL_34]], %[[VAL_35]], %[[VAL_7]] : f32
// CHECK-DAG: %[[VAL_37:.*]] = math.fma %[[VAL_34]], %[[VAL_36]], %[[VAL_8]] : f32
// CHECK-DAG: %[[VAL_38:.*]] = math.fma %[[VAL_34]], %[[VAL_37]], %[[VAL_9]] : f32
// CHECK-DAG: %[[VAL_39:.*]] = arith.mulf %[[VAL_38]], %[[VAL_34]] : f32
// CHECK-DAG: %[[VAL_40:.*]] = math.fma %[[VAL_34]], %[[VAL_10]], %[[VAL_11]] : f32
// CHECK-DAG: %[[VAL_41:.*]] = math.fma %[[VAL_34]], %[[VAL_40]], %[[VAL_12]] : f32
// CHECK-DAG: %[[VAL_42:.*]] = math.fma %[[VAL_34]], %[[VAL_41]], %[[VAL_13]] : f32
// CHECK-DAG: %[[VAL_43:.*]] = math.fma %[[VAL_34]], %[[VAL_42]], %[[VAL_14]] : f32
// CHECK-DAG: %[[VAL_44:.*]] = arith.divf %[[VAL_39]], %[[VAL_43]] : f32
// CHECK-DAG: %[[VAL_45:.*]] = math.fma %[[VAL_44]], %[[VAL_33]], %[[VAL_33]] : f32
// CHECK-DAG: %[[VAL_46:.*]] = arith.addf %[[VAL_45]], %[[VAL_15]] : f32
// CHECK-DAG: %[[VAL_47:.*]] = arith.select %[[VAL_25]], %[[VAL_46]], %[[VAL_45]] : f32
// CHECK-DAG: %[[VAL_48:.*]] = arith.subf %[[VAL_16]], %[[VAL_45]] : f32
// CHECK-DAG: %[[VAL_49:.*]] = arith.select %[[VAL_30]], %[[VAL_48]], %[[VAL_47]] : f32
// CHECK-DAG: %[[VAL_50:.*]] = math.copysign %[[VAL_49]], %[[VAL_23]] : f32
// CHECK-DAG: %[[VAL_51:.*]] = arith.addf %[[VAL_50]], %[[VAL_18]] : f32
// CHECK-DAG: %[[VAL_52:.*]] = arith.subf %[[VAL_50]], %[[VAL_18]] : f32
// CHECK-DAG: %[[VAL_53:.*]] = arith.cmpf ogt, %[[VAL_50]], %[[VAL_17]] : f32
// CHECK-DAG: %[[VAL_54:.*]] = arith.select %[[VAL_53]], %[[VAL_52]], %[[VAL_51]] : f32
// CHECK-DAG: %[[VAL_55:.*]] = arith.cmpf ogt, %[[VAL_22]], %[[VAL_17]] : f32
// CHECK-DAG: %[[VAL_56:.*]] = arith.select %[[VAL_55]], %[[VAL_50]], %[[VAL_54]] : f32
// CHECK-DAG: %[[VAL_57:.*]] = arith.cmpf oeq, %[[VAL_22]], %[[VAL_17]] : f32
// CHECK-DAG: %[[VAL_58:.*]] = arith.cmpf ogt, %[[VAL_21]], %[[VAL_17]] : f32
// CHECK-DAG: %[[VAL_59:.*]] = arith.andi %[[VAL_57]], %[[VAL_58]] : i1
// CHECK-DAG: %[[VAL_60:.*]] = arith.select %[[VAL_59]], %[[VAL_16]], %[[VAL_56]] : f32
// CHECK-DAG: %[[VAL_61:.*]] = arith.cmpf olt, %[[VAL_21]], %[[VAL_17]] : f32
// CHECK-DAG: %[[VAL_62:.*]] = arith.andi %[[VAL_57]], %[[VAL_61]] : i1
// CHECK-DAG: %[[VAL_63:.*]] = arith.select %[[VAL_62]], %[[VAL_19]], %[[VAL_60]] : f32
// CHECK-DAG: %[[VAL_64:.*]] = arith.cmpf oeq, %[[VAL_21]], %[[VAL_17]] : f32
// CHECK-DAG: %[[VAL_65:.*]] = arith.andi %[[VAL_57]], %[[VAL_64]] : i1
// CHECK-DAG: %[[VAL_66:.*]] = arith.select %[[VAL_65]], %[[VAL_20]], %[[VAL_63]] : f32
// CHECK-DAG: %[[VAL_67:.*]] = arith.truncf %[[VAL_66]] : f32 to f16
// CHECK: return %[[VAL_67]] : f1
func.func @atan2_scalar(%arg0: f16, %arg1: f16) -> f16 {
%0 = math.atan2 %arg0, %arg1 : f16
return %0 : f16

View File

@ -471,19 +471,19 @@ func.func @atan_f32(%a : f32) {
}
func.func @atan() {
// CHECK: -0.785184
// CHECK: -0.785398
%0 = arith.constant -1.0 : f32
call @atan_f32(%0) : (f32) -> ()
// CHECK: 0.785184
// CHECK: 0.785398
%1 = arith.constant 1.0 : f32
call @atan_f32(%1) : (f32) -> ()
// CHECK: -0.463643
// CHECK: -0.463648
%2 = arith.constant -0.5 : f32
call @atan_f32(%2) : (f32) -> ()
// CHECK: 0.463643
// CHECK: 0.463648
%3 = arith.constant 0.5 : f32
call @atan_f32(%3) : (f32) -> ()
@ -548,7 +548,7 @@ func.func @atan2() {
// CHECK: -1.10715
call @atan2_f32(%neg_two, %one) : (f32, f32) -> ()
// CHECK: 0.463643
// CHECK: 0.463648
call @atan2_f32(%one, %two) : (f32, f32) -> ()
// CHECK: 2.67795
@ -561,7 +561,7 @@ func.func @atan2() {
%y11 = arith.constant -1.0 : f32
call @atan2_f32(%neg_one, %neg_two) : (f32, f32) -> ()
// CHECK: -0.463643
// CHECK: -0.463648
call @atan2_f32(%neg_one, %two) : (f32, f32) -> ()
return