FLANG
openmp-utils.h
1//===-- flang/Parser/openmp-utils.h ---------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Common OpenMP utilities.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef FORTRAN_PARSER_OPENMP_UTILS_H
14#define FORTRAN_PARSER_OPENMP_UTILS_H
15
16#include "flang/Common/indirection.h"
17#include "flang/Common/template.h"
18#include "flang/Parser/parse-tree.h"
19#include "llvm/ADT/iterator_range.h"
20#include "llvm/Frontend/OpenMP/OMP.h"
21
22#include <cassert>
23#include <iterator>
24#include <tuple>
25#include <type_traits>
26#include <utility>
27#include <variant>
28#include <vector>
29
30namespace Fortran::parser::omp {
31
// Yield a pointer to the value stored in `x`, or nullptr when `x` holds no
// value. This overload takes a mutable optional, so the pointee can be
// modified through the returned pointer.
template <typename T> constexpr auto addr_if(std::optional<T> &x) {
  return x.has_value() ? &*x : nullptr;
}
// Yield a pointer to the value stored in `x`, or nullptr when `x` holds no
// value. This overload takes a const optional, so the pointee is const.
template <typename T> constexpr auto addr_if(const std::optional<T> &x) {
  return x.has_value() ? &*x : nullptr;
}
38
39const parser::Designator *GetDesignatorFromObj(const parser::OmpObject &object);
40const parser::DataRef *GetDataRefFromObj(const parser::OmpObject &object);
41const parser::ArrayElement *GetArrayElementFromObj(
42 const parser::OmpObject &object);
43std::optional<parser::CharBlock> GetObjectSource(
44 const parser::OmpObject &object);
45const parser::OmpObject *GetArgumentObject(const parser::OmpArgument &argument);
46
47const OmpDirectiveSpecification &GetOmpDirectiveSpecification(
48 const OpenMPConstruct &x);
49const OmpDirectiveSpecification &GetOmpDirectiveSpecification(
50 const OpenMPDeclarativeConstruct &x);
51
52namespace detail {
54 static OmpDirectiveName MakeName(CharBlock source = {},
55 llvm::omp::Directive id = llvm::omp::Directive::OMPD_unknown) {
57 name.source = source;
58 name.v = id;
59 return name;
60 }
61
62 static OmpDirectiveName GetOmpDirectiveName(const OmpDirectiveName &x) {
63 return x;
64 }
65
66 static OmpDirectiveName GetOmpDirectiveName(const OpenMPSectionConstruct &x) {
67 if (auto &spec{std::get<std::optional<OmpDirectiveSpecification>>(x.t)}) {
68 return spec->DirName();
69 } else {
70 return MakeName({}, llvm::omp::Directive::OMPD_section);
71 }
72 }
73
74 static OmpDirectiveName GetOmpDirectiveName(
76 return x.DirName();
77 }
78
79 template <typename T>
80 static OmpDirectiveName GetOmpDirectiveName(const T &x) {
81 if constexpr (WrapperTrait<T>) {
82 return GetOmpDirectiveName(x.v);
83 } else if constexpr (TupleTrait<T>) {
84 if constexpr (std::is_base_of_v<OmpBlockConstruct, T>) {
85 return std::get<OmpBeginDirective>(x.t).DirName();
86 } else {
87 return GetFromTuple(
88 x.t, std::make_index_sequence<std::tuple_size_v<decltype(x.t)>>{});
89 }
90 } else if constexpr (UnionTrait<T>) {
91 return common::visit(
92 [](auto &&s) { return GetOmpDirectiveName(s); }, x.u);
93 } else {
94 return MakeName();
95 }
96 }
97
98 template <typename... Ts, size_t... Is>
99 static OmpDirectiveName GetFromTuple(
100 const std::tuple<Ts...> &t, std::index_sequence<Is...>) {
101 OmpDirectiveName name = MakeName();
102 auto accumulate = [&](const OmpDirectiveName &n) {
103 if (name.v == llvm::omp::Directive::OMPD_unknown) {
104 name = n;
105 } else {
106 assert(
107 n.v == llvm::omp::Directive::OMPD_unknown && "Conflicting names");
108 }
109 };
110 (accumulate(GetOmpDirectiveName(std::get<Is>(t))), ...);
111 return name;
112 }
113
114 template <typename T>
115 static OmpDirectiveName GetOmpDirectiveName(const common::Indirection<T> &x) {
116 return GetOmpDirectiveName(x.value());
117 }
118};
119} // namespace detail
120
121template <typename T> OmpDirectiveName GetOmpDirectiveName(const T &x) {
122 return detail::DirectiveNameScope::GetOmpDirectiveName(x);
123}
124
125std::string GetUpperName(llvm::omp::Clause id, unsigned version);
126std::string GetUpperName(llvm::omp::Directive id, unsigned version);
127
129const OpenMPConstruct *GetOmp(const ExecutionPartConstruct &x);
130
131const OpenMPLoopConstruct *GetOmpLoop(const ExecutionPartConstruct &x);
132const DoConstruct *GetDoConstruct(const ExecutionPartConstruct &x);
133
134namespace detail {
136 template <typename T> static const OmpObjectList *Get(const T &x) {
137 if constexpr (std::is_same_v<OmpObjectList, T>) {
138 return &x;
139 } else if constexpr (WrapperTrait<T>) {
140 return Get(x.v);
141 } else if constexpr (UnionTrait<T>) {
142 return std::visit([](auto &&s) { return Get(s); }, x.u);
143 } else if constexpr (TupleTrait<T>) {
144 return GetFromTuple(
145 x.t, std::make_index_sequence<std::tuple_size_v<decltype(x.t)>>{});
146 } else if constexpr (ConstraintTrait<T>) {
147 return Get(x.thing);
148 } else {
149 return nullptr;
150 }
151 }
152
153 template <typename T>
154 static const OmpObjectList *Get(const common::Indirection<T> &x) {
155 return Get(x.value());
156 }
157
158 template <typename... Ts, size_t... Is>
159 static const OmpObjectList *GetFromTuple(
160 const std::tuple<Ts...> &t, std::index_sequence<Is...>) {
161 const OmpObjectList *objects{nullptr};
162 ((objects = objects ? objects : Get(std::get<Is>(t))), ...);
163 return objects;
164 }
165};
166} // namespace detail
167
168template <typename T> const OmpObjectList *GetOmpObjectList(const T &clause) {
169 static_assert(std::is_class_v<T>, "Unexpected argument type");
170 return detail::OmpObjectListScope::Get(clause);
171}
172
173template <typename T>
174const T *GetFirstArgument(const OmpDirectiveSpecification &spec) {
175 for (const OmpArgument &arg : spec.Arguments().v) {
176 if (auto *t{std::get_if<T>(&arg.u)}) {
177 return t;
178 }
179 }
180 return nullptr;
181}
182
183const OmpClause *FindClause(
184 const OmpDirectiveSpecification &spec, llvm::omp::Clause clauseId);
185
186const BlockConstruct *GetFortranBlockConstruct(
187 const ExecutionPartConstruct &epc);
188const Block &GetInnermostExecPart(const Block &block);
189bool IsStrictlyStructuredBlock(const Block &block);
190
191const OmpCombinerExpression *GetCombinerExpr(const OmpReductionSpecifier &x);
192const OmpCombinerExpression *GetCombinerExpr(const OmpClause &x);
193const OmpInitializerExpression *GetInitializerExpr(const OmpClause &x);
194
196 std::vector<const OmpAllocateDirective *> dirs;
197 const ExecutionPartConstruct *body{nullptr};
198};
199
200OmpAllocateInfo SplitOmpAllocate(const OmpAllocateDirective &x);
201
// Type trait: is_range<R>::value is true when both
// `std::declval<R>().begin()` and `std::declval<R>().end()` are well-formed
// member calls, and false otherwise. Only member begin/end are considered.
//
// Primary template: selected when either expression is ill-formed.
template <typename R, typename = void, typename = void>
struct is_range : std::false_type {};

// Partial specialization: viable only when both member calls are well-formed,
// since each std::void_t argument then collapses to void and matches the
// defaulted parameters of the primary template.
template <typename R>
struct is_range<R, //
    std::void_t<decltype(std::declval<R>().begin())>,
    std::void_t<decltype(std::declval<R>().end())>> : std::true_type {};

// Convenience variable template; `inline` gives the header a single
// definition across translation units.
template <typename R> inline constexpr bool is_range_v = is_range<R>::value;
214
// Iterate over a range of parser::Block::const_iterator's. When the end
// of the range is reached, the iterator becomes invalid.
// Treat BLOCK constructs as if they were transparent, i.e. as if the
// BLOCK/ENDBLOCK statements, and the specification part contained within
// were removed. The stepping determines whether the iterator steps "into"
// DO loops and OpenMP loop constructs, or steps "over" them.
//
// Example: consecutive locations of the iterator:
//
//      Step::Into                  Step::Over
//      block                       block
// 1 =>   stmt1                1 =>   stmt1
//        block                       block
//          integer :: x                integer :: x
// 2 =>     stmt2              2 =>     stmt2
//          block                       block
//          end block                   end block
//        end block                   end block
// 3 =>   do i = 1, n          3 =>   do i = 1, n
// 4 =>     continue                    continue
//        end do                      end do
// 5 =>   stmt3                4 =>   stmt3
//      end block                   end block
//
// 6 =>   <invalid>            5 =>   <invalid>
//
// The iterator is in a legal state (position) if it's at an
// ExecutionPartConstruct that is not a BlockConstruct, or is invalid.
243struct ExecutionPartIterator {
244 enum class Step {
245 Into,
246 Over,
247 Default = Into,
248 };
249
250 using IteratorType = Block::const_iterator;
251 using IteratorRange = llvm::iterator_range<IteratorType>;
252
253 // An iterator range with a third iterator indicating a position inside
254 // the range.
255 struct IteratorGauge : public IteratorRange {
256 IteratorGauge(IteratorType b, IteratorType e)
257 : IteratorRange(b, e), at(b) {}
258 IteratorGauge(IteratorRange r) : IteratorRange(r), at(r.begin()) {}
259
260 bool atEnd() const { return at == end(); }
261 IteratorType at;
262 };
263
264 struct Construct {
265 Construct(IteratorType b, IteratorType e, const ExecutionPartConstruct *c)
266 : location(b, e), owner(c) {}
267 template <typename R>
268 Construct(const R &r, const ExecutionPartConstruct *c)
269 : location(r), owner(c) {}
270 Construct(const Construct &c) = default;
271 // The original range of the construct with the current position in it.
272 // The location.at is the construct currently being pointed at, or
273 // stepped into.
274 IteratorGauge location;
275 const ExecutionPartConstruct *owner;
276 };
277
278 ExecutionPartIterator() = default;
279
280 ExecutionPartIterator(IteratorType b, IteratorType e, Step s = Step::Default,
281 const ExecutionPartConstruct *c = nullptr)
282 : stepping_(s) {
283 stack_.emplace_back(b, e, c);
284 adjust();
285 }
286 template <typename R, typename = std::enable_if_t<is_range_v<R>>>
287 ExecutionPartIterator(const R &range, Step stepping = Step::Default,
288 const ExecutionPartConstruct *construct = nullptr)
289 : ExecutionPartIterator(range.begin(), range.end(), stepping, construct) {
290 }
291
292 // Advance the iterator to the next legal position. If the current position
293 // is a DO-loop or a loop construct, step into the contained Block.
294 void step();
295
296 // Advance the iterator to the next legal position. If the current position
297 // is a DO-loop or a loop construct, step to the next legal position following
298 // the DO-loop or loop construct.
299 void next();
300
301 bool valid() const { return !stack_.empty(); }
302
303 const std::vector<Construct> &stack() const { return stack_; }
304 decltype(auto) operator*() const { return *at(); }
305 bool operator==(const ExecutionPartIterator &other) const {
306 if (valid() != other.valid()) {
307 return false;
308 }
309 // Invalid iterators are considered equal.
310 return !valid() ||
311 stack_.back().location.at == other.stack_.back().location.at;
312 }
313 bool operator!=(const ExecutionPartIterator &other) const {
314 return !(*this == other);
315 }
316
317 ExecutionPartIterator &operator++() {
318 if (stepping_ == Step::Into) {
319 step();
320 } else {
321 assert(stepping_ == Step::Over && "Unexpected stepping");
322 next();
323 }
324 return *this;
325 }
326
327 ExecutionPartIterator operator++(int) {
328 ExecutionPartIterator copy{*this};
329 operator++();
330 return copy;
331 }
332
333 using difference_type = IteratorType::difference_type;
334 using value_type = IteratorType::value_type;
335 using reference = IteratorType::reference;
336 using pointer = IteratorType::pointer;
337 using iterator_category = std::forward_iterator_tag;
338
339private:
340 IteratorType at() const { return stack_.back().location.at; };
341
342 // If the iterator is not at a legal location, keep advancing it until
343 // it lands at a legal location or becomes invalid.
344 void adjust();
345
346 const Step stepping_ = Step::Default;
347 std::vector<Construct> stack_;
348};
349
350template <typename Iterator = ExecutionPartIterator> struct ExecutionPartRange {
351 using Step = typename Iterator::Step;
352
353 ExecutionPartRange(Block::const_iterator begin, Block::const_iterator end,
354 Step stepping = Step::Default,
355 const ExecutionPartConstruct *owner = nullptr)
356 : begin_(begin, end, stepping, owner), end_() {}
357 template <typename R, typename = std::enable_if_t<is_range_v<R>>>
358 ExecutionPartRange(const R &range, Step stepping = Step::Default,
359 const ExecutionPartConstruct *owner = nullptr)
360 : ExecutionPartRange(range.begin(), range.end(), stepping, owner) {}
361
362 Iterator begin() const { return begin_; }
363 Iterator end() const { return end_; }
364
365private:
366 Iterator begin_, end_;
367};
368
369struct LoopNestIterator : public ExecutionPartIterator {
370 LoopNestIterator() = default;
371
372 LoopNestIterator(IteratorType b, IteratorType e, Step s = Step::Default,
373 const ExecutionPartConstruct *c = nullptr)
374 : ExecutionPartIterator(b, e, s, c) {
375 adjust();
376 }
377 template <typename R, typename = std::enable_if_t<is_range_v<R>>>
378 LoopNestIterator(const R &range, Step stepping = Step::Default,
379 const ExecutionPartConstruct *construct = nullptr)
380 : LoopNestIterator(range.begin(), range.end(), stepping, construct) {}
381
382 LoopNestIterator &operator++() {
383 ExecutionPartIterator::operator++();
384 adjust();
385 return *this;
386 }
387
388 LoopNestIterator operator++(int) {
389 LoopNestIterator copy{*this};
390 operator++();
391 return copy;
392 }
393
394private:
395 static bool isLoop(const ExecutionPartConstruct &c);
396
397 void adjust() {
398 while (valid() && !isLoop(**this)) {
399 ExecutionPartIterator::operator++();
400 }
401 }
402};
403
406
407} // namespace Fortran::parser::omp
408
409#endif // FORTRAN_PARSER_OPENMP_UTILS_H
Definition indirection.h:31
Definition char-block.h:28
Definition parse-tree.h:439
Definition parse-tree.h:2325
Definition parse-tree.h:556
Definition parse-tree.h:5321
Definition parse-tree.h:5185
Definition parse-tree.h:3529
Definition parse-tree.h:5085
Definition parse-tree.h:3574
Definition parse-tree.h:5473
Definition parse-tree.h:5459
Definition parse-tree.h:5199
Definition parse-tree.h:3693
Definition openmp-utils.h:243
Definition openmp-utils.h:350
Definition openmp-utils.h:195
Definition openmp-utils.h:202