FLANG
openmp-utils.h
1//===-- flang/Parser/openmp-utils.h ---------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Common OpenMP utilities.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef FORTRAN_PARSER_OPENMP_UTILS_H
14#define FORTRAN_PARSER_OPENMP_UTILS_H
15
16#include "flang/Common/indirection.h"
17#include "flang/Common/template.h"
18#include "flang/Parser/parse-tree.h"
19#include "llvm/ADT/iterator_range.h"
20#include "llvm/Frontend/OpenMP/OMP.h"
21
22#include <cassert>
23#include <iterator>
24#include <tuple>
25#include <type_traits>
26#include <utility>
27#include <variant>
28#include <vector>
29
30namespace Fortran::parser::omp {
31
// Returns a pointer to the value stored in `x`, or nullptr when `x` holds
// no value. The pointer refers into `x` itself.
template <typename T> constexpr auto addr_if(std::optional<T> &x) {
  T *held{nullptr};
  if (x.has_value()) {
    held = &*x;
  }
  return held;
}
// Const overload: returns a pointer-to-const to the value stored in `x`, or
// nullptr when `x` holds no value.
template <typename T> constexpr auto addr_if(const std::optional<T> &x) {
  const T *held{nullptr};
  if (x.has_value()) {
    held = &*x;
  }
  return held;
}
38
39namespace detail {
41 using ODS = OmpDirectiveSpecification;
42 template <typename T> static const ODS &GetODS(const T &x) {
43 if constexpr ( //
44 std::is_base_of_v<OmpBlockConstruct, T> ||
45 std::is_same_v<OpenMPLoopConstruct, T> ||
46 std::is_same_v<OpenMPSectionsConstruct, T>) {
47 return x.BeginDir();
48 } else if constexpr (WrapperTrait<T>) {
49 return GetODS(x.v);
50 } else if constexpr (UnionTrait<T>) {
51 return std::visit(
52 [](auto &&s) -> decltype(auto) { return GetODS(s); }, x.u);
53 } else {
54 static_assert(std::is_same_v<OpenMPSectionConstruct, T>);
55 llvm_unreachable("This function does not work for SECTION");
56 }
57 }
58 static inline const ODS &GetODS(const ODS &x) { return x; }
59};
60} // namespace detail
61
62const OmpDirectiveSpecification &GetOmpDirectiveSpecification(
63 const OpenMPConstruct &x);
64const OmpDirectiveSpecification &GetOmpDirectiveSpecification(
66
67namespace detail {
69 static OmpDirectiveName MakeName(CharBlock source = {},
70 llvm::omp::Directive id = llvm::omp::Directive::OMPD_unknown) {
72 name.source = source;
73 name.v = id;
74 return name;
75 }
76
77 static OmpDirectiveName GetOmpDirectiveName(const OmpDirectiveName &x) {
78 return x;
79 }
80
81 static OmpDirectiveName GetOmpDirectiveName(const OmpBeginLoopDirective &x) {
82 return x.DirName();
83 }
84
85 static OmpDirectiveName GetOmpDirectiveName(const OpenMPSectionConstruct &x) {
86 if (auto &spec{std::get<std::optional<OmpDirectiveSpecification>>(x.t)}) {
87 return spec->DirName();
88 } else {
89 return MakeName({}, llvm::omp::Directive::OMPD_section);
90 }
91 }
92
93 static OmpDirectiveName GetOmpDirectiveName(
95 return x.DirName();
96 }
97
98 template <typename T>
99 static OmpDirectiveName GetOmpDirectiveName(const T &x) {
100 if constexpr (WrapperTrait<T>) {
101 return GetOmpDirectiveName(x.v);
102 } else if constexpr (TupleTrait<T>) {
103 if constexpr (std::is_base_of_v<OmpBlockConstruct, T>) {
104 return std::get<OmpBeginDirective>(x.t).DirName();
105 } else {
106 return GetFromTuple(
107 x.t, std::make_index_sequence<std::tuple_size_v<decltype(x.t)>>{});
108 }
109 } else if constexpr (UnionTrait<T>) {
110 return common::visit(
111 [](auto &&s) { return GetOmpDirectiveName(s); }, x.u);
112 } else {
113 return MakeName();
114 }
115 }
116
117 template <typename... Ts, size_t... Is>
118 static OmpDirectiveName GetFromTuple(
119 const std::tuple<Ts...> &t, std::index_sequence<Is...>) {
120 OmpDirectiveName name = MakeName();
121 auto accumulate = [&](const OmpDirectiveName &n) {
122 if (name.v == llvm::omp::Directive::OMPD_unknown) {
123 name = n;
124 } else {
125 assert(
126 n.v == llvm::omp::Directive::OMPD_unknown && "Conflicting names");
127 }
128 };
129 (accumulate(GetOmpDirectiveName(std::get<Is>(t))), ...);
130 return name;
131 }
132
133 template <typename T>
134 static OmpDirectiveName GetOmpDirectiveName(const common::Indirection<T> &x) {
135 return GetOmpDirectiveName(x.value());
136 }
137};
138} // namespace detail
139
140template <typename T> OmpDirectiveName GetOmpDirectiveName(const T &x) {
141 return detail::DirectiveNameScope::GetOmpDirectiveName(x);
142}
143
144std::string GetUpperName(llvm::omp::Clause id, unsigned version);
145std::string GetUpperName(llvm::omp::Directive id, unsigned version);
146
148const OpenMPConstruct *GetOmp(const ExecutionPartConstruct &x);
149
150const OpenMPLoopConstruct *GetOmpLoop(const ExecutionPartConstruct &x);
151const DoConstruct *GetDoConstruct(const ExecutionPartConstruct &x);
152
153namespace detail {
154// Clauses with flangClass = "OmpObjectList".
155using MemberObjectListClauses =
156 std::tuple<OmpClause::Copyin, OmpClause::Copyprivate, OmpClause::Exclusive,
157 OmpClause::Firstprivate, OmpClause::HasDeviceAddr, OmpClause::Inclusive,
158 OmpClause::IsDevicePtr, OmpClause::Link, OmpClause::Private,
159 OmpClause::Shared, OmpClause::UseDeviceAddr, OmpClause::UseDevicePtr>;
160
161// Clauses with flangClass = "OmpSomeClause", and OmpObjectList a
162// member of tuple OmpSomeClause::t.
163using TupleObjectListClauses = std::tuple<OmpClause::AdjustArgs,
164 OmpClause::Affinity, OmpClause::Aligned, OmpClause::Allocate,
165 OmpClause::Enter, OmpClause::From, OmpClause::InReduction,
166 OmpClause::Lastprivate, OmpClause::Linear, OmpClause::Map,
167 OmpClause::Reduction, OmpClause::TaskReduction, OmpClause::To>;
168
// Is T exactly the type of U's member 'v'? The member is only inspected
// when IsWrapper is true (i.e. the caller has established that U has
// WrapperTrait); otherwise the answer is false.
template <typename T, typename U, bool IsWrapper> struct WrappedInType {
  static constexpr bool value = false;
};

// U is a wrapper: compare T against the declared type of U::v.
template <typename T, typename U> struct WrappedInType<T, U, true> {
  static constexpr bool value = std::is_same<T, decltype(U::v)>::value;
};
178
179// Same as WrappedInType, but with a list of types Us. Satisfied if any
180// type U in Us satisfies WrappedInType<T, U>.
181template <typename...> struct WrappedInTypes;
182
183template <typename T> struct WrappedInTypes<T> {
184 static constexpr bool value{false};
185};
186
187template <typename T, typename U, typename... Us>
188struct WrappedInTypes<T, U, Us...> {
189 static constexpr bool value{WrappedInType<T, U, WrapperTrait<U>>::value ||
190 WrappedInTypes<T, Us...>::value};
191};
192
193// Same as WrappedInTypes, but takes type list in a form of a tuple or
194// a variant.
195template <typename...> struct WrappedInTupleOrVariant {
196 static constexpr bool value{false};
197};
198template <typename T, typename... Us>
199struct WrappedInTupleOrVariant<T, std::tuple<Us...>> {
200 static constexpr bool value{WrappedInTypes<T, Us...>::value};
201};
202template <typename T, typename... Us>
203struct WrappedInTupleOrVariant<T, std::variant<Us...>> {
204 static constexpr bool value{WrappedInTypes<T, Us...>::value};
205};
206template <typename T, typename U>
207constexpr bool WrappedInTupleOrVariantV{WrappedInTupleOrVariant<T, U>::value};
208} // namespace detail
209
210template <typename T> const OmpObjectList *GetOmpObjectList(const T &clause) {
211 using namespace detail;
212 static_assert(std::is_class_v<T>, "Unexpected argument type");
213
214 if constexpr (common::HasMember<T, decltype(OmpClause::u)>) {
215 if constexpr (common::HasMember<T, MemberObjectListClauses>) {
216 return &clause.v;
217 } else if constexpr (common::HasMember<T, TupleObjectListClauses>) {
218 return &std::get<OmpObjectList>(clause.v.t);
219 } else {
220 return nullptr;
221 }
222 } else if constexpr (WrappedInTupleOrVariantV<T, TupleObjectListClauses>) {
223 return &std::get<OmpObjectList>(clause.t);
224 } else if constexpr (WrappedInTupleOrVariantV<T, decltype(OmpClause::u)>) {
225 return nullptr;
226 } else {
227 // The condition should be type-dependent, but it should always be false.
228 static_assert(sizeof(T) < 0 && "Unexpected argument type");
229 }
230}
231
232const OmpObjectList *GetOmpObjectList(const OmpClause &clause);
233const OmpObjectList *GetOmpObjectList(const OmpClause::Depend &clause);
234const OmpObjectList *GetOmpObjectList(const OmpDependClause::TaskDep &x);
235
236template <typename T>
237const T *GetFirstArgument(const OmpDirectiveSpecification &spec) {
238 for (const OmpArgument &arg : spec.Arguments().v) {
239 if (auto *t{std::get_if<T>(&arg.u)}) {
240 return t;
241 }
242 }
243 return nullptr;
244}
245
246const OmpClause *FindClause(
247 const OmpDirectiveSpecification &spec, llvm::omp::Clause clauseId);
248
249const BlockConstruct *GetFortranBlockConstruct(
250 const ExecutionPartConstruct &epc);
251const Block &GetInnermostExecPart(const Block &block);
252bool IsStrictlyStructuredBlock(const Block &block);
253
254const OmpCombinerExpression *GetCombinerExpr(const OmpReductionSpecifier &x);
255const OmpCombinerExpression *GetCombinerExpr(const OmpClause &x);
256const OmpInitializerExpression *GetInitializerExpr(const OmpClause &x);
257
259 std::vector<const OmpAllocateDirective *> dirs;
260 const ExecutionPartConstruct *body{nullptr};
261};
262
263OmpAllocateInfo SplitOmpAllocate(const OmpAllocateDirective &x);
264
// Detects whether a value of type R has callable begin() and end() member
// functions. The two defaulted parameters exist only for the detection and
// are not meant to be supplied by users.
template <typename R, typename = void, typename = void> struct is_range {
  static constexpr bool value = false;
};

// Selected only when both member calls are well-formed.
template <typename R>
struct is_range<R,
    std::void_t<decltype(std::declval<R>().begin()),
        decltype(std::declval<R>().end())>,
    void> {
  static constexpr bool value = true;
};

template <typename R> constexpr bool is_range_v{is_range<R>::value};
277
278// Iterate over a range of parser::Block::const_iterator's. When the end
279// of the range is reached, the iterator becomes invalid.
280// Treat BLOCK constructs as if they were transparent, i.e. as if the
281// BLOCK/ENDBLOCK statements, and the specification part contained within
282// were removed. The stepping determines whether the iterator steps "into"
283// DO loops and OpenMP loop constructs, or steps "over" them.
284//
285// Example: consecutive locations of the iterator:
286//
287// Step::Into Step::Over
288// block block
289// 1 => stmt1 1 => stmt1
290// block block
291// integer :: x integer :: x
292// 2 => stmt2 2 => stmt2
293// block block
294// end block end block
295// end block end block
296// 3 => do i = 1, n 3 => do i = 1, n
297// 4 => continue continue
298// end do end do
299// 5 => stmt3 4 => stmt3
300// end block end block
301//
302// 6 => <invalid> 5 => <invalid>
303//
304// The iterator is in a legal state (position) if it's at an
305// ExecutionPartConstruct that is not a BlockConstruct, or is invalid.
306struct ExecutionPartIterator {
307 enum class Step {
308 Into,
309 Over,
310 Default = Into,
311 };
312
313 using IteratorType = Block::const_iterator;
314 using IteratorRange = llvm::iterator_range<IteratorType>;
315
316 // An iterator range with a third iterator indicating a position inside
317 // the range.
318 struct IteratorGauge : public IteratorRange {
319 IteratorGauge(IteratorType b, IteratorType e)
320 : IteratorRange(b, e), at(b) {}
321 IteratorGauge(IteratorRange r) : IteratorRange(r), at(r.begin()) {}
322
323 bool atEnd() const { return at == end(); }
324 IteratorType at;
325 };
326
327 struct Construct {
328 Construct(IteratorType b, IteratorType e, const ExecutionPartConstruct *c)
329 : location(b, e), owner(c) {}
330 template <typename R>
331 Construct(const R &r, const ExecutionPartConstruct *c)
332 : location(r), owner(c) {}
333 Construct(const Construct &c) = default;
334 // The original range of the construct with the current position in it.
335 // The location.at is the construct currently being pointed at, or
336 // stepped into.
337 IteratorGauge location;
338 const ExecutionPartConstruct *owner;
339 };
340
341 ExecutionPartIterator() = default;
342
343 ExecutionPartIterator(IteratorType b, IteratorType e, Step s = Step::Default,
344 const ExecutionPartConstruct *c = nullptr)
345 : stepping_(s) {
346 stack_.emplace_back(b, e, c);
347 adjust();
348 }
349 template <typename R, typename = std::enable_if_t<is_range_v<R>>>
350 ExecutionPartIterator(const R &range, Step stepping = Step::Default,
351 const ExecutionPartConstruct *construct = nullptr)
352 : ExecutionPartIterator(range.begin(), range.end(), stepping, construct) {
353 }
354
355 // Advance the iterator to the next legal position. If the current position
356 // is a DO-loop or a loop construct, step into the contained Block.
357 void step();
358
359 // Advance the iterator to the next legal position. If the current position
360 // is a DO-loop or a loop construct, step to the next legal position following
361 // the DO-loop or loop construct.
362 void next();
363
364 bool valid() const { return !stack_.empty(); }
365
366 const std::vector<Construct> &stack() const { return stack_; }
367 decltype(auto) operator*() const { return *at(); }
368 bool operator==(const ExecutionPartIterator &other) const {
369 if (valid() != other.valid()) {
370 return false;
371 }
372 // Invalid iterators are considered equal.
373 return !valid() ||
374 stack_.back().location.at == other.stack_.back().location.at;
375 }
376 bool operator!=(const ExecutionPartIterator &other) const {
377 return !(*this == other);
378 }
379
380 ExecutionPartIterator &operator++() {
381 if (stepping_ == Step::Into) {
382 step();
383 } else {
384 assert(stepping_ == Step::Over && "Unexpected stepping");
385 next();
386 }
387 return *this;
388 }
389
390 ExecutionPartIterator operator++(int) {
391 ExecutionPartIterator copy{*this};
392 operator++();
393 return copy;
394 }
395
396 using difference_type = IteratorType::difference_type;
397 using value_type = IteratorType::value_type;
398 using reference = IteratorType::reference;
399 using pointer = IteratorType::pointer;
400 using iterator_category = std::forward_iterator_tag;
401
402private:
403 IteratorType at() const { return stack_.back().location.at; };
404
405 // If the iterator is not at a legal location, keep advancing it until
406 // it lands at a legal location or becomes invalid.
407 void adjust();
408
409 const Step stepping_ = Step::Default;
410 std::vector<Construct> stack_;
411};
412
413template <typename Iterator = ExecutionPartIterator> struct ExecutionPartRange {
414 using Step = typename Iterator::Step;
415
416 ExecutionPartRange(Block::const_iterator begin, Block::const_iterator end,
417 Step stepping = Step::Default,
418 const ExecutionPartConstruct *owner = nullptr)
419 : begin_(begin, end, stepping, owner), end_() {}
420 template <typename R, typename = std::enable_if_t<is_range_v<R>>>
421 ExecutionPartRange(const R &range, Step stepping = Step::Default,
422 const ExecutionPartConstruct *owner = nullptr)
423 : ExecutionPartRange(range.begin(), range.end(), stepping, owner) {}
424
425 Iterator begin() const { return begin_; }
426 Iterator end() const { return end_; }
427
428private:
429 Iterator begin_, end_;
430};
431
432struct LoopNestIterator : public ExecutionPartIterator {
433 LoopNestIterator() = default;
434
435 LoopNestIterator(IteratorType b, IteratorType e, Step s = Step::Default,
436 const ExecutionPartConstruct *c = nullptr)
437 : ExecutionPartIterator(b, e, s, c) {
438 adjust();
439 }
440 template <typename R, typename = std::enable_if_t<is_range_v<R>>>
441 LoopNestIterator(const R &range, Step stepping = Step::Default,
442 const ExecutionPartConstruct *construct = nullptr)
443 : LoopNestIterator(range.begin(), range.end(), stepping, construct) {}
444
445 LoopNestIterator &operator++() {
446 ExecutionPartIterator::operator++();
447 adjust();
448 return *this;
449 }
450
451 LoopNestIterator operator++(int) {
452 LoopNestIterator copy{*this};
453 operator++();
454 return copy;
455 }
456
457private:
458 static bool isLoop(const ExecutionPartConstruct &c);
459
460 void adjust() {
461 while (valid() && !isLoop(**this)) {
462 ExecutionPartIterator::operator++();
463 }
464 }
465};
466
469
470} // namespace Fortran::parser::omp
471
472#endif // FORTRAN_PARSER_OPENMP_UTILS_H
Definition indirection.h:31
Definition char-block.h:28
Definition parse-tree.h:437
Definition parse-tree.h:2323
Definition parse-tree.h:554
Definition parse-tree.h:5278
Definition parse-tree.h:5415
Definition parse-tree.h:5143
Definition parse-tree.h:3492
Definition parse-tree.h:5048
Definition parse-tree.h:3537
Definition parse-tree.h:5450
Definition parse-tree.h:5424
Definition parse-tree.h:5157
Definition openmp-utils.h:306
Definition openmp-utils.h:413
Definition openmp-utils.h:258
Definition openmp-utils.h:171
Definition openmp-utils.h:265