Fix fres, add NI, more test cases

`fres` is reimplemented in an unoptimized way, but it handles all edge cases.
The specific magic numbers were found by binary search on hardware.
Someday NI should maybe be an option for the normal fres tests as well, not only for ps_res.
This commit is contained in:
Geotale
2024-09-23 17:28:19 -05:00
parent e8a16a5ce2
commit 16dd5cbb90
3 changed files with 61 additions and 49 deletions

View File

@@ -192,7 +192,7 @@ double frsqrte_expected(double val)
}
double fres_expected(double val)
double fres_expected(double val, bool ni)
{
static const s32 estimate_base[] = {
0xfff000, 0xf07000, 0xe1d400, 0xd41000, 0xc71000, 0xbac400, 0xaf2000, 0xa41000,
@@ -212,13 +212,12 @@ double fres_expected(double val)
u32 vali;
};
u64 full_bits = Common::BitCast<u64>(val);
valf = RoundToFloatWithMode(val, RoundingMode::TowardsZero);
u32 mantissa = vali & FLOAT_FRAC;
u32 sign = vali & FLOAT_SIGN;
s32 exponent = static_cast<s32>(vali & FLOAT_EXP);
u32 mantissa = static_cast<u32>((full_bits & DOUBLE_FRAC) >> (DOUBLE_FRAC_WIDTH - FLOAT_FRAC_WIDTH));
u32 sign = static_cast<u32>(full_bits >> 32) & FLOAT_SIGN;
s32 exponent = static_cast<s32>(((full_bits & DOUBLE_EXP) >> DOUBLE_FRAC_WIDTH) - 0x380);
// Special case 0
if (exponent == 0 && mantissa < 0x200000)
if ((full_bits & DOUBLE_EXP) <= 0x37e0000000000000)
{
if ((full_bits & ~DOUBLE_SIGN) == 0)
{
@@ -231,28 +230,19 @@ double fres_expected(double val)
}
}
// Special case NaN-ish numbers
if ((full_bits & DOUBLE_EXP) >= 0x47f0000000000000ULL)
u64 max_float = ni ? 0x47d0000000000000ULL : 0x4940000000000000ULL;
// Special case huge and NaN-ish numbers
if ((full_bits & DOUBLE_EXP) >= max_float)
{
// If it's not NaN, it's infinite!
if (valf == valf)
// If it's not NaN, it's infinite! (Or just so big it gets cast down to 0)
if (val == val)
return sign ? -0.0 : 0.0;
return 0.0 + val;
}
// Number is denormal, shift the mantissa and adjust the exponent
if (exponent == 0)
{
mantissa <<= 1;
while ((mantissa & FLOAT_EXP) == 0) {
mantissa <<= 1;
exponent -= static_cast<s32>(1 << FLOAT_FRAC_WIDTH);
}
mantissa &= FLOAT_FRAC;
}
exponent = (253 << FLOAT_FRAC_WIDTH) - exponent;
exponent = 253 - exponent;
u32 key = mantissa >> 18;
u32 new_mantissa = static_cast<u32>(estimate_base[key] + estimate_dec[key] * static_cast<s32>((mantissa >> 8) & 0x3ff)) >> 1;
@@ -260,13 +250,21 @@ double fres_expected(double val)
if (exponent <= 0)
{
// Result is subnormal, format it properly!
u32 shift = 1 + (static_cast<u32>(-exponent) >> FLOAT_FRAC_WIDTH);
vali = sign | (((1 << FLOAT_FRAC_WIDTH) | new_mantissa) >> shift);
if (ni)
{
// Flush to 0 for inexact denormals
vali = sign;
}
else
{
u32 shift = 1 + static_cast<u32>(-exponent);
vali = sign | (((1 << FLOAT_FRAC_WIDTH) | new_mantissa) >> shift);
}
}
else
{
// Result is normal, just string things together
vali = sign | static_cast<u32>(exponent) | new_mantissa;
vali = sign | static_cast<u32>(exponent << FLOAT_FRAC_WIDTH) | new_mantissa;
}
return static_cast<double>(valf);
}

View File

@@ -288,14 +288,14 @@ static void Sum1Test(const u64* input_ptr, RoundingMode rounding_mode)
expected_ps0, Common::BitCast<double>(expected_ps0));
}
static void ResTest(const u64* input_ptr)
static void ResTest(const u64* input_ptr, bool ni)
{
double result_ps0;
double result_ps1;
u64 input = *input_ptr;
double input_float = Common::BitCast<double>(input);
double expected_ps0_float = fres_expected(input_float);
double expected_ps0_float = fres_expected(input_float, ni);
u64 expected_ps0 = TruncateMantissaBits(Common::BitCast<u64>(expected_ps0_float));
double expected_ps1_float = expected_ps0_float;
u64 expected_ps1 = expected_ps0;
@@ -398,6 +398,11 @@ static void PSMoveTest()
0x3690000000000000, // Min single denormal / 2
0x36a8000000000000, // Min single denormal * 3 / 2
0x36a8000000000000, // Min single denormal * 3 / 2
0x36b0000000000000, // Barely over not min accounted reciprocal
0x36c0000000000000, // Again barely over not min accounted reciprocal
0x37e0000000000000, // High 0 reciprocal
0x37f0000000000000, // Not 0 reciprocal
0x3800000000000000, // Very not 0 reciprocal
0x000fffffc0000000, // Not max double denormal
0x380fffff80000000, // Not max single denormal
0x000fffffffffffff, // Max double denormal
@@ -425,6 +430,11 @@ static void PSMoveTest()
0x1fffffffd0000000, // Similar denormal
0x1fffffffe0000000, // Similar denormal again (ties even)
0x1fffffffe0000001, // Similar denormal yet again (round up)
0x47cfffffc0000000, // Before fres stops working
0x47d0000000000000, // After fres stops working
0x47e0000000000000, // Another after fres stops working
0x47f0000000000000, // Yet another after fres stops working for good measure
0x4800000000000000, // ...?
0x0123456789abcdef, // Random
0x76543210fedcba09, // Random
@@ -436,32 +446,36 @@ static void PSMoveTest()
0x7ff5555555555555, // SNaN (extra bits)
};
const u64 max_rounding_mode = 4;
for (u64 rounding_mode_idx = 0; rounding_mode_idx < max_rounding_mode; ++rounding_mode_idx)
for (u64 ni = 0; ni < 2; ++ni)
{
asm volatile("mtfsf 7, %0" :: "f"(rounding_mode_idx));
network_printf("Rounding mode: %llu\n", rounding_mode_idx);
const u64 max_rounding_mode = 4;
for (u64 sign = 0; sign < 2; ++sign)
for (u64 rounding_mode_idx = 0; rounding_mode_idx < max_rounding_mode; ++rounding_mode_idx)
{
RoundingMode rounding_mode = static_cast<RoundingMode>(rounding_mode_idx);
u64 rounding_setting = ni << 2 | rounding_mode_idx;
asm volatile("mtfsf 7, %0" :: "f"(rounding_setting));
network_printf("Rounding mode: %llu | NI: %llu\n", rounding_mode_idx, ni);
for (size_t i = 0; i < sizeof(inputs) / sizeof(inputs[0]); ++i)
for (u64 sign = 0; sign < 2; ++sign)
{
const u64 input = inputs[i] | (sign << DOUBLE_SIGN_SHIFT);
const u64 *input_ref = &input;
RoundingMode rounding_mode = static_cast<RoundingMode>(rounding_mode_idx);
MergeTest(input_ref, rounding_mode);
MrTest(input_ref, rounding_mode);
NegTest(input_ref, rounding_mode);
AbsTest(input_ref, rounding_mode);
NabsTest(input_ref, rounding_mode);
SelTest(input_ref, rounding_mode);
Sum0Test(input_ref);
Sum1Test(input_ref, rounding_mode);
ResTest(input_ref);
RsqrteTest(input_ref);
for (size_t i = 0; i < sizeof(inputs) / sizeof(inputs[0]); ++i)
{
const u64 input = inputs[i] | (sign << DOUBLE_SIGN_SHIFT);
const u64 *input_ref = &input;
MergeTest(input_ref, rounding_mode);
MrTest(input_ref, rounding_mode);
NegTest(input_ref, rounding_mode);
AbsTest(input_ref, rounding_mode);
NabsTest(input_ref, rounding_mode);
SelTest(input_ref, rounding_mode);
Sum0Test(input_ref);
Sum1Test(input_ref, rounding_mode);
ResTest(input_ref, ni != 0);
RsqrteTest(input_ref);
}
}
}
}

View File

@@ -48,7 +48,7 @@ static void ReciprocalTest()
break;
testi = i << 32;
expectedf = fres_expected(testf);
expectedf = fres_expected(testf, true);
testf = fres_intrinsic(testf);
DO_TEST(testi == expectedi, "Bad fres {} {} {} {} {}", i, testf, testi, expectedf,
expectedi);