FLANG
fold-reduction.h
1//===-- lib/Evaluate/fold-reduction.h -------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_
10#define FORTRAN_EVALUATE_FOLD_REDUCTION_H_
11
12#include "fold-implementation.h"
13
14namespace Fortran::evaluate {
15
16// DOT_PRODUCT
17template <typename T>
18static Expr<T> FoldDotProduct(
19 FoldingContext &context, FunctionRef<T> &&funcRef) {
20 using Element = typename Constant<T>::Element;
21 auto args{funcRef.arguments()};
22 CHECK(args.size() == 2);
23 Folder<T> folder{context};
24 Constant<T> *va{folder.Folding(args[0])};
25 Constant<T> *vb{folder.Folding(args[1])};
26 if (va && vb) {
27 CHECK(va->Rank() == 1 && vb->Rank() == 1);
28 if (va->size() != vb->size()) {
29 context.messages().Say(
30 "Vector arguments to DOT_PRODUCT have distinct extents %zd and %zd"_err_en_US,
31 va->size(), vb->size());
32 return MakeInvalidIntrinsic(std::move(funcRef));
33 }
34 Element sum{};
35 bool overflow{false};
36 if constexpr (T::category == TypeCategory::Complex) {
37 std::vector<Element> conjugates;
38 for (const Element &x : va->values()) {
39 conjugates.emplace_back(x.CONJG());
40 }
41 Constant<T> conjgA{
42 std::move(conjugates), ConstantSubscripts{va->shape()}};
43 Expr<T> products{Fold(
44 context, Expr<T>{std::move(conjgA)} * Expr<T>{Constant<T>{*vb}})};
45 Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
46 [[maybe_unused]] Element correction{};
47 const auto &rounding{context.targetCharacteristics().roundingMode()};
48 for (const Element &x : cProducts.values()) {
49 if constexpr (useKahanSummation) {
50 auto added{sum.KahanSummation(x, correction, rounding)};
51 overflow |= added.flags.test(RealFlag::Overflow);
52 sum = added.value;
53 } else {
54 auto added{sum.Add(x, rounding)};
55 overflow |= added.flags.test(RealFlag::Overflow);
56 sum = added.value;
57 }
58 }
59 } else if constexpr (T::category == TypeCategory::Logical) {
60 Expr<T> conjunctions{Fold(context,
61 Expr<T>{LogicalOperation<T::kind>{LogicalOperator::And,
62 Expr<T>{Constant<T>{*va}}, Expr<T>{Constant<T>{*vb}}}})};
63 Constant<T> &cConjunctions{DEREF(UnwrapConstantValue<T>(conjunctions))};
64 for (const Element &x : cConjunctions.values()) {
65 if (x.IsTrue()) {
66 sum = Element{true};
67 break;
68 }
69 }
70 } else if constexpr (T::category == TypeCategory::Integer) {
71 Expr<T> products{
72 Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
73 Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
74 for (const Element &x : cProducts.values()) {
75 auto next{sum.AddSigned(x)};
76 overflow |= next.overflow;
77 sum = std::move(next.value);
78 }
79 } else if constexpr (T::category == TypeCategory::Unsigned) {
80 Expr<T> products{
81 Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
82 Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
83 for (const Element &x : cProducts.values()) {
84 sum = sum.AddUnsigned(x).value;
85 }
86 } else {
87 static_assert(T::category == TypeCategory::Real);
88 Expr<T> products{
89 Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
90 Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
91 [[maybe_unused]] Element correction{};
92 const auto &rounding{context.targetCharacteristics().roundingMode()};
93 for (const Element &x : cProducts.values()) {
94 if constexpr (useKahanSummation) {
95 auto added{sum.KahanSummation(x, correction, rounding)};
96 overflow |= added.flags.test(RealFlag::Overflow);
97 sum = added.value;
98 } else {
99 auto added{sum.Add(x, rounding)};
100 overflow |= added.flags.test(RealFlag::Overflow);
101 sum = added.value;
102 }
103 }
104 }
105 if (overflow) {
106 context.Warn(common::UsageWarning::FoldingException,
107 "DOT_PRODUCT of %s data overflowed during computation"_warn_en_US,
108 T::AsFortran());
109 }
110 return Expr<T>{Constant<T>{std::move(sum)}};
111 }
112 return Expr<T>{std::move(funcRef)};
113}
114
115// Fold and validate a DIM= argument. Returns false on error.
116bool CheckReductionDIM(std::optional<int> &dim, FoldingContext &,
117 ActualArguments &, std::optional<int> dimIndex, int rank);
118
119// Fold and validate a MASK= argument. Return null on error, absent MASK=, or
120// non-constant MASK=.
121Constant<LogicalResult> *GetReductionMASK(
122 std::optional<ActualArgument> &maskArg, const ConstantSubscripts &shape,
124
125// Common preprocessing for reduction transformational intrinsic function
126// folding. If the intrinsic can have DIM= &/or MASK= arguments, extract
127// and check them. If a MASK= is present, apply it to the array data and
128// substitute replacement values for elements corresponding to .FALSE. in
129// the mask. If the result is present, the intrinsic call can be folded.
130template <typename T> struct ArrayAndMask {
131 Constant<T> array;
133};
134template <typename T>
135static std::optional<ArrayAndMask<T>> ProcessReductionArgs(
136 FoldingContext &context, ActualArguments &arg, std::optional<int> &dim,
137 int arrayIndex, std::optional<int> dimIndex = std::nullopt,
138 std::optional<int> maskIndex = std::nullopt) {
139 if (arg.empty()) {
140 return std::nullopt;
141 }
142 Constant<T> *folded{Folder<T>{context}.Folding(arg[arrayIndex])};
143 if (!folded || folded->Rank() < 1) {
144 return std::nullopt;
145 }
146 if (!CheckReductionDIM(dim, context, arg, dimIndex, folded->Rank())) {
147 return std::nullopt;
148 }
149 std::size_t n{folded->size()};
150 std::vector<Scalar<LogicalResult>> maskElement;
151 if (maskIndex && static_cast<std::size_t>(*maskIndex) < arg.size() &&
152 arg[*maskIndex]) {
153 if (const Constant<LogicalResult> *origMask{
154 GetReductionMASK(arg[*maskIndex], folded->shape(), context)}) {
155 if (auto scalarMask{origMask->GetScalarValue()}) {
156 maskElement =
157 std::vector<Scalar<LogicalResult>>(n, scalarMask->IsTrue());
158 } else {
159 maskElement = origMask->values();
160 }
161 } else {
162 return std::nullopt;
163 }
164 } else {
165 maskElement = std::vector<Scalar<LogicalResult>>(n, true);
166 }
167 return ArrayAndMask<T>{Constant<T>(*folded),
169 std::move(maskElement), ConstantSubscripts{folded->shape()}}};
170}
171
172// Generalized reduction to an array of one dimension fewer (w/ DIM=)
173// or to a scalar (w/o DIM=). The ACCUMULATOR type must define
174// operator()(Scalar<T> &, const ConstantSubscripts &, bool first)
175// and Done(Scalar<T> &).
176template <typename T, typename ACCUMULATOR, typename ARRAY>
177static Constant<T> DoReduction(const Constant<ARRAY> &array,
178 const Constant<LogicalResult> &mask, std::optional<int> &dim,
179 const Scalar<T> &identity, ACCUMULATOR &accumulator) {
180 ConstantSubscripts at{array.lbounds()};
181 ConstantSubscripts maskAt{mask.lbounds()};
182 std::vector<typename Constant<T>::Element> elements;
183 ConstantSubscripts resultShape; // empty -> scalar
184 if (dim) { // DIM= is present, so result is an array
185 resultShape = array.shape();
186 resultShape.erase(resultShape.begin() + (*dim - 1));
187 ConstantSubscript dimExtent{array.shape().at(*dim - 1)};
188 CHECK(dimExtent == mask.shape().at(*dim - 1));
189 ConstantSubscript &dimAt{at[*dim - 1]};
190 ConstantSubscript dimLbound{dimAt};
191 ConstantSubscript &maskDimAt{maskAt[*dim - 1]};
192 ConstantSubscript maskDimLbound{maskDimAt};
193 for (auto n{GetSize(resultShape)}; n-- > 0;
194 array.IncrementSubscripts(at), mask.IncrementSubscripts(maskAt)) {
195 elements.push_back(identity);
196 if (dimExtent > 0) {
197 dimAt = dimLbound;
198 maskDimAt = maskDimLbound;
199 bool firstUnmasked{true};
200 for (ConstantSubscript j{0}; j < dimExtent; ++j, ++dimAt, ++maskDimAt) {
201 if (mask.At(maskAt).IsTrue()) {
202 accumulator(elements.back(), at, firstUnmasked);
203 firstUnmasked = false;
204 }
205 }
206 --dimAt, --maskDimAt;
207 }
208 accumulator.Done(elements.back());
209 }
210 } else { // no DIM=, result is scalar
211 elements.push_back(identity);
212 bool firstUnmasked{true};
213 for (auto n{array.size()}; n-- > 0;
214 array.IncrementSubscripts(at), mask.IncrementSubscripts(maskAt)) {
215 if (mask.At(maskAt).IsTrue()) {
216 accumulator(elements.back(), at, firstUnmasked);
217 firstUnmasked = false;
218 }
219 }
220 accumulator.Done(elements.back());
221 }
222 if constexpr (T::category == TypeCategory::Character) {
223 return {static_cast<ConstantSubscript>(identity.size()),
224 std::move(elements), std::move(resultShape)};
225 } else {
226 return {std::move(elements), std::move(resultShape)};
227 }
228}
229
230// MAXVAL & MINVAL
231template <typename T, bool ABS = false> class MaxvalMinvalAccumulator {
232public:
233 MaxvalMinvalAccumulator(
234 RelationalOperator opr, FoldingContext &context, const Constant<T> &array)
235 : opr_{opr}, context_{context}, array_{array} {};
236 void operator()(Scalar<T> &element, const ConstantSubscripts &at,
237 [[maybe_unused]] bool firstUnmasked) const {
238 auto aAt{array_.At(at)};
239 if constexpr (ABS) {
240 aAt = aAt.ABS();
241 }
242 if constexpr (T::category == TypeCategory::Real) {
243 if (firstUnmasked || element.IsNotANumber()) {
244 // Return NaN if and only if all unmasked elements are NaNs and
245 // at least one unmasked element is visible.
246 element = aAt;
247 return;
248 }
249 }
250 Expr<LogicalResult> test{PackageRelation(
251 opr_, Expr<T>{Constant<T>{aAt}}, Expr<T>{Constant<T>{element}})};
252 auto folded{GetScalarConstantValue<LogicalResult>(
253 test.Rewrite(context_, std::move(test)))};
254 CHECK(folded.has_value());
255 if (folded->IsTrue()) {
256 element = aAt;
257 }
258 }
259 void Done(Scalar<T> &) const {}
260
261private:
262 RelationalOperator opr_;
263 FoldingContext &context_;
264 const Constant<T> &array_;
265};
266
267template <typename T>
268static Expr<T> FoldMaxvalMinval(FoldingContext &context, FunctionRef<T> &&ref,
269 RelationalOperator opr, const Scalar<T> &identity) {
270 static_assert(T::category == TypeCategory::Integer ||
271 T::category == TypeCategory::Unsigned ||
272 T::category == TypeCategory::Real ||
273 T::category == TypeCategory::Character);
274 std::optional<int> dim;
275 if (std::optional<ArrayAndMask<T>> arrayAndMask{
276 ProcessReductionArgs<T>(context, ref.arguments(), dim,
277 /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
278 MaxvalMinvalAccumulator<T> accumulator{opr, context, arrayAndMask->array};
279 return Expr<T>{DoReduction<T>(
280 arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)};
281 }
282 return Expr<T>{std::move(ref)};
283}
284
285// PRODUCT
286template <typename T> class ProductAccumulator {
287public:
288 ProductAccumulator(const Constant<T> &array) : array_{array} {}
289 void operator()(
290 Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
291 if constexpr (T::category == TypeCategory::Integer) {
292 auto prod{element.MultiplySigned(array_.At(at))};
293 overflow_ |= prod.SignedMultiplicationOverflowed();
294 element = prod.lower;
295 } else if constexpr (T::category == TypeCategory::Unsigned) {
296 element = element.MultiplyUnsigned(array_.At(at)).lower;
297 } else { // Real & Complex
298 auto prod{element.Multiply(array_.At(at))};
299 overflow_ |= prod.flags.test(RealFlag::Overflow);
300 element = prod.value;
301 }
302 }
303 bool overflow() const { return overflow_; }
304 void Done(Scalar<T> &) const {}
305
306private:
307 const Constant<T> &array_;
308 bool overflow_{false};
309};
310
311template <typename T>
312static Expr<T> FoldProduct(
313 FoldingContext &context, FunctionRef<T> &&ref, Scalar<T> identity) {
314 static_assert(T::category == TypeCategory::Integer ||
315 T::category == TypeCategory::Unsigned ||
316 T::category == TypeCategory::Real ||
317 T::category == TypeCategory::Complex);
318 std::optional<int> dim;
319 if (std::optional<ArrayAndMask<T>> arrayAndMask{
320 ProcessReductionArgs<T>(context, ref.arguments(), dim,
321 /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
322 ProductAccumulator accumulator{arrayAndMask->array};
323 auto result{Expr<T>{DoReduction<T>(
324 arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}};
325 if (accumulator.overflow()) {
326 context.Warn(common::UsageWarning::FoldingException,
327 "PRODUCT() of %s data overflowed"_warn_en_US, T::AsFortran());
328 }
329 return result;
330 }
331 return Expr<T>{std::move(ref)};
332}
333
334// SUM
335template <typename T> class SumAccumulator {
336 using Element = typename Constant<T>::Element;
337
338public:
339 SumAccumulator(const Constant<T> &array, Rounding rounding)
340 : array_{array}, rounding_{rounding} {}
341 void operator()(
342 Element &element, const ConstantSubscripts &at, bool /*first*/) {
343 if constexpr (T::category == TypeCategory::Integer) {
344 auto sum{element.AddSigned(array_.At(at))};
345 overflow_ |= sum.overflow;
346 element = sum.value;
347 } else if constexpr (T::category == TypeCategory::Unsigned) {
348 element = element.AddUnsigned(array_.At(at)).value;
349 } else { // Real & Complex: use Kahan summation
350 auto sum{element.KahanSummation(array_.At(at), correction_, rounding_)};
351 overflow_ |= sum.flags.test(RealFlag::Overflow);
352 element = sum.value;
353 }
354 }
355 bool overflow() const { return overflow_; }
356 void Done([[maybe_unused]] Element &element) {
357 if constexpr (T::category != TypeCategory::Integer &&
358 T::category != TypeCategory::Unsigned) {
359 auto corrected{element.Add(correction_, rounding_)};
360 overflow_ |= corrected.flags.test(RealFlag::Overflow);
361 correction_ = Scalar<T>{};
362 element = corrected.value;
363 }
364 }
365
366private:
367 const Constant<T> &array_;
368 Rounding rounding_;
369 bool overflow_{false};
370 Element correction_{};
371};
372
373template <typename T>
374static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) {
375 static_assert(T::category == TypeCategory::Integer ||
376 T::category == TypeCategory::Unsigned ||
377 T::category == TypeCategory::Real ||
378 T::category == TypeCategory::Complex);
379 using Element = typename Constant<T>::Element;
380 std::optional<int> dim;
381 Element identity{};
382 if (std::optional<ArrayAndMask<T>> arrayAndMask{
383 ProcessReductionArgs<T>(context, ref.arguments(), dim,
384 /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
385 SumAccumulator accumulator{
386 arrayAndMask->array, context.targetCharacteristics().roundingMode()};
387 auto result{Expr<T>{DoReduction<T>(
388 arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}};
389 if (accumulator.overflow()) {
390 context.Warn(common::UsageWarning::FoldingException,
391 "SUM() of %s data overflowed"_warn_en_US, T::AsFortran());
392 }
393 return result;
394 }
395 return Expr<T>{std::move(ref)};
396}
397
398// Utility for IALL, IANY, IPARITY, ALL, ANY, & PARITY
399template <typename T> class OperationAccumulator {
400public:
401 OperationAccumulator(const Constant<T> &array,
402 Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const)
403 : array_{array}, operation_{operation} {}
404 void operator()(
405 Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
406 element = (element.*operation_)(array_.At(at));
407 }
408 void Done(Scalar<T> &) const {}
409
410private:
411 const Constant<T> &array_;
412 Scalar<T> (Scalar<T>::*operation_)(const Scalar<T> &) const;
413};
414
415} // namespace Fortran::evaluate
416#endif // FORTRAN_EVALUATE_FOLD_REDUCTION_H_
Definition constant.h:147
Definition common.h:214
Definition fold-implementation.h:55
Definition common.h:216
Definition call.h:293
Definition call.h:34
Definition target-rounding.h:18
Definition fold-reduction.h:130
Definition expression.h:379