[flang] Fold real-valued DIM()

Fold references to the intrinsic function DIM with constant real
arguments.  And clean up folding of comparisons with NaNs to address
a problem noticed in testing -- NaNs should successfully compare
unequal to all values, including themselves, instead of failing all
comparisons.

Differential Revision: https://reviews.llvm.org/D125146
This commit is contained in:
Peter Klausler 2022-05-04 16:35:31 -07:00
parent ad3b358180
commit 9e50168be4
6 changed files with 50 additions and 5 deletions

View File

@ -120,7 +120,7 @@ static constexpr bool Satisfies(RelationalOperator op, Relation relation) {
case Relation::Greater:
return Satisfies(op, Ordering::Greater);
case Relation::Unordered:
return false;
return op == RelationalOperator::NE;
}
return false; // silence g++ warning
}

View File

@ -128,6 +128,10 @@ public:
ValueWithRealFlags<Real> HYPOT(
const Real &, Rounding rounding = defaultRounding) const;
// DIM(X,Y) = MAX(X-Y, 0)
ValueWithRealFlags<Real> DIM(
const Real &, Rounding rounding = defaultRounding) const;
template <typename INT> constexpr INT EXPONENT() const {
if (Exponent() == maxExponent) {
return INT::HUGE();

View File

@ -129,6 +129,12 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
}
return y.value;
}));
} else if (name == "dim") {
return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef),
ScalarFunc<T, T, T>(
[](const Scalar<T> &x, const Scalar<T> &y) -> Scalar<T> {
return x.DIM(y).value;
}));
} else if (name == "dprod") {
if (auto scalars{GetScalarConstantArguments<T, T>(context, args)}) {
return Fold(context,
@ -284,8 +290,7 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
return result.value;
}));
}
// TODO: dim, dot_product, fraction, matmul,
// modulo, norm2, set_exponent, transfer,
// TODO: dot_product, fraction, matmul, norm2, set_exponent, transfer
return Expr<T>{std::move(funcRef)};
}

View File

@ -422,6 +422,21 @@ ValueWithRealFlags<Real<W, P>> Real<W, P>::HYPOT(
return result;
}
template <typename W, int P>
ValueWithRealFlags<Real<W, P>> Real<W, P>::DIM(
const Real &y, Rounding rounding) const {
ValueWithRealFlags<Real> result;
if (IsNotANumber() || y.IsNotANumber()) {
result.flags.set(RealFlag::InvalidArgument);
result.value = NotANumber();
} else if (Compare(y) == Relation::Greater) {
result = Subtract(y, rounding);
} else {
// result is already zero
}
return result;
}
template <typename W, int P>
ValueWithRealFlags<Real<W, P>> Real<W, P>::ToWholeNumber(
common::RoundingMode mode) const {

View File

@ -0,0 +1,17 @@
! RUN: %python %S/test_folding.py %s %flang_fc1
! Tests folding of DIM()
module m
logical, parameter :: test_i1 = dim(0, 0) == 0
logical, parameter :: test_i2 = dim(1, 2) == 0
logical, parameter :: test_i3 = dim(2, 1) == 1
logical, parameter :: test_i4 = dim(2, -1) == 3
logical, parameter :: test_i5 = dim(-1, 2) == 0
logical, parameter :: test_a1 = dim(0., 0.) == 0.
logical, parameter :: test_a2 = dim(1., 2.) == 0.
logical, parameter :: test_a3 = dim(2., 1.) == 1.
logical, parameter :: test_a4 = dim(2., -1.) == 3.
logical, parameter :: test_a5 = dim(-1., 2.) == 0.
!WARN: warning: invalid argument on division
real, parameter :: nan = 0./0.
logical, parameter :: test_a6 = dim(nan, 1.) /= dim(nan, 1.)
end module

View File

@ -85,7 +85,9 @@ module real_tests
real(4), parameter :: r4_ninf = -1._4/0._4
logical, parameter :: test_r4_nan_parentheses1 = .NOT.(((r4_nan)).EQ.r4_nan)
logical, parameter :: test_r4_nan_parentheses2 = .NOT.(((r4_nan)).NE.r4_nan)
logical, parameter :: test_r4_nan_parentheses2 = .NOT.(((r4_nan)).LT.r4_nan)
logical, parameter :: test_r4_nan_parentheses3 = .NOT.(((r4_nan)).GT.r4_nan)
logical, parameter :: test_r4_nan_parentheses4 = ((r4_nan)).NE.r4_nan
logical, parameter :: test_r4_pinf_parentheses = ((r4_pinf)).EQ.r4_pinf
logical, parameter :: test_r4_ninf_parentheses = ((r4_ninf)).EQ.r4_ninf
@ -251,7 +253,9 @@ module real_tests
! Invalid relational argument
logical, parameter :: test_nan_r4_eq1 = .NOT.(r4_nan.EQ.r4_nan)
logical, parameter :: test_nan_r4_ne1 = .NOT.(r4_nan.NE.r4_nan)
logical, parameter :: test_nan_r4_lt1 = .NOT.(r4_nan.LE.r4_nan)
logical, parameter :: test_nan_r4_gt1 = .NOT.(r4_nan.GT.r4_nan)
logical, parameter :: test_nan_r4_ne1 = r4_nan.NE.r4_nan
end module