FLANG
openmp-utils.h
1//===-- flang/Parser/openmp-utils.h ---------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Common OpenMP utilities.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef FORTRAN_PARSER_OPENMP_UTILS_H
14#define FORTRAN_PARSER_OPENMP_UTILS_H
15
16#include "flang/Common/indirection.h"
17#include "flang/Common/template.h"
18#include "flang/Parser/parse-tree.h"
19#include "llvm/Frontend/OpenMP/OMP.h"
20
21#include <cassert>
22#include <tuple>
23#include <type_traits>
24#include <utility>
25#include <variant>
26#include <vector>
27
28namespace Fortran::parser::omp {
29
/// Address of the value held by \p x, or nullptr when \p x is disengaged.
template <typename T> constexpr T *addr_if(std::optional<T> &x) {
  if (!x.has_value()) {
    return nullptr;
  }
  return &*x;
}
/// Const overload: address of the value held by \p x, or nullptr when
/// \p x is disengaged.
template <typename T> constexpr const T *addr_if(const std::optional<T> &x) {
  if (!x.has_value()) {
    return nullptr;
  }
  return &*x;
}
36
37namespace detail {
39 static OmpDirectiveName MakeName(CharBlock source = {},
40 llvm::omp::Directive id = llvm::omp::Directive::OMPD_unknown) {
42 name.source = source;
43 name.v = id;
44 return name;
45 }
46
47 static OmpDirectiveName GetOmpDirectiveName(const OmpDirectiveName &x) {
48 return x;
49 }
50
51 static OmpDirectiveName GetOmpDirectiveName(const OmpBeginLoopDirective &x) {
52 return x.DirName();
53 }
54
55 static OmpDirectiveName GetOmpDirectiveName(const OpenMPSectionConstruct &x) {
56 if (auto &spec{std::get<std::optional<OmpDirectiveSpecification>>(x.t)}) {
57 return spec->DirName();
58 } else {
59 return MakeName({}, llvm::omp::Directive::OMPD_section);
60 }
61 }
62
63 static OmpDirectiveName GetOmpDirectiveName(
65 return x.DirName();
66 }
67
68 template <typename T>
69 static OmpDirectiveName GetOmpDirectiveName(const T &x) {
70 if constexpr (WrapperTrait<T>) {
71 return GetOmpDirectiveName(x.v);
72 } else if constexpr (TupleTrait<T>) {
73 if constexpr (std::is_base_of_v<OmpBlockConstruct, T>) {
74 return std::get<OmpBeginDirective>(x.t).DirName();
75 } else {
76 return GetFromTuple(
77 x.t, std::make_index_sequence<std::tuple_size_v<decltype(x.t)>>{});
78 }
79 } else if constexpr (UnionTrait<T>) {
80 return common::visit(
81 [](auto &&s) { return GetOmpDirectiveName(s); }, x.u);
82 } else {
83 return MakeName();
84 }
85 }
86
87 template <typename... Ts, size_t... Is>
88 static OmpDirectiveName GetFromTuple(
89 const std::tuple<Ts...> &t, std::index_sequence<Is...>) {
90 OmpDirectiveName name = MakeName();
91 auto accumulate = [&](const OmpDirectiveName &n) {
92 if (name.v == llvm::omp::Directive::OMPD_unknown) {
93 name = n;
94 } else {
95 assert(
96 n.v == llvm::omp::Directive::OMPD_unknown && "Conflicting names");
97 }
98 };
99 (accumulate(GetOmpDirectiveName(std::get<Is>(t))), ...);
100 return name;
101 }
102
103 template <typename T>
104 static OmpDirectiveName GetOmpDirectiveName(const common::Indirection<T> &x) {
105 return GetOmpDirectiveName(x.value());
106 }
107};
108} // namespace detail
109
110template <typename T> OmpDirectiveName GetOmpDirectiveName(const T &x) {
111 return detail::DirectiveNameScope::GetOmpDirectiveName(x);
112}
113
115const OpenMPConstruct *GetOmp(const ExecutionPartConstruct &x);
116
117const OpenMPLoopConstruct *GetOmpLoop(const ExecutionPartConstruct &x);
118const DoConstruct *GetDoConstruct(const ExecutionPartConstruct &x);
119
120// Is the template argument "Statement<T>" for some T?
121template <typename T> struct IsStatement {
122 static constexpr bool value{false};
123};
124template <typename T> struct IsStatement<Statement<T>> {
125 static constexpr bool value{true};
126};
127
128std::optional<Label> GetStatementLabel(const ExecutionPartConstruct &x);
129std::optional<Label> GetFinalLabel(const OpenMPConstruct &x);
130
131namespace detail {
132// Clauses with flangClass = "OmpObjectList".
133using MemberObjectListClauses =
134 std::tuple<OmpClause::Copyin, OmpClause::Copyprivate, OmpClause::Exclusive,
135 OmpClause::Firstprivate, OmpClause::HasDeviceAddr, OmpClause::Inclusive,
136 OmpClause::IsDevicePtr, OmpClause::Link, OmpClause::Private,
137 OmpClause::Shared, OmpClause::UseDeviceAddr, OmpClause::UseDevicePtr>;
138
139// Clauses with flangClass = "OmpSomeClause", and OmpObjectList a
140// member of tuple OmpSomeClause::t.
141using TupleObjectListClauses = std::tuple<OmpClause::AdjustArgs,
142 OmpClause::Affinity, OmpClause::Aligned, OmpClause::Allocate,
143 OmpClause::Enter, OmpClause::From, OmpClause::InReduction,
144 OmpClause::Lastprivate, OmpClause::Linear, OmpClause::Map,
145 OmpClause::Reduction, OmpClause::TaskReduction, OmpClause::To>;
146
// WrappedInType<T, U, IsWrapper>::value is true exactly when IsWrapper is
// true and the declared type of U::v is T. The caller passes IsWrapper
// (normally WrapperTrait<U>), so U::v is only inspected for wrapper classes.
template <typename T, typename U, bool IsWrapper> struct WrappedInType {
  static constexpr bool value = false;
};

template <typename T, typename U> struct WrappedInType<T, U, true> {
  static constexpr bool value = std::is_same_v<decltype(U::v), T>;
};
156
157// Same as WrappedInType, but with a list of types Us. Satisfied if any
158// type U in Us satisfies WrappedInType<T, U>.
159template <typename...> struct WrappedInTypes;
160
161template <typename T> struct WrappedInTypes<T> {
162 static constexpr bool value{false};
163};
164
165template <typename T, typename U, typename... Us>
166struct WrappedInTypes<T, U, Us...> {
167 static constexpr bool value{WrappedInType<T, U, WrapperTrait<U>>::value ||
168 WrappedInTypes<T, Us...>::value};
169};
170
171// Same as WrappedInTypes, but takes type list in a form of a tuple or
172// a variant.
173template <typename...> struct WrappedInTupleOrVariant {
174 static constexpr bool value{false};
175};
176template <typename T, typename... Us>
177struct WrappedInTupleOrVariant<T, std::tuple<Us...>> {
178 static constexpr bool value{WrappedInTypes<T, Us...>::value};
179};
180template <typename T, typename... Us>
181struct WrappedInTupleOrVariant<T, std::variant<Us...>> {
182 static constexpr bool value{WrappedInTypes<T, Us...>::value};
183};
184template <typename T, typename U>
185constexpr bool WrappedInTupleOrVariantV{WrappedInTupleOrVariant<T, U>::value};
186} // namespace detail
187
188template <typename T> const OmpObjectList *GetOmpObjectList(const T &clause) {
189 using namespace detail;
190 static_assert(std::is_class_v<T>, "Unexpected argument type");
191
192 if constexpr (common::HasMember<T, decltype(OmpClause::u)>) {
193 if constexpr (common::HasMember<T, MemberObjectListClauses>) {
194 return &clause.v;
195 } else if constexpr (common::HasMember<T, TupleObjectListClauses>) {
196 return &std::get<OmpObjectList>(clause.v.t);
197 } else {
198 return nullptr;
199 }
200 } else if constexpr (WrappedInTupleOrVariantV<T, TupleObjectListClauses>) {
201 return &std::get<OmpObjectList>(clause.t);
202 } else if constexpr (WrappedInTupleOrVariantV<T, decltype(OmpClause::u)>) {
203 return nullptr;
204 } else {
205 // The condition should be type-dependent, but it should always be false.
206 static_assert(sizeof(T) < 0 && "Unexpected argument type");
207 }
208}
209
210const OmpObjectList *GetOmpObjectList(const OmpClause &clause);
211const OmpObjectList *GetOmpObjectList(const OmpClause::Depend &clause);
212const OmpObjectList *GetOmpObjectList(const OmpDependClause::TaskDep &x);
213
214template <typename T>
215const T *GetFirstArgument(const OmpDirectiveSpecification &spec) {
216 for (const OmpArgument &arg : spec.Arguments().v) {
217 if (auto *t{std::get_if<T>(&arg.u)}) {
218 return t;
219 }
220 }
221 return nullptr;
222}
223
224const BlockConstruct *GetFortranBlockConstruct(
225 const ExecutionPartConstruct &epc);
226const Block &GetInnermostExecPart(const Block &block);
227bool IsStrictlyStructuredBlock(const Block &block);
228
229const OmpCombinerExpression *GetCombinerExpr(
230 const OmpReductionSpecifier &rspec);
231const OmpInitializerExpression *GetInitializerExpr(const OmpClause &init);
232
234 std::vector<const OmpAllocateDirective *> dirs;
235 const ExecutionPartConstruct *body{nullptr};
236};
237
238OmpAllocateInfo SplitOmpAllocate(const OmpAllocateDirective &x);
239
namespace detail {
// ConstIf<IsConst, T>::type is `const T` when IsConst is true, and T
// otherwise.
template <bool IsConst, typename T> struct ConstIf {
  using type = T;
};
template <typename T> struct ConstIf<true, T> {
  using type = const T;
};

// Shorthand for ConstIf<IsConst, T>::type.
template <bool IsConst, typename T>
using ConstIfT = typename ConstIf<IsConst, T>::type;
} // namespace detail
248
249template <bool IsConst> struct LoopRange {
250 using QualBlock = detail::ConstIfT<IsConst, Block>;
251 using QualReference = decltype(std::declval<QualBlock>().front());
252 using QualPointer = std::remove_reference_t<QualReference> *;
253
254 LoopRange(QualBlock &x) { Initialize(x); }
255 LoopRange(QualReference x);
256
257 LoopRange(detail::ConstIfT<IsConst, OpenMPLoopConstruct> &x)
258 : LoopRange(std::get<Block>(x.t)) {}
259 LoopRange(detail::ConstIfT<IsConst, DoConstruct> &x)
260 : LoopRange(std::get<Block>(x.t)) {}
261
262 size_t size() const { return items.size(); }
263 bool empty() const { return items.size() == 0; }
264
265 struct iterator;
266
267 iterator begin();
268 iterator end();
269
270private:
271 void Initialize(QualBlock &body);
272
273 std::vector<QualPointer> items;
274};
275
276template <typename T> LoopRange(T &x) -> LoopRange<std::is_const_v<T>>;
277
278template <bool IsConst> struct LoopRange<IsConst>::iterator {
279 QualReference operator*() { return **at; }
280
281 bool operator==(const iterator &other) const { return at == other.at; }
282 bool operator!=(const iterator &other) const { return at != other.at; }
283
284 iterator &operator++() {
285 ++at;
286 return *this;
287 }
288 iterator &operator--() {
289 --at;
290 return *this;
291 }
292 iterator operator++(int);
293 iterator operator--(int);
294
295private:
296 friend struct LoopRange;
297 typename decltype(LoopRange::items)::iterator at;
298};
299
300template <bool IsConst> inline auto LoopRange<IsConst>::begin() -> iterator {
301 iterator x;
302 x.at = items.begin();
303 return x;
304}
305
306template <bool IsConst> inline auto LoopRange<IsConst>::end() -> iterator {
307 iterator x;
308 x.at = items.end();
309 return x;
310}
311
312using ConstLoopRange = LoopRange<true>;
313
314extern template struct LoopRange<true>;
315extern template struct LoopRange<false>;
316
317} // namespace Fortran::parser::omp
318
319#endif // FORTRAN_PARSER_OPENMP_UTILS_H
Definition indirection.h:31
Definition char-block.h:28
Definition parse-tree.h:437
Definition parse-tree.h:2348
Definition parse-tree.h:554
Definition parse-tree.h:5283
Definition parse-tree.h:5420
Definition parse-tree.h:5148
Definition parse-tree.h:3516
Definition parse-tree.h:3561
Definition parse-tree.h:5455
Definition parse-tree.h:5429
Definition parse-tree.h:5162
Definition parse-tree.h:359
Definition openmp-utils.h:121
Definition openmp-utils.h:278
Definition openmp-utils.h:249
Definition openmp-utils.h:233
Definition openmp-utils.h:241
Definition openmp-utils.h:149