FLANG
FIRToMemRefTypeConverter.h
1//===---- FIRToMemRefTypeConverter.h - FIR type conversion to MemRef ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines `FIRToMemRefTypeConverter`, a helper used by the
10// FIR-to-MemRef conversion pass to convert FIR types (scalars, arrays,
11// descriptors) into MemRef types suitable for the MemRef dialect.
12//
13//===----------------------------------------------------------------------===//
14
15#ifndef FORTRAN_OPTIMIZER_TRANSFORMS_FIRTOMEMREFTYPECONVERTER_H
16#define FORTRAN_OPTIMIZER_TRANSFORMS_FIRTOMEMREFTYPECONVERTER_H
17
18#include "flang/Optimizer/Dialect/FIRDialect.h"
19#include "flang/Optimizer/Dialect/FIROps.h"
20#include "flang/Optimizer/Dialect/FIRType.h"
21#include "flang/Optimizer/Dialect/Support/FIRContext.h"
22#include "flang/Optimizer/Dialect/Support/KindMapping.h"
23#include "mlir/IR/BuiltinAttributes.h"
24#include "mlir/IR/BuiltinTypes.h"
25#include "mlir/Transforms/DialectConversion.h"
26
27namespace fir {
28
29class FIRToMemRefTypeConverter : public mlir::TypeConverter {
30private:
31 KindMapping kindMapping;
32 bool convertComplexTypes = false;
33 bool convertScalarTypesOnly = false;
34
35public:
36 explicit FIRToMemRefTypeConverter(mlir::ModuleOp mod)
37 : kindMapping(fir::getKindMapping(mod)) {
38 addConversion([](mlir::Type type) { return type; });
39
40 addConversion([&](fir::LogicalType type) -> mlir::Type {
41 return mlir::IntegerType::get(
42 type.getContext(), kindMapping.getLogicalBitsize(type.getFKind()));
43 });
44
45 addSourceMaterialization([](mlir::OpBuilder &builder, mlir::Type type,
46 mlir::ValueRange inputs,
47 mlir::Location loc) -> mlir::Value {
48 assert(!inputs.empty() && "expected a single input for materialization");
49 builder.setInsertionPointAfter(inputs[0].getDefiningOp());
50 return fir::ConvertOp::create(builder, loc, type, inputs[0]);
51 });
52
53 addTargetMaterialization([](mlir::OpBuilder &builder, mlir::Type type,
54 mlir::ValueRange inputs,
55 mlir::Location loc) -> mlir::Value {
56 return fir::ConvertOp::create(builder, loc, type, inputs[0]);
57 });
58 }
59
61 void setConvertComplexTypes(bool value) { convertComplexTypes = value; }
62
64 void setConvertScalarTypesOnly(bool value) { convertScalarTypesOnly = value; }
65
68 bool convertibleMemrefType(mlir::Type ty) {
69 if (auto refTy = mlir::dyn_cast<fir::ReferenceType>(ty))
70 return convertibleMemrefType(refTy.getElementType());
71 else if (auto pointerTy = mlir::dyn_cast<fir::PointerType>(ty))
72 return convertibleMemrefType(pointerTy.getElementType());
73 else if (auto heapTy = mlir::dyn_cast<fir::HeapType>(ty))
74 return convertibleMemrefType(heapTy.getElementType());
75 else if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty))
76 return convertibleMemrefType(seqTy.getElementType());
77 else if (auto boxTy = mlir::dyn_cast<fir::BoxType>(ty))
78 return convertibleMemrefType(boxTy.getElementType());
79
81 bool result = convertibleType(ty);
83 return result;
84 }
85
88 bool isEmptyArray(mlir::Type ty) const {
89 if (auto refTy = mlir::dyn_cast<fir::ReferenceType>(ty))
90 return isEmptyArray(refTy.getElementType());
91 else if (auto pointerTy = mlir::dyn_cast<fir::PointerType>(ty))
92 return isEmptyArray(pointerTy.getElementType());
93 else if (auto heapTy = mlir::dyn_cast<fir::HeapType>(ty))
94 return isEmptyArray(heapTy.getElementType());
95 else if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty)) {
96 llvm::ArrayRef<int64_t> firShape = seqTy.getShape();
97 for (auto shape : firShape)
98 if (shape == 0)
99 return true;
100 return false;
101 }
102 return false;
103 }
104
107 bool convertibleType(mlir::Type type) const {
108 if (!convertScalarTypesOnly) {
109 if (auto refTy = mlir::dyn_cast<fir::ReferenceType>(type)) {
110 auto elTy = refTy.getElementType();
111 if (mlir::isa<fir::SequenceType>(elTy))
112 return false;
113 return convertibleType(elTy);
114 }
115
116 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(type))
117 return convertibleType(seqTy.getElementType());
118 }
119
120 if (fir::isa_fir_type(type)) {
121 if (mlir::isa<fir::LogicalType>(type))
122 return true;
123 return false;
124 }
125
126 if (type.isUnsignedInteger())
127 return false;
128
129 if (mlir::isa<mlir::ComplexType>(type))
130 return convertComplexTypes;
131
132 if (mlir::isa<mlir::FunctionType>(type))
133 return false;
134
135 if (mlir::isa<mlir::TupleType>(type))
136 return false;
137
138 return true;
139 }
140
142 mlir::MemRefType convertMemrefType(mlir::Type firTy) const {
143 auto convertBaseType = [&](mlir::Type firTy) -> mlir::MemRefType {
144 if (auto charTy = mlir::dyn_cast<fir::CharacterType>(firTy)) {
145 unsigned kind = charTy.getFKind();
146 unsigned bitWidth = kindMapping.getCharacterBitsize(kind);
147 mlir::Type elTy = mlir::IntegerType::get(charTy.getContext(), bitWidth);
148
149 if (charTy.hasConstantLen() && charTy.getLen() == 1) {
150 return mlir::MemRefType::get({}, elTy);
151 } else if (charTy.hasConstantLen()) {
152 int64_t len = charTy.getLen();
153 return mlir::MemRefType::get({len}, elTy);
154 } else {
155 return mlir::MemRefType::get({mlir::ShapedType::kDynamic}, elTy);
156 }
157 }
158
159 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(firTy)) {
160 auto elTy = seqTy.getElementType();
161 mlir::Type ty = convertType(elTy);
162
163 llvm::ArrayRef<int64_t> firShape = seqTy.getShape();
165 for (auto it = firShape.rbegin(); it != firShape.rend(); ++it)
166 shape.push_back(*it);
167
168 assert(mlir::BaseMemRefType::isValidElementType(ty) &&
169 "got invalid memref element type from array fir type");
170 return mlir::MemRefType::get(shape, ty);
171 }
172
173 mlir::Type ty = convertType(firTy);
174 assert(mlir::BaseMemRefType::isValidElementType(ty) &&
175 "got invalid memref element type from scalar fir type");
176 return mlir::MemRefType::get({}, ty);
177 };
178
179 if (auto refTy = mlir::dyn_cast<fir::ReferenceType>(firTy))
180 return convertBaseType(refTy.getElementType());
181
182 if (auto pointerTy = mlir::dyn_cast<fir::PointerType>(firTy))
183 return convertBaseType(pointerTy.getElementType());
184
185 if (auto heapTy = mlir::dyn_cast<fir::HeapType>(firTy))
186 return convertBaseType(heapTy.getElementType());
187
188 if (auto boxTy = mlir::dyn_cast<fir::BoxType>(firTy)) {
189 auto elTy = boxTy.getElementType();
190
191 auto memRefTy = convertMemrefType(elTy);
192 mlir::MemRefType dynTy = mlir::MemRefType::Builder(memRefTy).setLayout(
193 mlir::StridedLayoutAttr::get(
194 memRefTy.getContext(), mlir::ShapedType::kDynamic,
195 llvm::SmallVector<int64_t>(memRefTy.getRank(),
196 mlir::ShapedType::kDynamic)));
197 return dynTy;
198 }
199
200 return convertBaseType(firTy);
201 }
202};
203
204} // namespace fir
205
206#endif // FORTRAN_OPTIMIZER_TRANSFORMS_FIRTOMEMREFTYPECONVERTER_H
mlir::MemRefType convertMemrefType(mlir::Type firTy) const
Convert a FIR element / aggregate type to a MemRef descriptor type.
Definition FIRToMemRefTypeConverter.h:142
bool isEmptyArray(mlir::Type ty) const
Definition FIRToMemRefTypeConverter.h:88
bool convertibleMemrefType(mlir::Type ty)
Definition FIRToMemRefTypeConverter.h:68
void setConvertComplexTypes(bool value)
Control whether complex types are considered convertible.
Definition FIRToMemRefTypeConverter.h:61
void setConvertScalarTypesOnly(bool value)
Control whether only scalar types are considered during convertibleType.
Definition FIRToMemRefTypeConverter.h:64
bool convertibleType(mlir::Type type) const
Definition FIRToMemRefTypeConverter.h:107
Definition KindMapping.h:48
Definition FIRType.h:92
Definition OpenACC.h:20
Definition AbstractConverter.h:37
KindMapping getKindMapping(mlir::ModuleOp mod)
Definition FIRContext.cpp:43
bool isa_fir_type(mlir::Type t)
Is t any of the FIR dialect types?
Definition FIRType.cpp:207