13#ifndef FORTRAN_PARSER_OPENMP_UTILS_H
14#define FORTRAN_PARSER_OPENMP_UTILS_H
16#include "flang/Common/indirection.h"
17#include "flang/Common/template.h"
18#include "flang/Parser/parse-tree.h"
19#include "llvm/ADT/iterator_range.h"
20#include "llvm/Frontend/OpenMP/OMP.h"
30namespace Fortran::parser::omp {
32template <
typename T>
constexpr auto addr_if(std::optional<T> &x) {
33 return x ? &*x :
nullptr;
35template <
typename T>
constexpr auto addr_if(
const std::optional<T> &x) {
36 return x ? &*x :
nullptr;
42 template <
typename T>
static const ODS &GetODS(
const T &x) {
44 std::is_base_of_v<OmpBlockConstruct, T> ||
45 std::is_same_v<OpenMPLoopConstruct, T> ||
46 std::is_same_v<OpenMPSectionsConstruct, T>) {
48 }
else if constexpr (WrapperTrait<T>) {
50 }
else if constexpr (UnionTrait<T>) {
52 [](
auto &&s) ->
decltype(
auto) {
return GetODS(s); }, x.u);
54 static_assert(std::is_same_v<OpenMPSectionConstruct, T>);
55 llvm_unreachable(
"This function does not work for SECTION");
58 static inline const ODS &GetODS(
const ODS &x) {
return x; }
70 llvm::omp::Directive
id = llvm::omp::Directive::OMPD_unknown) {
86 if (
auto &spec{std::get<std::optional<OmpDirectiveSpecification>>(x.t)}) {
87 return spec->DirName();
89 return MakeName({}, llvm::omp::Directive::OMPD_section);
100 if constexpr (WrapperTrait<T>) {
101 return GetOmpDirectiveName(x.v);
102 }
else if constexpr (TupleTrait<T>) {
103 if constexpr (std::is_base_of_v<OmpBlockConstruct, T>) {
104 return std::get<OmpBeginDirective>(x.t).DirName();
107 x.t, std::make_index_sequence<std::tuple_size_v<
decltype(x.t)>>{});
109 }
else if constexpr (UnionTrait<T>) {
110 return common::visit(
111 [](
auto &&s) {
return GetOmpDirectiveName(s); }, x.u);
117 template <
typename... Ts,
size_t... Is>
119 const std::tuple<Ts...> &t, std::index_sequence<Is...>) {
122 if (name.v == llvm::omp::Directive::OMPD_unknown) {
126 n.v == llvm::omp::Directive::OMPD_unknown &&
"Conflicting names");
129 (accumulate(GetOmpDirectiveName(std::get<Is>(t))), ...);
133 template <
typename T>
135 return GetOmpDirectiveName(x.value());
141 return detail::DirectiveNameScope::GetOmpDirectiveName(x);
144std::string GetUpperName(llvm::omp::Clause
id,
unsigned version);
145std::string GetUpperName(llvm::omp::Directive
id,
unsigned version);
155using MemberObjectListClauses =
156 std::tuple<OmpClause::Copyin, OmpClause::Copyprivate, OmpClause::Exclusive,
157 OmpClause::Firstprivate, OmpClause::HasDeviceAddr, OmpClause::Inclusive,
158 OmpClause::IsDevicePtr, OmpClause::Link, OmpClause::Private,
159 OmpClause::Shared, OmpClause::UseDeviceAddr, OmpClause::UseDevicePtr>;
163using TupleObjectListClauses = std::tuple<OmpClause::AdjustArgs,
164 OmpClause::Affinity, OmpClause::Aligned, OmpClause::Allocate,
165 OmpClause::Enter, OmpClause::From, OmpClause::InReduction,
166 OmpClause::Lastprivate, OmpClause::Linear, OmpClause::Map,
167 OmpClause::Reduction, OmpClause::TaskReduction, OmpClause::To>;
172 static constexpr bool value{
false};
176 static constexpr bool value{std::is_same_v<T,
decltype(U::v)>};
184 static constexpr bool value{
false};
187template <
typename T,
typename U,
typename... Us>
196 static constexpr bool value{
false};
198template <
typename T,
typename... Us>
202template <
typename T,
typename... Us>
206template <
typename T,
typename U>
207constexpr bool WrappedInTupleOrVariantV{WrappedInTupleOrVariant<T, U>::value};
210template <
typename T>
const OmpObjectList *GetOmpObjectList(
const T &clause) {
211 using namespace detail;
212 static_assert(std::is_class_v<T>,
"Unexpected argument type");
214 if constexpr (common::HasMember<T,
decltype(OmpClause::u)>) {
215 if constexpr (common::HasMember<T, MemberObjectListClauses>) {
217 }
else if constexpr (common::HasMember<T, TupleObjectListClauses>) {
218 return &std::get<OmpObjectList>(clause.v.t);
222 }
else if constexpr (WrappedInTupleOrVariantV<T, TupleObjectListClauses>) {
223 return &std::get<OmpObjectList>(clause.t);
224 }
else if constexpr (WrappedInTupleOrVariantV<T,
decltype(OmpClause::u)>) {
228 static_assert(
sizeof(T) < 0 &&
"Unexpected argument type");
232const OmpObjectList *GetOmpObjectList(
const OmpClause &clause);
233const OmpObjectList *GetOmpObjectList(
const OmpClause::Depend &clause);
234const OmpObjectList *GetOmpObjectList(
const OmpDependClause::TaskDep &x);
237const T *GetFirstArgument(
const OmpDirectiveSpecification &spec) {
238 for (
const OmpArgument &arg : spec.Arguments().v) {
239 if (
auto *t{std::get_if<T>(&arg.u)}) {
246const OmpClause *FindClause(
247 const OmpDirectiveSpecification &spec, llvm::omp::Clause clauseId);
249const BlockConstruct *GetFortranBlockConstruct(
250 const ExecutionPartConstruct &epc);
251const Block &GetInnermostExecPart(
const Block &block);
252bool IsStrictlyStructuredBlock(
const Block &block);
254const OmpCombinerExpression *GetCombinerExpr(
const OmpReductionSpecifier &x);
255const OmpCombinerExpression *GetCombinerExpr(
const OmpClause &x);
256const OmpInitializerExpression *GetInitializerExpr(
const OmpClause &x);
259 std::vector<const OmpAllocateDirective *> dirs;
265template <
typename R,
typename =
void,
typename =
void>
struct is_range {
266 static constexpr bool value{
false};
271 std::void_t<decltype(std::declval<R>().begin())>,
272 std::void_t<decltype(std::declval<R>().end())>> {
273 static constexpr bool value{
true};
276template <
typename R>
constexpr bool is_range_v = is_range<R>::value;
306struct ExecutionPartIterator {
313 using IteratorType = Block::const_iterator;
314 using IteratorRange = llvm::iterator_range<IteratorType>;
318 struct IteratorGauge :
public IteratorRange {
319 IteratorGauge(IteratorType b, IteratorType e)
320 : IteratorRange(b, e), at(b) {}
321 IteratorGauge(IteratorRange r) : IteratorRange(r), at(r.begin()) {}
323 bool atEnd()
const {
return at == end(); }
329 : location(b, e), owner(c) {}
330 template <
typename R>
332 : location(r), owner(c) {}
333 Construct(
const Construct &c) =
default;
341 ExecutionPartIterator() =
default;
343 ExecutionPartIterator(IteratorType b, IteratorType e, Step s = Step::Default,
346 stack_.emplace_back(b, e, c);
349 template <
typename R,
typename = std::enable_if_t<is_range_v<R>>>
364 bool valid()
const {
return !stack_.empty(); }
366 const std::vector<Construct> &stack()
const {
return stack_; }
367 decltype(
auto)
operator*()
const {
return *at(); }
368 bool operator==(
const ExecutionPartIterator &other)
const {
369 if (valid() != other.valid()) {
374 stack_.back().location.at == other.stack_.back().location.at;
376 bool operator!=(
const ExecutionPartIterator &other)
const {
377 return !(*
this == other);
380 ExecutionPartIterator &operator++() {
381 if (stepping_ == Step::Into) {
384 assert(stepping_ == Step::Over &&
"Unexpected stepping");
390 ExecutionPartIterator operator++(
int) {
391 ExecutionPartIterator copy{*
this};
396 using difference_type = IteratorType::difference_type;
397 using value_type = IteratorType::value_type;
398 using reference = IteratorType::reference;
399 using pointer = IteratorType::pointer;
400 using iterator_category = std::forward_iterator_tag;
403 IteratorType at()
const {
return stack_.back().location.at; };
409 const Step stepping_ = Step::Default;
410 std::vector<Construct> stack_;
413template <
typename Iterator = ExecutionPartIterator>
struct ExecutionPartRange {
414 using Step =
typename Iterator::Step;
416 ExecutionPartRange(Block::const_iterator begin, Block::const_iterator end,
417 Step stepping = Step::Default,
419 : begin_(begin, end, stepping, owner), end_() {}
420 template <
typename R,
typename = std::enable_if_t<is_range_v<R>>>
421 ExecutionPartRange(
const R &range, Step stepping = Step::Default,
423 : ExecutionPartRange(range.begin(), range.end(), stepping, owner) {}
425 Iterator begin()
const {
return begin_; }
426 Iterator end()
const {
return end_; }
429 Iterator begin_, end_;
432struct LoopNestIterator :
public ExecutionPartIterator {
433 LoopNestIterator() =
default;
435 LoopNestIterator(IteratorType b, IteratorType e, Step s = Step::Default,
437 : ExecutionPartIterator(b, e, s, c) {
440 template <
typename R,
typename = std::enable_if_t<is_range_v<R>>>
441 LoopNestIterator(
const R &range, Step stepping = Step::Default,
443 : LoopNestIterator(range.begin(), range.end(), stepping, construct) {}
445 LoopNestIterator &operator++() {
446 ExecutionPartIterator::operator++();
451 LoopNestIterator operator++(
int) {
452 LoopNestIterator copy{*
this};
461 while (valid() && !isLoop(**
this)) {
462 ExecutionPartIterator::operator++();
Definition indirection.h:31
Definition char-block.h:28
Definition parse-tree.h:437
Definition parse-tree.h:2323
Definition parse-tree.h:554
Definition parse-tree.h:5278
Definition parse-tree.h:5415
Definition parse-tree.h:5143
Definition parse-tree.h:3492
Definition parse-tree.h:5048
Definition parse-tree.h:3537
Definition parse-tree.h:5450
Definition parse-tree.h:5282
Definition parse-tree.h:5424
Definition parse-tree.h:5157
Definition openmp-utils.h:318
Definition openmp-utils.h:306
Definition openmp-utils.h:413
Definition openmp-utils.h:258
Definition openmp-utils.h:68
Definition openmp-utils.h:40
Definition openmp-utils.h:195
Definition openmp-utils.h:171
Definition openmp-utils.h:181
Definition openmp-utils.h:265