FLANG
ClauseProcessor.h
1//===-- Lower/OpenMP/ClauseProcessor.h --------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12#ifndef FORTRAN_LOWER_CLAUSEPROCESSOR_H
13#define FORTRAN_LOWER_CLAUSEPROCESSOR_H
14
15#include "Clauses.h"
16#include "ReductionProcessor.h"
17#include "Utils.h"
18#include "flang/Lower/AbstractConverter.h"
19#include "flang/Lower/Bridge.h"
20#include "flang/Lower/DirectivesCommon.h"
21#include "flang/Optimizer/Builder/Todo.h"
22#include "flang/Parser/dump-parse-tree.h"
23#include "flang/Parser/parse-tree.h"
24#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
25
26namespace fir {
27class FirOpBuilder;
28} // namespace fir
29
30namespace Fortran {
31namespace lower {
32namespace omp {
33
49public:
52 const List<Clause> &clauses)
53 : converter(converter), semaCtx(semaCtx), clauses(clauses) {}
54
55 // 'Unique' clauses: They can appear at most once in the clause list.
56 bool processBare(mlir::omp::BareClauseOps &result) const;
57 bool processBind(mlir::omp::BindClauseOps &result) const;
58 bool
59 processCollapse(mlir::Location currentLocation, lower::pft::Evaluation &eval,
60 mlir::omp::LoopRelatedClauseOps &result,
61 llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const;
62 bool processDevice(lower::StatementContext &stmtCtx,
63 mlir::omp::DeviceClauseOps &result) const;
64 bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const;
65 bool processDistSchedule(lower::StatementContext &stmtCtx,
66 mlir::omp::DistScheduleClauseOps &result) const;
67 bool processFilter(lower::StatementContext &stmtCtx,
68 mlir::omp::FilterClauseOps &result) const;
69 bool processFinal(lower::StatementContext &stmtCtx,
70 mlir::omp::FinalClauseOps &result) const;
71 bool processHasDeviceAddr(
72 mlir::omp::HasDeviceAddrClauseOps &result,
73 llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
74 bool processHint(mlir::omp::HintClauseOps &result) const;
75 bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
76 bool processNowait(mlir::omp::NowaitClauseOps &result) const;
77 bool processNumTeams(lower::StatementContext &stmtCtx,
78 mlir::omp::NumTeamsClauseOps &result) const;
79 bool processNumThreads(lower::StatementContext &stmtCtx,
80 mlir::omp::NumThreadsClauseOps &result) const;
81 bool processOrder(mlir::omp::OrderClauseOps &result) const;
82 bool processOrdered(mlir::omp::OrderedClauseOps &result) const;
83 bool processPriority(lower::StatementContext &stmtCtx,
84 mlir::omp::PriorityClauseOps &result) const;
85 bool processProcBind(mlir::omp::ProcBindClauseOps &result) const;
86 bool processSafelen(mlir::omp::SafelenClauseOps &result) const;
87 bool processSchedule(lower::StatementContext &stmtCtx,
88 mlir::omp::ScheduleClauseOps &result) const;
89 bool processSimdlen(mlir::omp::SimdlenClauseOps &result) const;
90 bool processThreadLimit(lower::StatementContext &stmtCtx,
91 mlir::omp::ThreadLimitClauseOps &result) const;
92 bool processUntied(mlir::omp::UntiedClauseOps &result) const;
93
94 bool processDetach(mlir::omp::DetachClauseOps &result) const;
95 // 'Repeatable' clauses: They can appear multiple times in the clause list.
96 bool processAligned(mlir::omp::AlignedClauseOps &result) const;
97 bool processAllocate(mlir::omp::AllocateClauseOps &result) const;
98 bool processCopyin() const;
99 bool processCopyprivate(mlir::Location currentLocation,
100 mlir::omp::CopyprivateClauseOps &result) const;
101 bool processDepend(mlir::omp::DependClauseOps &result) const;
102 bool
103 processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
104 bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
105 mlir::omp::IfClauseOps &result) const;
106 bool processIsDevicePtr(
107 mlir::omp::IsDevicePtrClauseOps &result,
108 llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
109 bool
110 processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
111
112 // This method is used to process a map clause.
113 // The optional parameter mapSyms is used to store the original Fortran symbol
114 // for the map operands. It may be used later on to create the block_arguments
115 // for some of the directives that require it.
116 bool processMap(mlir::Location currentLocation,
118 mlir::omp::MapClauseOps &result,
119 llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms =
120 nullptr) const;
121 bool processMotionClauses(lower::StatementContext &stmtCtx,
122 mlir::omp::MapClauseOps &result);
123 bool processNontemporal(mlir::omp::NontemporalClauseOps &result) const;
124 bool processReduction(
125 mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
126 llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) const;
127 bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
128 bool processUseDeviceAddr(
130 mlir::omp::UseDeviceAddrClauseOps &result,
131 llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const;
132 bool processUseDevicePtr(
134 mlir::omp::UseDevicePtrClauseOps &result,
135 llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const;
136
137 // Call this method for these clauses that should be supported but are not
138 // implemented yet. It triggers a compilation error if any of the given
139 // clauses is found.
140 template <typename... Ts>
141 void processTODO(mlir::Location currentLocation,
142 llvm::omp::Directive directive) const;
143
144private:
145 using ClauseIterator = List<Clause>::const_iterator;
146
148 template <typename T>
149 static ClauseIterator findClause(ClauseIterator begin, ClauseIterator end);
150
154 template <typename T>
155 const T *findUniqueClause(const parser::CharBlock **source = nullptr) const;
156
159 template <typename T>
160 bool findRepeatableClause(
161 std::function<void(const T &, const parser::CharBlock &source)>
162 callbackFn) const;
163
165 template <typename T>
166 bool markClauseOccurrence(mlir::UnitAttr &result) const;
167
168 void processMapObjects(
169 lower::StatementContext &stmtCtx, mlir::Location clauseLocation,
170 const omp::ObjectList &objects,
171 llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
172 std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
173 llvm::SmallVectorImpl<mlir::Value> &mapVars,
174 llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) const;
175
176 lower::AbstractConverter &converter;
178 List<Clause> clauses;
179};
180
181template <typename... Ts>
182void ClauseProcessor::processTODO(mlir::Location currentLocation,
183 llvm::omp::Directive directive) const {
184 auto checkUnhandledClause = [&](llvm::omp::Clause id, const auto *x) {
185 if (!x)
186 return;
187 TODO(currentLocation,
188 "Unhandled clause " + llvm::omp::getOpenMPClauseName(id).upper() +
189 " in " + llvm::omp::getOpenMPDirectiveName(directive).upper() +
190 " construct");
191 };
192
193 for (ClauseIterator it = clauses.begin(); it != clauses.end(); ++it)
194 (checkUnhandledClause(it->id, std::get_if<Ts>(&it->u)), ...);
195}
196
197template <typename T>
198ClauseProcessor::ClauseIterator
199ClauseProcessor::findClause(ClauseIterator begin, ClauseIterator end) {
200 for (ClauseIterator it = begin; it != end; ++it) {
201 if (std::get_if<T>(&it->u))
202 return it;
203 }
204
205 return end;
206}
207
208template <typename T>
209const T *
210ClauseProcessor::findUniqueClause(const parser::CharBlock **source) const {
211 ClauseIterator it = findClause<T>(clauses.begin(), clauses.end());
212 if (it != clauses.end()) {
213 if (source)
214 *source = &it->source;
215 return &std::get<T>(it->u);
216 }
217 return nullptr;
218}
219
220template <typename T>
221bool ClauseProcessor::findRepeatableClause(
222 std::function<void(const T &, const parser::CharBlock &source)> callbackFn)
223 const {
224 bool found = false;
225 ClauseIterator nextIt, endIt = clauses.end();
226 for (ClauseIterator it = clauses.begin(); it != endIt; it = nextIt) {
227 nextIt = findClause<T>(it, endIt);
228
229 if (nextIt != endIt) {
230 callbackFn(std::get<T>(nextIt->u), nextIt->source);
231 found = true;
232 ++nextIt;
233 }
234 }
235 return found;
236}
237
238template <typename T>
239bool ClauseProcessor::markClauseOccurrence(mlir::UnitAttr &result) const {
240 if (findUniqueClause<T>()) {
241 result = converter.getFirOpBuilder().getUnitAttr();
242 return true;
243 }
244 return false;
245}
246
247} // namespace omp
248} // namespace lower
249} // namespace Fortran
250
251#endif // FORTRAN_LOWER_CLAUSEPROCESSOR_H
Definition: AbstractConverter.h:82
virtual fir::FirOpBuilder & getFirOpBuilder()=0
Get the OpBuilder.
Definition: StatementContext.h:46
Definition: ClauseProcessor.h:48
Definition: char-block.h:28
Definition: semantics.h:67
Definition: bit-population-count.h:20
Definition: AbstractConverter.h:31
Definition: PFTBuilder.h:216