[flang] Fold NORM2() (#66240)

Fold references to the (relatively new) intrinsic function NORM2 at
compilation time when the argument(s) are all constants. (Getting this
done right involved some changes to the API of the accumulator function
objects used by the DoReduction<> template, which rippled through some
other reduction function folding code.)
This commit is contained in:
Peter Klausler 2023-09-18 08:58:19 -07:00 committed by GitHub
parent 93e0658a83
commit 39f1860dcc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 248 additions and 68 deletions

View File

@ -264,6 +264,26 @@ Expr<Type<TypeCategory::Integer, KIND>> UBOUND(FoldingContext &context,
}
// COUNT()
template <typename T, int MASK_KIND> class CountAccumulator {
using MaskT = Type<TypeCategory::Logical, MASK_KIND>;
public:
CountAccumulator(const Constant<MaskT> &mask) : mask_{mask} {}
void operator()(Scalar<T> &element, const ConstantSubscripts &at) {
if (mask_.At(at).IsTrue()) {
auto incremented{element.AddSigned(Scalar<T>{1})};
overflow_ |= incremented.overflow;
element = incremented.value;
}
}
bool overflow() const { return overflow_; }
void Done(Scalar<T> &) const {}
private:
const Constant<MaskT> &mask_;
bool overflow_{false};
};
template <typename T, int maskKind>
static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) {
using LogicalResult = Type<TypeCategory::Logical, maskKind>;
@ -274,17 +294,9 @@ static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) {
: Folder<LogicalResult>{context}.Folding(arg[0])}) {
std::optional<int> dim;
if (CheckReductionDIM(dim, context, arg, 1, mask->Rank())) {
bool overflow{false};
auto accumulator{
[&mask, &overflow](Scalar<T> &element, const ConstantSubscripts &at) {
if (mask->At(at).IsTrue()) {
auto incremented{element.AddSigned(Scalar<T>{1})};
overflow |= incremented.overflow;
element = incremented.value;
}
}};
CountAccumulator<T, maskKind> accumulator{*mask};
Constant<T> result{DoReduction<T>(*mask, dim, Scalar<T>{}, accumulator)};
if (overflow) {
if (accumulator.overflow()) {
context.messages().Say(
"Result of intrinsic function COUNT overflows its result type"_warn_en_US);
}
@ -513,9 +525,7 @@ static Expr<T> FoldBitReduction(FoldingContext &context, FunctionRef<T> &&ref,
if (std::optional<Constant<T>> array{
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
/*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
auto accumulator{[&](Scalar<T> &element, const ConstantSubscripts &at) {
element = (element.*operation)(array->At(at));
}};
OperationAccumulator<T> accumulator{*array, operation};
return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
}
return Expr<T>{std::move(ref)};

View File

@ -28,14 +28,11 @@ static Expr<T> FoldAllAnyParity(FoldingContext &context, FunctionRef<T> &&ref,
Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const,
Scalar<T> identity) {
static_assert(T::category == TypeCategory::Logical);
using Element = Scalar<T>;
std::optional<int> dim;
if (std::optional<Constant<T>> array{
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
/*ARRAY(MASK)=*/0, /*DIM=*/1)}) {
auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
element = (element.*operation)(array->At(at));
}};
OperationAccumulator accumulator{*array, operation};
return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
}
return Expr<T>{std::move(ref)};

View File

@ -43,6 +43,80 @@ static Expr<T> FoldTransformationalBessel(
return Expr<T>{std::move(funcRef)};
}
// NORM2
template <int KIND> class Norm2Accumulator {
using T = Type<TypeCategory::Real, KIND>;
public:
Norm2Accumulator(
const Constant<T> &array, const Constant<T> &maxAbs, Rounding rounding)
: array_{array}, maxAbs_{maxAbs}, rounding_{rounding} {};
void operator()(Scalar<T> &element, const ConstantSubscripts &at) {
// Kahan summation of scaled elements
auto scale{maxAbs_.At(maxAbsAt_)};
if (scale.IsZero()) {
// If maxAbs is zero, so are all elements, and result
element = scale;
} else {
auto item{array_.At(at)};
auto scaled{item.Divide(scale).value};
auto square{item.Multiply(scaled).value};
auto next{square.Add(correction_, rounding_)};
overflow_ |= next.flags.test(RealFlag::Overflow);
auto sum{element.Add(next.value, rounding_)};
overflow_ |= sum.flags.test(RealFlag::Overflow);
correction_ = sum.value.Subtract(element, rounding_)
.value.Subtract(next.value, rounding_)
.value;
element = sum.value;
}
}
bool overflow() const { return overflow_; }
void Done(Scalar<T> &result) {
auto corrected{result.Add(correction_, rounding_)};
overflow_ |= corrected.flags.test(RealFlag::Overflow);
correction_ = Scalar<T>{};
auto rescaled{corrected.value.Multiply(maxAbs_.At(maxAbsAt_))};
maxAbs_.IncrementSubscripts(maxAbsAt_);
overflow_ |= rescaled.flags.test(RealFlag::Overflow);
result = rescaled.value.SQRT().value;
}
private:
const Constant<T> &array_;
const Constant<T> &maxAbs_;
const Rounding rounding_;
bool overflow_{false};
Scalar<T> correction_{};
ConstantSubscripts maxAbsAt_{maxAbs_.lbounds()};
};
template <int KIND>
static Expr<Type<TypeCategory::Real, KIND>> FoldNorm2(FoldingContext &context,
FunctionRef<Type<TypeCategory::Real, KIND>> &&funcRef) {
using T = Type<TypeCategory::Real, KIND>;
using Element = typename Constant<T>::Element;
std::optional<int> dim;
const Element identity{};
if (std::optional<Constant<T>> array{
ProcessReductionArgs<T>(context, funcRef.arguments(), dim, identity,
/*X=*/0, /*DIM=*/1)}) {
MaxvalMinvalAccumulator<T, /*ABS=*/true> maxAbsAccumulator{
RelationalOperator::GT, context, *array};
Constant<T> maxAbs{
DoReduction<T>(*array, dim, identity, maxAbsAccumulator)};
Norm2Accumulator norm2Accumulator{
*array, maxAbs, context.targetCharacteristics().roundingMode()};
Constant<T> result{DoReduction<T>(*array, dim, identity, norm2Accumulator)};
if (norm2Accumulator.overflow()) {
context.messages().Say(
"NORM2() of REAL(%d) data overflowed"_warn_en_US, KIND);
}
return Expr<T>{std::move(result)};
}
return Expr<T>{std::move(funcRef)};
}
template <int KIND>
Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
FoldingContext &context,
@ -238,6 +312,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
},
sExpr->u);
}
} else if (name == "norm2") {
return FoldNorm2<T::kind>(context, std::move(funcRef));
} else if (name == "product") {
auto one{Scalar<T>::FromInteger(value::Integer<8>{1}).value};
return FoldProduct<T>(context, std::move(funcRef), one);
@ -354,7 +430,7 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
return result.value;
}));
}
// TODO: dot_product, matmul, norm2
// TODO: matmul
return Expr<T>{std::move(funcRef)};
}

View File

@ -6,8 +6,6 @@
//
//===----------------------------------------------------------------------===//
// TODO: NORM2, PARITY
#ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_
#define FORTRAN_EVALUATE_FOLD_REDUCTION_H_
@ -77,7 +75,8 @@ static Expr<T> FoldDotProduct(
overflow |= next.overflow;
sum = std::move(next.value);
}
} else { // T::category == TypeCategory::Real
} else {
static_assert(T::category == TypeCategory::Real);
Expr<T> products{
Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
@ -172,7 +171,8 @@ static std::optional<Constant<T>> ProcessReductionArgs(FoldingContext &context,
}
// Generalized reduction to an array of one dimension fewer (w/ DIM=)
// or to a scalar (w/o DIM=).
// or to a scalar (w/o DIM=). The ACCUMULATOR type must define
// operator()(Scalar<T> &, const ConstantSubscripts &) and Done(Scalar<T> &).
template <typename T, typename ACCUMULATOR, typename ARRAY>
static Constant<T> DoReduction(const Constant<ARRAY> &array,
std::optional<int> &dim, const Scalar<T> &identity,
@ -193,6 +193,7 @@ static Constant<T> DoReduction(const Constant<ARRAY> &array,
for (ConstantSubscript j{0}; j < dimExtent; ++j, ++dimAt) {
accumulator(elements.back(), at);
}
accumulator.Done(elements.back());
}
} else { // no DIM=, result is scalar
elements.push_back(identity);
@ -200,6 +201,7 @@ static Constant<T> DoReduction(const Constant<ARRAY> &array,
IncrementSubscripts(at, array.shape())) {
accumulator(elements.back(), at);
}
accumulator.Done(elements.back());
}
if constexpr (T::category == TypeCategory::Character) {
return {static_cast<ConstantSubscript>(identity.size()),
@ -210,58 +212,85 @@ static Constant<T> DoReduction(const Constant<ARRAY> &array,
}
// MAXVAL & MINVAL
template <typename T, bool ABS = false> class MaxvalMinvalAccumulator {
public:
MaxvalMinvalAccumulator(
RelationalOperator opr, FoldingContext &context, const Constant<T> &array)
: opr_{opr}, context_{context}, array_{array} {};
void operator()(Scalar<T> &element, const ConstantSubscripts &at) const {
auto aAt{array_.At(at)};
if constexpr (ABS) {
aAt = aAt.ABS();
}
Expr<LogicalResult> test{PackageRelation(
opr_, Expr<T>{Constant<T>{aAt}}, Expr<T>{Constant<T>{element}})};
auto folded{GetScalarConstantValue<LogicalResult>(
test.Rewrite(context_, std::move(test)))};
CHECK(folded.has_value());
if (folded->IsTrue()) {
element = array_.At(at);
}
}
void Done(Scalar<T> &) const {}
private:
RelationalOperator opr_;
FoldingContext &context_;
const Constant<T> &array_;
};
template <typename T>
static Expr<T> FoldMaxvalMinval(FoldingContext &context, FunctionRef<T> &&ref,
RelationalOperator opr, const Scalar<T> &identity) {
static_assert(T::category == TypeCategory::Integer ||
T::category == TypeCategory::Real ||
T::category == TypeCategory::Character);
using Element = Scalar<T>;
std::optional<int> dim;
if (std::optional<Constant<T>> array{
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
/*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
Expr<LogicalResult> test{PackageRelation(opr,
Expr<T>{Constant<T>{array->At(at)}}, Expr<T>{Constant<T>{element}})};
auto folded{GetScalarConstantValue<LogicalResult>(
test.Rewrite(context, std::move(test)))};
CHECK(folded.has_value());
if (folded->IsTrue()) {
element = array->At(at);
}
}};
MaxvalMinvalAccumulator accumulator{opr, context, *array};
return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
}
return Expr<T>{std::move(ref)};
}
// PRODUCT
template <typename T> class ProductAccumulator {
public:
ProductAccumulator(const Constant<T> &array) : array_{array} {}
void operator()(Scalar<T> &element, const ConstantSubscripts &at) {
if constexpr (T::category == TypeCategory::Integer) {
auto prod{element.MultiplySigned(array_.At(at))};
overflow_ |= prod.SignedMultiplicationOverflowed();
element = prod.lower;
} else { // Real & Complex
auto prod{element.Multiply(array_.At(at))};
overflow_ |= prod.flags.test(RealFlag::Overflow);
element = prod.value;
}
}
bool overflow() const { return overflow_; }
void Done(Scalar<T> &) const {}
private:
const Constant<T> &array_;
bool overflow_{false};
};
template <typename T>
static Expr<T> FoldProduct(
FoldingContext &context, FunctionRef<T> &&ref, Scalar<T> identity) {
static_assert(T::category == TypeCategory::Integer ||
T::category == TypeCategory::Real ||
T::category == TypeCategory::Complex);
using Element = typename Constant<T>::Element;
std::optional<int> dim;
if (std::optional<Constant<T>> array{
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
/*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
bool overflow{false};
auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
if constexpr (T::category == TypeCategory::Integer) {
auto prod{element.MultiplySigned(array->At(at))};
overflow |= prod.SignedMultiplicationOverflowed();
element = prod.lower;
} else { // Real & Complex
auto prod{element.Multiply(array->At(at))};
overflow |= prod.flags.test(RealFlag::Overflow);
element = prod.value;
}
}};
ProductAccumulator accumulator{*array};
auto result{Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)}};
if (overflow) {
if (accumulator.overflow()) {
context.messages().Say(
"PRODUCT() of %s data overflowed"_warn_en_US, T::AsFortran());
}
@ -271,6 +300,46 @@ static Expr<T> FoldProduct(
}
// SUM
template <typename T> class SumAccumulator {
using Element = typename Constant<T>::Element;
public:
SumAccumulator(const Constant<T> &array, Rounding rounding)
: array_{array}, rounding_{rounding} {}
void operator()(Element &element, const ConstantSubscripts &at) {
if constexpr (T::category == TypeCategory::Integer) {
auto sum{element.AddSigned(array_.At(at))};
overflow_ |= sum.overflow;
element = sum.value;
} else { // Real & Complex: use Kahan summation
auto next{array_.At(at).Add(correction_, rounding_)};
overflow_ |= next.flags.test(RealFlag::Overflow);
auto sum{element.Add(next.value, rounding_)};
overflow_ |= sum.flags.test(RealFlag::Overflow);
// correction = (sum - element) - next; algebraically zero
correction_ = sum.value.Subtract(element, rounding_)
.value.Subtract(next.value, rounding_)
.value;
element = sum.value;
}
}
bool overflow() const { return overflow_; }
void Done([[maybe_unused]] Element &element) {
if constexpr (T::category != TypeCategory::Integer) {
auto corrected{element.Add(correction_, rounding_)};
overflow_ |= corrected.flags.test(RealFlag::Overflow);
correction_ = Scalar<T>{};
element = corrected.value;
}
}
private:
const Constant<T> &array_;
Rounding rounding_;
bool overflow_{false};
Element correction_{};
};
template <typename T>
static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) {
static_assert(T::category == TypeCategory::Integer ||
@ -278,31 +347,14 @@ static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) {
T::category == TypeCategory::Complex);
using Element = typename Constant<T>::Element;
std::optional<int> dim;
Element identity{}, correction{};
Element identity{};
if (std::optional<Constant<T>> array{
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
/*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
bool overflow{false};
auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
if constexpr (T::category == TypeCategory::Integer) {
auto sum{element.AddSigned(array->At(at))};
overflow |= sum.overflow;
element = sum.value;
} else { // Real & Complex: use Kahan summation
const auto &rounding{context.targetCharacteristics().roundingMode()};
auto next{array->At(at).Add(correction, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto sum{element.Add(next.value, rounding)};
overflow |= sum.flags.test(RealFlag::Overflow);
// correction = (sum - element) - next; algebraically zero
correction = sum.value.Subtract(element, rounding)
.value.Subtract(next.value, rounding)
.value;
element = sum.value;
}
}};
SumAccumulator accumulator{
*array, context.targetCharacteristics().roundingMode()};
auto result{Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)}};
if (overflow) {
if (accumulator.overflow()) {
context.messages().Say(
"SUM() of %s data overflowed"_warn_en_US, T::AsFortran());
}
@ -311,5 +363,21 @@ static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) {
return Expr<T>{std::move(ref)};
}
// Utility for IALL, IANY, IPARITY, ALL, ANY, & PARITY
template <typename T> class OperationAccumulator {
public:
OperationAccumulator(const Constant<T> &array,
Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const)
: array_{array}, operation_{operation} {}
void operator()(Scalar<T> &element, const ConstantSubscripts &at) {
element = (element.*operation_)(array_.At(at));
}
void Done(Scalar<T> &) const {}
private:
const Constant<T> &array_;
Scalar<T> (Scalar<T>::*operation_)(const Scalar<T> &) const;
};
} // namespace Fortran::evaluate
#endif // FORTRAN_EVALUATE_FOLD_REDUCTION_H_

View File

@ -0,0 +1,29 @@
! RUN: %python %S/test_folding.py %s %flang_fc1
! Tests folding of NORM2(), F'2023 16.9.153
module m
! Examples from the standard
logical, parameter :: test_ex1 = norm2([3.,4.]) == 5.
real, parameter :: ex2(2,2) = reshape([1.,3.,2.,4.],[2,2])
real, parameter :: ex2_norm2_1(2) = norm2(ex2, dim=1)
real, parameter :: ex2_1(2) = [3.162277698516845703125,4.472136020660400390625]
logical, parameter :: test_ex2_1 = all(ex2_norm2_1 == ex2_1)
real, parameter :: ex2_norm2_2(2) = norm2(ex2, dim=2)
real, parameter :: ex2_2(2) = [2.2360680103302001953125,5.]
logical, parameter :: test_ex2_2 = all(ex2_norm2_2 == ex2_2)
! 0 3 6 9
! 1 4 7 10
! 2 5 8 11
integer, parameter :: dp = kind(0.d0)
real(dp), parameter :: a(3,4) = &
reshape([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], shape(a))
real(dp), parameter :: nAll = norm2(a)
real(dp), parameter :: check_nAll = sqrt(sum(a * a))
logical, parameter :: test_all = nAll == check_nAll
real(dp), parameter :: norms1(4) = norm2(a, dim=1)
real(dp), parameter :: check_norms1(4) = sqrt(sum(a * a, dim=1))
logical, parameter :: test_norms1 = all(norms1 == check_norms1)
real(dp), parameter :: norms2(3) = norm2(a, dim=2)
real(dp), parameter :: check_norms2(3) = sqrt(sum(a * a, dim=2))
logical, parameter :: test_norms2 = all(norms2 == check_norms2)
logical, parameter :: test_normZ = norm2([0.,0.,0.]) == 0.
end