FLANG
ReductionProcessor.h
1//===-- Lower/OpenMP/ReductionProcessor.h -----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef FORTRAN_LOWER_REDUCTIONPROCESSOR_H
14#define FORTRAN_LOWER_REDUCTIONPROCESSOR_H
15
16#include "flang/Lower/OpenMP/Clauses.h"
17#include "flang/Optimizer/Builder/FIRBuilder.h"
18#include "flang/Optimizer/Dialect/FIRType.h"
19#include "flang/Parser/parse-tree.h"
20#include "flang/Semantics/symbol.h"
21#include "flang/Semantics/type.h"
22#include "mlir/IR/Location.h"
23#include "mlir/IR/Types.h"
24
// Forward declaration of the OpenMP dialect's declare-reduction operation, so
// this header can name it without including the OpenMP dialect headers.
namespace mlir::omp {
class DeclareReductionOp;
} // namespace mlir::omp
30
// Forward declaration of the lowering converter. ReductionProcessor refers to
// it by reference only (createDeclareReduction, processReductionArguments), so
// the full AbstractConverter.h header is not needed here.
namespace Fortran {
namespace lower {
class AbstractConverter;
} // namespace lower
} // namespace Fortran
36
37namespace Fortran {
38namespace lower {
39namespace omp {
40
42public:
43 // TODO: Move this enumeration to the OpenMP dialect
44 enum ReductionIdentifier {
45 ID,
46 USER_DEF_OP,
47 ADD,
48 SUBTRACT,
49 MULTIPLY,
50 AND,
51 OR,
52 EQV,
53 NEQV,
54 MAX,
55 MIN,
56 IAND,
57 IOR,
58 IEOR
59 };
60
61 static ReductionIdentifier
62 getReductionType(const omp::clause::ProcedureDesignator &pd);
63
64 static ReductionIdentifier
65 getReductionType(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp);
66
67 static ReductionIdentifier
68 getReductionType(const fir::ReduceOperationEnum &pd);
69
70 static bool
71 supportedIntrinsicProcReduction(const omp::clause::ProcedureDesignator &pd);
72
73 static const semantics::SourceName
74 getRealName(const semantics::Symbol *symbol);
75
76 static const semantics::SourceName
77 getRealName(const omp::clause::ProcedureDesignator &pd);
78
79 static std::string getReductionName(llvm::StringRef name,
80 const fir::KindMapping &kindMap,
81 mlir::Type ty, bool isByRef);
82
83 static std::string getReductionName(ReductionIdentifier redId,
84 const fir::KindMapping &kindMap,
85 mlir::Type ty, bool isByRef);
86
91 static int getOperationIdentity(ReductionIdentifier redId,
92 mlir::Location loc);
93
94 static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type,
95 ReductionIdentifier redId,
96 fir::FirOpBuilder &builder);
97
98 template <typename FloatOp, typename IntegerOp>
99 static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
100 mlir::Type type, mlir::Location loc,
101 mlir::Value op1, mlir::Value op2);
102 template <typename FloatOp, typename IntegerOp, typename ComplexOp>
103 static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
104 mlir::Type type, mlir::Location loc,
105 mlir::Value op1, mlir::Value op2);
106
107 static mlir::Value createScalarCombiner(fir::FirOpBuilder &builder,
108 mlir::Location loc,
109 ReductionIdentifier redId,
110 mlir::Type type, mlir::Value op1,
111 mlir::Value op2);
112
117 template <typename OpType>
118 static OpType createDeclareReduction(AbstractConverter &builder,
119 llvm::StringRef reductionOpName,
120 const ReductionIdentifier redId,
121 mlir::Type type, mlir::Location loc,
122 bool isByRef);
123
126 template <typename OpType, typename RedOperatorListTy>
127 static bool processReductionArguments(
128 mlir::Location currentLocation, lower::AbstractConverter &converter,
129 const RedOperatorListTy &redOperatorList,
130 llvm::SmallVectorImpl<mlir::Value> &reductionVars,
131 llvm::SmallVectorImpl<bool> &reduceVarByRef,
132 llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
133 const llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
134};
135
136template <typename FloatOp, typename IntegerOp>
137mlir::Value
138ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
139 mlir::Type type, mlir::Location loc,
140 mlir::Value op1, mlir::Value op2) {
141 type = fir::unwrapRefType(type);
142 assert(type.isIntOrIndexOrFloat() &&
143 "only integer, float and complex types are currently supported");
144 if (type.isIntOrIndex())
145 return IntegerOp::create(builder, loc, op1, op2);
146 return FloatOp::create(builder, loc, op1, op2);
147}
148
149template <typename FloatOp, typename IntegerOp, typename ComplexOp>
150mlir::Value
151ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
152 mlir::Type type, mlir::Location loc,
153 mlir::Value op1, mlir::Value op2) {
154 assert((type.isIntOrIndexOrFloat() || fir::isa_complex(type)) &&
155 "only integer, float and complex types are currently supported");
156 if (type.isIntOrIndex())
157 return IntegerOp::create(builder, loc, op1, op2);
158 if (fir::isa_real(type))
159 return FloatOp::create(builder, loc, op1, op2);
160 return ComplexOp::create(builder, loc, op1, op2);
161}
162
163} // namespace omp
164} // namespace lower
165} // namespace Fortran
166
167#endif // FORTRAN_LOWER_REDUCTIONPROCESSOR_H
Definition AbstractConverter.h:85
Definition ReductionProcessor.h:41
static int getOperationIdentity(ReductionIdentifier redId, mlir::Location loc)
Definition ReductionProcessor.cpp:782
static OpType createDeclareReduction(AbstractConverter &builder, llvm::StringRef reductionOpName, const ReductionIdentifier redId, mlir::Type type, mlir::Location loc, bool isByRef)
Definition ReductionProcessor.cpp:560
static bool processReductionArguments(mlir::Location currentLocation, lower::AbstractConverter &converter, const RedOperatorListTy &redOperatorList, llvm::SmallVectorImpl< mlir::Value > &reductionVars, llvm::SmallVectorImpl< bool > &reduceVarByRef, llvm::SmallVectorImpl< mlir::Attribute > &reductionDeclSymbols, const llvm::SmallVectorImpl< const semantics::Symbol * > &reductionSymbols)
Definition ReductionProcessor.cpp:610
Definition symbol.h:778
Definition FIRBuilder.h:55
Definition KindMapping.h:48
Definition ParserActions.h:24
Definition bit-population-count.h:20
bool isa_complex(mlir::Type t)
Is t a floating point complex type?
Definition FIRType.h:203
bool isa_real(mlir::Type t)
Is t a real type?
Definition FIRType.h:182
Definition AbstractConverter.h:29