FLANG
openmp-utils.h
1//===-- flang/Parser/openmp-utils.h ---------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Common OpenMP utilities.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef FORTRAN_PARSER_OPENMP_UTILS_H
14#define FORTRAN_PARSER_OPENMP_UTILS_H
15
16#include "flang/Common/indirection.h"
17#include "flang/Common/template.h"
18#include "flang/Parser/parse-tree.h"
19#include "llvm/ADT/iterator_range.h"
20#include "llvm/Frontend/OpenMP/OMP.h"
21
22#include <cassert>
23#include <iterator>
24#include <tuple>
25#include <type_traits>
26#include <utility>
27#include <variant>
28#include <vector>
29
30namespace Fortran::parser::omp {
31
// Address of the value held by `x`, or nullptr when `x` is empty.
template <typename T> constexpr auto addr_if(std::optional<T> &x) {
  if (!x.has_value()) {
    return static_cast<T *>(nullptr);
  }
  return &*x;
}
// Same as above for a const optional; yields a pointer-to-const.
template <typename T> constexpr auto addr_if(const std::optional<T> &x) {
  if (!x.has_value()) {
    return static_cast<const T *>(nullptr);
  }
  return &*x;
}
38
39namespace detail {
41 static OmpDirectiveName MakeName(CharBlock source = {},
42 llvm::omp::Directive id = llvm::omp::Directive::OMPD_unknown) {
44 name.source = source;
45 name.v = id;
46 return name;
47 }
48
49 static OmpDirectiveName GetOmpDirectiveName(const OmpDirectiveName &x) {
50 return x;
51 }
52
53 static OmpDirectiveName GetOmpDirectiveName(const OmpBeginLoopDirective &x) {
54 return x.DirName();
55 }
56
57 static OmpDirectiveName GetOmpDirectiveName(const OpenMPSectionConstruct &x) {
58 if (auto &spec{std::get<std::optional<OmpDirectiveSpecification>>(x.t)}) {
59 return spec->DirName();
60 } else {
61 return MakeName({}, llvm::omp::Directive::OMPD_section);
62 }
63 }
64
// Overload for a directive node that exposes DirName(): returns x.DirName().
// NOTE(review): the parameter declaration of this overload (original header
// line 66) is missing from this copy; restore it from the upstream header.
65 static OmpDirectiveName GetOmpDirectiveName(
67 return x.DirName();
68 }
69
70 template <typename T>
71 static OmpDirectiveName GetOmpDirectiveName(const T &x) {
72 if constexpr (WrapperTrait<T>) {
73 return GetOmpDirectiveName(x.v);
74 } else if constexpr (TupleTrait<T>) {
75 if constexpr (std::is_base_of_v<OmpBlockConstruct, T>) {
76 return std::get<OmpBeginDirective>(x.t).DirName();
77 } else {
78 return GetFromTuple(
79 x.t, std::make_index_sequence<std::tuple_size_v<decltype(x.t)>>{});
80 }
81 } else if constexpr (UnionTrait<T>) {
82 return common::visit(
83 [](auto &&s) { return GetOmpDirectiveName(s); }, x.u);
84 } else {
85 return MakeName();
86 }
87 }
88
89 template <typename... Ts, size_t... Is>
90 static OmpDirectiveName GetFromTuple(
91 const std::tuple<Ts...> &t, std::index_sequence<Is...>) {
92 OmpDirectiveName name = MakeName();
93 auto accumulate = [&](const OmpDirectiveName &n) {
94 if (name.v == llvm::omp::Directive::OMPD_unknown) {
95 name = n;
96 } else {
97 assert(
98 n.v == llvm::omp::Directive::OMPD_unknown && "Conflicting names");
99 }
100 };
101 (accumulate(GetOmpDirectiveName(std::get<Is>(t))), ...);
102 return name;
103 }
104
105 template <typename T>
106 static OmpDirectiveName GetOmpDirectiveName(const common::Indirection<T> &x) {
107 return GetOmpDirectiveName(x.value());
108 }
109};
110} // namespace detail
111
112template <typename T> OmpDirectiveName GetOmpDirectiveName(const T &x) {
113 return detail::DirectiveNameScope::GetOmpDirectiveName(x);
114}
115
116std::string GetUpperName(llvm::omp::Clause id, unsigned version);
117std::string GetUpperName(llvm::omp::Directive id, unsigned version);
118
120const OpenMPConstruct *GetOmp(const ExecutionPartConstruct &x);
121
122const OpenMPLoopConstruct *GetOmpLoop(const ExecutionPartConstruct &x);
123const DoConstruct *GetDoConstruct(const ExecutionPartConstruct &x);
124
125namespace detail {
126// Clauses with flangClass = "OmpObjectList".
127using MemberObjectListClauses =
128 std::tuple<OmpClause::Copyin, OmpClause::Copyprivate, OmpClause::Exclusive,
129 OmpClause::Firstprivate, OmpClause::HasDeviceAddr, OmpClause::Inclusive,
130 OmpClause::IsDevicePtr, OmpClause::Link, OmpClause::Private,
131 OmpClause::Shared, OmpClause::UseDeviceAddr, OmpClause::UseDevicePtr>;
132
133// Clauses with flangClass = "OmpSomeClause", and OmpObjectList a
134// member of tuple OmpSomeClause::t.
135using TupleObjectListClauses = std::tuple<OmpClause::AdjustArgs,
136 OmpClause::Affinity, OmpClause::Aligned, OmpClause::Allocate,
137 OmpClause::Enter, OmpClause::From, OmpClause::InReduction,
138 OmpClause::Lastprivate, OmpClause::Linear, OmpClause::Map,
139 OmpClause::Reduction, OmpClause::TaskReduction, OmpClause::To>;
140
141// Does U have WrapperTrait (i.e. has a member 'v'), and if so, is T the
142// type of v?
143template <typename T, typename U, bool IsWrapper> struct WrappedInType {
144 static constexpr bool value{false};
145};
146
147template <typename T, typename U> struct WrappedInType<T, U, true> {
148 static constexpr bool value{std::is_same_v<T, decltype(U::v)>};
149};
150
151// Same as WrappedInType, but with a list of types Us. Satisfied if any
152// type U in Us satisfies WrappedInType<T, U>.
153template <typename...> struct WrappedInTypes;
154
155template <typename T> struct WrappedInTypes<T> {
156 static constexpr bool value{false};
157};
158
159template <typename T, typename U, typename... Us>
160struct WrappedInTypes<T, U, Us...> {
161 static constexpr bool value{WrappedInType<T, U, WrapperTrait<U>>::value ||
162 WrappedInTypes<T, Us...>::value};
163};
164
165// Same as WrappedInTypes, but takes type list in a form of a tuple or
166// a variant.
167template <typename...> struct WrappedInTupleOrVariant {
168 static constexpr bool value{false};
169};
170template <typename T, typename... Us>
171struct WrappedInTupleOrVariant<T, std::tuple<Us...>> {
172 static constexpr bool value{WrappedInTypes<T, Us...>::value};
173};
174template <typename T, typename... Us>
175struct WrappedInTupleOrVariant<T, std::variant<Us...>> {
176 static constexpr bool value{WrappedInTypes<T, Us...>::value};
177};
178template <typename T, typename U>
179constexpr bool WrappedInTupleOrVariantV{WrappedInTupleOrVariant<T, U>::value};
180} // namespace detail
181
182template <typename T> const OmpObjectList *GetOmpObjectList(const T &clause) {
183 using namespace detail;
184 static_assert(std::is_class_v<T>, "Unexpected argument type");
185
186 if constexpr (common::HasMember<T, decltype(OmpClause::u)>) {
187 if constexpr (common::HasMember<T, MemberObjectListClauses>) {
188 return &clause.v;
189 } else if constexpr (common::HasMember<T, TupleObjectListClauses>) {
190 return &std::get<OmpObjectList>(clause.v.t);
191 } else {
192 return nullptr;
193 }
194 } else if constexpr (WrappedInTupleOrVariantV<T, TupleObjectListClauses>) {
195 return &std::get<OmpObjectList>(clause.t);
196 } else if constexpr (WrappedInTupleOrVariantV<T, decltype(OmpClause::u)>) {
197 return nullptr;
198 } else {
199 // The condition should be type-dependent, but it should always be false.
200 static_assert(sizeof(T) < 0 && "Unexpected argument type");
201 }
202}
203
204const OmpObjectList *GetOmpObjectList(const OmpClause &clause);
205const OmpObjectList *GetOmpObjectList(const OmpClause::Depend &clause);
206const OmpObjectList *GetOmpObjectList(const OmpDependClause::TaskDep &x);
207
208template <typename T>
209const T *GetFirstArgument(const OmpDirectiveSpecification &spec) {
210 for (const OmpArgument &arg : spec.Arguments().v) {
211 if (auto *t{std::get_if<T>(&arg.u)}) {
212 return t;
213 }
214 }
215 return nullptr;
216}
217
218const OmpClause *FindClause(
219 const OmpDirectiveSpecification &spec, llvm::omp::Clause clauseId);
220
221const BlockConstruct *GetFortranBlockConstruct(
222 const ExecutionPartConstruct &epc);
223const Block &GetInnermostExecPart(const Block &block);
224bool IsStrictlyStructuredBlock(const Block &block);
225
226const OmpCombinerExpression *GetCombinerExpr(const OmpReductionSpecifier &x);
227const OmpCombinerExpression *GetCombinerExpr(const OmpClause &x);
228const OmpInitializerExpression *GetInitializerExpr(const OmpClause &x);
229
231 std::vector<const OmpAllocateDirective *> dirs;
232 const ExecutionPartConstruct *body{nullptr};
233};
234
235OmpAllocateInfo SplitOmpAllocate(const OmpAllocateDirective &x);
236
237template <typename R, typename = void, typename = void> struct is_range {
238 static constexpr bool value{false};
239};
240
241template <typename R>
242struct is_range<R, //
243 std::void_t<decltype(std::declval<R>().begin())>,
244 std::void_t<decltype(std::declval<R>().end())>> {
245 static constexpr bool value{true};
246};
247
248template <typename R> constexpr bool is_range_v = is_range<R>::value;
249
250// Iterate over a range of parser::Block::const_iterator's. When the end
251// of the range is reached, the iterator becomes invalid.
252// Treat BLOCK constructs as if they were transparent, i.e. as if the
253// BLOCK/ENDBLOCK statements, and the specification part contained within
254// were removed. The stepping determines whether the iterator steps "into"
255// DO loops and OpenMP loop constructs, or steps "over" them.
256//
257// Example: consecutive locations of the iterator:
258//
259// Step::Into Step::Over
260// block block
261// 1 => stmt1 1 => stmt1
262// block block
263// integer :: x integer :: x
264// 2 => stmt2 2 => stmt2
265// block block
266// end block end block
267// end block end block
268// 3 => do i = 1, n 3 => do i = 1, n
269// 4 => continue continue
270// end do end do
271// 5 => stmt3 4 => stmt3
272// end block end block
273//
274// 6 => <invalid> 5 => <invalid>
275//
276// The iterator is in a legal state (position) if it's at an
277// ExecutionPartConstruct that is not a BlockConstruct, or is invalid.
278struct ExecutionPartIterator {
279 enum class Step {
280 Into,
281 Over,
282 Default = Into,
283 };
284
285 using IteratorType = Block::const_iterator;
286 using IteratorRange = llvm::iterator_range<IteratorType>;
287
288 // An iterator range with a third iterator indicating a position inside
289 // the range.
290 struct IteratorGauge : public IteratorRange {
291 IteratorGauge(IteratorType b, IteratorType e)
292 : IteratorRange(b, e), at(b) {}
293 IteratorGauge(IteratorRange r) : IteratorRange(r), at(r.begin()) {}
294
295 bool atEnd() const { return at == end(); }
296 IteratorType at;
297 };
298
299 struct Construct {
300 Construct(IteratorType b, IteratorType e, const ExecutionPartConstruct *c)
301 : location(b, e), owner(c) {}
302 template <typename R>
303 Construct(const R &r, const ExecutionPartConstruct *c)
304 : location(r), owner(c) {}
305 Construct(const Construct &c) = default;
306 // The original range of the construct with the current position in it.
307 // The location.at is the construct currently being pointed at, or
308 // stepped into.
309 IteratorGauge location;
310 const ExecutionPartConstruct *owner;
311 };
312
313 ExecutionPartIterator() = default;
314
315 ExecutionPartIterator(IteratorType b, IteratorType e, Step s = Step::Default,
316 const ExecutionPartConstruct *c = nullptr)
317 : stepping_(s) {
318 stack_.emplace_back(b, e, c);
319 adjust();
320 }
321 template <typename R, typename = std::enable_if_t<is_range_v<R>>>
322 ExecutionPartIterator(const R &range, Step stepping = Step::Default,
323 const ExecutionPartConstruct *construct = nullptr)
324 : ExecutionPartIterator(range.begin(), range.end(), stepping, construct) {
325 }
326
327 // Advance the iterator to the next legal position. If the current position
328 // is a DO-loop or a loop construct, step into the contained Block.
329 void step();
330
331 // Advance the iterator to the next legal position. If the current position
332 // is a DO-loop or a loop construct, step to the next legal position following
333 // the DO-loop or loop construct.
334 void next();
335
336 bool valid() const { return !stack_.empty(); }
337
338 const std::vector<Construct> &stack() const { return stack_; }
339 decltype(auto) operator*() const { return *at(); }
340 bool operator==(const ExecutionPartIterator &other) const {
341 if (valid() != other.valid()) {
342 return false;
343 }
344 // Invalid iterators are considered equal.
345 return !valid() ||
346 stack_.back().location.at == other.stack_.back().location.at;
347 }
348 bool operator!=(const ExecutionPartIterator &other) const {
349 return !(*this == other);
350 }
351
352 ExecutionPartIterator &operator++() {
353 if (stepping_ == Step::Into) {
354 step();
355 } else {
356 assert(stepping_ == Step::Over && "Unexpected stepping");
357 next();
358 }
359 return *this;
360 }
361
362 ExecutionPartIterator operator++(int) {
363 ExecutionPartIterator copy{*this};
364 operator++();
365 return copy;
366 }
367
368 using difference_type = IteratorType::difference_type;
369 using value_type = IteratorType::value_type;
370 using reference = IteratorType::reference;
371 using pointer = IteratorType::pointer;
372 using iterator_category = std::forward_iterator_tag;
373
374private:
375 IteratorType at() const { return stack_.back().location.at; };
376
377 // If the iterator is not at a legal location, keep advancing it until
378 // it lands at a legal location or becomes invalid.
379 void adjust();
380
381 const Step stepping_ = Step::Default;
382 std::vector<Construct> stack_;
383};
384
385template <typename Iterator = ExecutionPartIterator> struct ExecutionPartRange {
386 using Step = typename Iterator::Step;
387
388 ExecutionPartRange(Block::const_iterator begin, Block::const_iterator end,
389 Step stepping = Step::Default,
390 const ExecutionPartConstruct *owner = nullptr)
391 : begin_(begin, end, stepping, owner), end_() {}
392 template <typename R, typename = std::enable_if_t<is_range_v<R>>>
393 ExecutionPartRange(const R &range, Step stepping = Step::Default,
394 const ExecutionPartConstruct *owner = nullptr)
395 : ExecutionPartRange(range.begin(), range.end(), stepping, owner) {}
396
397 Iterator begin() const { return begin_; }
398 Iterator end() const { return end_; }
399
400private:
401 Iterator begin_, end_;
402};
403
404struct LoopNestIterator : public ExecutionPartIterator {
405 LoopNestIterator() = default;
406
407 LoopNestIterator(IteratorType b, IteratorType e, Step s = Step::Default,
408 const ExecutionPartConstruct *c = nullptr)
409 : ExecutionPartIterator(b, e, s, c) {
410 adjust();
411 }
412 template <typename R, typename = std::enable_if_t<is_range_v<R>>>
413 LoopNestIterator(const R &range, Step stepping = Step::Default,
414 const ExecutionPartConstruct *construct = nullptr)
415 : LoopNestIterator(range.begin(), range.end(), stepping, construct) {}
416
417 LoopNestIterator &operator++() {
418 ExecutionPartIterator::operator++();
419 adjust();
420 return *this;
421 }
422
423 LoopNestIterator operator++(int) {
424 LoopNestIterator copy{*this};
425 operator++();
426 return copy;
427 }
428
429private:
430 static bool isLoop(const ExecutionPartConstruct &c);
431
432 void adjust() {
433 while (valid() && !isLoop(**this)) {
434 ExecutionPartIterator::operator++();
435 }
436 }
437};
438
441
442} // namespace Fortran::parser::omp
443
444#endif // FORTRAN_PARSER_OPENMP_UTILS_H
Definition indirection.h:31
Definition char-block.h:28
Definition parse-tree.h:437
Definition parse-tree.h:2311
Definition parse-tree.h:554
Definition parse-tree.h:5266
Definition parse-tree.h:5403
Definition parse-tree.h:5131
Definition parse-tree.h:3480
Definition parse-tree.h:3525
Definition parse-tree.h:5438
Definition parse-tree.h:5412
Definition parse-tree.h:5145
Definition openmp-utils.h:278
Definition openmp-utils.h:385
Definition openmp-utils.h:230
Definition openmp-utils.h:143
Definition openmp-utils.h:237