mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-12-14 03:29:57 +00:00
[flang] Fold NORM2() (#66240)
Fold references to the (relatively new) intrinsic function NORM2 at compilation time when the argument(s) are all constants. (Getting this done right involved some changes to the API of the accumulator function objects used by the DoReduction<> template, which rippled through some other reduction function folding code.)
This commit is contained in:
parent
93e0658a83
commit
39f1860dcc
@ -264,6 +264,26 @@ Expr<Type<TypeCategory::Integer, KIND>> UBOUND(FoldingContext &context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// COUNT()
|
// COUNT()
|
||||||
|
template <typename T, int MASK_KIND> class CountAccumulator {
|
||||||
|
using MaskT = Type<TypeCategory::Logical, MASK_KIND>;
|
||||||
|
|
||||||
|
public:
|
||||||
|
CountAccumulator(const Constant<MaskT> &mask) : mask_{mask} {}
|
||||||
|
void operator()(Scalar<T> &element, const ConstantSubscripts &at) {
|
||||||
|
if (mask_.At(at).IsTrue()) {
|
||||||
|
auto incremented{element.AddSigned(Scalar<T>{1})};
|
||||||
|
overflow_ |= incremented.overflow;
|
||||||
|
element = incremented.value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bool overflow() const { return overflow_; }
|
||||||
|
void Done(Scalar<T> &) const {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const Constant<MaskT> &mask_;
|
||||||
|
bool overflow_{false};
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T, int maskKind>
|
template <typename T, int maskKind>
|
||||||
static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) {
|
static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) {
|
||||||
using LogicalResult = Type<TypeCategory::Logical, maskKind>;
|
using LogicalResult = Type<TypeCategory::Logical, maskKind>;
|
||||||
@ -274,17 +294,9 @@ static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) {
|
|||||||
: Folder<LogicalResult>{context}.Folding(arg[0])}) {
|
: Folder<LogicalResult>{context}.Folding(arg[0])}) {
|
||||||
std::optional<int> dim;
|
std::optional<int> dim;
|
||||||
if (CheckReductionDIM(dim, context, arg, 1, mask->Rank())) {
|
if (CheckReductionDIM(dim, context, arg, 1, mask->Rank())) {
|
||||||
bool overflow{false};
|
CountAccumulator<T, maskKind> accumulator{*mask};
|
||||||
auto accumulator{
|
|
||||||
[&mask, &overflow](Scalar<T> &element, const ConstantSubscripts &at) {
|
|
||||||
if (mask->At(at).IsTrue()) {
|
|
||||||
auto incremented{element.AddSigned(Scalar<T>{1})};
|
|
||||||
overflow |= incremented.overflow;
|
|
||||||
element = incremented.value;
|
|
||||||
}
|
|
||||||
}};
|
|
||||||
Constant<T> result{DoReduction<T>(*mask, dim, Scalar<T>{}, accumulator)};
|
Constant<T> result{DoReduction<T>(*mask, dim, Scalar<T>{}, accumulator)};
|
||||||
if (overflow) {
|
if (accumulator.overflow()) {
|
||||||
context.messages().Say(
|
context.messages().Say(
|
||||||
"Result of intrinsic function COUNT overflows its result type"_warn_en_US);
|
"Result of intrinsic function COUNT overflows its result type"_warn_en_US);
|
||||||
}
|
}
|
||||||
@ -513,9 +525,7 @@ static Expr<T> FoldBitReduction(FoldingContext &context, FunctionRef<T> &&ref,
|
|||||||
if (std::optional<Constant<T>> array{
|
if (std::optional<Constant<T>> array{
|
||||||
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
|
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
|
||||||
/*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
|
/*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
|
||||||
auto accumulator{[&](Scalar<T> &element, const ConstantSubscripts &at) {
|
OperationAccumulator<T> accumulator{*array, operation};
|
||||||
element = (element.*operation)(array->At(at));
|
|
||||||
}};
|
|
||||||
return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
|
return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
|
||||||
}
|
}
|
||||||
return Expr<T>{std::move(ref)};
|
return Expr<T>{std::move(ref)};
|
||||||
|
@ -28,14 +28,11 @@ static Expr<T> FoldAllAnyParity(FoldingContext &context, FunctionRef<T> &&ref,
|
|||||||
Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const,
|
Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const,
|
||||||
Scalar<T> identity) {
|
Scalar<T> identity) {
|
||||||
static_assert(T::category == TypeCategory::Logical);
|
static_assert(T::category == TypeCategory::Logical);
|
||||||
using Element = Scalar<T>;
|
|
||||||
std::optional<int> dim;
|
std::optional<int> dim;
|
||||||
if (std::optional<Constant<T>> array{
|
if (std::optional<Constant<T>> array{
|
||||||
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
|
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
|
||||||
/*ARRAY(MASK)=*/0, /*DIM=*/1)}) {
|
/*ARRAY(MASK)=*/0, /*DIM=*/1)}) {
|
||||||
auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
|
OperationAccumulator accumulator{*array, operation};
|
||||||
element = (element.*operation)(array->At(at));
|
|
||||||
}};
|
|
||||||
return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
|
return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
|
||||||
}
|
}
|
||||||
return Expr<T>{std::move(ref)};
|
return Expr<T>{std::move(ref)};
|
||||||
|
@ -43,6 +43,80 @@ static Expr<T> FoldTransformationalBessel(
|
|||||||
return Expr<T>{std::move(funcRef)};
|
return Expr<T>{std::move(funcRef)};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NORM2
|
||||||
|
template <int KIND> class Norm2Accumulator {
|
||||||
|
using T = Type<TypeCategory::Real, KIND>;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Norm2Accumulator(
|
||||||
|
const Constant<T> &array, const Constant<T> &maxAbs, Rounding rounding)
|
||||||
|
: array_{array}, maxAbs_{maxAbs}, rounding_{rounding} {};
|
||||||
|
void operator()(Scalar<T> &element, const ConstantSubscripts &at) {
|
||||||
|
// Kahan summation of scaled elements
|
||||||
|
auto scale{maxAbs_.At(maxAbsAt_)};
|
||||||
|
if (scale.IsZero()) {
|
||||||
|
// If maxAbs is zero, so are all elements, and result
|
||||||
|
element = scale;
|
||||||
|
} else {
|
||||||
|
auto item{array_.At(at)};
|
||||||
|
auto scaled{item.Divide(scale).value};
|
||||||
|
auto square{item.Multiply(scaled).value};
|
||||||
|
auto next{square.Add(correction_, rounding_)};
|
||||||
|
overflow_ |= next.flags.test(RealFlag::Overflow);
|
||||||
|
auto sum{element.Add(next.value, rounding_)};
|
||||||
|
overflow_ |= sum.flags.test(RealFlag::Overflow);
|
||||||
|
correction_ = sum.value.Subtract(element, rounding_)
|
||||||
|
.value.Subtract(next.value, rounding_)
|
||||||
|
.value;
|
||||||
|
element = sum.value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bool overflow() const { return overflow_; }
|
||||||
|
void Done(Scalar<T> &result) {
|
||||||
|
auto corrected{result.Add(correction_, rounding_)};
|
||||||
|
overflow_ |= corrected.flags.test(RealFlag::Overflow);
|
||||||
|
correction_ = Scalar<T>{};
|
||||||
|
auto rescaled{corrected.value.Multiply(maxAbs_.At(maxAbsAt_))};
|
||||||
|
maxAbs_.IncrementSubscripts(maxAbsAt_);
|
||||||
|
overflow_ |= rescaled.flags.test(RealFlag::Overflow);
|
||||||
|
result = rescaled.value.SQRT().value;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const Constant<T> &array_;
|
||||||
|
const Constant<T> &maxAbs_;
|
||||||
|
const Rounding rounding_;
|
||||||
|
bool overflow_{false};
|
||||||
|
Scalar<T> correction_{};
|
||||||
|
ConstantSubscripts maxAbsAt_{maxAbs_.lbounds()};
|
||||||
|
};
|
||||||
|
|
||||||
|
template <int KIND>
|
||||||
|
static Expr<Type<TypeCategory::Real, KIND>> FoldNorm2(FoldingContext &context,
|
||||||
|
FunctionRef<Type<TypeCategory::Real, KIND>> &&funcRef) {
|
||||||
|
using T = Type<TypeCategory::Real, KIND>;
|
||||||
|
using Element = typename Constant<T>::Element;
|
||||||
|
std::optional<int> dim;
|
||||||
|
const Element identity{};
|
||||||
|
if (std::optional<Constant<T>> array{
|
||||||
|
ProcessReductionArgs<T>(context, funcRef.arguments(), dim, identity,
|
||||||
|
/*X=*/0, /*DIM=*/1)}) {
|
||||||
|
MaxvalMinvalAccumulator<T, /*ABS=*/true> maxAbsAccumulator{
|
||||||
|
RelationalOperator::GT, context, *array};
|
||||||
|
Constant<T> maxAbs{
|
||||||
|
DoReduction<T>(*array, dim, identity, maxAbsAccumulator)};
|
||||||
|
Norm2Accumulator norm2Accumulator{
|
||||||
|
*array, maxAbs, context.targetCharacteristics().roundingMode()};
|
||||||
|
Constant<T> result{DoReduction<T>(*array, dim, identity, norm2Accumulator)};
|
||||||
|
if (norm2Accumulator.overflow()) {
|
||||||
|
context.messages().Say(
|
||||||
|
"NORM2() of REAL(%d) data overflowed"_warn_en_US, KIND);
|
||||||
|
}
|
||||||
|
return Expr<T>{std::move(result)};
|
||||||
|
}
|
||||||
|
return Expr<T>{std::move(funcRef)};
|
||||||
|
}
|
||||||
|
|
||||||
template <int KIND>
|
template <int KIND>
|
||||||
Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
|
Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
|
||||||
FoldingContext &context,
|
FoldingContext &context,
|
||||||
@ -238,6 +312,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
|
|||||||
},
|
},
|
||||||
sExpr->u);
|
sExpr->u);
|
||||||
}
|
}
|
||||||
|
} else if (name == "norm2") {
|
||||||
|
return FoldNorm2<T::kind>(context, std::move(funcRef));
|
||||||
} else if (name == "product") {
|
} else if (name == "product") {
|
||||||
auto one{Scalar<T>::FromInteger(value::Integer<8>{1}).value};
|
auto one{Scalar<T>::FromInteger(value::Integer<8>{1}).value};
|
||||||
return FoldProduct<T>(context, std::move(funcRef), one);
|
return FoldProduct<T>(context, std::move(funcRef), one);
|
||||||
@ -354,7 +430,7 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
|
|||||||
return result.value;
|
return result.value;
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
// TODO: dot_product, matmul, norm2
|
// TODO: matmul
|
||||||
return Expr<T>{std::move(funcRef)};
|
return Expr<T>{std::move(funcRef)};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,8 +6,6 @@
|
|||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// TODO: NORM2, PARITY
|
|
||||||
|
|
||||||
#ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_
|
#ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_
|
||||||
#define FORTRAN_EVALUATE_FOLD_REDUCTION_H_
|
#define FORTRAN_EVALUATE_FOLD_REDUCTION_H_
|
||||||
|
|
||||||
@ -77,7 +75,8 @@ static Expr<T> FoldDotProduct(
|
|||||||
overflow |= next.overflow;
|
overflow |= next.overflow;
|
||||||
sum = std::move(next.value);
|
sum = std::move(next.value);
|
||||||
}
|
}
|
||||||
} else { // T::category == TypeCategory::Real
|
} else {
|
||||||
|
static_assert(T::category == TypeCategory::Real);
|
||||||
Expr<T> products{
|
Expr<T> products{
|
||||||
Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
|
Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
|
||||||
Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
|
Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
|
||||||
@ -172,7 +171,8 @@ static std::optional<Constant<T>> ProcessReductionArgs(FoldingContext &context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Generalized reduction to an array of one dimension fewer (w/ DIM=)
|
// Generalized reduction to an array of one dimension fewer (w/ DIM=)
|
||||||
// or to a scalar (w/o DIM=).
|
// or to a scalar (w/o DIM=). The ACCUMULATOR type must define
|
||||||
|
// operator()(Scalar<T> &, const ConstantSubscripts &) and Done(Scalar<T> &).
|
||||||
template <typename T, typename ACCUMULATOR, typename ARRAY>
|
template <typename T, typename ACCUMULATOR, typename ARRAY>
|
||||||
static Constant<T> DoReduction(const Constant<ARRAY> &array,
|
static Constant<T> DoReduction(const Constant<ARRAY> &array,
|
||||||
std::optional<int> &dim, const Scalar<T> &identity,
|
std::optional<int> &dim, const Scalar<T> &identity,
|
||||||
@ -193,6 +193,7 @@ static Constant<T> DoReduction(const Constant<ARRAY> &array,
|
|||||||
for (ConstantSubscript j{0}; j < dimExtent; ++j, ++dimAt) {
|
for (ConstantSubscript j{0}; j < dimExtent; ++j, ++dimAt) {
|
||||||
accumulator(elements.back(), at);
|
accumulator(elements.back(), at);
|
||||||
}
|
}
|
||||||
|
accumulator.Done(elements.back());
|
||||||
}
|
}
|
||||||
} else { // no DIM=, result is scalar
|
} else { // no DIM=, result is scalar
|
||||||
elements.push_back(identity);
|
elements.push_back(identity);
|
||||||
@ -200,6 +201,7 @@ static Constant<T> DoReduction(const Constant<ARRAY> &array,
|
|||||||
IncrementSubscripts(at, array.shape())) {
|
IncrementSubscripts(at, array.shape())) {
|
||||||
accumulator(elements.back(), at);
|
accumulator(elements.back(), at);
|
||||||
}
|
}
|
||||||
|
accumulator.Done(elements.back());
|
||||||
}
|
}
|
||||||
if constexpr (T::category == TypeCategory::Character) {
|
if constexpr (T::category == TypeCategory::Character) {
|
||||||
return {static_cast<ConstantSubscript>(identity.size()),
|
return {static_cast<ConstantSubscript>(identity.size()),
|
||||||
@ -210,58 +212,85 @@ static Constant<T> DoReduction(const Constant<ARRAY> &array,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// MAXVAL & MINVAL
|
// MAXVAL & MINVAL
|
||||||
|
template <typename T, bool ABS = false> class MaxvalMinvalAccumulator {
|
||||||
|
public:
|
||||||
|
MaxvalMinvalAccumulator(
|
||||||
|
RelationalOperator opr, FoldingContext &context, const Constant<T> &array)
|
||||||
|
: opr_{opr}, context_{context}, array_{array} {};
|
||||||
|
void operator()(Scalar<T> &element, const ConstantSubscripts &at) const {
|
||||||
|
auto aAt{array_.At(at)};
|
||||||
|
if constexpr (ABS) {
|
||||||
|
aAt = aAt.ABS();
|
||||||
|
}
|
||||||
|
Expr<LogicalResult> test{PackageRelation(
|
||||||
|
opr_, Expr<T>{Constant<T>{aAt}}, Expr<T>{Constant<T>{element}})};
|
||||||
|
auto folded{GetScalarConstantValue<LogicalResult>(
|
||||||
|
test.Rewrite(context_, std::move(test)))};
|
||||||
|
CHECK(folded.has_value());
|
||||||
|
if (folded->IsTrue()) {
|
||||||
|
element = array_.At(at);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void Done(Scalar<T> &) const {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
RelationalOperator opr_;
|
||||||
|
FoldingContext &context_;
|
||||||
|
const Constant<T> &array_;
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static Expr<T> FoldMaxvalMinval(FoldingContext &context, FunctionRef<T> &&ref,
|
static Expr<T> FoldMaxvalMinval(FoldingContext &context, FunctionRef<T> &&ref,
|
||||||
RelationalOperator opr, const Scalar<T> &identity) {
|
RelationalOperator opr, const Scalar<T> &identity) {
|
||||||
static_assert(T::category == TypeCategory::Integer ||
|
static_assert(T::category == TypeCategory::Integer ||
|
||||||
T::category == TypeCategory::Real ||
|
T::category == TypeCategory::Real ||
|
||||||
T::category == TypeCategory::Character);
|
T::category == TypeCategory::Character);
|
||||||
using Element = Scalar<T>;
|
|
||||||
std::optional<int> dim;
|
std::optional<int> dim;
|
||||||
if (std::optional<Constant<T>> array{
|
if (std::optional<Constant<T>> array{
|
||||||
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
|
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
|
||||||
/*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
|
/*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
|
||||||
auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
|
MaxvalMinvalAccumulator accumulator{opr, context, *array};
|
||||||
Expr<LogicalResult> test{PackageRelation(opr,
|
|
||||||
Expr<T>{Constant<T>{array->At(at)}}, Expr<T>{Constant<T>{element}})};
|
|
||||||
auto folded{GetScalarConstantValue<LogicalResult>(
|
|
||||||
test.Rewrite(context, std::move(test)))};
|
|
||||||
CHECK(folded.has_value());
|
|
||||||
if (folded->IsTrue()) {
|
|
||||||
element = array->At(at);
|
|
||||||
}
|
|
||||||
}};
|
|
||||||
return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
|
return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
|
||||||
}
|
}
|
||||||
return Expr<T>{std::move(ref)};
|
return Expr<T>{std::move(ref)};
|
||||||
}
|
}
|
||||||
|
|
||||||
// PRODUCT
|
// PRODUCT
|
||||||
|
template <typename T> class ProductAccumulator {
|
||||||
|
public:
|
||||||
|
ProductAccumulator(const Constant<T> &array) : array_{array} {}
|
||||||
|
void operator()(Scalar<T> &element, const ConstantSubscripts &at) {
|
||||||
|
if constexpr (T::category == TypeCategory::Integer) {
|
||||||
|
auto prod{element.MultiplySigned(array_.At(at))};
|
||||||
|
overflow_ |= prod.SignedMultiplicationOverflowed();
|
||||||
|
element = prod.lower;
|
||||||
|
} else { // Real & Complex
|
||||||
|
auto prod{element.Multiply(array_.At(at))};
|
||||||
|
overflow_ |= prod.flags.test(RealFlag::Overflow);
|
||||||
|
element = prod.value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bool overflow() const { return overflow_; }
|
||||||
|
void Done(Scalar<T> &) const {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const Constant<T> &array_;
|
||||||
|
bool overflow_{false};
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static Expr<T> FoldProduct(
|
static Expr<T> FoldProduct(
|
||||||
FoldingContext &context, FunctionRef<T> &&ref, Scalar<T> identity) {
|
FoldingContext &context, FunctionRef<T> &&ref, Scalar<T> identity) {
|
||||||
static_assert(T::category == TypeCategory::Integer ||
|
static_assert(T::category == TypeCategory::Integer ||
|
||||||
T::category == TypeCategory::Real ||
|
T::category == TypeCategory::Real ||
|
||||||
T::category == TypeCategory::Complex);
|
T::category == TypeCategory::Complex);
|
||||||
using Element = typename Constant<T>::Element;
|
|
||||||
std::optional<int> dim;
|
std::optional<int> dim;
|
||||||
if (std::optional<Constant<T>> array{
|
if (std::optional<Constant<T>> array{
|
||||||
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
|
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
|
||||||
/*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
|
/*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
|
||||||
bool overflow{false};
|
ProductAccumulator accumulator{*array};
|
||||||
auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
|
|
||||||
if constexpr (T::category == TypeCategory::Integer) {
|
|
||||||
auto prod{element.MultiplySigned(array->At(at))};
|
|
||||||
overflow |= prod.SignedMultiplicationOverflowed();
|
|
||||||
element = prod.lower;
|
|
||||||
} else { // Real & Complex
|
|
||||||
auto prod{element.Multiply(array->At(at))};
|
|
||||||
overflow |= prod.flags.test(RealFlag::Overflow);
|
|
||||||
element = prod.value;
|
|
||||||
}
|
|
||||||
}};
|
|
||||||
auto result{Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)}};
|
auto result{Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)}};
|
||||||
if (overflow) {
|
if (accumulator.overflow()) {
|
||||||
context.messages().Say(
|
context.messages().Say(
|
||||||
"PRODUCT() of %s data overflowed"_warn_en_US, T::AsFortran());
|
"PRODUCT() of %s data overflowed"_warn_en_US, T::AsFortran());
|
||||||
}
|
}
|
||||||
@ -271,6 +300,46 @@ static Expr<T> FoldProduct(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SUM
|
// SUM
|
||||||
|
template <typename T> class SumAccumulator {
|
||||||
|
using Element = typename Constant<T>::Element;
|
||||||
|
|
||||||
|
public:
|
||||||
|
SumAccumulator(const Constant<T> &array, Rounding rounding)
|
||||||
|
: array_{array}, rounding_{rounding} {}
|
||||||
|
void operator()(Element &element, const ConstantSubscripts &at) {
|
||||||
|
if constexpr (T::category == TypeCategory::Integer) {
|
||||||
|
auto sum{element.AddSigned(array_.At(at))};
|
||||||
|
overflow_ |= sum.overflow;
|
||||||
|
element = sum.value;
|
||||||
|
} else { // Real & Complex: use Kahan summation
|
||||||
|
auto next{array_.At(at).Add(correction_, rounding_)};
|
||||||
|
overflow_ |= next.flags.test(RealFlag::Overflow);
|
||||||
|
auto sum{element.Add(next.value, rounding_)};
|
||||||
|
overflow_ |= sum.flags.test(RealFlag::Overflow);
|
||||||
|
// correction = (sum - element) - next; algebraically zero
|
||||||
|
correction_ = sum.value.Subtract(element, rounding_)
|
||||||
|
.value.Subtract(next.value, rounding_)
|
||||||
|
.value;
|
||||||
|
element = sum.value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bool overflow() const { return overflow_; }
|
||||||
|
void Done([[maybe_unused]] Element &element) {
|
||||||
|
if constexpr (T::category != TypeCategory::Integer) {
|
||||||
|
auto corrected{element.Add(correction_, rounding_)};
|
||||||
|
overflow_ |= corrected.flags.test(RealFlag::Overflow);
|
||||||
|
correction_ = Scalar<T>{};
|
||||||
|
element = corrected.value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const Constant<T> &array_;
|
||||||
|
Rounding rounding_;
|
||||||
|
bool overflow_{false};
|
||||||
|
Element correction_{};
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) {
|
static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) {
|
||||||
static_assert(T::category == TypeCategory::Integer ||
|
static_assert(T::category == TypeCategory::Integer ||
|
||||||
@ -278,31 +347,14 @@ static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) {
|
|||||||
T::category == TypeCategory::Complex);
|
T::category == TypeCategory::Complex);
|
||||||
using Element = typename Constant<T>::Element;
|
using Element = typename Constant<T>::Element;
|
||||||
std::optional<int> dim;
|
std::optional<int> dim;
|
||||||
Element identity{}, correction{};
|
Element identity{};
|
||||||
if (std::optional<Constant<T>> array{
|
if (std::optional<Constant<T>> array{
|
||||||
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
|
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
|
||||||
/*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
|
/*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
|
||||||
bool overflow{false};
|
SumAccumulator accumulator{
|
||||||
auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
|
*array, context.targetCharacteristics().roundingMode()};
|
||||||
if constexpr (T::category == TypeCategory::Integer) {
|
|
||||||
auto sum{element.AddSigned(array->At(at))};
|
|
||||||
overflow |= sum.overflow;
|
|
||||||
element = sum.value;
|
|
||||||
} else { // Real & Complex: use Kahan summation
|
|
||||||
const auto &rounding{context.targetCharacteristics().roundingMode()};
|
|
||||||
auto next{array->At(at).Add(correction, rounding)};
|
|
||||||
overflow |= next.flags.test(RealFlag::Overflow);
|
|
||||||
auto sum{element.Add(next.value, rounding)};
|
|
||||||
overflow |= sum.flags.test(RealFlag::Overflow);
|
|
||||||
// correction = (sum - element) - next; algebraically zero
|
|
||||||
correction = sum.value.Subtract(element, rounding)
|
|
||||||
.value.Subtract(next.value, rounding)
|
|
||||||
.value;
|
|
||||||
element = sum.value;
|
|
||||||
}
|
|
||||||
}};
|
|
||||||
auto result{Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)}};
|
auto result{Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)}};
|
||||||
if (overflow) {
|
if (accumulator.overflow()) {
|
||||||
context.messages().Say(
|
context.messages().Say(
|
||||||
"SUM() of %s data overflowed"_warn_en_US, T::AsFortran());
|
"SUM() of %s data overflowed"_warn_en_US, T::AsFortran());
|
||||||
}
|
}
|
||||||
@ -311,5 +363,21 @@ static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) {
|
|||||||
return Expr<T>{std::move(ref)};
|
return Expr<T>{std::move(ref)};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Utility for IALL, IANY, IPARITY, ALL, ANY, & PARITY
|
||||||
|
template <typename T> class OperationAccumulator {
|
||||||
|
public:
|
||||||
|
OperationAccumulator(const Constant<T> &array,
|
||||||
|
Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const)
|
||||||
|
: array_{array}, operation_{operation} {}
|
||||||
|
void operator()(Scalar<T> &element, const ConstantSubscripts &at) {
|
||||||
|
element = (element.*operation_)(array_.At(at));
|
||||||
|
}
|
||||||
|
void Done(Scalar<T> &) const {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const Constant<T> &array_;
|
||||||
|
Scalar<T> (Scalar<T>::*operation_)(const Scalar<T> &) const;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace Fortran::evaluate
|
} // namespace Fortran::evaluate
|
||||||
#endif // FORTRAN_EVALUATE_FOLD_REDUCTION_H_
|
#endif // FORTRAN_EVALUATE_FOLD_REDUCTION_H_
|
||||||
|
29
flang/test/Evaluate/fold-norm2.f90
Normal file
29
flang/test/Evaluate/fold-norm2.f90
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
! RUN: %python %S/test_folding.py %s %flang_fc1
|
||||||
|
! Tests folding of NORM2(), F'2023 16.9.153
|
||||||
|
! Compile-time checks of NORM2 folding.  Every declaration here is a named
! constant, so each NORM2 reference must be folded for the module to compile.
! NOTE(review): the test_* logical constants are presumably what
! test_folding.py inspects -- confirm against that script.
module m
|
||||||
|
! Examples from the standard
|
||||||
|
logical, parameter :: test_ex1 = norm2([3.,4.]) == 5.
|
||||||
|
! ex2 is filled in column-major order: its rows are (1,2) and (3,4).
real, parameter :: ex2(2,2) = reshape([1.,3.,2.,4.],[2,2])
|
||||||
|
real, parameter :: ex2_norm2_1(2) = norm2(ex2, dim=1)
|
||||||
|
! Expected column norms, sqrt(10) and sqrt(20), written out as decimals.
real, parameter :: ex2_1(2) = [3.162277698516845703125,4.472136020660400390625]
|
||||||
|
logical, parameter :: test_ex2_1 = all(ex2_norm2_1 == ex2_1)
|
||||||
|
real, parameter :: ex2_norm2_2(2) = norm2(ex2, dim=2)
|
||||||
|
! Expected row norms, sqrt(5) and 5, written out as decimals.
real, parameter :: ex2_2(2) = [2.2360680103302001953125,5.]
|
||||||
|
logical, parameter :: test_ex2_2 = all(ex2_norm2_2 == ex2_2)
|
||||||
|
! Layout of the 3x4 array a declared below:
! 0 3 6 9
|
||||||
|
! 1 4 7 10
|
||||||
|
! 2 5 8 11
|
||||||
|
integer, parameter :: dp = kind(0.d0)
|
||||||
|
real(dp), parameter :: a(3,4) = &
|
||||||
|
reshape([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], shape(a))
|
||||||
|
! Each NORM2 result below is compared for exact equality with the
! corresponding sqrt(sum(a * a)) expression.
real(dp), parameter :: nAll = norm2(a)
|
||||||
|
real(dp), parameter :: check_nAll = sqrt(sum(a * a))
|
||||||
|
logical, parameter :: test_all = nAll == check_nAll
|
||||||
|
real(dp), parameter :: norms1(4) = norm2(a, dim=1)
|
||||||
|
real(dp), parameter :: check_norms1(4) = sqrt(sum(a * a, dim=1))
|
||||||
|
logical, parameter :: test_norms1 = all(norms1 == check_norms1)
|
||||||
|
real(dp), parameter :: norms2(3) = norm2(a, dim=2)
|
||||||
|
real(dp), parameter :: check_norms2(3) = sqrt(sum(a * a, dim=2))
|
||||||
|
logical, parameter :: test_norms2 = all(norms2 == check_norms2)
|
||||||
|
! An all-zero argument must give exactly zero.
logical, parameter :: test_normZ = norm2([0.,0.,0.]) == 0.
|
||||||
|
end
|
Loading…
Reference in New Issue
Block a user