Skip to content

Commit 6751be6

Browse files
committed
Structs and Headers flattening pass
Signed-off-by: Pietro Ghiglio <pghiglio@accesssoftek.com>
1 parent 64525cc commit 6751be6

File tree

6 files changed

+617
-0
lines changed

6 files changed

+617
-0
lines changed

include/p4mlir/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ std::unique_ptr<mlir::Pass> createInlineParsersPass();
4040
std::unique_ptr<mlir::Pass> createInlineControlsPass();
4141
std::unique_ptr<mlir::Pass> createExpandEmitPass();
4242
std::unique_ptr<mlir::Pass> createSymbolDCEPass();
43+
std::unique_ptr<mlir::Pass> createFlattenStructsPass();
4344

4445
#define GEN_PASS_REGISTRATION
4546
#include "p4mlir/Transforms/Passes.h.inc"

include/p4mlir/Transforms/Passes.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,4 +344,19 @@ def SymbolDCE : Pass<"p4hir-symbol-dce"> {
344344
];
345345
}
346346

347+
//===----------------------------------------------------------------------===//
348+
// FlattenStructs
349+
//===----------------------------------------------------------------------===//
350+
351+
def FlattenStructs : Pass<"p4hir-flatten-structs", "mlir::ModuleOp"> {
352+
let summary = "Flattens Structs";
353+
let description = [{
354+
This pass performs struct and header flattening, ensuring that all the struct types in
355+
the module contain only headers and scalars, and that headers contain only scalars.
356+
}];
357+
358+
let constructor = "P4MLIR::createFlattenStructsPass()";
359+
let dependentDialects = ["P4MLIR::P4HIR::P4HIRDialect"];
360+
}
361+
347362
#endif // P4MLIR_TRANSFORMS_PASSES_TD

lib/Dialect/BMv2IR/Pipelines/BMv2Pipeline.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ void P4::P4MLIR::buildBMv2Pipeline(OpPassManager &pm, const BMv2PipelineOpts &op
1414
pm.addPass(createEnumEliminationPass());
1515
pm.addPass(createSerEnumEliminationPass());
1616
pm.addPass(createRemoveAliasesPass());
17+
pm.addPass(createFlattenStructsPass());
1718

1819
// TODO: eliminate switches
1920

lib/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ add_mlir_dialect_library(P4MLIRTransforms
1212
TupleToStruct.cpp
1313
ExpandEmit.cpp
1414
SymbolDCE.cpp
15+
FlattenStructs.cpp
1516

1617
ADDITIONAL_HEADER_DIRS
1718
${PROJECT_SOURCE_DIR}/include/p4mlir/Transforms

lib/Transforms/FlattenStructs.cpp

Lines changed: 332 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,332 @@
1+
#include <optional>
2+
3+
#include "llvm/ADT/STLExtras.h"
4+
#include "llvm/ADT/STLFunctionalExtras.h"
5+
#include "llvm/ADT/StringExtras.h"
6+
#include "llvm/ADT/TypeSwitch.h"
7+
#include "llvm/Support/LogicalResult.h"
8+
#include "mlir/IR/BuiltinAttributes.h"
9+
#include "mlir/IR/OperationSupport.h"
10+
#include "mlir/IR/PatternMatch.h"
11+
#include "mlir/Pass/Pass.h"
12+
#include "mlir/Transforms/DialectConversion.h"
13+
#include "p4mlir/Conversion/ConversionPatterns.h"
14+
#include "p4mlir/Dialect/P4HIR/P4HIR_Dialect.h"
15+
#include "p4mlir/Dialect/P4HIR/P4HIR_Ops.h"
16+
#include "p4mlir/Transforms/Passes.h"
17+
18+
using namespace mlir;
19+
20+
namespace P4::P4MLIR {
21+
#define GEN_PASS_DEF_FLATTENSTRUCTS
22+
#include "p4mlir/Transforms/Passes.cpp.inc"
23+
namespace {
24+
25+
// This TypeConverter flattens structs and headers, using `_` as a seperator for the new field
26+
// names, e.g. From: !ingress_metadata_t = !p4hir.struct<"ingress_metadata_t", vrf: !b12i, bd:
27+
// !b16i, nexthop_index: !b16i> !metadata = !p4hir.struct<"metadata", ingress_metadata:
28+
// !ingress_metadata_t> To: !metadata = !p4hir.struct<"metadata", ingress_metadata_vrf: !b12i,
29+
// ingress_metadata_bd: !b16i, ingress_metadata_nexthop_index: !b16i> Structs are flattened into
30+
// headers but headers aren't flattened into structs
31+
class StructFlatteningTypeConverter : public P4HIRTypeConverter {
32+
public:
33+
StructFlatteningTypeConverter() {
34+
addConversion([](P4HIR::StructLikeTypeInterface type) -> std::optional<Type> {
35+
return flattenStructLikeType(type);
36+
});
37+
}
38+
39+
static bool isHeader(Type type) {
40+
return isa<P4HIR::HeaderType, P4HIR::HeaderUnionType, P4HIR::HeaderStackType>(type);
41+
}
42+
43+
static bool isScalar(Type type) {
44+
return isa<P4HIR::BitsType, P4HIR::VarBitsType, P4HIR::BoolType, P4HIR::ErrorType,
45+
P4HIR::EnumType>(type);
46+
}
47+
48+
private:
49+
static P4HIR::StructLikeTypeInterface flattenStructLikeType(
50+
P4HIR::StructLikeTypeInterface structLikeType) {
51+
if (auto headerStack = dyn_cast<P4HIR::HeaderStackType>(structLikeType)) {
52+
auto flattened = flattenStructLikeType(headerStack.getArrayElementType());
53+
return P4HIR::HeaderStackType::get(headerStack.getContext(), headerStack.getArraySize(),
54+
flattened);
55+
}
56+
57+
SmallVector<P4HIR::FieldInfo> flattenedFields;
58+
59+
for (auto [fieldName, fieldType, annotations] : structLikeType.getFields()) {
60+
if (auto nestedStruct = dyn_cast<P4HIR::StructType>(fieldType)) {
61+
auto flattened = flattenStructLikeType(nestedStruct);
62+
63+
for (auto [nestedName, nestedType, annotations] : flattened.getFields()) {
64+
std::string newName =
65+
(fieldName.getValue() + "_" + nestedName.getValue()).str();
66+
auto newNameAttr = StringAttr::get(structLikeType.getContext(), newName);
67+
flattenedFields.push_back({newNameAttr, nestedType, annotations});
68+
}
69+
} else if (isHeader(fieldType)) {
70+
auto flattened =
71+
flattenStructLikeType(cast<P4HIR::StructLikeTypeInterface>(fieldType));
72+
flattenedFields.push_back({fieldName, cast<Type>(flattened), annotations});
73+
} else if (isScalar(fieldType)) {
74+
flattenedFields.push_back({fieldName, fieldType, annotations});
75+
} else if (isa<P4HIR::ValidBitType>(fieldType)) {
76+
continue;
77+
} else {
78+
llvm::errs() << "Unexpected type " << fieldType << "\n";
79+
llvm_unreachable("Unexpected field type during flattening");
80+
}
81+
}
82+
83+
return llvm::TypeSwitch<P4HIR::StructLikeTypeInterface, P4HIR::StructLikeTypeInterface>(
84+
structLikeType)
85+
.Case([&](P4HIR::StructType structType) {
86+
return P4HIR::StructType::get(structType.getContext(), structType.getName(),
87+
flattenedFields, structType.getAnnotations());
88+
})
89+
.Case([&](P4HIR::HeaderUnionType headerUnionType) {
90+
return P4HIR::HeaderUnionType::get(headerUnionType.getContext(),
91+
headerUnionType.getName(), flattenedFields,
92+
headerUnionType.getAnnotations());
93+
})
94+
.Case([&](P4HIR::HeaderType headerType) {
95+
return P4HIR::HeaderType::get(headerType.getContext(), headerType.getName(),
96+
flattenedFields, headerType.getAnnotations());
97+
});
98+
}
99+
};
100+
101+
// This pattern handles StructFieldRefOps and StructExtractOps, it traverses the tree of
102+
// struct accesses, constructs the new field names as it goes (using `_` as separator) and
103+
// replaces the "leaf" accesses with new ones that use the new field names.
104+
template <typename OpTy>
105+
class FlattenStructAccess : public OpConversionPattern<OpTy> {
106+
public:
107+
using OpConversionPattern<OpTy>::OpConversionPattern;
108+
using AdaptorTy = typename OpConversionPattern<OpTy>::OpAdaptor;
109+
110+
LogicalResult matchAndRewrite(OpTy op, AdaptorTy adaptor,
111+
ConversionPatternRewriter &rewriter) const override {
112+
SmallVector<Operation *> eraseList;
113+
DenseMap<Operation *, Value> replacements;
114+
115+
if (failed(processStructAccessTree(op.getOperation(), adaptor.getInput(), "", eraseList,
116+
replacements, rewriter)))
117+
return failure();
118+
119+
for (auto [oldOp, newValue] : replacements) {
120+
rewriter.replaceOp(oldOp, newValue);
121+
}
122+
123+
for (auto *opToErase : eraseList) {
124+
rewriter.eraseOp(opToErase);
125+
}
126+
127+
return success();
128+
}
129+
130+
private:
131+
LogicalResult processStructAccessTree(Operation *op, Value convertedInput,
132+
StringRef parentFieldPath,
133+
SmallVector<Operation *> &eraseList,
134+
DenseMap<Operation *, Value> &replacements,
135+
ConversionPatternRewriter &rewriter) const {
136+
return llvm::TypeSwitch<Operation *, LogicalResult>(op)
137+
.Case([&](P4HIR::StructFieldRefOp fieldRefOp) {
138+
return processStructAccessOp(fieldRefOp, convertedInput, parentFieldPath, eraseList,
139+
replacements, rewriter);
140+
})
141+
.Case([&](P4HIR::StructExtractOp extractOp) {
142+
return processStructAccessOp(extractOp, convertedInput, parentFieldPath, eraseList,
143+
replacements, rewriter);
144+
})
145+
.Case([&](P4HIR::ReadOp readOp) {
146+
return processRead(readOp, convertedInput, parentFieldPath, eraseList, replacements,
147+
rewriter);
148+
})
149+
.Default([](Operation *op) {
150+
return op->emitError("Unexpected operation in struct access tree");
151+
});
152+
}
153+
154+
// Helper function for StructFieldRefOp and StructExtractOp
155+
template <typename StructAccessOpTy>
156+
LogicalResult processStructAccessOp(StructAccessOpTy op, Value convertedInput,
157+
StringRef parentFieldPath,
158+
SmallVector<Operation *> &eraseList,
159+
DenseMap<Operation *, Value> &replacements,
160+
ConversionPatternRewriter &rewriter) const {
161+
std::string currentFieldPath = parentFieldPath.empty()
162+
? op.getFieldName().str()
163+
: (parentFieldPath + "_" + op.getFieldName()).str();
164+
165+
auto resultTy = this->getTypeConverter()->convertType(op.getResult().getType());
166+
if (!resultTy) return op.emitError("Unable to convert result type");
167+
168+
auto isPredicateOrRefToPredicate = [](Type ty,
169+
llvm::function_ref<bool(Type)> pred) -> bool {
170+
if (auto refTy = dyn_cast<P4HIR::ReferenceType>(ty)) return pred(refTy.getObjectType());
171+
return pred(ty);
172+
};
173+
174+
bool isLeaf =
175+
isPredicateOrRefToPredicate(resultTy, StructFlatteningTypeConverter::isHeader) ||
176+
isPredicateOrRefToPredicate(resultTy, StructFlatteningTypeConverter::isScalar) ||
177+
isPredicateOrRefToPredicate(resultTy,
178+
[](Type ty) { return isa<P4HIR::ValidBitType>(ty); });
179+
180+
// If this is the leaf of a chain of struct accesses, we create the final replacement
181+
// (we have to be careful and use FieldRef/Extract/Read based on whether the input and
182+
// output are references or not)
183+
if (isLeaf) {
184+
bool outputIsRef = isa<P4HIR::ReferenceType>(op.getResult().getType());
185+
Value newResult =
186+
llvm::TypeSwitch<Type, Value>(convertedInput.getType())
187+
.Case<P4HIR::ReferenceType>([&](auto) {
188+
Value res = rewriter
189+
.create<P4HIR::StructFieldRefOp>(
190+
op.getLoc(), convertedInput, currentFieldPath)
191+
.getResult();
192+
if (!outputIsRef)
193+
res = rewriter.create<P4HIR::ReadOp>(op.getLoc(), res).getResult();
194+
return res;
195+
})
196+
.Default([&](auto) {
197+
return rewriter
198+
.create<P4HIR::StructExtractOp>(op.getLoc(), convertedInput,
199+
currentFieldPath)
200+
.getResult();
201+
});
202+
203+
replacements[op.getOperation()] = newResult;
204+
return success();
205+
}
206+
207+
// This is not a leaf, process the rest of the tree and add the op the erase list
208+
for (auto user : op.getResult().getUsers()) {
209+
if (!isa<P4HIR::StructFieldRefOp, P4HIR::StructExtractOp, P4HIR::ReadOp>(user)) {
210+
return user->emitError(
211+
"Expected struct access or read operation as user of intermediate struct "
212+
"access");
213+
}
214+
215+
if (failed(processStructAccessTree(user, convertedInput, currentFieldPath, eraseList,
216+
replacements, rewriter)))
217+
return failure();
218+
}
219+
220+
eraseList.push_back(op);
221+
return success();
222+
}
223+
224+
LogicalResult processRead(P4HIR::ReadOp op, Value convertedInput, StringRef parentFieldPath,
225+
SmallVector<Operation *> &eraseList,
226+
DenseMap<Operation *, Value> &replacements,
227+
ConversionPatternRewriter &rewriter) const {
228+
// For intermediate reads we just process the rest of the tree and add the read to the erase
229+
// list
230+
for (auto user : op.getResult().getUsers()) {
231+
if (!isa<P4HIR::StructFieldRefOp, P4HIR::StructExtractOp>(user)) {
232+
return user->emitError("Expected struct access operation as user of read");
233+
}
234+
235+
if (failed(processStructAccessTree(user, convertedInput, parentFieldPath, eraseList,
236+
replacements, rewriter)))
237+
return failure();
238+
}
239+
240+
eraseList.push_back(op);
241+
return success();
242+
}
243+
};
244+
245+
struct StructFlatteningPass : public P4::P4MLIR::impl::FlattenStructsBase<StructFlatteningPass> {
246+
void runOnOperation() override {
247+
auto moduleOp = getOperation();
248+
249+
StructFlatteningTypeConverter typeConverter;
250+
251+
ConversionTarget target(getContext());
252+
target.addLegalDialect<P4HIR::P4HIRDialect>();
253+
254+
target.addDynamicallyLegalOp<P4HIR::StructFieldRefOp>([&](P4HIR::StructFieldRefOp op) {
255+
return typeConverter.isLegal(op.getInput().getType());
256+
});
257+
258+
target.addDynamicallyLegalOp<P4HIR::StructExtractOp>([&](P4HIR::StructExtractOp op) {
259+
return typeConverter.isLegal(op.getInput().getType());
260+
});
261+
262+
target.addDynamicallyLegalOp<P4HIR::FuncOp>(
263+
[&](P4HIR::FuncOp op) { return typeConverter.isLegal(op.getFunctionType()); });
264+
265+
target.addDynamicallyLegalOp<P4HIR::TableKeyOp>(
266+
[&](P4HIR::TableKeyOp op) { return typeConverter.isLegal(op.getApplyType()); });
267+
268+
target.addDynamicallyLegalOp<P4HIR::ControlOp>(
269+
[&](P4HIR::ControlOp op) { return typeConverter.isLegal(op.getApplyType()); });
270+
271+
target.addDynamicallyLegalOp<P4HIR::ParserOp>(
272+
[&](P4HIR::ParserOp op) { return typeConverter.isLegal(op.getApplyType()); });
273+
274+
target.addDynamicallyLegalOp<P4HIR::CallOp>([&](P4HIR::CallOp op) {
275+
return typeConverter.isLegal(op.getResultTypes()) &&
276+
typeConverter.isLegal(op.getOperandTypes());
277+
});
278+
279+
target.addDynamicallyLegalOp<P4HIR::CallMethodOp>([&](P4HIR::CallMethodOp op) {
280+
return typeConverter.isLegal(op.getResultTypes()) &&
281+
typeConverter.isLegal(op.getOperandTypes());
282+
});
283+
284+
target.addDynamicallyLegalOp<P4HIR::TableApplyOp>(
285+
[&](P4HIR::TableApplyOp op) { return typeConverter.isLegal(op.getOperandTypes()); });
286+
287+
target.addDynamicallyLegalOp<P4HIR::ConstructOp>(
288+
[&](P4HIR::ConstructOp op) { return typeConverter.isLegal(op.getType()); });
289+
290+
target.addDynamicallyLegalOp<P4HIR::VariableOp>(
291+
[&](P4HIR::VariableOp op) { return typeConverter.isLegal(op.getType()); });
292+
293+
target.addDynamicallyLegalOp<P4HIR::ControlLocalOp>(
294+
[&](P4HIR::ControlLocalOp op) { return typeConverter.isLegal(op.getVal().getType()); });
295+
296+
target.addDynamicallyLegalOp<P4HIR::SymToValueOp>(
297+
[&](P4HIR::SymToValueOp op) { return typeConverter.isLegal(op.getType()); });
298+
299+
target.addDynamicallyLegalOp<P4HIR::ReadOp>(
300+
[&](P4HIR::ReadOp op) { return typeConverter.isLegal(op.getType()); });
301+
302+
target.addDynamicallyLegalOp<P4HIR::InstantiateOp>([&](P4HIR::InstantiateOp op) {
303+
auto types = op.getTypeParameters();
304+
if (!types) return true;
305+
return llvm::all_of(types.value(), [&](Attribute ty) {
306+
return typeConverter.isLegal(cast<TypeAttr>(ty).getValue());
307+
});
308+
});
309+
310+
RewritePatternSet patterns(&getContext());
311+
patterns.add<FlattenStructAccess<P4HIR::StructFieldRefOp>,
312+
FlattenStructAccess<P4HIR::StructExtractOp>>(typeConverter, &getContext());
313+
populateOpTypeConversionPattern<
314+
P4HIR::FuncOp, P4HIR::ConstructOp, P4HIR::CallOp, P4HIR::CallMethodOp,
315+
P4HIR::TableApplyOp, P4HIR::ControlApplyOp, P4HIR::VariableOp, P4HIR::ControlLocalOp,
316+
P4HIR::TableKeyOp, P4HIR::ControlOp, P4HIR::ParserOp, P4HIR::ReadOp,
317+
P4HIR::InstantiateOp, P4HIR::SymToValueOp>(patterns, typeConverter);
318+
319+
populateFunctionOpInterfaceTypeConversionPattern<P4HIR::FuncOp>(patterns, typeConverter);
320+
321+
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
322+
signalPassFailure();
323+
}
324+
}
325+
};
326+
327+
} // namespace
328+
329+
std::unique_ptr<Pass> createFlattenStructsPass() {
330+
return std::make_unique<StructFlatteningPass>();
331+
}
332+
} // namespace P4::P4MLIR

0 commit comments

Comments
 (0)