From 39f1860dccf224f9151f5f3c7e86b84f153d142a Mon Sep 17 00:00:00 2001 From: Peter Klausler <35819229+klausler@users.noreply.github.com> Date: Mon, 18 Sep 2023 08:58:19 -0700 Subject: [PATCH] [flang] Fold NORM2() (#66240) Fold references to the (relatively new) intrinsic function NORM2 at compilation time when the argument(s) are all constants. (Getting this done right involved some changes to the API of the accumulator function objects used by the DoReduction<> template, which rippled through some other reduction function folding code.) --- flang/lib/Evaluate/fold-integer.cpp | 36 +++--- flang/lib/Evaluate/fold-logical.cpp | 5 +- flang/lib/Evaluate/fold-real.cpp | 78 ++++++++++++- flang/lib/Evaluate/fold-reduction.h | 168 +++++++++++++++++++--------- flang/test/Evaluate/fold-norm2.f90 | 29 +++++ 5 files changed, 248 insertions(+), 68 deletions(-) create mode 100644 flang/test/Evaluate/fold-norm2.f90 diff --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp index fe38c81d9768..dedfc20a491c 100644 --- a/flang/lib/Evaluate/fold-integer.cpp +++ b/flang/lib/Evaluate/fold-integer.cpp @@ -264,6 +264,26 @@ Expr> UBOUND(FoldingContext &context, } // COUNT() +template class CountAccumulator { + using MaskT = Type; + +public: + CountAccumulator(const Constant &mask) : mask_{mask} {} + void operator()(Scalar &element, const ConstantSubscripts &at) { + if (mask_.At(at).IsTrue()) { + auto incremented{element.AddSigned(Scalar{1})}; + overflow_ |= incremented.overflow; + element = incremented.value; + } + } + bool overflow() const { return overflow_; } + void Done(Scalar &) const {} + +private: + const Constant &mask_; + bool overflow_{false}; +}; + template static Expr FoldCount(FoldingContext &context, FunctionRef &&ref) { using LogicalResult = Type; @@ -274,17 +294,9 @@ static Expr FoldCount(FoldingContext &context, FunctionRef &&ref) { : Folder{context}.Folding(arg[0])}) { std::optional dim; if (CheckReductionDIM(dim, context, arg, 1, mask->Rank())) { - bool overflow{false}; - auto accumulator{ - [&mask, &overflow](Scalar &element, const ConstantSubscripts &at) { - if (mask->At(at).IsTrue()) { - auto incremented{element.AddSigned(Scalar{1})}; - overflow |= incremented.overflow; - element = incremented.value; - } - }}; + CountAccumulator accumulator{*mask}; Constant result{DoReduction(*mask, dim, Scalar{}, accumulator)}; - if (overflow) { + if (accumulator.overflow()) { context.messages().Say( "Result of intrinsic function COUNT overflows its result type"_warn_en_US); } @@ -513,9 +525,7 @@ static Expr FoldBitReduction(FoldingContext &context, FunctionRef &&ref, if (std::optional> array{ ProcessReductionArgs(context, ref.arguments(), dim, identity, /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) { - auto accumulator{[&](Scalar &element, const ConstantSubscripts &at) { - element = (element.*operation)(array->At(at)); - }}; + OperationAccumulator accumulator{*array, operation}; return Expr{DoReduction(*array, dim, identity, accumulator)}; } return Expr{std::move(ref)}; diff --git a/flang/lib/Evaluate/fold-logical.cpp b/flang/lib/Evaluate/fold-logical.cpp index 95335f7f48bb..9fc42adf805f 100644 --- a/flang/lib/Evaluate/fold-logical.cpp +++ b/flang/lib/Evaluate/fold-logical.cpp @@ -28,14 +28,11 @@ static Expr FoldAllAnyParity(FoldingContext &context, FunctionRef &&ref, Scalar (Scalar::*operation)(const Scalar &) const, Scalar identity) { static_assert(T::category == TypeCategory::Logical); - using Element = Scalar; std::optional dim; if (std::optional> array{ ProcessReductionArgs(context, ref.arguments(), dim, identity, /*ARRAY(MASK)=*/0, /*DIM=*/1)}) { - auto accumulator{[&](Element &element, const ConstantSubscripts &at) { - element = (element.*operation)(array->At(at)); - }}; + OperationAccumulator accumulator{*array, operation}; return Expr{DoReduction(*array, dim, identity, accumulator)}; } return Expr{std::move(ref)}; diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp index 671d897ef7b2..8e3ab1d8fd30 100644 --- a/flang/lib/Evaluate/fold-real.cpp +++ b/flang/lib/Evaluate/fold-real.cpp @@ -43,6 +43,80 @@ static Expr FoldTransformationalBessel( return Expr{std::move(funcRef)}; } +// NORM2 +template class Norm2Accumulator { + using T = Type; + +public: + Norm2Accumulator( + const Constant &array, const Constant &maxAbs, Rounding rounding) + : array_{array}, maxAbs_{maxAbs}, rounding_{rounding} {}; + void operator()(Scalar &element, const ConstantSubscripts &at) { + // Kahan summation of scaled elements + auto scale{maxAbs_.At(maxAbsAt_)}; + if (scale.IsZero()) { + // If maxAbs is zero, so are all elements, and result + element = scale; + } else { + auto item{array_.At(at)}; + auto scaled{item.Divide(scale).value}; + auto square{item.Multiply(scaled).value}; + auto next{square.Add(correction_, rounding_)}; + overflow_ |= next.flags.test(RealFlag::Overflow); + auto sum{element.Add(next.value, rounding_)}; + overflow_ |= sum.flags.test(RealFlag::Overflow); + correction_ = sum.value.Subtract(element, rounding_) + .value.Subtract(next.value, rounding_) + .value; + element = sum.value; + } + } + bool overflow() const { return overflow_; } + void Done(Scalar &result) { + auto corrected{result.Add(correction_, rounding_)}; + overflow_ |= corrected.flags.test(RealFlag::Overflow); + correction_ = Scalar{}; + auto rescaled{corrected.value.Multiply(maxAbs_.At(maxAbsAt_))}; + maxAbs_.IncrementSubscripts(maxAbsAt_); + overflow_ |= rescaled.flags.test(RealFlag::Overflow); + result = rescaled.value.SQRT().value; + } + +private: + const Constant &array_; + const Constant &maxAbs_; + const Rounding rounding_; + bool overflow_{false}; + Scalar correction_{}; + ConstantSubscripts maxAbsAt_{maxAbs_.lbounds()}; +}; + +template +static Expr> FoldNorm2(FoldingContext &context, + FunctionRef> &&funcRef) { + using T = Type; + using Element = typename Constant::Element; + std::optional dim; + const Element identity{}; + if (std::optional> array{ + ProcessReductionArgs(context, funcRef.arguments(), dim, identity, + /*X=*/0, /*DIM=*/1)}) { + MaxvalMinvalAccumulator maxAbsAccumulator{ + RelationalOperator::GT, context, *array}; + Constant maxAbs{ + DoReduction(*array, dim, identity, maxAbsAccumulator)}; + Norm2Accumulator norm2Accumulator{ + *array, maxAbs, context.targetCharacteristics().roundingMode()}; + Constant result{DoReduction(*array, dim, identity, norm2Accumulator)}; + if (norm2Accumulator.overflow()) { + context.messages().Say( + "NORM2() of REAL(%d) data overflowed"_warn_en_US, KIND); + } + return Expr{std::move(result)}; + } + return Expr{std::move(funcRef)}; +} + template Expr> FoldIntrinsicFunction( FoldingContext &context, @@ -238,6 +312,8 @@ Expr> FoldIntrinsicFunction( }, sExpr->u); } + } else if (name == "norm2") { + return FoldNorm2(context, std::move(funcRef)); } else if (name == "product") { auto one{Scalar::FromInteger(value::Integer<8>{1}).value}; return FoldProduct(context, std::move(funcRef), one); @@ -354,7 +430,7 @@ Expr> FoldIntrinsicFunction( return result.value; })); } - // TODO: dot_product, matmul, norm2 + // TODO: matmul return Expr{std::move(funcRef)}; } diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h index b76cecffaf1c..cff7f54c60d9 100644 --- a/flang/lib/Evaluate/fold-reduction.h +++ b/flang/lib/Evaluate/fold-reduction.h @@ -6,8 +6,6 @@ // //===----------------------------------------------------------------------===// -// TODO: NORM2, PARITY - #ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_ #define FORTRAN_EVALUATE_FOLD_REDUCTION_H_ @@ -77,7 +75,8 @@ static Expr FoldDotProduct( overflow |= next.overflow; sum = std::move(next.value); } - } else { // T::category == TypeCategory::Real + } else { + static_assert(T::category == TypeCategory::Real); Expr products{ Fold(context, Expr{Constant{*va}} * Expr{Constant{*vb}})}; Constant &cProducts{DEREF(UnwrapConstantValue(products))}; @@ -172,7 +171,8 @@ static std::optional> ProcessReductionArgs(FoldingContext &context, } // Generalized reduction to an array of one dimension fewer (w/ DIM=) -// or to a scalar (w/o DIM=). +// or to a scalar (w/o DIM=). The ACCUMULATOR type must define +// operator()(Scalar &, const ConstantSubscripts &) and Done(Scalar &). template static Constant DoReduction(const Constant &array, std::optional &dim, const Scalar &identity, @@ -193,6 +193,7 @@ static Constant DoReduction(const Constant &array, for (ConstantSubscript j{0}; j < dimExtent; ++j, ++dimAt) { accumulator(elements.back(), at); } + accumulator.Done(elements.back()); } } else { // no DIM=, result is scalar elements.push_back(identity); @@ -200,6 +201,7 @@ static Constant DoReduction(const Constant &array, IncrementSubscripts(at, array.shape())) { accumulator(elements.back(), at); } + accumulator.Done(elements.back()); } if constexpr (T::category == TypeCategory::Character) { return {static_cast(identity.size()), @@ -210,58 +212,85 @@ static Constant DoReduction(const Constant &array, } // MAXVAL & MINVAL +template class MaxvalMinvalAccumulator { +public: + MaxvalMinvalAccumulator( + RelationalOperator opr, FoldingContext &context, const Constant &array) + : opr_{opr}, context_{context}, array_{array} {}; + void operator()(Scalar &element, const ConstantSubscripts &at) const { + auto aAt{array_.At(at)}; + if constexpr (ABS) { + aAt = aAt.ABS(); + } + Expr test{PackageRelation( + opr_, Expr{Constant{aAt}}, Expr{Constant{element}})}; + auto folded{GetScalarConstantValue( + test.Rewrite(context_, std::move(test)))}; + CHECK(folded.has_value()); + if (folded->IsTrue()) { + element = array_.At(at); + } + } + void Done(Scalar &) const {} + +private: + RelationalOperator opr_; + FoldingContext &context_; + const Constant &array_; +}; + template static Expr FoldMaxvalMinval(FoldingContext &context, FunctionRef &&ref, RelationalOperator opr, const Scalar &identity) { static_assert(T::category == TypeCategory::Integer || T::category == TypeCategory::Real || T::category == TypeCategory::Character); - using Element = Scalar; std::optional dim; if (std::optional> array{ ProcessReductionArgs(context, ref.arguments(), dim, identity, /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) { - auto accumulator{[&](Element &element, const ConstantSubscripts &at) { - Expr test{PackageRelation(opr, - Expr{Constant{array->At(at)}}, Expr{Constant{element}})}; - auto folded{GetScalarConstantValue( - test.Rewrite(context, std::move(test)))}; - CHECK(folded.has_value()); - if (folded->IsTrue()) { - element = array->At(at); - } - }}; + MaxvalMinvalAccumulator accumulator{opr, context, *array}; return Expr{DoReduction(*array, dim, identity, accumulator)}; } return Expr{std::move(ref)}; } // PRODUCT +template class ProductAccumulator { +public: + ProductAccumulator(const Constant &array) : array_{array} {} + void operator()(Scalar &element, const ConstantSubscripts &at) { + if constexpr (T::category == TypeCategory::Integer) { + auto prod{element.MultiplySigned(array_.At(at))}; + overflow_ |= prod.SignedMultiplicationOverflowed(); + element = prod.lower; + } else { // Real & Complex + auto prod{element.Multiply(array_.At(at))}; + overflow_ |= prod.flags.test(RealFlag::Overflow); + element = prod.value; + } + } + bool overflow() const { return overflow_; } + void Done(Scalar &) const {} + +private: + const Constant &array_; + bool overflow_{false}; +}; + template static Expr FoldProduct( FoldingContext &context, FunctionRef &&ref, Scalar identity) { static_assert(T::category == TypeCategory::Integer || T::category == TypeCategory::Real || T::category == TypeCategory::Complex); - using Element = typename Constant::Element; std::optional dim; if (std::optional> array{ ProcessReductionArgs(context, ref.arguments(), dim, identity, /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) { - bool overflow{false}; - auto accumulator{[&](Element &element, const ConstantSubscripts &at) { - if constexpr (T::category == TypeCategory::Integer) { - auto prod{element.MultiplySigned(array->At(at))}; - overflow |= prod.SignedMultiplicationOverflowed(); - element = prod.lower; - } else { // Real & Complex - auto prod{element.Multiply(array->At(at))}; - overflow |= prod.flags.test(RealFlag::Overflow); - element = prod.value; - } - }}; + ProductAccumulator accumulator{*array}; auto result{Expr{DoReduction(*array, dim, identity, accumulator)}}; - if (overflow) { + if (accumulator.overflow()) { context.messages().Say( "PRODUCT() of %s data overflowed"_warn_en_US, T::AsFortran()); } @@ -271,6 +300,46 @@ static Expr FoldProduct( } // SUM +template class SumAccumulator { + using Element = typename Constant::Element; + +public: + SumAccumulator(const Constant &array, Rounding rounding) + : array_{array}, rounding_{rounding} {} + void operator()(Element &element, const ConstantSubscripts &at) { + if constexpr (T::category == TypeCategory::Integer) { + auto sum{element.AddSigned(array_.At(at))}; + overflow_ |= sum.overflow; + element = sum.value; + } else { // Real & Complex: use Kahan summation + auto next{array_.At(at).Add(correction_, rounding_)}; + overflow_ |= next.flags.test(RealFlag::Overflow); + auto sum{element.Add(next.value, rounding_)}; + overflow_ |= sum.flags.test(RealFlag::Overflow); + // correction = (sum - element) - next; algebraically zero + correction_ = sum.value.Subtract(element, rounding_) + .value.Subtract(next.value, rounding_) + .value; + element = sum.value; + } + } + bool overflow() const { return overflow_; } + void Done([[maybe_unused]] Element &element) { + if constexpr (T::category != TypeCategory::Integer) { + auto corrected{element.Add(correction_, rounding_)}; + overflow_ |= corrected.flags.test(RealFlag::Overflow); + correction_ = Scalar{}; + element = corrected.value; + } + } + +private: + const Constant &array_; + Rounding rounding_; + bool overflow_{false}; + Element correction_{}; +}; + template static Expr FoldSum(FoldingContext &context, FunctionRef &&ref) { static_assert(T::category == TypeCategory::Integer || @@ -278,31 +347,14 @@ static Expr FoldSum(FoldingContext &context, FunctionRef &&ref) { T::category == TypeCategory::Complex); using Element = typename Constant::Element; std::optional dim; - Element identity{}, correction{}; + Element identity{}; if (std::optional> array{ ProcessReductionArgs(context, ref.arguments(), dim, identity, /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) { - bool overflow{false}; - auto accumulator{[&](Element &element, const ConstantSubscripts &at) { - if constexpr (T::category == TypeCategory::Integer) { - auto sum{element.AddSigned(array->At(at))}; - overflow |= sum.overflow; - element = sum.value; - } else { // Real & Complex: use Kahan summation - const auto &rounding{context.targetCharacteristics().roundingMode()}; - auto next{array->At(at).Add(correction, rounding)}; - overflow |= next.flags.test(RealFlag::Overflow); - auto sum{element.Add(next.value, rounding)}; - overflow |= sum.flags.test(RealFlag::Overflow); - // correction = (sum - element) - next; algebraically zero - correction = sum.value.Subtract(element, rounding) - .value.Subtract(next.value, rounding) - .value; - element = sum.value; - } - }}; + SumAccumulator accumulator{ + *array, context.targetCharacteristics().roundingMode()}; auto result{Expr{DoReduction(*array, dim, identity, accumulator)}}; - if (overflow) { + if (accumulator.overflow()) { context.messages().Say( "SUM() of %s data overflowed"_warn_en_US, T::AsFortran()); } @@ -311,5 +363,21 @@ static Expr FoldSum(FoldingContext &context, FunctionRef &&ref) { return Expr{std::move(ref)}; } +// Utility for IALL, IANY, IPARITY, ALL, ANY, & PARITY +template class OperationAccumulator { +public: + OperationAccumulator(const Constant &array, + Scalar (Scalar::*operation)(const Scalar &) const) + : array_{array}, operation_{operation} {} + void operator()(Scalar &element, const ConstantSubscripts &at) { + element = (element.*operation_)(array_.At(at)); + } + void Done(Scalar &) const {} + +private: + const Constant &array_; + Scalar (Scalar::*operation_)(const Scalar &) const; +}; + } // namespace Fortran::evaluate #endif // FORTRAN_EVALUATE_FOLD_REDUCTION_H_ diff --git a/flang/test/Evaluate/fold-norm2.f90 b/flang/test/Evaluate/fold-norm2.f90 new file mode 100644 index 000000000000..30d5289b5a6e --- /dev/null +++ b/flang/test/Evaluate/fold-norm2.f90 @@ -0,0 +1,29 @@ +! RUN: %python %S/test_folding.py %s %flang_fc1 +! Tests folding of NORM2(), F'2023 16.9.153 +module m + ! Examples from the standard + logical, parameter :: test_ex1 = norm2([3.,4.]) == 5. + real, parameter :: ex2(2,2) = reshape([1.,3.,2.,4.],[2,2]) + real, parameter :: ex2_norm2_1(2) = norm2(ex2, dim=1) + real, parameter :: ex2_1(2) = [3.162277698516845703125,4.472136020660400390625] + logical, parameter :: test_ex2_1 = all(ex2_norm2_1 == ex2_1) + real, parameter :: ex2_norm2_2(2) = norm2(ex2, dim=2) + real, parameter :: ex2_2(2) = [2.2360680103302001953125,5.] + logical, parameter :: test_ex2_2 = all(ex2_norm2_2 == ex2_2) + ! 0 3 6 9 + ! 1 4 7 10 + ! 2 5 8 11 + integer, parameter :: dp = kind(0.d0) + real(dp), parameter :: a(3,4) = & + reshape([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], shape(a)) + real(dp), parameter :: nAll = norm2(a) + real(dp), parameter :: check_nAll = sqrt(sum(a * a)) + logical, parameter :: test_all = nAll == check_nAll + real(dp), parameter :: norms1(4) = norm2(a, dim=1) + real(dp), parameter :: check_norms1(4) = sqrt(sum(a * a, dim=1)) + logical, parameter :: test_norms1 = all(norms1 == check_norms1) + real(dp), parameter :: norms2(3) = norm2(a, dim=2) + real(dp), parameter :: check_norms2(3) = sqrt(sum(a * a, dim=2)) + logical, parameter :: test_norms2 = all(norms2 == check_norms2) + logical, parameter :: test_normZ = norm2([0.,0.,0.]) == 0. +end