FLANG
openmp-utils.h
1//===-- flang/Parser/openmp-utils.h ---------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Common OpenMP utilities.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef FORTRAN_PARSER_OPENMP_UTILS_H
14#define FORTRAN_PARSER_OPENMP_UTILS_H
15
16#include "flang/Common/indirection.h"
17#include "flang/Common/template.h"
18#include "flang/Parser/parse-tree.h"
19#include "llvm/ADT/iterator_range.h"
20#include "llvm/Frontend/OpenMP/OMP.h"
21
22#include <cassert>
23#include <iterator>
24#include <tuple>
25#include <type_traits>
26#include <utility>
27#include <variant>
28#include <vector>
29
30namespace Fortran::parser::omp {
31
// Return a pointer to the value stored in `x`, or nullptr when `x` holds no
// value. The returned pointer refers into `x` itself.
template <typename T> constexpr auto addr_if(std::optional<T> &x) {
  if (!x.has_value()) {
    return static_cast<T *>(nullptr);
  }
  return &*x;
}
// Const overload: same as above, but yields a pointer to const.
template <typename T> constexpr auto addr_if(const std::optional<T> &x) {
  if (!x.has_value()) {
    return static_cast<const T *>(nullptr);
  }
  return &*x;
}
38
39namespace detail {
41 static OmpDirectiveName MakeName(CharBlock source = {},
42 llvm::omp::Directive id = llvm::omp::Directive::OMPD_unknown) {
44 name.source = source;
45 name.v = id;
46 return name;
47 }
48
49 static OmpDirectiveName GetOmpDirectiveName(const OmpDirectiveName &x) {
50 return x;
51 }
52
53 static OmpDirectiveName GetOmpDirectiveName(const OmpBeginLoopDirective &x) {
54 return x.DirName();
55 }
56
57 static OmpDirectiveName GetOmpDirectiveName(const OpenMPSectionConstruct &x) {
58 if (auto &spec{std::get<std::optional<OmpDirectiveSpecification>>(x.t)}) {
59 return spec->DirName();
60 } else {
61 return MakeName({}, llvm::omp::Directive::OMPD_section);
62 }
63 }
64
65 static OmpDirectiveName GetOmpDirectiveName(
67 return x.DirName();
68 }
69
70 template <typename T>
71 static OmpDirectiveName GetOmpDirectiveName(const T &x) {
72 if constexpr (WrapperTrait<T>) {
73 return GetOmpDirectiveName(x.v);
74 } else if constexpr (TupleTrait<T>) {
75 if constexpr (std::is_base_of_v<OmpBlockConstruct, T>) {
76 return std::get<OmpBeginDirective>(x.t).DirName();
77 } else {
78 return GetFromTuple(
79 x.t, std::make_index_sequence<std::tuple_size_v<decltype(x.t)>>{});
80 }
81 } else if constexpr (UnionTrait<T>) {
82 return common::visit(
83 [](auto &&s) { return GetOmpDirectiveName(s); }, x.u);
84 } else {
85 return MakeName();
86 }
87 }
88
89 template <typename... Ts, size_t... Is>
90 static OmpDirectiveName GetFromTuple(
91 const std::tuple<Ts...> &t, std::index_sequence<Is...>) {
92 OmpDirectiveName name = MakeName();
93 auto accumulate = [&](const OmpDirectiveName &n) {
94 if (name.v == llvm::omp::Directive::OMPD_unknown) {
95 name = n;
96 } else {
97 assert(
98 n.v == llvm::omp::Directive::OMPD_unknown && "Conflicting names");
99 }
100 };
101 (accumulate(GetOmpDirectiveName(std::get<Is>(t))), ...);
102 return name;
103 }
104
105 template <typename T>
106 static OmpDirectiveName GetOmpDirectiveName(const common::Indirection<T> &x) {
107 return GetOmpDirectiveName(x.value());
108 }
109};
110} // namespace detail
111
112template <typename T> OmpDirectiveName GetOmpDirectiveName(const T &x) {
113 return detail::DirectiveNameScope::GetOmpDirectiveName(x);
114}
115
116std::string GetUpperName(llvm::omp::Clause id, unsigned version);
117std::string GetUpperName(llvm::omp::Directive id, unsigned version);
118
120const OpenMPConstruct *GetOmp(const ExecutionPartConstruct &x);
121
122const OpenMPLoopConstruct *GetOmpLoop(const ExecutionPartConstruct &x);
123const DoConstruct *GetDoConstruct(const ExecutionPartConstruct &x);
124
125// Is the template argument "Statement<T>" for some T?
126template <typename T> struct IsStatement {
127 static constexpr bool value{false};
128};
129template <typename T> struct IsStatement<Statement<T>> {
130 static constexpr bool value{true};
131};
132
133std::optional<Label> GetStatementLabel(const ExecutionPartConstruct &x);
134std::optional<Label> GetFinalLabel(const OpenMPConstruct &x);
135
136namespace detail {
137// Clauses with flangClass = "OmpObjectList".
138using MemberObjectListClauses =
139 std::tuple<OmpClause::Copyin, OmpClause::Copyprivate, OmpClause::Exclusive,
140 OmpClause::Firstprivate, OmpClause::HasDeviceAddr, OmpClause::Inclusive,
141 OmpClause::IsDevicePtr, OmpClause::Link, OmpClause::Private,
142 OmpClause::Shared, OmpClause::UseDeviceAddr, OmpClause::UseDevicePtr>;
143
144// Clauses with flangClass = "OmpSomeClause", and OmpObjectList a
145// member of tuple OmpSomeClause::t.
146using TupleObjectListClauses = std::tuple<OmpClause::AdjustArgs,
147 OmpClause::Affinity, OmpClause::Aligned, OmpClause::Allocate,
148 OmpClause::Enter, OmpClause::From, OmpClause::InReduction,
149 OmpClause::Lastprivate, OmpClause::Linear, OmpClause::Map,
150 OmpClause::Reduction, OmpClause::TaskReduction, OmpClause::To>;
151
152// Does U have WrapperTrait (i.e. has a member 'v'), and if so, is T the
153// type of v?
154template <typename T, typename U, bool IsWrapper> struct WrappedInType {
155 static constexpr bool value{false};
156};
157
158template <typename T, typename U> struct WrappedInType<T, U, true> {
159 static constexpr bool value{std::is_same_v<T, decltype(U::v)>};
160};
161
162// Same as WrappedInType, but with a list of types Us. Satisfied if any
163// type U in Us satisfies WrappedInType<T, U>.
164template <typename...> struct WrappedInTypes;
165
166template <typename T> struct WrappedInTypes<T> {
167 static constexpr bool value{false};
168};
169
170template <typename T, typename U, typename... Us>
171struct WrappedInTypes<T, U, Us...> {
172 static constexpr bool value{WrappedInType<T, U, WrapperTrait<U>>::value ||
173 WrappedInTypes<T, Us...>::value};
174};
175
176// Same as WrappedInTypes, but takes type list in a form of a tuple or
177// a variant.
178template <typename...> struct WrappedInTupleOrVariant {
179 static constexpr bool value{false};
180};
181template <typename T, typename... Us>
182struct WrappedInTupleOrVariant<T, std::tuple<Us...>> {
183 static constexpr bool value{WrappedInTypes<T, Us...>::value};
184};
185template <typename T, typename... Us>
186struct WrappedInTupleOrVariant<T, std::variant<Us...>> {
187 static constexpr bool value{WrappedInTypes<T, Us...>::value};
188};
189template <typename T, typename U>
190constexpr bool WrappedInTupleOrVariantV{WrappedInTupleOrVariant<T, U>::value};
191} // namespace detail
192
193template <typename T> const OmpObjectList *GetOmpObjectList(const T &clause) {
194 using namespace detail;
195 static_assert(std::is_class_v<T>, "Unexpected argument type");
196
197 if constexpr (common::HasMember<T, decltype(OmpClause::u)>) {
198 if constexpr (common::HasMember<T, MemberObjectListClauses>) {
199 return &clause.v;
200 } else if constexpr (common::HasMember<T, TupleObjectListClauses>) {
201 return &std::get<OmpObjectList>(clause.v.t);
202 } else {
203 return nullptr;
204 }
205 } else if constexpr (WrappedInTupleOrVariantV<T, TupleObjectListClauses>) {
206 return &std::get<OmpObjectList>(clause.t);
207 } else if constexpr (WrappedInTupleOrVariantV<T, decltype(OmpClause::u)>) {
208 return nullptr;
209 } else {
210 // The condition should be type-dependent, but it should always be false.
211 static_assert(sizeof(T) < 0 && "Unexpected argument type");
212 }
213}
214
215const OmpObjectList *GetOmpObjectList(const OmpClause &clause);
216const OmpObjectList *GetOmpObjectList(const OmpClause::Depend &clause);
217const OmpObjectList *GetOmpObjectList(const OmpDependClause::TaskDep &x);
218
219template <typename T>
220const T *GetFirstArgument(const OmpDirectiveSpecification &spec) {
221 for (const OmpArgument &arg : spec.Arguments().v) {
222 if (auto *t{std::get_if<T>(&arg.u)}) {
223 return t;
224 }
225 }
226 return nullptr;
227}
228
229const OmpClause *FindClause(
230 const OmpDirectiveSpecification &spec, llvm::omp::Clause clauseId);
231
232const BlockConstruct *GetFortranBlockConstruct(
233 const ExecutionPartConstruct &epc);
234const Block &GetInnermostExecPart(const Block &block);
235bool IsStrictlyStructuredBlock(const Block &block);
236
237const OmpCombinerExpression *GetCombinerExpr(const OmpReductionSpecifier &x);
238const OmpCombinerExpression *GetCombinerExpr(const OmpClause &x);
239const OmpInitializerExpression *GetInitializerExpr(const OmpClause &x);
240
242 std::vector<const OmpAllocateDirective *> dirs;
243 const ExecutionPartConstruct *body{nullptr};
244};
245
246OmpAllocateInfo SplitOmpAllocate(const OmpAllocateDirective &x);
247
// is_range<R>::value is true when an rvalue of type R supports both
// `.begin()` and `.end()` member calls. Detection uses SFINAE: the partial
// specialization is viable only when both expressions are well-formed.
template <typename R, typename = void, typename = void>
struct is_range : std::false_type {};

template <typename R>
struct is_range<R, //
    decltype(void(std::declval<R>().begin())),
    decltype(void(std::declval<R>().end()))> : std::true_type {};

template <typename R> constexpr bool is_range_v{is_range<R>::value};
260
261// Iterate over a range of parser::Block::const_iterator's. When the end
262// of the range is reached, the iterator becomes invalid.
263// Treat BLOCK constructs as if they were transparent, i.e. as if the
264// BLOCK/ENDBLOCK statements, and the specification part contained within
265// were removed. The stepping determines whether the iterator steps "into"
266// DO loops and OpenMP loop constructs, or steps "over" them.
267//
268// Example: consecutive locations of the iterator:
269//
270// Step::Into Step::Over
271// block block
272// 1 => stmt1 1 => stmt1
273// block block
274// integer :: x integer :: x
275// 2 => stmt2 2 => stmt2
276// block block
277// end block end block
278// end block end block
279// 3 => do i = 1, n 3 => do i = 1, n
280// 4 => continue continue
281// end do end do
282// 5 => stmt3 4 => stmt3
283// end block end block
284//
285// 6 => <invalid> 5 => <invalid>
286//
287// The iterator is in a legal state (position) if it's at an
288// ExecutionPartConstruct that is not a BlockConstruct, or is invalid.
289struct ExecutionPartIterator {
290 enum class Step {
291 Into,
292 Over,
293 Default = Into,
294 };
295
296 using IteratorType = Block::const_iterator;
297 using IteratorRange = llvm::iterator_range<IteratorType>;
298
299 // An iterator range with a third iterator indicating a position inside
300 // the range.
301 struct IteratorGauge : public IteratorRange {
302 IteratorGauge(IteratorType b, IteratorType e)
303 : IteratorRange(b, e), at(b) {}
304 IteratorGauge(IteratorRange r) : IteratorRange(r), at(r.begin()) {}
305
306 bool atEnd() const { return at == end(); }
307 IteratorType at;
308 };
309
310 struct Construct {
311 Construct(IteratorType b, IteratorType e, const ExecutionPartConstruct *c)
312 : location(b, e), owner(c) {}
313 template <typename R>
314 Construct(const R &r, const ExecutionPartConstruct *c)
315 : location(r), owner(c) {}
316 Construct(const Construct &c) = default;
317 // The original range of the construct with the current position in it.
318 // The location.at is the construct currently being pointed at, or
319 // stepped into.
320 IteratorGauge location;
321 const ExecutionPartConstruct *owner;
322 };
323
324 ExecutionPartIterator() = default;
325
326 ExecutionPartIterator(IteratorType b, IteratorType e, Step s = Step::Default,
327 const ExecutionPartConstruct *c = nullptr)
328 : stepping_(s) {
329 stack_.emplace_back(b, e, c);
330 adjust();
331 }
332 template <typename R, typename = std::enable_if_t<is_range_v<R>>>
333 ExecutionPartIterator(const R &range, Step stepping = Step::Default,
334 const ExecutionPartConstruct *construct = nullptr)
335 : ExecutionPartIterator(range.begin(), range.end(), stepping, construct) {
336 }
337
338 // Advance the iterator to the next legal position. If the current position
339 // is a DO-loop or a loop construct, step into the contained Block.
340 void step();
341
342 // Advance the iterator to the next legal position. If the current position
343 // is a DO-loop or a loop construct, step to the next legal position following
344 // the DO-loop or loop construct.
345 void next();
346
347 bool valid() const { return !stack_.empty(); }
348
349 const std::vector<Construct> &stack() const { return stack_; }
350 decltype(auto) operator*() const { return *at(); }
351 bool operator==(const ExecutionPartIterator &other) const {
352 if (valid() != other.valid()) {
353 return false;
354 }
355 // Invalid iterators are considered equal.
356 return !valid() ||
357 stack_.back().location.at == other.stack_.back().location.at;
358 }
359 bool operator!=(const ExecutionPartIterator &other) const {
360 return !(*this == other);
361 }
362
363 ExecutionPartIterator &operator++() {
364 if (stepping_ == Step::Into) {
365 step();
366 } else {
367 assert(stepping_ == Step::Over && "Unexpected stepping");
368 next();
369 }
370 return *this;
371 }
372
373 ExecutionPartIterator operator++(int) {
374 ExecutionPartIterator copy{*this};
375 operator++();
376 return copy;
377 }
378
379 using difference_type = IteratorType::difference_type;
380 using value_type = IteratorType::value_type;
381 using reference = IteratorType::reference;
382 using pointer = IteratorType::pointer;
383 using iterator_category = std::forward_iterator_tag;
384
385private:
386 IteratorType at() const { return stack_.back().location.at; };
387
388 // If the iterator is not at a legal location, keep advancing it until
389 // it lands at a legal location or becomes invalid.
390 void adjust();
391
392 const Step stepping_ = Step::Default;
393 std::vector<Construct> stack_;
394};
395
396template <typename Iterator = ExecutionPartIterator> struct ExecutionPartRange {
397 using Step = typename Iterator::Step;
398
399 ExecutionPartRange(Block::const_iterator begin, Block::const_iterator end,
400 Step stepping = Step::Default,
401 const ExecutionPartConstruct *owner = nullptr)
402 : begin_(begin, end, stepping, owner), end_() {}
403 template <typename R, typename = std::enable_if_t<is_range_v<R>>>
404 ExecutionPartRange(const R &range, Step stepping = Step::Default,
405 const ExecutionPartConstruct *owner = nullptr)
406 : ExecutionPartRange(range.begin(), range.end(), stepping, owner) {}
407
408 Iterator begin() const { return begin_; }
409 Iterator end() const { return end_; }
410
411private:
412 Iterator begin_, end_;
413};
414
415struct LoopNestIterator : public ExecutionPartIterator {
416 LoopNestIterator() = default;
417
418 LoopNestIterator(IteratorType b, IteratorType e, Step s = Step::Default,
419 const ExecutionPartConstruct *c = nullptr)
420 : ExecutionPartIterator(b, e, s, c) {
421 adjust();
422 }
423 template <typename R, typename = std::enable_if_t<is_range_v<R>>>
424 LoopNestIterator(const R &range, Step stepping = Step::Default,
425 const ExecutionPartConstruct *construct = nullptr)
426 : LoopNestIterator(range.begin(), range.end(), stepping, construct) {}
427
428 LoopNestIterator &operator++() {
429 ExecutionPartIterator::operator++();
430 adjust();
431 return *this;
432 }
433
434 LoopNestIterator operator++(int) {
435 LoopNestIterator copy{*this};
436 operator++();
437 return copy;
438 }
439
440private:
441 static bool isLoop(const ExecutionPartConstruct &c);
442
443 void adjust() {
444 while (valid() && !isLoop(**this)) {
445 ExecutionPartIterator::operator++();
446 }
447 }
448};
449
452
453} // namespace Fortran::parser::omp
454
455#endif // FORTRAN_PARSER_OPENMP_UTILS_H
Definition indirection.h:31
Definition char-block.h:28
Definition parse-tree.h:437
Definition parse-tree.h:2311
Definition parse-tree.h:554
Definition parse-tree.h:5260
Definition parse-tree.h:5397
Definition parse-tree.h:5125
Definition parse-tree.h:3474
Definition parse-tree.h:3519
Definition parse-tree.h:5432
Definition parse-tree.h:5406
Definition parse-tree.h:5139
Definition parse-tree.h:359
Definition openmp-utils.h:289
Definition openmp-utils.h:396
Definition openmp-utils.h:126
Definition openmp-utils.h:241
Definition openmp-utils.h:154
Definition openmp-utils.h:248