FLANG
IterationSpace.h
1//===-- IterationSpace.h ----------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef FORTRAN_LOWER_ITERATIONSPACE_H
14#define FORTRAN_LOWER_ITERATIONSPACE_H
15
16#include "flang/Evaluate/tools.h"
17#include "flang/Lower/StatementContext.h"
18#include "flang/Lower/SymbolMap.h"
19#include "flang/Optimizer/Builder/FIRBuilder.h"
20#include <optional>
21
22namespace llvm {
23class raw_ostream;
24}
25
26namespace Fortran {
27namespace evaluate {
28struct SomeType;
29template <typename>
30class Expr;
31} // namespace evaluate
32
33namespace lower {
34
35using FrontEndExpr = const evaluate::Expr<evaluate::SomeType> *;
36using FrontEndSymbol = const semantics::Symbol *;
37
39
40} // namespace lower
41} // namespace Fortran
42
43namespace Fortran::lower {
44
47class IterationSpace {
48public:
49 IterationSpace() = default;
50
51 template <typename A>
52 explicit IterationSpace(mlir::Value inArg, mlir::Value outRes,
53 llvm::iterator_range<A> range)
54 : inArg{inArg}, outRes{outRes}, indices{range.begin(), range.end()} {}
55
56 explicit IterationSpace(const IterationSpace &from,
58 : inArg(from.inArg), outRes(from.outRes), element(from.element),
59 indices(idxs) {}
60
63 explicit IterationSpace(const IterationSpace &from,
66 : inArg(from.inArg), outRes(from.outRes), element(from.element) {
67 indices.assign(prefix.begin(), prefix.end());
68 indices.append(from.indices.begin(), from.indices.end());
69 indices.append(suffix.begin(), suffix.end());
70 }
71
72 bool empty() const { return indices.empty(); }
73
77 mlir::Value innerArgument() const { return inArg; }
78
82 mlir::Value outerResult() const { return outRes; }
83
86 llvm::SmallVector<mlir::Value> iterVec() const { return indices; }
87
88 mlir::Value iterValue(std::size_t i) const {
89 assert(i < indices.size());
90 return indices[i];
91 }
92
94 void setIndexValue(std::size_t i, mlir::Value v) {
95 assert(i < indices.size());
96 indices[i] = v;
97 }
98
99 void setIndexValues(llvm::ArrayRef<mlir::Value> vals) {
100 indices.assign(vals.begin(), vals.end());
101 }
102
103 void insertIndexValue(std::size_t i, mlir::Value av) {
104 assert(i <= indices.size());
105 indices.insert(indices.begin() + i, av);
106 }
107
111 assert(!fir::getBase(element) && "result element already set");
112 element = ele;
113 }
114
117 mlir::Value getElement() const {
118 assert(fir::getBase(element) && "element must be set");
119 return fir::getBase(element);
120 }
121
123 fir::ExtendedValue elementExv() const { return element; }
124
125 void clearIndices() { indices.clear(); }
126
127private:
128 mlir::Value inArg;
129 mlir::Value outRes;
130 fir::ExtendedValue element;
132};
133
134using GenerateElementalArrayFunc =
135 std::function<fir::ExtendedValue(const IterationSpace &)>;
136
137template <typename A>
139public:
140 bool empty() const { return stack.empty(); }
141
142 void growStack() { stack.push_back(A{}); }
143
145 void bind(FrontEndExpr e, GenerateElementalArrayFunc &&fun) {
146 vmap.insert({e, std::move(fun)});
147 }
148
150 void rebind(FrontEndExpr e, GenerateElementalArrayFunc &&fun) {
151 vmap.erase(e);
152 bind(e, std::move(fun));
153 }
154
156 GenerateElementalArrayFunc getBoundClosure(FrontEndExpr e) const {
157 if (!vmap.count(e))
158 llvm::report_fatal_error(
159 "evaluate::Expr is not in the map of lowered mask expressions");
160 return vmap.lookup(e);
161 }
162
164 bool isLowered(FrontEndExpr e) const { return vmap.count(e); }
165
166 StatementContext &stmtContext() { return stmtCtx; }
167
168protected:
169 void shrinkStack() {
170 assert(!empty());
171 stack.pop_back();
172 if (empty()) {
173 stmtCtx.finalizeAndReset();
174 vmap.clear();
175 }
176 }
177
178 // The stack for the construct information.
179 llvm::SmallVector<A> stack;
180
181 // Map each mask expression back to the temporary holding the initial
182 // evaluation results.
183 llvm::DenseMap<FrontEndExpr, GenerateElementalArrayFunc> vmap;
184
185 // Inflate the statement context for the entire construct. We have to cache
186 // the mask expression results, which are always evaluated first, across the
187 // entire construct.
188 StatementContext stmtCtx;
189};
190
192llvm::raw_ostream &operator<<(llvm::raw_ostream &, const ImplicitIterSpace &);
193
205 : public StackableConstructExpr<llvm::SmallVector<FrontEndExpr>> {
206public:
208 using FrontEndMaskExpr = FrontEndExpr;
209
210 friend llvm::raw_ostream &operator<<(llvm::raw_ostream &,
211 const ImplicitIterSpace &);
212
213 LLVM_DUMP_METHOD void dump() const;
214
215 void append(FrontEndMaskExpr e) {
216 assert(!empty());
217 getMasks().back().push_back(e);
218 }
219
220 llvm::SmallVector<FrontEndMaskExpr> getExprs() const {
221 llvm::SmallVector<FrontEndMaskExpr> maskList = getMasks()[0];
222 for (size_t i = 1, d = getMasks().size(); i < d; ++i)
223 maskList.append(getMasks()[i].begin(), getMasks()[i].end());
224 return maskList;
225 }
226
229 void addMaskVariable(FrontEndExpr exp, mlir::Value var, mlir::Value shape,
230 mlir::Value header) {
231 maskVarMap.try_emplace(exp, std::make_tuple(var, shape, header));
232 }
233
236 mlir::Value lookupMaskVariable(FrontEndExpr exp) {
237 return std::get<0>(maskVarMap.lookup(exp));
238 }
239
242 mlir::Value lookupMaskShapeBuffer(FrontEndExpr exp) {
243 return std::get<1>(maskVarMap.lookup(exp));
244 }
245
246 mlir::Value lookupMaskHeader(FrontEndExpr exp) {
247 return std::get<2>(maskVarMap.lookup(exp));
248 }
249
250 // Stack of WHERE constructs, each building a list of mask expressions.
252 return stack;
253 }
255 getMasks() const {
256 return stack;
257 }
258
259 // Cleanup at the end of a WHERE statement or construct.
260 void shrinkStack() {
261 Base::shrinkStack();
262 if (stack.empty())
263 maskVarMap.clear();
264 }
265
266private:
267 llvm::DenseMap<FrontEndExpr,
268 std::tuple<mlir::Value, mlir::Value, mlir::Value>>
269 maskVarMap;
270};
271
273llvm::raw_ostream &operator<<(llvm::raw_ostream &, const ExplicitIterSpace &);
274
278 SymMap &symMap);
279
283 ExplicitIterSpace &esp);
284using ExplicitSpaceArrayBases =
285 std::variant<FrontEndSymbol, const evaluate::Component *,
286 const evaluate::ArrayRef *>;
287
288unsigned getHashValue(const ExplicitSpaceArrayBases &x);
289bool isEqual(const ExplicitSpaceArrayBases &x,
290 const ExplicitSpaceArrayBases &y);
291
292} // namespace Fortran::lower
293
294namespace llvm {
295template <>
296struct DenseMapInfo<Fortran::lower::ExplicitSpaceArrayBases> {
297 static unsigned
298 getHashValue(const Fortran::lower::ExplicitSpaceArrayBases &v) {
299 return Fortran::lower::getHashValue(v);
300 }
301 static bool isEqual(const Fortran::lower::ExplicitSpaceArrayBases &lhs,
302 const Fortran::lower::ExplicitSpaceArrayBases &rhs) {
303 return Fortran::lower::isEqual(lhs, rhs);
304 }
305};
306} // namespace llvm
307
308namespace Fortran::lower {
321public:
322 using IterSpaceDim =
323 std::tuple<FrontEndSymbol, FrontEndExpr, FrontEndExpr, FrontEndExpr>;
324 using ConcurrentSpec =
325 std::pair<llvm::SmallVector<IterSpaceDim>, FrontEndExpr>;
326 using ArrayBases = ExplicitSpaceArrayBases;
327
328 friend void createArrayLoads(AbstractConverter &converter,
329 ExplicitIterSpace &esp, SymMap &symMap);
331 ExplicitIterSpace &esp);
332
336 bool isActive() const { return forallContextOpen != 0; }
337
339 StatementContext &stmtContext() { return stmtCtx; }
340
341 //===--------------------------------------------------------------------===//
342 // Analysis support
343 //===--------------------------------------------------------------------===//
344
346 void pushLevel();
347
349 void popLevel();
350
352 void addSymbol(FrontEndSymbol sym);
353
355 void exprBase(FrontEndExpr x, bool lhs);
356
358 void endAssign();
359
362
363 //===--------------------------------------------------------------------===//
364 // Code gen support
365 //===--------------------------------------------------------------------===//
366
368 void enter() { forallContextOpen++; }
369
371 void leave();
372
373 void pushLoopNest(std::function<void()> lambda) {
374 ccLoopNest.push_back(lambda);
375 }
376
378 mlir::ValueRange getInnerArgs() const { return innerArgs; }
379
382 innerArgs.clear();
383 for (auto &arg : args)
384 innerArgs.push_back(arg);
385 }
386
388 void resetInnerArgs() { innerArgs = initialArgs; }
389
391 void setOuterLoop(fir::DoLoopOp loop) {
392 clearLoops();
393 outerLoop = loop;
394 }
395
397 void setInnerArg(size_t offset, mlir::Value val) {
398 assert(offset < innerArgs.size());
399 innerArgs[offset] = val;
400 }
401
405 for (auto &arg : innerArgs)
406 result.push_back(arg.getType());
407 return result;
408 }
409
412 void bindLoad(ArrayBases base, fir::ArrayLoadOp load) {
413 loadBindings.try_emplace(std::move(base), load);
414 }
415
416 fir::ArrayLoadOp findBinding(const ArrayBases &base) {
417 return loadBindings.lookup(base);
418 }
419
421 std::optional<size_t> findArgPosition(fir::ArrayLoadOp load);
422
423 bool isLHS(fir::ArrayLoadOp load) {
424 return findArgPosition(load).has_value();
425 }
426
429 mlir::Value findArgumentOfLoad(fir::ArrayLoadOp load) {
430 if (auto opt = findArgPosition(load))
431 return innerArgs[*opt];
432 llvm_unreachable("array load argument not found");
433 }
434
435 size_t argPosition(mlir::Value arg) {
436 for (auto i : llvm::enumerate(innerArgs))
437 if (arg == i.value())
438 return i.index();
439 llvm_unreachable("inner argument value was not found");
440 }
441
442 std::optional<fir::ArrayLoadOp> getLhsLoad(size_t i) {
443 assert(i < lhsBases.size());
444 if (lhsBases[counter])
445 return findBinding(*lhsBases[counter]);
446 return std::nullopt;
447 }
448
450 fir::DoLoopOp getOuterLoop() {
451 assert(outerLoop.has_value());
452 return *outerLoop;
453 }
454
456 StatementContext &outermostContext() { return outerContext; }
457
459 void genLoopNest() {
460 for (auto &lambda : ccLoopNest)
461 lambda();
462 }
463
465 void resetBindings() { loadBindings.clear(); }
466
468 std::size_t getCounter() const { return counter; }
469
471 void incrementCounter() { counter++; }
472
473 bool isOutermostForall() const {
474 assert(forallContextOpen);
475 return forallContextOpen == 1;
476 }
477
478 void attachLoopCleanup(std::function<void(fir::FirOpBuilder &builder)> fn) {
479 if (!loopCleanup) {
480 loopCleanup = fn;
481 return;
482 }
483 std::function<void(fir::FirOpBuilder &)> oldFn = *loopCleanup;
484 loopCleanup = [=](fir::FirOpBuilder &builder) {
485 oldFn(builder);
486 fn(builder);
487 };
488 }
489
490 // LLVM standard dump method.
491 LLVM_DUMP_METHOD void dump() const;
492
493 // Pretty-print.
494 friend llvm::raw_ostream &operator<<(llvm::raw_ostream &,
495 const ExplicitIterSpace &);
496
498 void finalizeContext() { stmtCtx.finalizeAndReset(); }
499
500 void appendLoops(const llvm::SmallVector<fir::DoLoopOp> &loops) {
501 loopStack.push_back(loops);
502 }
503
504 void clearLoops() { loopStack.clear(); }
505
507 return loopStack;
508 }
509
510private:
512 void conditionalCleanup();
513
514 StatementContext outerContext;
515
516 // A stack of lists of front-end symbols.
517 llvm::SmallVector<llvm::SmallVector<FrontEndSymbol>> symbolStack;
518 llvm::SmallVector<std::optional<ArrayBases>> lhsBases;
519 llvm::SmallVector<llvm::SmallVector<ArrayBases>> rhsBases;
520 llvm::DenseMap<ArrayBases, fir::ArrayLoadOp> loadBindings;
521
522 // Stack of lambdas to create the loop nest.
523 llvm::SmallVector<std::function<void()>> ccLoopNest;
524
525 // Assignment statement context (inside the loop nest).
526 StatementContext stmtCtx;
527 llvm::SmallVector<mlir::Value> innerArgs;
528 llvm::SmallVector<mlir::Value> initialArgs;
529 std::optional<fir::DoLoopOp> outerLoop;
530 llvm::SmallVector<llvm::SmallVector<fir::DoLoopOp>> loopStack;
531 std::optional<std::function<void(fir::FirOpBuilder &)>> loopCleanup;
532 std::size_t forallContextOpen = 0;
533 std::size_t counter = 0;
534};
535
538template <typename A>
540 const A &exprSyms) {
541 for (const auto &sym : exprSyms)
542 if (llvm::is_contained(ctrlSet, &sym.get()))
543 return true;
544 return false;
545}
546
549template <typename A>
551 const A &subscripts) {
552 for (auto &sub : subscripts) {
553 if (const auto *expr =
554 std::get_if<evaluate::IndirectSubscriptIntegerExpr>(&sub.u))
555 if (symbolSetsIntersect(ctrlSet, evaluate::CollectSymbols(expr->value())))
556 return true;
557 }
558 return false;
559}
560
561} // namespace Fortran::lower
562
563#endif // FORTRAN_LOWER_ITERATIONSPACE_H
Definition common.h:215
Definition AbstractConverter.h:87
Definition IterationSpace.h:320
StatementContext & stmtContext()
Get the statement context.
Definition IterationSpace.h:339
void resetInnerArgs()
Reset the outermost array_load arguments to the loop nest.
Definition IterationSpace.h:388
void popLevel()
Close the construct.
Definition IterationSpace.cpp:279
void finalizeContext()
Finalize the current body statement context.
Definition IterationSpace.h:498
mlir::Value findArgumentOfLoad(fir::ArrayLoadOp load)
Definition IterationSpace.h:429
mlir::ValueRange getInnerArgs() const
Get the inner arguments that correspond to the output arrays.
Definition IterationSpace.h:378
void setInnerArgs(llvm::ArrayRef< mlir::BlockArgument > args)
Set the inner arguments for the next loop level.
Definition IterationSpace.h:381
friend void createArrayMergeStores(AbstractConverter &converter, ExplicitIterSpace &esp)
friend void createArrayLoads(AbstractConverter &converter, ExplicitIterSpace &esp, SymMap &symMap)
void resetBindings()
Clear the array_load bindings.
Definition IterationSpace.h:465
llvm::SmallVector< mlir::Type > innerArgTypes() const
Get the types of the output arrays.
Definition IterationSpace.h:403
void leave()
Leave a FORALL context.
Definition IterationSpace.cpp:239
void enter()
Enter a FORALL context.
Definition IterationSpace.h:368
void genLoopNest()
Generate the explicit loop nest.
Definition IterationSpace.h:459
void incrementCounter()
Increment the counter value to the next assignment statement.
Definition IterationSpace.h:471
std::optional< size_t > findArgPosition(fir::ArrayLoadOp load)
load must be an LHS array_load. Returns std::nullopt on error.
Definition IterationSpace.cpp:300
void bindLoad(ArrayBases base, fir::ArrayLoadOp load)
Definition IterationSpace.h:412
void pushLevel()
Open a new construct. The analysis phase starts here.
Definition IterationSpace.cpp:275
void addSymbol(FrontEndSymbol sym)
Add new concurrent header control variable symbol.
Definition IterationSpace.cpp:245
void setInnerArg(size_t offset, mlir::Value val)
Sets the inner loop argument at position offset to val.
Definition IterationSpace.h:397
fir::DoLoopOp getOuterLoop()
Return the outermost loop in this FORALL nest.
Definition IterationSpace.h:450
void setOuterLoop(fir::DoLoopOp loop)
Capture the current outermost loop.
Definition IterationSpace.h:391
void exprBase(FrontEndExpr x, bool lhs)
Collect array bases from the expression, x.
Definition IterationSpace.cpp:251
llvm::SmallVector< FrontEndSymbol > collectAllSymbols()
Return all the active control variables on the stack.
Definition IterationSpace.cpp:313
void endAssign()
Called at the end of an assignment statement.
Definition IterationSpace.cpp:273
StatementContext & outermostContext()
Return the statement context for the entire, outermost FORALL construct.
Definition IterationSpace.h:456
bool isActive() const
Definition IterationSpace.h:336
std::size_t getCounter() const
Get the current counter value.
Definition IterationSpace.h:468
Definition IterationSpace.h:205
mlir::Value lookupMaskVariable(FrontEndExpr exp)
Definition IterationSpace.h:236
mlir::Value lookupMaskShapeBuffer(FrontEndExpr exp)
Definition IterationSpace.h:242
void addMaskVariable(FrontEndExpr exp, mlir::Value var, mlir::Value shape, mlir::Value header)
Definition IterationSpace.h:229
mlir::Value getElement() const
Definition IterationSpace.h:117
IterationSpace(const IterationSpace &from, llvm::ArrayRef< mlir::Value > prefix, llvm::ArrayRef< mlir::Value > suffix)
Definition IterationSpace.h:63
mlir::Value innerArgument() const
Definition IterationSpace.h:77
fir::ExtendedValue elementExv() const
Get the element as an extended value.
Definition IterationSpace.h:123
mlir::Value outerResult() const
Definition IterationSpace.h:82
llvm::SmallVector< mlir::Value > iterVec() const
Definition IterationSpace.h:86
void setElement(fir::ExtendedValue &&ele)
Definition IterationSpace.h:110
void setIndexValue(std::size_t i, mlir::Value v)
Set (rewrite) the Value at a given index.
Definition IterationSpace.h:94
Definition IterationSpace.h:138
bool isLowered(FrontEndExpr e) const
Has the front-end expression, e, been lowered and bound?
Definition IterationSpace.h:164
void rebind(FrontEndExpr e, GenerateElementalArrayFunc &&fun)
Replace the binding of front-end expression e with a new closure.
Definition IterationSpace.h:150
void bind(FrontEndExpr e, GenerateElementalArrayFunc &&fun)
Bind a front-end expression to a closure.
Definition IterationSpace.h:145
GenerateElementalArrayFunc getBoundClosure(FrontEndExpr e) const
Get the closure bound to the front-end expression, e.
Definition IterationSpace.h:156
Definition StatementContext.h:46
void finalizeAndReset()
Make cleanup calls. Clear the stack top list.
Definition StatementContext.h:90
Definition SymbolMap.h:182
Definition BoxValue.h:478
Definition FIRBuilder.h:59
Definition FIRType.h:103
Definition OpenACC.h:20
Definition call.h:34
Definition ParserActions.h:24
bool symbolSetsIntersect(llvm::ArrayRef< FrontEndSymbol > ctrlSet, const A &exprSyms)
Definition IterationSpace.h:539
void createArrayMergeStores(AbstractConverter &converter, ExplicitIterSpace &esp)
Definition ConvertExpr.cpp:7711
bool symbolsIntersectSubscripts(llvm::ArrayRef< FrontEndSymbol > ctrlSet, const A &subscripts)
Definition IterationSpace.h:550
void createArrayLoads(AbstractConverter &converter, ExplicitIterSpace &esp, SymMap &symMap)
Definition ConvertExpr.cpp:7689
Definition bit-population-count.h:20
mlir::Value getBase(const ExtendedValue &exv)
Definition BoxValue.cpp:21
Definition type.h:418