Implement midpoint for floating point types. Reviewed as https://reviews.llvm.org/D61014.

git-svn-id: https://llvm.org/svn/llvm-project/libcxx/trunk@359184 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Marshall Clow 2019-04-25 12:11:43 +00:00
parent 26ff1b3321
commit 321a1f890d
8 changed files with 228 additions and 1 deletions

View File

@ -188,6 +188,8 @@ Status
------------------------------------------------- -----------------
``__cpp_lib_generic_unordered_lookup`` *unimplemented*
------------------------------------------------- -----------------
``__cpp_lib_interpolate`` ``201902L``
------------------------------------------------- -----------------
``__cpp_lib_is_constant_evaluated`` ``201811L``
------------------------------------------------- -----------------
``__cpp_lib_list_remove_return_type`` *unimplemented*

View File

@ -145,6 +145,7 @@ floating_point midpoint(floating_point a, floating_point b); // C++20
#include <iterator>
#include <limits> // for numeric_limits
#include <functional>
#include <cmath> // for isnormal
#include <version>
#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
@ -552,7 +553,24 @@ midpoint(_TPtr __a, _TPtr __b) noexcept
{
return __a + _VSTD::midpoint(ptrdiff_t(0), __b - __a);
}
#endif
template <typename _Tp>
int __sign(_Tp __val) {
return (_Tp(0) < __val) - (__val < _Tp(0));
}
template <class _Fp>
_LIBCPP_INLINE_VISIBILITY constexpr
enable_if_t<is_floating_point_v<_Fp>, _Fp>
midpoint(_Fp __a, _Fp __b) noexcept
{
return isnormal(__a) && isnormal(__b)
&& ((__sign(__a) != __sign(__b)) || ((numeric_limits<_Fp>::max() - abs(__a)) < abs(__b)))
? __a / 2 + __b / 2
: (__a + __b) / 2;
}
#endif // _LIBCPP_STD_VER > 17
_LIBCPP_END_NAMESPACE_STD

View File

@ -58,6 +58,7 @@ __cpp_lib_hypot 201603L <cmath>
__cpp_lib_incomplete_container_elements 201505L <forward_list> <list> <vector>
__cpp_lib_integer_sequence 201304L <utility>
__cpp_lib_integral_constant_callable 201304L <type_traits>
__cpp_lib_interpolate 201902L <numeric>
__cpp_lib_invoke 201411L <functional>
__cpp_lib_is_aggregate 201703L <type_traits>
__cpp_lib_is_constant_evaluated 201811L <type_traits>
@ -222,6 +223,7 @@ __cpp_lib_void_t 201411L <type_traits>
// # define __cpp_lib_destroying_delete 201806L
# define __cpp_lib_erase_if 201811L
// # define __cpp_lib_generic_unordered_lookup 201811L
# define __cpp_lib_interpolate 201902L
# if !defined(_LIBCPP_HAS_NO_BUILTIN_IS_CONSTANT_EVALUATED)
# define __cpp_lib_is_constant_evaluated 201811L
# endif

View File

@ -15,6 +15,7 @@
/* Constant Value
__cpp_lib_gcd_lcm 201606L [C++17]
__cpp_lib_interpolate 201902L [C++2a]
__cpp_lib_parallel_algorithm 201603L [C++17]
*/
@ -27,6 +28,10 @@
# error "__cpp_lib_gcd_lcm should not be defined before c++17"
# endif
# ifdef __cpp_lib_interpolate
# error "__cpp_lib_interpolate should not be defined before c++2a"
# endif
# ifdef __cpp_lib_parallel_algorithm
# error "__cpp_lib_parallel_algorithm should not be defined before c++17"
# endif
@ -37,6 +42,10 @@
# error "__cpp_lib_gcd_lcm should not be defined before c++17"
# endif
# ifdef __cpp_lib_interpolate
# error "__cpp_lib_interpolate should not be defined before c++2a"
# endif
# ifdef __cpp_lib_parallel_algorithm
# error "__cpp_lib_parallel_algorithm should not be defined before c++17"
# endif
@ -50,6 +59,10 @@
# error "__cpp_lib_gcd_lcm should have the value 201606L in c++17"
# endif
# ifdef __cpp_lib_interpolate
# error "__cpp_lib_interpolate should not be defined before c++2a"
# endif
# if !defined(_LIBCPP_VERSION)
# ifndef __cpp_lib_parallel_algorithm
# error "__cpp_lib_parallel_algorithm should be defined in c++17"
@ -72,6 +85,13 @@
# error "__cpp_lib_gcd_lcm should have the value 201606L in c++2a"
# endif
# ifndef __cpp_lib_interpolate
# error "__cpp_lib_interpolate should be defined in c++2a"
# endif
# if __cpp_lib_interpolate != 201902L
# error "__cpp_lib_interpolate should have the value 201902L in c++2a"
# endif
# if !defined(_LIBCPP_VERSION)
# ifndef __cpp_lib_parallel_algorithm
# error "__cpp_lib_parallel_algorithm should be defined in c++2a"

View File

@ -50,6 +50,7 @@
__cpp_lib_incomplete_container_elements 201505L [C++17]
__cpp_lib_integer_sequence 201304L [C++14]
__cpp_lib_integral_constant_callable 201304L [C++14]
__cpp_lib_interpolate 201902L [C++2a]
__cpp_lib_invoke 201411L [C++17]
__cpp_lib_is_aggregate 201703L [C++17]
__cpp_lib_is_constant_evaluated 201811L [C++2a]
@ -248,6 +249,10 @@
# error "__cpp_lib_integral_constant_callable should not be defined before c++14"
# endif
# ifdef __cpp_lib_interpolate
# error "__cpp_lib_interpolate should not be defined before c++2a"
# endif
# ifdef __cpp_lib_invoke
# error "__cpp_lib_invoke should not be defined before c++17"
# endif
@ -596,6 +601,10 @@
# error "__cpp_lib_integral_constant_callable should have the value 201304L in c++14"
# endif
# ifdef __cpp_lib_interpolate
# error "__cpp_lib_interpolate should not be defined before c++2a"
# endif
# ifdef __cpp_lib_invoke
# error "__cpp_lib_invoke should not be defined before c++17"
# endif
@ -1082,6 +1091,10 @@
# error "__cpp_lib_integral_constant_callable should have the value 201304L in c++17"
# endif
# ifdef __cpp_lib_interpolate
# error "__cpp_lib_interpolate should not be defined before c++2a"
# endif
# ifndef __cpp_lib_invoke
# error "__cpp_lib_invoke should be defined in c++17"
# endif
@ -1778,6 +1791,13 @@
# error "__cpp_lib_integral_constant_callable should have the value 201304L in c++2a"
# endif
# ifndef __cpp_lib_interpolate
# error "__cpp_lib_interpolate should be defined in c++2a"
# endif
# if __cpp_lib_interpolate != 201902L
# error "__cpp_lib_interpolate should have the value 201902L in c++2a"
# endif
# ifndef __cpp_lib_invoke
# error "__cpp_lib_invoke should be defined in c++2a"
# endif

View File

@ -0,0 +1,113 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// UNSUPPORTED: c++98, c++03, c++11, c++14, c++17
// <numeric>
// template <class _Float>
// _Tp midpoint(_Float __a, _Float __b) noexcept
//
#include <numeric>
#include <cassert>
#include "test_macros.h"
#include "fp_compare.h"
// Totally arbitrary picks for precision
template <typename T>
constexpr T fp_error_pct();
template <>
constexpr float fp_error_pct<float>() { return 1.0e-4f; }
template <>
constexpr double fp_error_pct<double>() { return 1.0e-12; }
template <>
constexpr long double fp_error_pct<long double>() { return 1.0e-13l; }
template <typename T>
void fp_test()
{
ASSERT_SAME_TYPE(T, decltype(std::midpoint(T(), T())));
ASSERT_NOEXCEPT( std::midpoint(T(), T()));
constexpr T maxV = std::numeric_limits<T>::max();
constexpr T minV = std::numeric_limits<T>::min();
// Things that can be compared exactly
assert((std::midpoint(T(0), T(0)) == T(0)));
assert((std::midpoint(T(2), T(4)) == T(3)));
assert((std::midpoint(T(4), T(2)) == T(3)));
assert((std::midpoint(T(3), T(4)) == T(3.5)));
assert((std::midpoint(T(0), T(0.4)) == T(0.2)));
// Things that can't be compared exactly
constexpr T pct = fp_error_pct<T>();
assert((fptest_close_pct(std::midpoint(T( 1.3), T(11.4)), T( 6.35), pct)));
assert((fptest_close_pct(std::midpoint(T(11.33), T(31.45)), T(21.39), pct)));
assert((fptest_close_pct(std::midpoint(T(-1.3), T(11.4)), T( 5.05), pct)));
assert((fptest_close_pct(std::midpoint(T(11.4), T(-1.3)), T( 5.05), pct)));
assert((fptest_close_pct(std::midpoint(T(0.1), T(0.4)), T(0.25), pct)));
assert((fptest_close_pct(std::midpoint(T(11.2345), T(14.5432)), T(12.88885), pct)));
// From e to pi
assert((fptest_close_pct(std::midpoint(T(2.71828182845904523536028747135266249775724709369995),
T(3.14159265358979323846264338327950288419716939937510)),
T(2.92993724102441923691146542731608269097720824653752), pct)));
assert((fptest_close_pct(std::midpoint(maxV, T(0)), maxV/2, pct)));
assert((fptest_close_pct(std::midpoint(T(0), maxV), maxV/2, pct)));
assert((fptest_close_pct(std::midpoint(minV, T(0)), minV/2, pct)));
assert((fptest_close_pct(std::midpoint(T(0), minV), minV/2, pct)));
assert((fptest_close_pct(std::midpoint(maxV, maxV), maxV, pct)));
assert((fptest_close_pct(std::midpoint(minV, minV), minV, pct)));
// Denormalized values
// TODO
// Check two values "close to each other"
T d1 = 3.14;
T d0 = std::nexttoward(d1, T(2));
T d2 = std::nexttoward(d1, T(5));
assert(d0 < d1); // sanity checking
assert(d1 < d2); // sanity checking
// Since there's nothing in between, the midpoint has to be one or the other
T res;
res = std::midpoint(d0, d1);
assert(res == d0 || res == d1);
assert(d0 <= res);
assert(res <= d1);
res = std::midpoint(d1, d0);
assert(res == d0 || res == d1);
assert(d0 <= res);
assert(res <= d1);
res = std::midpoint(d1, d2);
assert(res == d1 || res == d2);
assert(d1 <= res);
assert(res <= d2);
res = std::midpoint(d2, d1);
assert(res == d1 || res == d2);
assert(d1 <= res);
assert(res <= d2);
}
int main (int, char**)
{
fp_test<float>();
fp_test<double>();
fp_test<long double>();
return 0;
}

46
test/support/fp_compare.h Normal file
View File

@ -0,0 +1,46 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef SUPPORT_FP_COMPARE_H
#define SUPPORT_FP_COMPARE_H
#include <cmath> // for std::abs
#include <algorithm> // for std::max
#include <cassert>
// See https://www.boost.org/doc/libs/1_70_0/libs/test/doc/html/boost_test/testing_tools/extended_comparison/floating_point/floating_points_comparison_theory.html
template<typename T>
bool fptest_close(T val, T expected, T eps)
{
constexpr T zero = T(0);
assert(eps >= zero);
// Handle the zero cases
if (eps == zero) return val == expected;
if (val == zero) return std::abs(expected) <= eps;
if (expected == zero) return std::abs(val) <= eps;
return std::abs(val - expected) < eps
&& std::abs(val - expected)/std::abs(val) < eps;
}
template<typename T>
bool fptest_close_pct(T val, T expected, T percent)
{
constexpr T zero = T(0);
assert(percent >= zero);
// Handle the zero cases
if (percent == zero) return val == expected;
T eps = (percent / T(100)) * std::max(std::abs(val), std::abs(expected));
return fptest_close(val, expected, eps);
}
#endif // SUPPORT_FP_COMPARE_H

View File

@ -565,6 +565,12 @@ feature_test_macros = sorted([ add_version_header(x) for x in [
"depends": "!defined(_LIBCPP_HAS_NO_THREADS)",
"internal_depends": "!defined(_LIBCPP_HAS_NO_THREADS)",
},
{"name": "__cpp_lib_interpolate",
"values": {
"c++2a": 201902L,
},
"headers": ["numeric"],
},
]], key=lambda tc: tc["name"])
def get_std_dialects():