9#ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_
10#define FORTRAN_EVALUATE_FOLD_REDUCTION_H_
12#include "fold-implementation.h"
// DOT_PRODUCT folding (fragment). NOTE(review): this listing is missing
// upstream lines -- the function header and several statements are absent,
// and the leading integers are upstream line numbers, not code (e.g. 22 is
// followed by 27). Restore from upstream before relying on it.
// Visible behavior:
//  - exactly two arguments are expected, both rank-1 constants (CHECKed);
//  - vectors of different sizes yield an error message and an invalid
//    intrinsic;
//  - COMPLEX: the first vector's elements are conjugated (CONJG) and the
//    products are summed; REAL: the products are summed; both use
//    Kahan-compensated addition when useKahanSummation is set and plain
//    addition otherwise (no final compensation step appears in the lines
//    shown);
//  - INTEGER: products are summed with AddSigned and overflow is tracked;
//  - UNSIGNED: products are summed with AddUnsigned, keeping only the value;
//  - LOGICAL: an expression held in `conjunctions` is folded and iterated
//    (its construction and the loop body are not shown);
//  - any recorded overflow is reported as a FoldingException warning.
20 using Element =
typename Constant<T>::Element;
21 auto args{funcRef.arguments()};
22 CHECK(args.size() == 2);
// Both arguments must already be vectors; va/vb are presumably the folded
// constants for the two arguments (their declarations are not shown).
27 CHECK(va->Rank() == 1 && vb->Rank() == 1);
// Mismatched extents are a user error: diagnose and return an invalid
// intrinsic instead of folding.
28 if (va->size() != vb->size()) {
29 context.messages().Say(
30 "Vector arguments to DOT_PRODUCT have distinct extents %zd and %zd"_err_en_US,
31 va->size(), vb->size());
32 return MakeInvalidIntrinsic(std::move(funcRef));
// COMPLEX: build a constant of the conjugated elements of the first vector,
// with the same shape as va.
36 if constexpr (T::category == TypeCategory::Complex) {
37 std::vector<Element> conjugates;
38 for (
const Element &x : va->values()) {
39 conjugates.emplace_back(x.CONJG());
42 std::move(conjugates), ConstantSubscripts{va->shape()}};
// `products` is built on lines not shown and is expected to fold to a
// constant (DEREF presumably enforces non-null).
45 Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
46 [[maybe_unused]] Element correction{};
47 const auto &rounding{context.targetCharacteristics().roundingMode()};
48 for (
const Element &x : cProducts.values()) {
// Kahan step: subtract the previous rounding excess from this term, add it,
// then recompute the excess as (new sum - old sum) - term.
49 if constexpr (useKahanSummation) {
50 auto next{x.Subtract(correction, rounding)};
51 overflow |= next.flags.test(RealFlag::Overflow);
52 auto added{sum.Add(next.value, rounding)};
53 overflow |= added.flags.test(RealFlag::Overflow);
54 correction = added.value.Subtract(sum, rounding)
55 .value.Subtract(next.value, rounding)
57 sum = std::move(added.value);
// Uncompensated addition.
59 auto added{sum.Add(x, rounding)};
60 overflow |= added.flags.test(RealFlag::Overflow);
61 sum = std::move(added.value);
64 }
// LOGICAL: fold and walk the `conjunctions` expression (details not shown).
else if constexpr (T::category == TypeCategory::Logical) {
65 Expr<T> conjunctions{Fold(context,
68 Constant<T> &cConjunctions{DEREF(UnwrapConstantValue<T>(conjunctions))};
69 for (
const Element &x : cConjunctions.values()) {
75 }
// INTEGER: signed addition of the products with overflow tracking.
else if constexpr (T::category == TypeCategory::Integer) {
78 Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
79 for (
const Element &x : cProducts.values()) {
80 auto next{sum.AddSigned(x)};
81 overflow |= next.overflow;
82 sum = std::move(next.value);
84 }
// UNSIGNED: only the value of AddUnsigned is kept; no overflow is recorded.
else if constexpr (T::category == TypeCategory::Unsigned) {
87 Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
88 for (
const Element &x : cProducts.values()) {
89 sum = sum.AddUnsigned(x).value;
// The only remaining category is REAL; it uses the same summation scheme as
// the COMPLEX branch.
92 static_assert(T::category == TypeCategory::Real);
95 Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
96 [[maybe_unused]] Element correction{};
97 const auto &rounding{context.targetCharacteristics().roundingMode()};
98 for (
const Element &x : cProducts.values()) {
99 if constexpr (useKahanSummation) {
100 auto next{x.Subtract(correction, rounding)};
101 overflow |= next.flags.test(RealFlag::Overflow);
102 auto added{sum.Add(next.value, rounding)};
103 overflow |= added.flags.test(RealFlag::Overflow);
104 correction = added.value.Subtract(sum, rounding)
105 .value.Subtract(next.value, rounding)
107 sum = std::move(added.value);
109 auto added{sum.Add(x, rounding)};
110 overflow |= added.flags.test(RealFlag::Overflow);
111 sum = std::move(added.value);
// Overflow is only a warning; the governing `if` is on a line not shown.
116 context.Warn(common::UsageWarning::FoldingException,
117 "DOT_PRODUCT of %s data overflowed during computation"_warn_en_US,
// Fallback: return the reference unchanged (presumably reached when the
// arguments are not both constant; the condition is on lines not shown).
122 return Expr<T>{std::move(funcRef)};
// NOTE(review): parameter tail of a helper declaration whose name and leading
// parameters (upstream lines 123-126) are missing from this listing. The call
// CheckReductionDIM(dim, context, arg, dimIndex, folded->Rank()) in
// ProcessReductionArgs suggests this is that declaration -- confirm upstream.
127 ActualArguments &, std::optional<int> dimIndex,
int rank);
// NOTE(review): parameter fragment of a helper declaration whose beginning
// (upstream lines 128-131) and end are missing from this listing. The call
// GetReductionMASK(arg[*maskIndex], folded->shape(), context) in
// ProcessReductionArgs suggests this is that declaration -- confirm upstream.
132 std::optional<ActualArgument> &maskArg,
const ConstantSubscripts &shape,
// Folds the array argument of a reduction intrinsic and pairs it with a
// LOGICAL mask of the same shape. NOTE(review): many upstream lines are
// missing from this fragment; the leading integers are upstream line
// numbers, not code.
//   context    - folding context
//   arg        - the intrinsic's actual arguments
//   dim        - passed by non-const reference to CheckReductionDIM
//                (presumably set/validated there)
//   arrayIndex - position of the array argument in `arg`
//   dimIndex   - optional position of a DIM argument
//   maskIndex  - optional position of a MASK argument
// Returns std::optional<ArrayAndMask<T>>; the early-exit paths are on lines
// not shown.
145static std::optional<ArrayAndMask<T>> ProcessReductionArgs(
146 FoldingContext &context, ActualArguments &arg, std::optional<int> &dim,
147 int arrayIndex, std::optional<int> dimIndex = std::nullopt,
148 std::optional<int> maskIndex = std::nullopt) {
// Fold the array argument to a constant. The handling of a non-constant or
// scalar array is on lines not shown.
152 Constant<T> *folded{Folder<T>{context}.Folding(arg[arrayIndex])};
153 if (!folded || folded->Rank() < 1) {
// Validate DIM against the array's rank.
156 if (!CheckReductionDIM(dim, context, arg, dimIndex, folded->Rank())) {
159 std::size_t n{folded->size()};
// One mask element per array element.
160 std::vector<Scalar<LogicalResult>> maskElement;
161 if (maskIndex &&
static_cast<std::size_t
>(*maskIndex) < arg.size() &&
164 GetReductionMASK(arg[*maskIndex], folded->shape(), context)}) {
// A scalar MASK is broadcast to all n elements.
165 if (
auto scalarMask{origMask->GetScalarValue()}) {
167 std::vector<Scalar<LogicalResult>>(n, scalarMask->IsTrue());
// An array MASK supplies its elements directly.
169 maskElement = origMask->values();
// Default (branch header not shown; presumably no MASK present): every
// element participates.
175 maskElement = std::vector<Scalar<LogicalResult>>(n,
true);
// The mask constant takes the folded array's shape.
179 std::move(maskElement), ConstantSubscripts{folded->shape()}}};
// Core reduction driver. NOTE(review): fragment -- the function name and
// leading parameters (upstream lines 187-188) and several other lines are
// missing; the leading integers are upstream line numbers, not code.
// Callers in this file invoke it as
//   DoReduction<T>(array, mask, dim, identity, accumulator).
// Each result element starts as `identity`. For every array element whose
// mask entry is true, it calls
//   accumulator(element, subscripts, firstUnmasked)
// and then accumulator.Done(element).
// With DIM, the result shape is the array shape minus dimension *dim.
// Without DIM, resultShape stays empty and a single element is produced.
// CHARACTER results take their length from identity.size().
186template <
typename T,
typename ACCUMULATOR,
typename ARRAY>
189 const Scalar<T> &identity, ACCUMULATOR &accumulator) {
// Current subscripts into the array and the mask, starting at their lower
// bounds.
190 ConstantSubscripts at{array.lbounds()};
191 ConstantSubscripts maskAt{mask.lbounds()};
192 std::vector<typename Constant<T>::Element> elements;
193 ConstantSubscripts resultShape;
// DIM case (branch header not shown): drop dimension *dim (1-based) from the
// result shape; the mask must have the same extent along it.
195 resultShape = array.shape();
196 resultShape.erase(resultShape.begin() + (*dim - 1));
197 ConstantSubscript dimExtent{array.shape().at(*dim - 1)};
198 CHECK(dimExtent == mask.shape().at(*dim - 1));
// dimAt/maskDimAt are references into at/maskAt, so stepping them walks
// along dimension *dim of the array and the mask.
199 ConstantSubscript &dimAt{at[*dim - 1]};
200 ConstantSubscript dimLbound{dimAt};
201 ConstantSubscript &maskDimAt{maskAt[*dim - 1]};
202 ConstantSubscript maskDimLbound{maskDimAt};
// One iteration per result element; the loop header advances both subscript
// vectors after each reduction.
203 for (
auto n{GetSize(resultShape)}; n-- > 0;
204 array.IncrementSubscripts(at), mask.IncrementSubscripts(maskAt)) {
205 elements.push_back(identity);
// Rewind the mask's DIM subscript (the matching array rewind is presumably
// on a line not shown).
208 maskDimAt = maskDimLbound;
209 bool firstUnmasked{
true};
210 for (ConstantSubscript j{0}; j < dimExtent; ++j, ++dimAt, ++maskDimAt) {
211 if (mask.At(maskAt).IsTrue()) {
212 accumulator(elements.back(), at, firstUnmasked);
213 firstUnmasked =
false;
// The inner loop leaves dimAt/maskDimAt one past the last index; step back
// so the subscripts are in range when the loop header increments them.
216 --dimAt, --maskDimAt;
// Let the accumulator finalize this result element (accumulator-specific).
218 accumulator.Done(elements.back());
// No-DIM case (branch header not shown): reduce every element into a single
// result; resultShape stays empty.
221 elements.push_back(identity);
222 bool firstUnmasked{
true};
223 for (
auto n{array.size()}; n-- > 0;
224 array.IncrementSubscripts(at), mask.IncrementSubscripts(maskAt)) {
225 if (mask.At(maskAt).IsTrue()) {
226 accumulator(elements.back(), at, firstUnmasked);
227 firstUnmasked =
false;
230 accumulator.Done(elements.back());
// CHARACTER constants also need a length; it is taken from the identity.
232 if constexpr (T::category == TypeCategory::Character) {
233 return {
static_cast<ConstantSubscript
>(identity.size()),
234 std::move(elements), std::move(resultShape)};
236 return {std::move(elements), std::move(resultShape)};
// Accumulator for MAXVAL/MINVAL folding. NOTE(review): fragment -- several
// upstream lines are missing, and the leading integers are upstream line
// numbers, not code.
// Each call evaluates a comparison expression `test` by folding it in
// context_. `test` is built on lines not shown, presumably from opr_,
// element and array_(at).
// NOTE(review): the ABS template parameter is not used on any visible line;
// confirm its role upstream.
241template <
typename T,
bool ABS = false>
class MaxvalMinvalAccumulator {
// Constructor (parameter list not shown): stores the relational operator,
// the folding context and the source array.
243 MaxvalMinvalAccumulator(
245 : opr_{opr}, context_{context}, array_{array} {};
// Possibly replaces `element` with array_(at). Declared const.
246 void operator()(Scalar<T> &element,
const ConstantSubscripts &at,
247 [[maybe_unused]]
bool firstUnmasked)
const {
248 auto aAt{array_.At(at)};
// REAL: the first unmasked element, or a NaN running value, is treated
// specially (the action taken is on lines not shown).
252 if constexpr (T::category == TypeCategory::Real) {
253 if (firstUnmasked || element.IsNotANumber()) {
// The comparison must fold to a LOGICAL scalar constant (CHECKed); a true
// result triggers the update on lines not shown.
262 auto folded{GetScalarConstantValue<LogicalResult>(
263 test.Rewrite(context_, std::move(test)))};
264 CHECK(folded.has_value());
265 if (folded->IsTrue()) {
// Nothing to finalize per result element.
269 void Done(Scalar<T> &)
const {}
272 RelationalOperator opr_;
// Folds a MAXVAL/MINVAL-style reduction. NOTE(review): fragment -- the
// function name and leading parameters are missing, and the leading integers
// are upstream line numbers, not code.
//   opr      - relational operator handed to MaxvalMinvalAccumulator
//   identity - initial value of every result element
// Supported categories: INTEGER, UNSIGNED, REAL, CHARACTER.
// When ProcessReductionArgs succeeds (the surrounding `if` is only partly
// shown), the folded reduction is returned. Otherwise the reference is
// returned unchanged.
279 RelationalOperator opr,
const Scalar<T> &identity) {
280 static_assert(T::category == TypeCategory::Integer ||
281 T::category == TypeCategory::Unsigned ||
282 T::category == TypeCategory::Real ||
283 T::category == TypeCategory::Character);
// Filled in (or validated) by ProcessReductionArgs.
284 std::optional<int> dim;
286 ProcessReductionArgs<T>(context, ref.arguments(), dim,
288 MaxvalMinvalAccumulator<T> accumulator{opr, context, arrayAndMask->array};
289 return Expr<T>{DoReduction<T>(
290 arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)};
// Not foldable: leave the reference as-is.
292 return Expr<T>{std::move(ref)};
// Accumulator for PRODUCT folding. NOTE(review): fragment -- some upstream
// lines are missing, and the leading integers are upstream line numbers, not
// code.
// Multiplies the running product by array_(at):
//  - INTEGER: keeps `.lower` and records signed-multiplication overflow;
//  - UNSIGNED: keeps `.lower` only, with no overflow flag;
//  - other categories: records the Overflow flag from Multiply.
296template <
typename T>
class ProductAccumulator {
298 ProductAccumulator(
const Constant<T> &array) : array_{array} {}
// Call operator (its first line is not shown). The trailing bool
// (first-unmasked flag) is unused.
300 Scalar<T> &element,
const ConstantSubscripts &at,
bool ) {
301 if constexpr (T::category == TypeCategory::Integer) {
302 auto prod{element.MultiplySigned(array_.At(at))};
303 overflow_ |= prod.SignedMultiplicationOverflowed();
304 element = prod.lower;
305 }
else if constexpr (T::category == TypeCategory::Unsigned) {
306 element = element.MultiplyUnsigned(array_.At(at)).lower;
// Floating categories (branch header not shown).
308 auto prod{element.Multiply(array_.At(at))};
309 overflow_ |= prod.flags.test(RealFlag::Overflow);
310 element = prod.value;
// True if any multiplication so far overflowed.
313 bool overflow()
const {
return overflow_; }
// Nothing to finalize per result element.
314 void Done(Scalar<T> &)
const {}
318 bool overflow_{
false};
// Folds a PRODUCT reduction. NOTE(review): fragment -- the function header
// and some lines are missing, and the leading integers are upstream line
// numbers, not code.
// Supported categories: INTEGER, UNSIGNED, REAL, COMPLEX.
// Any overflow seen by the accumulator is reported as a FoldingException
// warning. When the arguments cannot be processed, the reference is returned
// unchanged.
324 static_assert(T::category == TypeCategory::Integer ||
325 T::category == TypeCategory::Unsigned ||
326 T::category == TypeCategory::Real ||
327 T::category == TypeCategory::Complex);
328 std::optional<int> dim;
330 ProcessReductionArgs<T>(context, ref.arguments(), dim,
// Class template argument deduction yields ProductAccumulator<T> from the
// Constant<T> array.
332 ProductAccumulator accumulator{arrayAndMask->array};
333 auto result{Expr<T>{DoReduction<T>(
334 arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}};
// Overflow is diagnosed as a warning only.
335 if (accumulator.overflow()) {
336 context.Warn(common::UsageWarning::FoldingException,
337 "PRODUCT() of %s data overflowed"_warn_en_US, T::AsFortran());
341 return Expr<T>{std::move(ref)};
// Accumulator for SUM folding, driven by DoReduction. NOTE(review): fragment
// -- upstream lines 347-349, 351, 356, 359, 364, 367-370 and 379-384 are
// missing, and the leading integers are upstream line numbers, not code.
//  - INTEGER: AddSigned, recording overflow;
//  - UNSIGNED: AddUnsigned, keeping only the value;
//  - otherwise (REAL/COMPLEX, per FoldSum's static_assert): Kahan-compensated
//    addition in the rounding mode given to the constructor, recording
//    overflow.
345template <
typename T>
class SumAccumulator {
346 using Element =
typename Constant<T>::Element;
// Constructor tail: keeps the source array and the target rounding mode.
350 : array_{array}, rounding_{rounding} {}
// Adds array_(at) into the partial sum `element`. The trailing bool
// (first-unmasked flag) is unused.
352 Element &element,
const ConstantSubscripts &at,
bool ) {
353 if constexpr (T::category == TypeCategory::Integer) {
354 auto sum{element.AddSigned(array_.At(at))};
355 overflow_ |= sum.overflow;
357 }
else if constexpr (T::category == TypeCategory::Unsigned) {
358 element = element.AddUnsigned(array_.At(at)).value;
// Kahan step: before adding this term, remove from it the rounding excess
// carried over from the previous addition.
360 auto next{array_.At(at).Subtract(correction_, rounding_)};
361 overflow_ |= next.flags.test(RealFlag::Overflow);
362 auto sum{element.Add(next.value, rounding_)};
363 overflow_ |= sum.flags.test(RealFlag::Overflow);
// correction_ = (sum - element) - next: how far the rounded sum overshot
// the exact one (algebraically zero). It is subtracted from the next term.
365 correction_ = sum.value.Subtract(element, rounding_)
366 .value.Subtract(next.value, rounding_)
// True if any addition (or the final correction) overflowed.
371 bool overflow()
const {
return overflow_; }
// Finalizes one result element; DoReduction calls this once per element.
// For floating categories, the pending Kahan correction is applied and then
// reset, so the next result element starts fresh.
// correction_ is an *excess* (see the Kahan step above), so it must be
// SUBTRACTED. Adding it pushes the result away from the exact sum.
// Example: for SUM([1.0d0, 1.0d-16]) the partial sum is 1.0d0 and
// correction_ == -1.0d-16. Then 1.0d0 - (-1.0d-16) rounds to 1.0d0, whereas
// 1.0d0 + (-1.0d-16) would round to 0.9999999999999999d0.
372 void Done([[maybe_unused]] Element &element) {
373 if constexpr (T::category != TypeCategory::Integer &&
374 T::category != TypeCategory::Unsigned) {
375 auto corrected{element.Subtract(correction_, rounding_)};
376 overflow_ |= corrected.flags.test(RealFlag::Overflow);
377 correction_ = Scalar<T>{};
378 element = corrected.value;
385 bool overflow_{
false};
386 Element correction_{};
// Folds a SUM reduction. NOTE(review): fragment -- the function header and
// some lines are missing, and the leading integers are upstream line
// numbers, not code.
// Supported categories: INTEGER, UNSIGNED, REAL, COMPLEX.
// Additions use the target's rounding mode. Any overflow seen by the
// accumulator is reported as a FoldingException warning. When the arguments
// cannot be processed, the reference is returned unchanged.
391 static_assert(T::category == TypeCategory::Integer ||
392 T::category == TypeCategory::Unsigned ||
393 T::category == TypeCategory::Real ||
394 T::category == TypeCategory::Complex);
// NOTE(review): Element is not referenced on the visible lines; its use is
// presumably on a missing line (397).
395 using Element =
typename Constant<T>::Element;
396 std::optional<int> dim;
398 if (std::optional<ArrayAndMask<T>> arrayAndMask{
399 ProcessReductionArgs<T>(context, ref.arguments(), dim,
// Class template argument deduction yields SumAccumulator<T>; the rounding
// mode comes from the target characteristics.
401 SumAccumulator accumulator{
402 arrayAndMask->array, context.targetCharacteristics().roundingMode()};
403 auto result{
Expr<T>{DoReduction<T>(
404 arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}};
// Overflow is diagnosed as a warning only.
405 if (accumulator.overflow()) {
406 context.Warn(common::UsageWarning::FoldingException,
407 "SUM() of %s data overflowed"_warn_en_US, T::AsFortran());
411 return Expr<T>{std::move(ref)};
// Generic accumulator that combines the running value with array_(at)
// through a const member function of Scalar<T>, supplied as a
// pointer-to-member. NOTE(review): fragment -- some upstream lines are
// missing, and the leading integers are upstream line numbers, not code.
415template <
typename T>
class OperationAccumulator {
// Constructor tail: `operation` is a pointer to a const member function
// taking and returning Scalar<T>.
418 Scalar<T> (Scalar<T>::*operation)(
const Scalar<T> &)
const)
419 : array_{array}, operation_{operation} {}
// element = element.operation(array_(at)). The trailing bool
// (first-unmasked flag) is unused.
421 Scalar<T> &element,
const ConstantSubscripts &at,
bool ) {
422 element = (element.*operation_)(array_.At(at));
// Nothing to finalize per result element.
424 void Done(Scalar<T> &)
const {}
428 Scalar<T> (Scalar<T>::*operation_)(
const Scalar<T> &)
const;
Definition constant.h:147
Definition fold-implementation.h:55
Definition target-rounding.h:18
Definition fold-reduction.h:140
Definition expression.h:379