9#ifndef FORTRAN_EVALUATE_FOLD_MATMUL_H_
10#define FORTRAN_EVALUATE_FOLD_MATMUL_H_
12#include "fold-implementation.h"
18 using Element =
typename Constant<T>::Element;
19 auto args{funcRef.arguments()};
20 CHECK(args.size() == 2);
25 return Expr<T>{std::move(funcRef)};
27 CHECK(ma->Rank() >= 1 && ma->Rank() <= 2 && mb->Rank() >= 1 &&
28 mb->Rank() <= 2 && (ma->Rank() == 2 || mb->Rank() == 2));
29 ConstantSubscript commonExtent{ma->shape().back()};
30 if (mb->shape().front() != commonExtent) {
31 context.messages().Say(
32 "Arguments to MATMUL have distinct extents %zd and %zd on their last and first dimensions"_err_en_US,
33 commonExtent, mb->shape().front());
34 return MakeInvalidIntrinsic(std::move(funcRef));
36 ConstantSubscript rows{ma->Rank() == 1 ? 1 : ma->shape()[0]};
37 ConstantSubscript columns{mb->Rank() == 1 ? 1 : mb->shape()[1]};
38 std::vector<Element> elements;
39 elements.reserve(rows * columns);
41 [[maybe_unused]]
const auto &rounding{
42 context.targetCharacteristics().roundingMode()};
44 for (ConstantSubscript ci{0}; ci < columns; ++ci) {
45 for (ConstantSubscript ri{0}; ri < rows; ++ri) {
46 ConstantSubscripts aAt{ma->lbounds()};
47 if (ma->Rank() == 2) {
50 ConstantSubscripts bAt{mb->lbounds()};
51 if (mb->Rank() == 2) {
55 [[maybe_unused]] Element correction{};
56 for (ConstantSubscript j{0}; j < commonExtent; ++j) {
57 Element aElt{ma->At(aAt)};
58 Element bElt{mb->At(bAt)};
59 if constexpr (T::category == TypeCategory::Real ||
60 T::category == TypeCategory::Complex) {
61 auto product{aElt.Multiply(bElt)};
62 overflow |= product.flags.test(RealFlag::Overflow);
63 if constexpr (useKahanSummation) {
64 auto added{sum.KahanSummation(product.value, correction)};
65 overflow |= added.flags.test(RealFlag::Overflow);
68 auto added{sum.Add(product.value)};
69 overflow |= added.flags.test(RealFlag::Overflow);
72 }
else if constexpr (T::category == TypeCategory::Integer) {
73 auto product{aElt.MultiplySigned(bElt)};
74 overflow |= product.SignedMultiplicationOverflowed();
75 auto added{sum.AddSigned(product.lower)};
76 overflow |= added.overflow;
77 sum = std::move(added.value);
78 }
else if constexpr (T::category == TypeCategory::Unsigned) {
79 sum = sum.AddUnsigned(aElt.MultiplyUnsigned(bElt).lower).value;
81 static_assert(T::category == TypeCategory::Logical);
82 sum = sum.OR(aElt.AND(bElt));
87 elements.push_back(sum);
91 context.Warn(common::UsageWarning::FoldingException,
92 "MATMUL of %s data overflowed during computation"_warn_en_US,
95 ConstantSubscripts shape;
96 if (ma->Rank() == 2) {
97 shape.push_back(rows);
99 if (mb->Rank() == 2) {
100 shape.push_back(columns);
Definition constant.h:147
Definition fold-implementation.h:55