FLANG
ReductionProcessor.h
1//===-- Lower/OpenMP/ReductionProcessor.h -----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef FORTRAN_LOWER_REDUCTIONPROCESSOR_H
14#define FORTRAN_LOWER_REDUCTIONPROCESSOR_H
15
16#include "flang/Lower/OpenMP/Clauses.h"
17#include "flang/Optimizer/Builder/FIRBuilder.h"
18#include "flang/Optimizer/Dialect/FIRType.h"
19#include "flang/Parser/parse-tree.h"
20#include "flang/Semantics/symbol.h"
21#include "flang/Semantics/type.h"
22#include "mlir/IR/Location.h"
23#include "mlir/IR/Types.h"
24
// Forward declaration of the OpenMP dialect's declare-reduction op, so it can
// be named without including the dialect headers. NOTE(review): presumably it
// is the `OpType` used with ReductionProcessor::createDeclareReduction --
// confirm in ReductionProcessor.cpp.
namespace mlir::omp {
class DeclareReductionOp;
} // namespace mlir::omp
30
namespace Fortran {
namespace lower {
/// Forward declaration: the ReductionProcessor interface only takes
/// AbstractConverter by reference, so the full definition is not needed here.
class AbstractConverter;
} // namespace lower
} // namespace Fortran
36
37namespace Fortran {
38namespace lower {
39namespace omp {
40
42public:
43 using GenInitValueCBTy =
44 std::function<mlir::Value(fir::FirOpBuilder &builder, mlir::Location loc,
45 mlir::Type type, mlir::Value ompOrig)>;
46 using GenCombinerCBTy = std::function<void(
47 fir::FirOpBuilder &builder, mlir::Location loc, mlir::Type type,
48 mlir::Value op1, mlir::Value op2, bool isByRef)>;
49
50 // TODO: Move this enumeration to the OpenMP dialect
51 enum ReductionIdentifier {
52 ID,
53 USER_DEF_OP,
54 ADD,
55 SUBTRACT,
56 MULTIPLY,
57 AND,
58 OR,
59 EQV,
60 NEQV,
61 MAX,
62 MIN,
63 IAND,
64 IOR,
65 IEOR
66 };
67
68 static bool doReductionByRef(mlir::Type reductionType);
69 static bool doReductionByRef(mlir::Value reductionVar);
70
71 static ReductionIdentifier
72 getReductionType(const omp::clause::ProcedureDesignator &pd);
73
74 static ReductionIdentifier
75 getReductionType(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp);
76
77 static ReductionIdentifier
78 getReductionType(const fir::ReduceOperationEnum &pd);
79
80 static bool
81 supportedIntrinsicProcReduction(const omp::clause::ProcedureDesignator &pd);
82
83 static const semantics::SourceName
84 getRealName(const semantics::Symbol *symbol);
85
86 static const semantics::SourceName
87 getRealName(const omp::clause::ProcedureDesignator &pd);
88
89 static std::string getReductionName(llvm::StringRef name,
90 const fir::KindMapping &kindMap,
91 mlir::Type ty, bool isByRef);
92
93 static std::string getReductionName(ReductionIdentifier redId,
94 const fir::KindMapping &kindMap,
95 mlir::Type ty, bool isByRef);
96
101 static int getOperationIdentity(ReductionIdentifier redId,
102 mlir::Location loc);
103
104 static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type,
105 ReductionIdentifier redId,
106 fir::FirOpBuilder &builder);
107
108 template <typename FloatOp, typename IntegerOp>
109 static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
110 mlir::Type type, mlir::Location loc,
111 mlir::Value op1, mlir::Value op2);
112 template <typename FloatOp, typename IntegerOp, typename ComplexOp>
113 static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
114 mlir::Type type, mlir::Location loc,
115 mlir::Value op1, mlir::Value op2);
116
117 static mlir::Value createScalarCombiner(fir::FirOpBuilder &builder,
118 mlir::Location loc,
119 ReductionIdentifier redId,
120 mlir::Type type, mlir::Value op1,
121 mlir::Value op2);
125 template <typename DeclareRedType>
126 static DeclareRedType createDeclareReductionHelper(
127 AbstractConverter &converter, llvm::StringRef reductionOpName,
128 mlir::Type type, mlir::Location loc, bool isByRef,
129 GenCombinerCBTy genCombinerCB, GenInitValueCBTy genInitValueCB);
130
135 template <typename OpType>
136 static OpType createDeclareReduction(AbstractConverter &builder,
137 llvm::StringRef reductionOpName,
138 const ReductionIdentifier redId,
139 mlir::Type type, mlir::Location loc,
140 bool isByRef);
141
144 template <typename OpType, typename RedOperatorListTy>
145 static bool processReductionArguments(
146 mlir::Location currentLocation, lower::AbstractConverter &converter,
147 const RedOperatorListTy &redOperatorList,
148 llvm::SmallVectorImpl<mlir::Value> &reductionVars,
149 llvm::SmallVectorImpl<bool> &reduceVarByRef,
150 llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
151 const llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
152};
153
154template <typename FloatOp, typename IntegerOp>
155mlir::Value
156ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
157 mlir::Type type, mlir::Location loc,
158 mlir::Value op1, mlir::Value op2) {
159 type = fir::unwrapRefType(type);
160 assert(type.isIntOrIndexOrFloat() &&
161 "only integer, float and complex types are currently supported");
162 if (type.isIntOrIndex())
163 return IntegerOp::create(builder, loc, op1, op2);
164 return FloatOp::create(builder, loc, op1, op2);
165}
166
167template <typename FloatOp, typename IntegerOp, typename ComplexOp>
168mlir::Value
169ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
170 mlir::Type type, mlir::Location loc,
171 mlir::Value op1, mlir::Value op2) {
172 assert((type.isIntOrIndexOrFloat() || fir::isa_complex(type)) &&
173 "only integer, float and complex types are currently supported");
174 if (type.isIntOrIndex())
175 return IntegerOp::create(builder, loc, op1, op2);
176 if (fir::isa_real(type))
177 return FloatOp::create(builder, loc, op1, op2);
178 return ComplexOp::create(builder, loc, op1, op2);
179}
180
181} // namespace omp
182} // namespace lower
183} // namespace Fortran
184
185#endif // FORTRAN_LOWER_REDUCTIONPROCESSOR_H
Definition AbstractConverter.h:85
Definition ReductionProcessor.h:41
static int getOperationIdentity(ReductionIdentifier redId, mlir::Location loc)
Definition ReductionProcessor.cpp:865
static OpType createDeclareReduction(AbstractConverter &builder, llvm::StringRef reductionOpName, const ReductionIdentifier redId, mlir::Type type, mlir::Location loc, bool isByRef)
Definition ReductionProcessor.cpp:635
static bool processReductionArguments(mlir::Location currentLocation, lower::AbstractConverter &converter, const RedOperatorListTy &redOperatorList, llvm::SmallVectorImpl< mlir::Value > &reductionVars, llvm::SmallVectorImpl< bool > &reduceVarByRef, llvm::SmallVectorImpl< mlir::Attribute > &reductionDeclSymbols, const llvm::SmallVectorImpl< const semantics::Symbol * > &reductionSymbols)
Definition ReductionProcessor.cpp:680
static DeclareRedType createDeclareReductionHelper(AbstractConverter &converter, llvm::StringRef reductionOpName, mlir::Type type, mlir::Location loc, bool isByRef, GenCombinerCBTy genCombinerCB, GenInitValueCBTy genInitValueCB)
Definition ReductionProcessor.cpp:559
Definition symbol.h:809
Definition FIRBuilder.h:55
Definition KindMapping.h:48
Definition ParserActions.h:24
Definition bit-population-count.h:20
bool isa_complex(mlir::Type t)
Is t a floating point complex type?
Definition FIRType.h:206
bool isa_real(mlir::Type t)
Is t a real type?
Definition FIRType.h:185
Definition AbstractConverter.h:29