FLANG
fold-reduction.h
1//===-- lib/Evaluate/fold-reduction.h -------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_
10#define FORTRAN_EVALUATE_FOLD_REDUCTION_H_
11
12#include "fold-implementation.h"
13
14namespace Fortran::evaluate {
15
16// DOT_PRODUCT
17template <typename T>
18static Expr<T> FoldDotProduct(
19 FoldingContext &context, FunctionRef<T> &&funcRef) {
20 using Element = typename Constant<T>::Element;
21 auto args{funcRef.arguments()};
22 CHECK(args.size() == 2);
23 Folder<T> folder{context};
24 Constant<T> *va{folder.Folding(args[0])};
25 Constant<T> *vb{folder.Folding(args[1])};
26 if (va && vb) {
27 CHECK(va->Rank() == 1 && vb->Rank() == 1);
28 if (va->size() != vb->size()) {
29 context.messages().Say(
30 "Vector arguments to DOT_PRODUCT have distinct extents %zd and %zd"_err_en_US,
31 va->size(), vb->size());
32 return MakeInvalidIntrinsic(std::move(funcRef));
33 }
34 Element sum{};
35 bool overflow{false};
36 if constexpr (T::category == TypeCategory::Complex) {
37 std::vector<Element> conjugates;
38 for (const Element &x : va->values()) {
39 conjugates.emplace_back(x.CONJG());
40 }
41 Constant<T> conjgA{
42 std::move(conjugates), ConstantSubscripts{va->shape()}};
43 Expr<T> products{Fold(
44 context, Expr<T>{std::move(conjgA)} * Expr<T>{Constant<T>{*vb}})};
45 Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
46 [[maybe_unused]] Element correction{};
47 const auto &rounding{context.targetCharacteristics().roundingMode()};
48 for (const Element &x : cProducts.values()) {
49 if constexpr (useKahanSummation) {
50 auto next{x.Subtract(correction, rounding)};
51 overflow |= next.flags.test(RealFlag::Overflow);
52 auto added{sum.Add(next.value, rounding)};
53 overflow |= added.flags.test(RealFlag::Overflow);
54 correction = added.value.Subtract(sum, rounding)
55 .value.Subtract(next.value, rounding)
56 .value;
57 sum = std::move(added.value);
58 } else {
59 auto added{sum.Add(x, rounding)};
60 overflow |= added.flags.test(RealFlag::Overflow);
61 sum = std::move(added.value);
62 }
63 }
64 } else if constexpr (T::category == TypeCategory::Logical) {
65 Expr<T> conjunctions{Fold(context,
66 Expr<T>{LogicalOperation<T::kind>{LogicalOperator::And,
67 Expr<T>{Constant<T>{*va}}, Expr<T>{Constant<T>{*vb}}}})};
68 Constant<T> &cConjunctions{DEREF(UnwrapConstantValue<T>(conjunctions))};
69 for (const Element &x : cConjunctions.values()) {
70 if (x.IsTrue()) {
71 sum = Element{true};
72 break;
73 }
74 }
75 } else if constexpr (T::category == TypeCategory::Integer) {
76 Expr<T> products{
77 Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
78 Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
79 for (const Element &x : cProducts.values()) {
80 auto next{sum.AddSigned(x)};
81 overflow |= next.overflow;
82 sum = std::move(next.value);
83 }
84 } else if constexpr (T::category == TypeCategory::Unsigned) {
85 Expr<T> products{
86 Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
87 Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
88 for (const Element &x : cProducts.values()) {
89 sum = sum.AddUnsigned(x).value;
90 }
91 } else {
92 static_assert(T::category == TypeCategory::Real);
93 Expr<T> products{
94 Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
95 Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
96 [[maybe_unused]] Element correction{};
97 const auto &rounding{context.targetCharacteristics().roundingMode()};
98 for (const Element &x : cProducts.values()) {
99 if constexpr (useKahanSummation) {
100 auto next{x.Subtract(correction, rounding)};
101 overflow |= next.flags.test(RealFlag::Overflow);
102 auto added{sum.Add(next.value, rounding)};
103 overflow |= added.flags.test(RealFlag::Overflow);
104 correction = added.value.Subtract(sum, rounding)
105 .value.Subtract(next.value, rounding)
106 .value;
107 sum = std::move(added.value);
108 } else {
109 auto added{sum.Add(x, rounding)};
110 overflow |= added.flags.test(RealFlag::Overflow);
111 sum = std::move(added.value);
112 }
113 }
114 }
115 if (overflow &&
116 context.languageFeatures().ShouldWarn(
117 common::UsageWarning::FoldingException)) {
118 context.messages().Say(common::UsageWarning::FoldingException,
119 "DOT_PRODUCT of %s data overflowed during computation"_warn_en_US,
120 T::AsFortran());
121 }
122 return Expr<T>{Constant<T>{std::move(sum)}};
123 }
124 return Expr<T>{std::move(funcRef)};
125}
126
127// Fold and validate a DIM= argument. Returns false on error.
128bool CheckReductionDIM(std::optional<int> &dim, FoldingContext &,
129 ActualArguments &, std::optional<int> dimIndex, int rank);
130
131// Fold and validate a MASK= argument. Return null on error, absent MASK=, or
132// non-constant MASK=.
133Constant<LogicalResult> *GetReductionMASK(
134 std::optional<ActualArgument> &maskArg, const ConstantSubscripts &shape,
135 FoldingContext &);
136
137// Common preprocessing for reduction transformational intrinsic function
138// folding. If the intrinsic can have DIM= &/or MASK= arguments, extract
139// and check them. If a MASK= is present, apply it to the array data and
140// substitute replacement values for elements corresponding to .FALSE. in
141// the mask. If the result is present, the intrinsic call can be folded.
142template <typename T> struct ArrayAndMask {
143 Constant<T> array;
145};
146template <typename T>
147static std::optional<ArrayAndMask<T>> ProcessReductionArgs(
148 FoldingContext &context, ActualArguments &arg, std::optional<int> &dim,
149 int arrayIndex, std::optional<int> dimIndex = std::nullopt,
150 std::optional<int> maskIndex = std::nullopt) {
151 if (arg.empty()) {
152 return std::nullopt;
153 }
154 Constant<T> *folded{Folder<T>{context}.Folding(arg[arrayIndex])};
155 if (!folded || folded->Rank() < 1) {
156 return std::nullopt;
157 }
158 if (!CheckReductionDIM(dim, context, arg, dimIndex, folded->Rank())) {
159 return std::nullopt;
160 }
161 std::size_t n{folded->size()};
162 std::vector<Scalar<LogicalResult>> maskElement;
163 if (maskIndex && static_cast<std::size_t>(*maskIndex) < arg.size() &&
164 arg[*maskIndex]) {
165 if (const Constant<LogicalResult> *origMask{
166 GetReductionMASK(arg[*maskIndex], folded->shape(), context)}) {
167 if (auto scalarMask{origMask->GetScalarValue()}) {
168 maskElement =
169 std::vector<Scalar<LogicalResult>>(n, scalarMask->IsTrue());
170 } else {
171 maskElement = origMask->values();
172 }
173 } else {
174 return std::nullopt;
175 }
176 } else {
177 maskElement = std::vector<Scalar<LogicalResult>>(n, true);
178 }
179 return ArrayAndMask<T>{Constant<T>(*folded),
180 Constant<LogicalResult>{
181 std::move(maskElement), ConstantSubscripts{folded->shape()}}};
182}
183
184// Generalized reduction to an array of one dimension fewer (w/ DIM=)
185// or to a scalar (w/o DIM=). The ACCUMULATOR type must define
186// operator()(Scalar<T> &, const ConstantSubscripts &, bool first)
187// and Done(Scalar<T> &).
188template <typename T, typename ACCUMULATOR, typename ARRAY>
189static Constant<T> DoReduction(const Constant<ARRAY> &array,
190 const Constant<LogicalResult> &mask, std::optional<int> &dim,
191 const Scalar<T> &identity, ACCUMULATOR &accumulator) {
192 ConstantSubscripts at{array.lbounds()};
193 ConstantSubscripts maskAt{mask.lbounds()};
194 std::vector<typename Constant<T>::Element> elements;
195 ConstantSubscripts resultShape; // empty -> scalar
196 if (dim) { // DIM= is present, so result is an array
197 resultShape = array.shape();
198 resultShape.erase(resultShape.begin() + (*dim - 1));
199 ConstantSubscript dimExtent{array.shape().at(*dim - 1)};
200 CHECK(dimExtent == mask.shape().at(*dim - 1));
201 ConstantSubscript &dimAt{at[*dim - 1]};
202 ConstantSubscript dimLbound{dimAt};
203 ConstantSubscript &maskDimAt{maskAt[*dim - 1]};
204 ConstantSubscript maskDimLbound{maskDimAt};
205 for (auto n{GetSize(resultShape)}; n-- > 0;
206 array.IncrementSubscripts(at), mask.IncrementSubscripts(maskAt)) {
207 elements.push_back(identity);
208 if (dimExtent > 0) {
209 dimAt = dimLbound;
210 maskDimAt = maskDimLbound;
211 bool firstUnmasked{true};
212 for (ConstantSubscript j{0}; j < dimExtent; ++j, ++dimAt, ++maskDimAt) {
213 if (mask.At(maskAt).IsTrue()) {
214 accumulator(elements.back(), at, firstUnmasked);
215 firstUnmasked = false;
216 }
217 }
218 --dimAt, --maskDimAt;
219 }
220 accumulator.Done(elements.back());
221 }
222 } else { // no DIM=, result is scalar
223 elements.push_back(identity);
224 bool firstUnmasked{true};
225 for (auto n{array.size()}; n-- > 0;
226 array.IncrementSubscripts(at), mask.IncrementSubscripts(maskAt)) {
227 if (mask.At(maskAt).IsTrue()) {
228 accumulator(elements.back(), at, firstUnmasked);
229 firstUnmasked = false;
230 }
231 }
232 accumulator.Done(elements.back());
233 }
234 if constexpr (T::category == TypeCategory::Character) {
235 return {static_cast<ConstantSubscript>(identity.size()),
236 std::move(elements), std::move(resultShape)};
237 } else {
238 return {std::move(elements), std::move(resultShape)};
239 }
240}
241
242// MAXVAL & MINVAL
243template <typename T, bool ABS = false> class MaxvalMinvalAccumulator {
244public:
246 RelationalOperator opr, FoldingContext &context, const Constant<T> &array)
247 : opr_{opr}, context_{context}, array_{array} {};
248 void operator()(Scalar<T> &element, const ConstantSubscripts &at,
249 [[maybe_unused]] bool firstUnmasked) const {
250 auto aAt{array_.At(at)};
251 if constexpr (ABS) {
252 aAt = aAt.ABS();
253 }
254 if constexpr (T::category == TypeCategory::Real) {
255 if (firstUnmasked || element.IsNotANumber()) {
256 // Return NaN if and only if all unmasked elements are NaNs and
257 // at least one unmasked element is visible.
258 element = aAt;
259 return;
260 }
261 }
262 Expr<LogicalResult> test{PackageRelation(
263 opr_, Expr<T>{Constant<T>{aAt}}, Expr<T>{Constant<T>{element}})};
264 auto folded{GetScalarConstantValue<LogicalResult>(
265 test.Rewrite(context_, std::move(test)))};
266 CHECK(folded.has_value());
267 if (folded->IsTrue()) {
268 element = aAt;
269 }
270 }
271 void Done(Scalar<T> &) const {}
272
273private:
274 RelationalOperator opr_;
275 FoldingContext &context_;
276 const Constant<T> &array_;
277};
278
279template <typename T>
280static Expr<T> FoldMaxvalMinval(FoldingContext &context, FunctionRef<T> &&ref,
281 RelationalOperator opr, const Scalar<T> &identity) {
282 static_assert(T::category == TypeCategory::Integer ||
283 T::category == TypeCategory::Unsigned ||
284 T::category == TypeCategory::Real ||
285 T::category == TypeCategory::Character);
286 std::optional<int> dim;
287 if (std::optional<ArrayAndMask<T>> arrayAndMask{
288 ProcessReductionArgs<T>(context, ref.arguments(), dim,
289 /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
290 MaxvalMinvalAccumulator<T> accumulator{opr, context, arrayAndMask->array};
291 return Expr<T>{DoReduction<T>(
292 arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)};
293 }
294 return Expr<T>{std::move(ref)};
295}
296
297// PRODUCT
298template <typename T> class ProductAccumulator {
299public:
300 ProductAccumulator(const Constant<T> &array) : array_{array} {}
301 void operator()(
302 Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
303 if constexpr (T::category == TypeCategory::Integer) {
304 auto prod{element.MultiplySigned(array_.At(at))};
305 overflow_ |= prod.SignedMultiplicationOverflowed();
306 element = prod.lower;
307 } else if constexpr (T::category == TypeCategory::Unsigned) {
308 element = element.MultiplyUnsigned(array_.At(at)).lower;
309 } else { // Real & Complex
310 auto prod{element.Multiply(array_.At(at))};
311 overflow_ |= prod.flags.test(RealFlag::Overflow);
312 element = prod.value;
313 }
314 }
315 bool overflow() const { return overflow_; }
316 void Done(Scalar<T> &) const {}
317
318private:
319 const Constant<T> &array_;
320 bool overflow_{false};
321};
322
323template <typename T>
324static Expr<T> FoldProduct(
325 FoldingContext &context, FunctionRef<T> &&ref, Scalar<T> identity) {
326 static_assert(T::category == TypeCategory::Integer ||
327 T::category == TypeCategory::Unsigned ||
328 T::category == TypeCategory::Real ||
329 T::category == TypeCategory::Complex);
330 std::optional<int> dim;
331 if (std::optional<ArrayAndMask<T>> arrayAndMask{
332 ProcessReductionArgs<T>(context, ref.arguments(), dim,
333 /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
334 ProductAccumulator accumulator{arrayAndMask->array};
335 auto result{Expr<T>{DoReduction<T>(
336 arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}};
337 if (accumulator.overflow() &&
338 context.languageFeatures().ShouldWarn(
339 common::UsageWarning::FoldingException)) {
340 context.messages().Say(common::UsageWarning::FoldingException,
341 "PRODUCT() of %s data overflowed"_warn_en_US, T::AsFortran());
342 }
343 return result;
344 }
345 return Expr<T>{std::move(ref)};
346}
347
348// SUM
349template <typename T> class SumAccumulator {
350 using Element = typename Constant<T>::Element;
351
352public:
353 SumAccumulator(const Constant<T> &array, Rounding rounding)
354 : array_{array}, rounding_{rounding} {}
355 void operator()(
356 Element &element, const ConstantSubscripts &at, bool /*first*/) {
357 if constexpr (T::category == TypeCategory::Integer) {
358 auto sum{element.AddSigned(array_.At(at))};
359 overflow_ |= sum.overflow;
360 element = sum.value;
361 } else if constexpr (T::category == TypeCategory::Unsigned) {
362 element = element.AddUnsigned(array_.At(at)).value;
363 } else { // Real & Complex: use Kahan summation
364 auto next{array_.At(at).Subtract(correction_, rounding_)};
365 overflow_ |= next.flags.test(RealFlag::Overflow);
366 auto sum{element.Add(next.value, rounding_)};
367 overflow_ |= sum.flags.test(RealFlag::Overflow);
368 // correction = (sum - element) - next; algebraically zero
369 correction_ = sum.value.Subtract(element, rounding_)
370 .value.Subtract(next.value, rounding_)
371 .value;
372 element = sum.value;
373 }
374 }
375 bool overflow() const { return overflow_; }
376 void Done([[maybe_unused]] Element &element) {
377 if constexpr (T::category != TypeCategory::Integer &&
378 T::category != TypeCategory::Unsigned) {
379 auto corrected{element.Add(correction_, rounding_)};
380 overflow_ |= corrected.flags.test(RealFlag::Overflow);
381 correction_ = Scalar<T>{};
382 element = corrected.value;
383 }
384 }
385
386private:
387 const Constant<T> &array_;
388 Rounding rounding_;
389 bool overflow_{false};
390 Element correction_{};
391};
392
393template <typename T>
394static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) {
395 static_assert(T::category == TypeCategory::Integer ||
396 T::category == TypeCategory::Unsigned ||
397 T::category == TypeCategory::Real ||
398 T::category == TypeCategory::Complex);
399 using Element = typename Constant<T>::Element;
400 std::optional<int> dim;
401 Element identity{};
402 if (std::optional<ArrayAndMask<T>> arrayAndMask{
403 ProcessReductionArgs<T>(context, ref.arguments(), dim,
404 /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
405 SumAccumulator accumulator{
406 arrayAndMask->array, context.targetCharacteristics().roundingMode()};
407 auto result{Expr<T>{DoReduction<T>(
408 arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}};
409 if (accumulator.overflow() &&
410 context.languageFeatures().ShouldWarn(
411 common::UsageWarning::FoldingException)) {
412 context.messages().Say(common::UsageWarning::FoldingException,
413 "SUM() of %s data overflowed"_warn_en_US, T::AsFortran());
414 }
415 return result;
416 }
417 return Expr<T>{std::move(ref)};
418}
419
420// Utility for IALL, IANY, IPARITY, ALL, ANY, & PARITY
421template <typename T> class OperationAccumulator {
422public:
424 Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const)
425 : array_{array}, operation_{operation} {}
426 void operator()(
427 Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
428 element = (element.*operation_)(array_.At(at));
429 }
430 void Done(Scalar<T> &) const {}
431
432private:
433 const Constant<T> &array_;
434 Scalar<T> (Scalar<T>::*operation_)(const Scalar<T> &) const;
435};
436
437} // namespace Fortran::evaluate
438#endif // FORTRAN_EVALUATE_FOLD_REDUCTION_H_
Definition: constant.h:141
Definition: common.h:213
Definition: common.h:215
Definition: call.h:282
Definition: fold-reduction.h:243
Definition: fold-reduction.h:421
Definition: fold-reduction.h:298
Definition: fold-reduction.h:349
Definition: call.h:34
Definition: target-rounding.h:18
Definition: fold-reduction.h:142