FLANG
ReductionProcessor.h
1//===-- Lower/OpenMP/ReductionProcessor.h -----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef FORTRAN_LOWER_REDUCTIONPROCESSOR_H
14#define FORTRAN_LOWER_REDUCTIONPROCESSOR_H
15
16#include "flang/Lower/OpenMP/Clauses.h"
17#include "flang/Optimizer/Builder/FIRBuilder.h"
18#include "flang/Optimizer/Dialect/FIRType.h"
19#include "flang/Parser/parse-tree.h"
20#include "flang/Semantics/symbol.h"
21#include "flang/Semantics/type.h"
22#include "mlir/IR/Location.h"
23#include "mlir/IR/Types.h"
24
// Forward declaration of the OpenMP dialect's declare-reduction operation;
// keeps the dialect's op definitions out of this header's include set.
namespace mlir::omp {
class DeclareReductionOp;
} // namespace mlir::omp
30
31namespace Fortran {
32namespace lower {
34} // namespace lower
35} // namespace Fortran
36
37namespace Fortran {
38namespace lower {
39namespace omp {
40
42public:
43 // ompOrig: mold/original variable
44 // ompPriv: private allocation (may be null for by-value reductions)
45 using GenInitValueCBTy = std::function<mlir::Value(
46 fir::FirOpBuilder &builder, mlir::Location loc, mlir::Type type,
47 mlir::Value ompOrig, mlir::Value ompPriv)>;
48 using GenCombinerCBTy = std::function<void(
49 fir::FirOpBuilder &builder, mlir::Location loc, mlir::Type type,
50 mlir::Value op1, mlir::Value op2, bool isByRef)>;
51
52 // TODO: Move this enumeration to the OpenMP dialect
53 enum ReductionIdentifier {
54 ID,
55 USER_DEF_OP,
56 ADD,
57 SUBTRACT,
58 MULTIPLY,
59 AND,
60 OR,
61 EQV,
62 NEQV,
63 MAX,
64 MIN,
65 IAND,
66 IOR,
67 IEOR
68 };
69
70 static bool doReductionByRef(mlir::Type reductionType);
71 static bool doReductionByRef(mlir::Value reductionVar);
72
73 static ReductionIdentifier
74 getReductionType(const omp::clause::ProcedureDesignator &pd);
75
76 static ReductionIdentifier
77 getReductionType(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp);
78
79 static ReductionIdentifier
80 getReductionType(const fir::ReduceOperationEnum &pd);
81
82 static bool
83 supportedIntrinsicProcReduction(const omp::clause::ProcedureDesignator &pd);
84
85 static const semantics::SourceName
86 getRealName(const semantics::Symbol *symbol);
87
88 static const semantics::SourceName
89 getRealName(const omp::clause::ProcedureDesignator &pd);
90
91 static std::string getReductionName(llvm::StringRef name,
92 const fir::KindMapping &kindMap,
93 mlir::Type ty, bool isByRef);
94
95 static std::string getReductionName(ReductionIdentifier redId,
96 const fir::KindMapping &kindMap,
97 mlir::Type ty, bool isByRef);
98
103 static int getOperationIdentity(ReductionIdentifier redId,
104 mlir::Location loc);
105
106 static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type,
107 ReductionIdentifier redId,
108 fir::FirOpBuilder &builder);
109
110 template <typename FloatOp, typename IntegerOp>
111 static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
112 mlir::Type type, mlir::Location loc,
113 mlir::Value op1, mlir::Value op2);
114 template <typename FloatOp, typename IntegerOp, typename ComplexOp>
115 static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
116 mlir::Type type, mlir::Location loc,
117 mlir::Value op1, mlir::Value op2);
118
119 static mlir::Value createScalarCombiner(fir::FirOpBuilder &builder,
120 mlir::Location loc,
121 ReductionIdentifier redId,
122 mlir::Type type, mlir::Value op1,
123 mlir::Value op2);
127 template <typename DeclareRedType>
128 static DeclareRedType createDeclareReductionHelper(
129 AbstractConverter &converter, llvm::StringRef reductionOpName,
130 mlir::Type type, mlir::Location loc, bool isByRef,
131 GenCombinerCBTy genCombinerCB, GenInitValueCBTy genInitValueCB,
132 const semantics::Symbol *sym = nullptr);
133
138 template <typename OpType>
139 static OpType createDeclareReduction(AbstractConverter &builder,
140 llvm::StringRef reductionOpName,
141 const ReductionIdentifier redId,
142 mlir::Type type, mlir::Location loc,
143 bool isByRef);
144
153 template <typename OpType, typename RedOperatorListTy>
154 static bool processReductionArguments(
155 mlir::Location currentLocation, lower::AbstractConverter &converter,
156 const RedOperatorListTy &redOperatorList,
157 llvm::SmallVectorImpl<mlir::Value> &reductionVars,
158 llvm::SmallVectorImpl<bool> &reduceVarByRef,
159 llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
160 const llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
161 llvm::DenseMap<const semantics::Symbol *, mlir::Value>
162 *reductionVarCache = nullptr);
163};
164
165template <typename FloatOp, typename IntegerOp>
166mlir::Value
167ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
168 mlir::Type type, mlir::Location loc,
169 mlir::Value op1, mlir::Value op2) {
170 type = fir::unwrapRefType(type);
171 assert(type.isIntOrIndexOrFloat() &&
172 "only integer, float and complex types are currently supported");
173 if (type.isIntOrIndex())
174 return IntegerOp::create(builder, loc, op1, op2);
175 return FloatOp::create(builder, loc, op1, op2);
176}
177
178template <typename FloatOp, typename IntegerOp, typename ComplexOp>
179mlir::Value
180ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
181 mlir::Type type, mlir::Location loc,
182 mlir::Value op1, mlir::Value op2) {
183 assert((type.isIntOrIndexOrFloat() || fir::isa_complex(type)) &&
184 "only integer, float and complex types are currently supported");
185 if (type.isIntOrIndex())
186 return IntegerOp::create(builder, loc, op1, op2);
187 if (fir::isa_real(type))
188 return FloatOp::create(builder, loc, op1, op2);
189 return ComplexOp::create(builder, loc, op1, op2);
190}
191
192} // namespace omp
193} // namespace lower
194} // namespace Fortran
195
196#endif // FORTRAN_LOWER_REDUCTIONPROCESSOR_H
Definition AbstractConverter.h:87
Definition ReductionProcessor.h:41
static bool processReductionArguments(mlir::Location currentLocation, lower::AbstractConverter &converter, const RedOperatorListTy &redOperatorList, llvm::SmallVectorImpl< mlir::Value > &reductionVars, llvm::SmallVectorImpl< bool > &reduceVarByRef, llvm::SmallVectorImpl< mlir::Attribute > &reductionDeclSymbols, const llvm::SmallVectorImpl< const semantics::Symbol * > &reductionSymbols, llvm::DenseMap< const semantics::Symbol *, mlir::Value > *reductionVarCache=nullptr)
Definition ReductionProcessor.cpp:657
static int getOperationIdentity(ReductionIdentifier redId, mlir::Location loc)
Definition ReductionProcessor.cpp:873
static OpType createDeclareReduction(AbstractConverter &builder, llvm::StringRef reductionOpName, const ReductionIdentifier redId, mlir::Type type, mlir::Location loc, bool isByRef)
Definition ReductionProcessor.cpp:611
static DeclareRedType createDeclareReductionHelper(AbstractConverter &converter, llvm::StringRef reductionOpName, mlir::Type type, mlir::Location loc, bool isByRef, GenCombinerCBTy genCombinerCB, GenInitValueCBTy genInitValueCB, const semantics::Symbol *sym=nullptr)
Definition ReductionProcessor.cpp:549
Definition symbol.h:832
Definition FIRBuilder.h:56
Definition KindMapping.h:48
Definition ParserActions.h:24
Definition bit-population-count.h:20
bool isa_complex(mlir::Type t)
Is t a floating point complex type?
Definition FIRType.h:217
bool isa_real(mlir::Type t)
Is t a real type?
Definition FIRType.h:196
Definition AbstractConverter.h:32