Skip to content

Commit 30b942f

Browse files
committed
Use same instance for same struct type
1 parent 280e754 commit 30b942f

File tree

2 files changed

+101
-37
lines changed

2 files changed

+101
-37
lines changed

lib/Conversion/P4HIRToBMv2IR/LowerToHeaderInstance.cpp

Lines changed: 76 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
1+
#include "llvm/ADT/DenseMap.h"
12
#include "llvm/ADT/STLExtras.h"
23
#include "llvm/ADT/SmallPtrSet.h"
34
#include "llvm/ADT/SmallVector.h"
45
#include "llvm/ADT/StringMap.h"
6+
#include "llvm/ADT/StringRef.h"
57
#include "llvm/ADT/TypeSwitch.h"
68
#include "llvm/Support/Casting.h"
79
#include "llvm/Support/LogicalResult.h"
810
#include "mlir/IR/BuiltinAttributeInterfaces.h"
911
#include "mlir/IR/BuiltinAttributes.h"
1012
#include "mlir/IR/BuiltinOps.h"
1113
#include "mlir/IR/Diagnostics.h"
14+
#include "mlir/IR/MLIRContext.h"
1215
#include "mlir/IR/PatternMatch.h"
1316
#include "mlir/IR/SymbolTable.h"
1417
#include "mlir/IR/Value.h"
@@ -57,15 +60,52 @@ P4HIR::HeaderType isHeaderOrRefToHeader(mlir::Type ty) {
5760
return nullptr;
5861
}
5962

63+
// Map for <struct type name, field name> -> header instance
64+
using InstanceConversionContext =
65+
std::map<std::pair<mlir::Type *, std::string>, BMv2IR::HeaderInstanceOp>;
66+
// Map used for raw headers in control/parser args
67+
using HeaderConversionContext = llvm::DenseMap<P4HIR::HeaderType, BMv2IR::HeaderInstanceOp>;
68+
6069
// Adds instances from a StructType, splitting the struct to create separate instances for
6170
// the header fields, and creating a new struct containing only the bit fields if necessary
6271
LogicalResult splitStructAndAddInstances(Value val, P4HIR::StructType structTy, Location loc,
6372
StringRef parentName, ModuleOp moduleOp,
64-
PatternRewriter &rewriter) {
73+
PatternRewriter &rewriter,
74+
InstanceConversionContext *instances) {
6575
auto ctx = rewriter.getContext();
66-
llvm::StringMap<BMv2IR::HeaderInstanceOp> instances;
76+
llvm::StringMap<BMv2IR::HeaderInstanceOp> localInstance;
6777
SmallPtrSet<Operation *, 5> fieldRefs;
6878
SmallPtrSet<Operation *, 5> bitRefs;
79+
auto getOrInsertInstance = [&localInstance, &rewriter, &instances, &moduleOp, &loc,
80+
&parentName](P4HIR::StructType ty, StringRef name,
81+
bool referenceFullStruct) -> BMv2IR::HeaderInstanceOp {
82+
auto instanceTy = referenceFullStruct ? P4HIR::ReferenceType::get(ty)
83+
: P4HIR::ReferenceType::get(ty.getFieldType(name));
84+
auto symName = referenceFullStruct ? rewriter.getStringAttr(ty.getName())
85+
: rewriter.getStringAttr(ty.getName() + "_" + name);
86+
if (instances) {
87+
auto it = instances->find({&ty, name.str()});
88+
if (it != instances->end()) {
89+
return it->second;
90+
}
91+
PatternRewriter::InsertionGuard guard(rewriter);
92+
rewriter.setInsertionPointToStart(moduleOp.getBody());
93+
auto instanceOp = rewriter.create<BMv2IR::HeaderInstanceOp>(loc, symName, instanceTy);
94+
instances->insert({{&ty, name.str()}, instanceOp});
95+
return instanceOp;
96+
}
97+
// We don't want to use the global context for Header Instances (e.g. we are lowering a
98+
// P4HIR::Variable) We still need to avoid adding duplicate HeaderInstances if the same
99+
// field of a variable is accessed multiple times, so we use the "local" map
100+
auto it = localInstance.find(name);
101+
if (it != localInstance.end()) return it->second;
102+
PatternRewriter::InsertionGuard guard(rewriter);
103+
rewriter.setInsertionPointToStart(moduleOp.getBody());
104+
auto instanceOp = rewriter.create<BMv2IR::HeaderInstanceOp>(
105+
loc, rewriter.getStringAttr(parentName + "_" + name), instanceTy);
106+
localInstance.insert({name, instanceOp});
107+
return instanceOp;
108+
};
69109

70110
// Find the StructFieldRefOps that access the struct
71111
for (auto user : val.getUsers()) {
@@ -86,18 +126,7 @@ LogicalResult splitStructAndAddInstances(Value val, P4HIR::StructType structTy,
86126
for (auto op : fieldRefs) {
87127
auto fieldRefOp = cast<P4HIR::StructFieldRefOp>(op);
88128
auto name = fieldRefOp.getFieldName();
89-
auto instance = instances.find(name);
90-
BMv2IR::HeaderInstanceOp instanceOp = nullptr;
91-
PatternRewriter::InsertionGuard guard(rewriter);
92-
if (instance != instances.end()) {
93-
instanceOp = instance->second;
94-
} else {
95-
rewriter.setInsertionPointToStart(moduleOp.getBody());
96-
instanceOp = rewriter.create<BMv2IR::HeaderInstanceOp>(
97-
loc, rewriter.getStringAttr(parentName + "_" + name),
98-
P4HIR::ReferenceType::get(structTy.getFieldType(name)));
99-
instances.insert({name, instanceOp});
100-
}
129+
auto instanceOp = getOrInsertInstance(structTy, name.str(), false);
101130
rewriter.setInsertionPointAfter(fieldRefOp);
102131
rewriter.replaceOpWithNewOp<BMv2IR::SymToValueOp>(
103132
fieldRefOp, instanceOp.getHeaderType(),
@@ -107,7 +136,9 @@ LogicalResult splitStructAndAddInstances(Value val, P4HIR::StructType structTy,
107136
if (bitRefs.empty()) return success();
108137

109138
// Since the struct has bit fields, we create a new type dropping the header fields, and add a
110-
// header instance for it
139+
// header instance for it. Note that this behaves differently from p4c, which creates a single
140+
// header instance containing all the scalars (and separate header instance for every varbit
141+
// since an header instance can contain only one varbit field).
111142

112143
SmallVector<P4HIR::FieldInfo> bitFields;
113144
for (auto field : structTy.getFields()) {
@@ -118,8 +149,7 @@ LogicalResult splitStructAndAddInstances(Value val, P4HIR::StructType structTy,
118149
rewriter.setInsertionPointToStart(moduleOp.getBody());
119150
auto newTy = P4HIR::StructType::get(rewriter.getContext(), structTy.getName(), bitFields,
120151
structTy.getAnnotations());
121-
auto newInstance = rewriter.create<BMv2IR::HeaderInstanceOp>(
122-
loc, rewriter.getStringAttr(parentName), P4HIR::ReferenceType::get(newTy));
152+
auto newInstance = getOrInsertInstance(newTy, newTy.getName(), true);
123153
for (auto op : bitRefs) {
124154
auto fieldRefOp = cast<P4HIR::StructFieldRefOp>(op);
125155
rewriter.setInsertionPointAfter(fieldRefOp);
@@ -149,12 +179,20 @@ LogicalResult addInstanceForHeader(Operation *op, P4HIR::HeaderType headerTy, Tw
149179
return success();
150180
}
151181

152-
LogicalResult addInstanceForHeader(BlockArgument arg, P4HIR::HeaderType headerTy, Twine name,
153-
ModuleOp moduleOp, PatternRewriter &rewriter) {
182+
LogicalResult addInstanceForHeader(BlockArgument arg, P4HIR::HeaderType headerTy, ModuleOp moduleOp,
183+
PatternRewriter &rewriter, HeaderConversionContext *instances) {
154184
PatternRewriter::InsertionGuard guard(rewriter);
155185
rewriter.setInsertionPointToStart(moduleOp.getBody());
156-
auto newInstance = rewriter.create<BMv2IR::HeaderInstanceOp>(
157-
arg.getLoc(), rewriter.getStringAttr(name), P4HIR::ReferenceType::get(headerTy));
186+
BMv2IR::HeaderInstanceOp newInstance;
187+
auto it = instances->find(headerTy);
188+
if (it != instances->end()) {
189+
newInstance = it->second;
190+
} else {
191+
newInstance = rewriter.create<BMv2IR::HeaderInstanceOp>(
192+
arg.getLoc(), rewriter.getStringAttr(headerTy.getName()),
193+
P4HIR::ReferenceType::get(headerTy));
194+
instances->insert({headerTy, newInstance});
195+
}
158196

159197
for (auto &use : arg.getUses()) {
160198
Operation *user = use.getOwner();
@@ -171,7 +209,11 @@ LogicalResult addInstanceForHeader(BlockArgument arg, P4HIR::HeaderType headerTy
171209
}
172210

173211
struct ParserOpPattern : public OpRewritePattern<P4HIR::ParserOp> {
174-
using OpRewritePattern<P4HIR::ParserOp>::OpRewritePattern;
212+
ParserOpPattern(MLIRContext *context, InstanceConversionContext *instances,
213+
HeaderConversionContext *instancesFromHeaderArgs)
214+
: OpRewritePattern<P4HIR::ParserOp>(context),
215+
instances(instances),
216+
instancesFromHeaderArgs(instancesFromHeaderArgs) {}
175217

176218
mlir::LogicalResult matchAndRewrite(P4HIR::ParserOp parserOp,
177219
mlir::PatternRewriter &rewriter) const override {
@@ -183,17 +225,22 @@ struct ParserOpPattern : public OpRewritePattern<P4HIR::ParserOp> {
183225
std::string parentName =
184226
(parserOp.getSymName() + std::to_string(arg.getArgNumber())).str();
185227
if (auto headerTy = isHeaderOrRefToHeader(ty)) {
186-
if (failed(addInstanceForHeader(arg, headerTy, parentName, moduleOp, rewriter)))
228+
if (failed(addInstanceForHeader(arg, headerTy, moduleOp, rewriter,
229+
instancesFromHeaderArgs)))
187230
return parserOp->emitError("Failed to process parserOp");
188231
} else if (auto structTy = isStructOrRefToStruct(ty)) {
189232
if (failed(splitStructAndAddInstances(arg, structTy, parserOp.getLoc(), parentName,
190-
moduleOp, rewriter)))
233+
moduleOp, rewriter, instances)))
191234
return parserOp->emitError("Failed to process parserOp");
192235
}
193236
}
194237

195238
return mlir::success();
196239
}
240+
241+
private:
242+
InstanceConversionContext *instances;
243+
HeaderConversionContext *instancesFromHeaderArgs;
197244
};
198245

199246
FailureOr<StringRef> getParentName(Operation *op) {
@@ -223,7 +270,7 @@ struct VariableOpPattern : public OpRewritePattern<P4HIR::VariableOp> {
223270
.Case([&](P4HIR::StructType structTy) -> LogicalResult {
224271
if (failed(splitStructAndAddInstances(variableOp.getResult(), structTy,
225272
variableOp.getLoc(), name, moduleOp,
226-
rewriter)))
273+
rewriter, nullptr)))
227274
return variableOp.emitError("Error translating variableOp");
228275
return success();
229276
})
@@ -260,7 +307,10 @@ struct LowerToHeaderInstancePass
260307
});
261308

262309
// TODO: add support for controls and other ops that may lead to header instances
263-
patterns.add<ParserOpPattern, VariableOpPattern>(patterns.getContext());
310+
InstanceConversionContext instances;
311+
HeaderConversionContext instancesFromHeaderArgs;
312+
patterns.add<VariableOpPattern>(patterns.getContext());
313+
patterns.add<ParserOpPattern>(patterns.getContext(), &instances, &instancesFromHeaderArgs);
264314

265315
if (failed(applyPartialConversion(getOperation(), target, std::move(patterns))))
266316
signalPassFailure();

test/Conversion/BMv2IR/lower-to-header-instance.mlir

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,16 @@
1515
!header_and_bit = !p4hir.struct<"header_and_bit", top: !header_top, bit: !b8i>
1616
module {
1717
// CHECK: bmv2ir.header_instance @prs_e_0 : !p4hir.ref<!header_one>
18-
// CHECK: bmv2ir.header_instance @prs1_top : !p4hir.ref<!header_top>
19-
// CHECK: bmv2ir.header_instance @prs1_one : !p4hir.ref<!header_one>
20-
// CHECK: bmv2ir.header_instance @prs1_two : !p4hir.ref<!header_two>
18+
// CHECK: bmv2ir.header_instance @Headers_t_top : !p4hir.ref<!header_top>
19+
// CHECK: bmv2ir.header_instance @Headers_t_one : !p4hir.ref<!header_one>
20+
// CHECK: bmv2ir.header_instance @Headers_t_two : !p4hir.ref<!header_two>
2121
p4hir.parser @prs(%arg0: !p4corelib.packet_in {p4hir.dir = #p4hir<dir undir>, p4hir.param_name = "p"}, %arg1: !p4hir.ref<!Headers_t> {p4hir.dir = #p4hir<dir out>, p4hir.param_name = "headers"})() {
2222
%e_0 = p4hir.variable ["e_0"] annotations {name = "ParserImpl.e"} : <!header_one>
2323
// CHECK: %[[E_0:.*]] = bmv2ir.symbol_ref @prs_e_0 : !p4hir.ref<!header_one>
2424
p4hir.state @start {
2525
%top_field_ref = p4hir.struct_field_ref %arg1["top"] : <!Headers_t>
2626
p4corelib.extract_header %top_field_ref : <!header_top> from %arg0 : !p4corelib.packet_in
27-
// CHECK: %[[REF:.*]] = bmv2ir.symbol_ref @prs1_top : !p4hir.ref<!header_top>
27+
// CHECK: %[[REF:.*]] = bmv2ir.symbol_ref @Headers_t_top : !p4hir.ref<!header_top>
2828
// CHECK: p4corelib.extract_header %[[REF]] : <!header_top> from %arg0 : !p4corelib.packet_in
2929
p4hir.transition to @prs::@parse_headers
3030
}
@@ -54,7 +54,7 @@ module {
5454
}
5555
}
5656
p4hir.state @parse_one {
57-
// CHECK: %[[REF2:.*]] = bmv2ir.symbol_ref @prs1_one : !p4hir.ref<!header_one>
57+
// CHECK: %[[REF2:.*]] = bmv2ir.symbol_ref @Headers_t_one : !p4hir.ref<!header_one>
5858
%one_field_ref = p4hir.struct_field_ref %arg1["one"] : <!Headers_t>
5959
p4corelib.extract_header %e_0 : <!header_one> from %arg0 : !p4corelib.packet_in
6060
// CHECK: p4corelib.extract_header %[[E_0]] : <!header_one> from %arg0 : !p4corelib.packet_in
@@ -64,7 +64,7 @@ module {
6464
p4hir.transition to @prs::@parse_two
6565
}
6666
p4hir.state @parse_two {
67-
// CHECK: %[[REF3:.*]] = bmv2ir.symbol_ref @prs1_two : !p4hir.ref<!header_two>
67+
// CHECK: %[[REF3:.*]] = bmv2ir.symbol_ref @Headers_t_two : !p4hir.ref<!header_two>
6868
%two_field_ref = p4hir.struct_field_ref %arg1["two"] : <!Headers_t>
6969
p4corelib.extract_header %two_field_ref : <!header_two> from %arg0 : !p4corelib.packet_in
7070
// CHECK: p4corelib.extract_header %[[REF3]] : <!header_two> from %arg0 : !p4corelib.packet_in
@@ -78,6 +78,20 @@ module {
7878
}
7979
p4hir.transition to @prs::@start
8080
}
81+
// Check that the Headers_t arg now leads to the same instance being used for its fields
82+
p4hir.parser @other(%arg0: !p4corelib.packet_in {p4hir.dir = #p4hir<dir undir>, p4hir.param_name = "p"}, %arg1: !p4hir.ref<!Headers_t> {p4hir.dir = #p4hir<dir out>, p4hir.param_name = "headers"})() {
83+
p4hir.state @start {
84+
%top_field_ref = p4hir.struct_field_ref %arg1["top"] : <!Headers_t>
85+
p4corelib.extract_header %top_field_ref : <!header_top> from %arg0 : !p4corelib.packet_in
86+
// CHECK: %[[REF:.*]] = bmv2ir.symbol_ref @Headers_t_top : !p4hir.ref<!header_top>
87+
// CHECK: p4corelib.extract_header %[[REF]] : <!header_top> from %arg0 : !p4corelib.packet_in
88+
p4hir.transition to @prs::@accept
89+
}
90+
p4hir.state @accept {
91+
p4hir.parser_accept
92+
}
93+
p4hir.transition to @prs::@start
94+
}
8195
}
8296

8397
// -----
@@ -87,10 +101,10 @@ module {
87101
!validity_bit = !p4hir.validity.bit
88102
!header_top = !p4hir.header<"header_top", skip: !b8i, __valid: !validity_bit>
89103
module {
90-
// CHECK: bmv2ir.header_instance @prs_header_arg1 : !p4hir.ref<!header_top>
104+
// CHECK: bmv2ir.header_instance @header_top : !p4hir.ref<!header_top>
91105
p4hir.parser @prs_header_arg(%arg0: !p4corelib.packet_in {p4hir.dir = #p4hir<dir undir>, p4hir.param_name = "p"}, %arg1: !p4hir.ref<!header_top> {p4hir.dir = #p4hir<dir out>, p4hir.param_name = "headers"})() {
92106
p4hir.state @start {
93-
// CHECK: %[[REF:.*]] = bmv2ir.symbol_ref @prs_header_arg1 : !p4hir.ref<!header_top>
107+
// CHECK: %[[REF:.*]] = bmv2ir.symbol_ref @header_top : !p4hir.ref<!header_top>
94108
// CHECK: p4corelib.extract_header %[[REF]] : <!header_top> from %arg0 : !p4corelib.packet_in
95109
p4corelib.extract_header %arg1 : <!header_top> from %arg0 : !p4corelib.packet_in
96110
p4hir.transition to @prs_header_arg::@accept
@@ -112,15 +126,15 @@ module {
112126
!header_and_bit = !p4hir.struct<"header_and_bit", top: !header_top, bit: !b8i>
113127
// CHECK: ![[SPLIT_STRUCT:.*]] = !p4hir.struct<"header_and_bit", bit: !b8i>
114128
module {
115-
// CHECK: bmv2ir.header_instance @prs_header_and_bit1 : !p4hir.ref<![[SPLIT_STRUCT]]>
129+
// CHECK: bmv2ir.header_instance @header_and_bit : !p4hir.ref<![[SPLIT_STRUCT]]>
116130
p4hir.parser @prs_header_and_bit(%arg0: !p4corelib.packet_in {p4hir.dir = #p4hir<dir undir>, p4hir.param_name = "p"}, %arg1: !p4hir.ref<!header_and_bit> {p4hir.dir = #p4hir<dir out>, p4hir.param_name = "headers"})() {
117131
%var = p4hir.variable ["top_0"] annotations {name = "ParserImpl.e"} : <!header_top>
118132
p4hir.state @start {
119133
p4corelib.extract_header %var : <!header_top> from %arg0 : !p4corelib.packet_in
120134
%bit = p4hir.struct_field_ref %var["skip"] : <!header_top>
121135
%val = p4hir.read %bit : <!b8i>
122136
%ref = p4hir.struct_field_ref %arg1["bit"] : <!header_and_bit>
123-
// CHECK: %[[REF:.*]] = bmv2ir.symbol_ref @prs_header_and_bit1 : !p4hir.ref<!header_and_bit>
137+
// CHECK: %[[REF:.*]] = bmv2ir.symbol_ref @header_and_bit : !p4hir.ref<!header_and_bit>
124138
// CHECK: %{{.*}} = p4hir.struct_field_ref %[[REF]]["bit"] : <!header_and_bit>
125139
p4hir.assign %val, %ref : <!b8i>
126140
p4hir.transition to @prs_header_and_bit::@accept
@@ -140,7 +154,7 @@ module {
140154
!header_top = !p4hir.header<"header_top", skip: !b8i, __valid: !validity_bit>
141155
!bit_only = !p4hir.struct<"bit_only", bit: !b8i>
142156
module {
143-
// CHECK: bmv2ir.header_instance @prs_only_bit1 : !p4hir.ref<!bit_only>
157+
// CHECK: bmv2ir.header_instance @bit_only : !p4hir.ref<!bit_only>
144158
p4hir.parser @prs_only_bit(%arg0: !p4corelib.packet_in {p4hir.dir = #p4hir<dir undir>, p4hir.param_name = "p"}, %arg1: !p4hir.ref<!bit_only> {p4hir.dir = #p4hir<dir out>, p4hir.param_name = "headers"}, %arg2: !p4hir.ref<!header_top> {p4hir.dir = #p4hir<dir out>, p4hir.param_name = "headers"})() {
145159
%var = p4hir.variable ["top_0"] annotations {name = "ParserImpl.e"} : <!header_top>
146160
p4hir.state @start {

0 commit comments

Comments
 (0)