Skip to content

Commit ce25459

Browse files
[mlir][Conversion] Store const type converter in ConversionPattern
ConversionPatterns do not (and should not) modify the type converter that they are using. * Make `ConversionPattern::typeConverter` const. * Make member functions of the `LLVMTypeConverter` const. * Conversion patterns take a const type converter. * Various helper functions (that are called from patterns) now also take a const type converter. Differential Revision: https://reviews.llvm.org/D157601
1 parent ce16c3c commit ce25459

File tree

35 files changed

+383
-358
lines changed

35 files changed

+383
-358
lines changed

flang/include/flang/Optimizer/CodeGen/TypeConverter.h

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -49,56 +49,54 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
4949

5050
// i32 is used here because LLVM wants i32 constants when indexing into struct
5151
// types. Indexing into other aggregate types is more flexible.
52-
mlir::Type offsetType();
52+
mlir::Type offsetType() const;
5353

5454
// i64 can be used to index into aggregates like arrays
55-
mlir::Type indexType();
55+
mlir::Type indexType() const;
5656

5757
// fir.type<name(p : TY'...){f : TY...}> --> llvm<"%name = { ty... }">
5858
std::optional<mlir::LogicalResult>
5959
convertRecordType(fir::RecordType derived,
6060
llvm::SmallVectorImpl<mlir::Type> &results,
61-
llvm::ArrayRef<mlir::Type> callStack);
61+
llvm::ArrayRef<mlir::Type> callStack) const;
6262

6363
// Is an extended descriptor needed given the element type of a fir.box type ?
6464
// Extended descriptors are required for derived types.
65-
bool requiresExtendedDesc(mlir::Type boxElementType);
65+
bool requiresExtendedDesc(mlir::Type boxElementType) const;
6666

6767
// Magic value to indicate we do not know the rank of an entity, either
6868
// because it is assumed rank or because we have not determined it yet.
6969
static constexpr int unknownRank() { return -1; }
7070

7171
// This corresponds to the descriptor as defined in ISO_Fortran_binding.h and
7272
// the addendum defined in descriptor.h.
73-
mlir::Type convertBoxType(BaseBoxType box, int rank = unknownRank());
73+
mlir::Type convertBoxType(BaseBoxType box, int rank = unknownRank()) const;
7474

7575
/// Convert fir.box type to the corresponding llvm struct type instead of a
7676
/// pointer to this struct type.
77-
mlir::Type convertBoxTypeAsStruct(BaseBoxType box);
77+
mlir::Type convertBoxTypeAsStruct(BaseBoxType box) const;
7878

7979
// fir.boxproc<any> --> llvm<"{ any*, i8* }">
80-
mlir::Type convertBoxProcType(BoxProcType boxproc);
80+
mlir::Type convertBoxProcType(BoxProcType boxproc) const;
8181

82-
unsigned characterBitsize(fir::CharacterType charTy);
82+
unsigned characterBitsize(fir::CharacterType charTy) const;
8383

8484
// fir.char<k,?> --> llvm<"ix"> where ix is scaled by kind mapping
8585
// fir.char<k,n> --> llvm.array<n x "ix">
86-
mlir::Type convertCharType(fir::CharacterType charTy);
86+
mlir::Type convertCharType(fir::CharacterType charTy) const;
8787

8888
// Use the target specifics to figure out how to map complex to LLVM IR. The
8989
// use of complex values in function signatures is handled before conversion
9090
// to LLVM IR dialect here.
9191
//
9292
// fir.complex<T> | std.complex<T> --> llvm<"{t,t}">
93-
template <typename C>
94-
mlir::Type convertComplexType(C cmplx) {
93+
template <typename C> mlir::Type convertComplexType(C cmplx) const {
9594
LLVM_DEBUG(llvm::dbgs() << "type convert: " << cmplx << '\n');
9695
auto eleTy = cmplx.getElementType();
9796
return convertType(specifics->complexMemoryType(eleTy));
9897
}
9998

100-
template <typename A>
101-
mlir::Type convertPointerLike(A &ty) {
99+
template <typename A> mlir::Type convertPointerLike(A &ty) const {
102100
mlir::Type eleTy = ty.getEleTy();
103101
// A sequence type is a special case. A sequence of runtime size on its
104102
// interior dimensions lowers to a memory reference. In that case, we
@@ -126,27 +124,27 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
126124

127125
// convert a front-end kind value to either a std or LLVM IR dialect type
128126
// fir.real<n> --> llvm.anyfloat where anyfloat is a kind mapping
129-
mlir::Type convertRealType(fir::KindTy kind);
127+
mlir::Type convertRealType(fir::KindTy kind) const;
130128

131129
// fir.array<c ... :any> --> llvm<"[...[c x any]]">
132-
mlir::Type convertSequenceType(SequenceType seq);
130+
mlir::Type convertSequenceType(SequenceType seq) const;
133131

134132
// fir.tdesc<any> --> llvm<"i8*">
135133
// TODO: For now use a void*, however pointer identity is not sufficient for
136134
// the f18 object v. class distinction (F2003).
137-
mlir::Type convertTypeDescType(mlir::MLIRContext *ctx);
135+
mlir::Type convertTypeDescType(mlir::MLIRContext *ctx) const;
138136

139-
KindMapping &getKindMap() { return kindMapping; }
137+
const KindMapping &getKindMap() const { return kindMapping; }
140138

141139
// Relay TBAA tag attachment to TBAABuilder.
142140
void attachTBAATag(mlir::LLVM::AliasAnalysisOpInterface op,
143141
mlir::Type baseFIRType, mlir::Type accessFIRType,
144-
mlir::LLVM::GEPOp gep);
142+
mlir::LLVM::GEPOp gep) const;
145143

146144
private:
147145
KindMapping kindMapping;
148146
std::unique_ptr<CodeGenSpecifics> specifics;
149-
TBAABuilder tbaaBuilder;
147+
std::unique_ptr<TBAABuilder> tbaaBuilder;
150148
};
151149

152150
} // namespace fir

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ namespace {
117117
template <typename FromOp>
118118
class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
119119
public:
120-
explicit FIROpConversion(fir::LLVMTypeConverter &lowering,
120+
explicit FIROpConversion(const fir::LLVMTypeConverter &lowering,
121121
const fir::FIRToLLVMPassOptions &options)
122122
: mlir::ConvertOpToLLVMPattern<FromOp>(lowering), options(options) {}
123123

@@ -359,8 +359,9 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
359359
return al;
360360
}
361361

362-
fir::LLVMTypeConverter &lowerTy() const {
363-
return *static_cast<fir::LLVMTypeConverter *>(this->getTypeConverter());
362+
const fir::LLVMTypeConverter &lowerTy() const {
363+
return *static_cast<const fir::LLVMTypeConverter *>(
364+
this->getTypeConverter());
364365
}
365366

366367
void attachTBAATag(mlir::LLVM::AliasAnalysisOpInterface op,
@@ -3191,8 +3192,8 @@ struct SelectCaseOpConversion : public FIROpConversion<fir::SelectCaseOp> {
31913192
};
31923193

31933194
template <typename OP>
3194-
static void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select,
3195-
typename OP::Adaptor adaptor,
3195+
static void selectMatchAndRewrite(const fir::LLVMTypeConverter &lowering,
3196+
OP select, typename OP::Adaptor adaptor,
31963197
mlir::ConversionPatternRewriter &rewriter) {
31973198
unsigned conds = select.getNumConditions();
31983199
auto cases = select.getCases().getValue();
@@ -3461,7 +3462,7 @@ template <typename LLVMOP, typename OPTY>
34613462
static mlir::LLVM::InsertValueOp
34623463
complexSum(OPTY sumop, mlir::ValueRange opnds,
34633464
mlir::ConversionPatternRewriter &rewriter,
3464-
fir::LLVMTypeConverter &lowering) {
3465+
const fir::LLVMTypeConverter &lowering) {
34653466
mlir::Value a = opnds[0];
34663467
mlir::Value b = opnds[1];
34673468
auto loc = sumop.getLoc();
@@ -3610,7 +3611,7 @@ struct NegcOpConversion : public FIROpConversion<fir::NegcOp> {
36103611
/// These operations are normally dead after the pre-codegen pass.
36113612
template <typename FromOp>
36123613
struct MustBeDeadConversion : public FIROpConversion<FromOp> {
3613-
explicit MustBeDeadConversion(fir::LLVMTypeConverter &lowering,
3614+
explicit MustBeDeadConversion(const fir::LLVMTypeConverter &lowering,
36143615
const fir::FIRToLLVMPassOptions &options)
36153616
: FIROpConversion<FromOp>(lowering, options) {}
36163617
using OpAdaptor = typename FromOp::Adaptor;

flang/lib/Optimizer/CodeGen/TypeConverter.cpp

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA)
3737
specifics(CodeGenSpecifics::get(module.getContext(),
3838
getTargetTriple(module),
3939
getKindMapping(module))),
40-
tbaaBuilder(module->getContext(), applyTBAA) {
40+
tbaaBuilder(
41+
std::make_unique<TBAABuilder>(module->getContext(), applyTBAA)) {
4142
LLVM_DEBUG(llvm::dbgs() << "FIR type converter\n");
4243

4344
// Each conversion should return a value of type mlir::Type.
@@ -155,20 +156,19 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA)
155156

156157
// i32 is used here because LLVM wants i32 constants when indexing into struct
157158
// types. Indexing into other aggregate types is more flexible.
158-
mlir::Type LLVMTypeConverter::offsetType() {
159+
mlir::Type LLVMTypeConverter::offsetType() const {
159160
return mlir::IntegerType::get(&getContext(), 32);
160161
}
161162

162163
// i64 can be used to index into aggregates like arrays
163-
mlir::Type LLVMTypeConverter::indexType() {
164+
mlir::Type LLVMTypeConverter::indexType() const {
164165
return mlir::IntegerType::get(&getContext(), 64);
165166
}
166167

167168
// fir.type<name(p : TY'...){f : TY...}> --> llvm<"%name = { ty... }">
168-
std::optional<mlir::LogicalResult>
169-
LLVMTypeConverter::convertRecordType(fir::RecordType derived,
170-
llvm::SmallVectorImpl<mlir::Type> &results,
171-
llvm::ArrayRef<mlir::Type> callStack) {
169+
std::optional<mlir::LogicalResult> LLVMTypeConverter::convertRecordType(
170+
fir::RecordType derived, llvm::SmallVectorImpl<mlir::Type> &results,
171+
llvm::ArrayRef<mlir::Type> callStack) const {
172172
auto name = derived.getName();
173173
auto st = mlir::LLVM::LLVMStructType::getIdentified(&getContext(), name);
174174
if (llvm::count(callStack, derived) > 1) {
@@ -192,14 +192,14 @@ LLVMTypeConverter::convertRecordType(fir::RecordType derived,
192192

193193
// Is an extended descriptor needed given the element type of a fir.box type ?
194194
// Extended descriptors are required for derived types.
195-
bool LLVMTypeConverter::requiresExtendedDesc(mlir::Type boxElementType) {
195+
bool LLVMTypeConverter::requiresExtendedDesc(mlir::Type boxElementType) const {
196196
auto eleTy = fir::unwrapSequenceType(boxElementType);
197197
return eleTy.isa<fir::RecordType>();
198198
}
199199

200200
// This corresponds to the descriptor as defined in ISO_Fortran_binding.h and
201201
// the addendum defined in descriptor.h.
202-
mlir::Type LLVMTypeConverter::convertBoxType(BaseBoxType box, int rank) {
202+
mlir::Type LLVMTypeConverter::convertBoxType(BaseBoxType box, int rank) const {
203203
// (base_addr*, elem_len, version, rank, type, attribute, f18Addendum, [dim]
204204
llvm::SmallVector<mlir::Type> dataDescFields;
205205
mlir::Type ele = box.getEleTy();
@@ -269,14 +269,14 @@ mlir::Type LLVMTypeConverter::convertBoxType(BaseBoxType box, int rank) {
269269

270270
/// Convert fir.box type to the corresponding llvm struct type instead of a
271271
/// pointer to this struct type.
272-
mlir::Type LLVMTypeConverter::convertBoxTypeAsStruct(BaseBoxType box) {
272+
mlir::Type LLVMTypeConverter::convertBoxTypeAsStruct(BaseBoxType box) const {
273273
return convertBoxType(box)
274274
.cast<mlir::LLVM::LLVMPointerType>()
275275
.getElementType();
276276
}
277277

278278
// fir.boxproc<any> --> llvm<"{ any*, i8* }">
279-
mlir::Type LLVMTypeConverter::convertBoxProcType(BoxProcType boxproc) {
279+
mlir::Type LLVMTypeConverter::convertBoxProcType(BoxProcType boxproc) const {
280280
auto funcTy = convertType(boxproc.getEleTy());
281281
auto i8PtrTy = mlir::LLVM::LLVMPointerType::get(
282282
mlir::IntegerType::get(&getContext(), 8));
@@ -285,13 +285,13 @@ mlir::Type LLVMTypeConverter::convertBoxProcType(BoxProcType boxproc) {
285285
/*isPacked=*/false);
286286
}
287287

288-
unsigned LLVMTypeConverter::characterBitsize(fir::CharacterType charTy) {
288+
unsigned LLVMTypeConverter::characterBitsize(fir::CharacterType charTy) const {
289289
return kindMapping.getCharacterBitsize(charTy.getFKind());
290290
}
291291

292292
// fir.char<k,?> --> llvm<"ix"> where ix is scaled by kind mapping
293293
// fir.char<k,n> --> llvm.array<n x "ix">
294-
mlir::Type LLVMTypeConverter::convertCharType(fir::CharacterType charTy) {
294+
mlir::Type LLVMTypeConverter::convertCharType(fir::CharacterType charTy) const {
295295
auto iTy = mlir::IntegerType::get(&getContext(), characterBitsize(charTy));
296296
if (charTy.getLen() == fir::CharacterType::unknownLen())
297297
return iTy;
@@ -300,13 +300,13 @@ mlir::Type LLVMTypeConverter::convertCharType(fir::CharacterType charTy) {
300300

301301
// convert a front-end kind value to either a std or LLVM IR dialect type
302302
// fir.real<n> --> llvm.anyfloat where anyfloat is a kind mapping
303-
mlir::Type LLVMTypeConverter::convertRealType(fir::KindTy kind) {
303+
mlir::Type LLVMTypeConverter::convertRealType(fir::KindTy kind) const {
304304
return fir::fromRealTypeID(&getContext(), kindMapping.getRealTypeID(kind),
305305
kind);
306306
}
307307

308308
// fir.array<c ... :any> --> llvm<"[...[c x any]]">
309-
mlir::Type LLVMTypeConverter::convertSequenceType(SequenceType seq) {
309+
mlir::Type LLVMTypeConverter::convertSequenceType(SequenceType seq) const {
310310
auto baseTy = convertType(seq.getEleTy());
311311
if (characterWithDynamicLen(seq.getEleTy()))
312312
return mlir::LLVM::LLVMPointerType::get(baseTy);
@@ -328,7 +328,8 @@ mlir::Type LLVMTypeConverter::convertSequenceType(SequenceType seq) {
328328
// fir.tdesc<any> --> llvm<"i8*">
329329
// TODO: For now use a void*, however pointer identity is not sufficient for
330330
// the f18 object v. class distinction (F2003).
331-
mlir::Type LLVMTypeConverter::convertTypeDescType(mlir::MLIRContext *ctx) {
331+
mlir::Type
332+
LLVMTypeConverter::convertTypeDescType(mlir::MLIRContext *ctx) const {
332333
return mlir::LLVM::LLVMPointerType::get(
333334
mlir::IntegerType::get(&getContext(), 8));
334335
}
@@ -337,8 +338,8 @@ mlir::Type LLVMTypeConverter::convertTypeDescType(mlir::MLIRContext *ctx) {
337338
void LLVMTypeConverter::attachTBAATag(mlir::LLVM::AliasAnalysisOpInterface op,
338339
mlir::Type baseFIRType,
339340
mlir::Type accessFIRType,
340-
mlir::LLVM::GEPOp gep) {
341-
tbaaBuilder.attachTBAATag(op, baseFIRType, accessFIRType, gep);
341+
mlir::LLVM::GEPOp gep) const {
342+
tbaaBuilder->attachTBAATag(op, baseFIRType, accessFIRType, gep);
342343
}
343344

344345
} // namespace fir

0 commit comments

Comments
 (0)