Skip to content

Commit b30b5b3

Browse files
authored
feat(avm): tagged value type in C++ (#13540)
The main new class is in `tagged_value.{hpp,cpp}`. You can see its use in `memory.{hpp,cpp}` and... everywhere. It was already all over the place so it's a good thing that we tackle it now. It was a bit of a pain. ## AI generated description This PR introduces a new `TaggedValue` class to replace the previous memory value representation. The `TaggedValue` class encapsulates both a value and its type tag, providing a more robust and type-safe way to handle different data types in the VM. Key changes: - Added `TaggedValue` class that uses a variant to store different numeric types (uint1_t, uint8_t, uint16_t, etc.) - Implemented a new `uint1_t` class to represent boolean values - Updated memory operations to use `TaggedValue` instead of separate value and tag parameters - Modified bitwise operations to work with the new `TaggedValue` type - Updated all related code to use the new type system, including tests and simulation code
1 parent 9a73c4a commit b30b5b3

40 files changed

+1662
-552
lines changed

barretenberg/cpp/src/barretenberg/vm2/common/memory_types.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
#include "barretenberg/vm2/common/memory_types.hpp"
22

3+
#include <cassert>
4+
#include <stdexcept>
5+
36
namespace bb::avm2 {
47

58
uint8_t integral_tag_length(MemoryTag tag)
@@ -24,4 +27,4 @@ uint8_t integral_tag_length(MemoryTag tag)
2427
return 0; // Should never happen. To please the compiler.
2528
}
2629

27-
} // namespace bb::avm2
30+
} // namespace bb::avm2

barretenberg/cpp/src/barretenberg/vm2/common/memory_types.hpp

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,15 @@
22

33
#include <cstdint>
44

5-
#include "barretenberg/vm2/common/field.hpp"
5+
#include "barretenberg/vm2/common/tagged_value.hpp"
66

77
namespace bb::avm2 {
88

9-
enum class MemoryTag {
10-
FF,
11-
U1,
12-
U8,
13-
U16,
14-
U32,
15-
U64,
16-
U128,
17-
MAX = U128,
18-
};
19-
9+
using MemoryTag = ValueTag;
10+
using MemoryValue = TaggedValue;
2011
using MemoryAddress = uint32_t;
21-
using MemoryValue = FF;
2212
constexpr auto MemoryAddressTag = MemoryTag::U32;
2313

2414
uint8_t integral_tag_length(MemoryTag tag);
2515

26-
} // namespace bb::avm2
16+
} // namespace bb::avm2
Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
#include "barretenberg/vm2/common/tagged_value.hpp"
2+
3+
#include <cassert>
4+
#include <functional>
5+
#include <stdexcept>
6+
#include <variant>
7+
8+
#include "barretenberg/numeric/bitop/get_msb.hpp"
9+
#include "barretenberg/numeric/uint128/uint128.hpp"
10+
#include "barretenberg/numeric/uint256/uint256.hpp"
11+
#include "barretenberg/vm2/common/stringify.hpp"
12+
#include "barretenberg/vm2/common/uint1.hpp"
13+
14+
namespace bb::avm2 {
15+
16+
namespace {
17+
18+
// Helper type for ad-hoc visitors. See https://en.cppreference.com/w/cpp/utility/variant/visit2.
19+
template <class... Ts> struct overloads : Ts... {
20+
using Ts::operator()...;
21+
};
22+
// This is a deduction guide. Apparently not needed in C++20, but we somehow still need it.
23+
template <class... Ts> overloads(Ts...) -> overloads<Ts...>;
24+
25+
struct shift_left {
26+
template <typename T, typename U> T operator()(const T& a, const U& b) const
27+
{
28+
if constexpr (std::is_same_v<T, uint1_t>) {
29+
return static_cast<T>(a.operator<<(b));
30+
} else {
31+
return static_cast<T>(a << b);
32+
}
33+
}
34+
};
35+
36+
struct shift_right {
37+
template <typename T, typename U> T operator()(const T& a, const U& b) const
38+
{
39+
if constexpr (std::is_same_v<T, uint1_t>) {
40+
return static_cast<T>(a.operator>>(b));
41+
} else {
42+
return static_cast<T>(a >> b);
43+
}
44+
}
45+
};
46+
47+
template <typename Op>
48+
constexpr bool is_bitwise_operation_v =
49+
std::is_same_v<Op, std::bit_and<>> || std::is_same_v<Op, std::bit_or<>> || std::is_same_v<Op, std::bit_xor<>> ||
50+
std::is_same_v<Op, std::bit_not<>> || std::is_same_v<Op, shift_left> || std::is_same_v<Op, shift_right>;
51+
52+
// Helper visitor for binary operations. These assume that the types are the same.
53+
template <typename Op> struct BinaryOperationVisitor {
54+
template <typename T, typename U> TaggedValue::value_type operator()(const T& a, const U& b) const
55+
{
56+
if constexpr (std::is_same_v<T, U>) {
57+
if constexpr (std::is_same_v<T, FF> && is_bitwise_operation_v<Op>) {
58+
throw std::runtime_error("Bitwise operations not valid for FF");
59+
} else {
60+
// Note: IDK why this static_cast is needed, but if you remove it, operations seem to go through FF.
61+
return static_cast<T>(Op{}(a, b));
62+
}
63+
} else {
64+
throw std::runtime_error("Type mismatch in BinaryOperationVisitor!");
65+
}
66+
}
67+
};
68+
69+
// Helper visitor for shift operations. The right hand side is a different type.
70+
template <typename Op> struct ShiftOperationVisitor {
71+
template <typename T, typename U> TaggedValue::value_type operator()(const T& a, const U& b) const
72+
{
73+
if constexpr (std::is_same_v<T, FF> || std::is_same_v<U, FF>) {
74+
throw std::runtime_error("Bitwise operations not valid for FF");
75+
} else {
76+
return static_cast<T>(Op{}(a, b));
77+
}
78+
}
79+
};
80+
81+
// Helper visitor for unary operations
82+
template <typename Op> struct UnaryOperationVisitor {
83+
template <typename T> TaggedValue::value_type operator()(const T& a) const
84+
{
85+
if constexpr (std::is_same_v<T, FF> && is_bitwise_operation_v<Op>) {
86+
throw std::runtime_error("Can't do unary bitwise operations on an FF");
87+
} else {
88+
// Note: IDK why this static_cast is needed, but if you remove it, operations seem to go through FF.
89+
return static_cast<T>(Op{}(a));
90+
}
91+
}
92+
};
93+
94+
} // namespace
95+
96+
// Constructor
97+
TaggedValue::TaggedValue(TaggedValue::value_type value_)
98+
: value(value_)
99+
{}
100+
101+
TaggedValue TaggedValue::from_tag(ValueTag tag, FF value)
102+
{
103+
auto assert_bounds = [](const FF& value, uint8_t bits) {
104+
if (static_cast<uint256_t>(value).get_msb() >= bits) {
105+
throw std::runtime_error("Value out of bounds");
106+
}
107+
};
108+
109+
// Check bounds first.
110+
switch (tag) {
111+
case ValueTag::U1:
112+
assert_bounds(value, 1);
113+
break;
114+
case ValueTag::U8:
115+
assert_bounds(value, 8);
116+
break;
117+
case ValueTag::U16:
118+
assert_bounds(value, 16);
119+
break;
120+
case ValueTag::U32:
121+
assert_bounds(value, 32);
122+
break;
123+
case ValueTag::U64:
124+
assert_bounds(value, 64);
125+
break;
126+
case ValueTag::U128:
127+
assert_bounds(value, 128);
128+
break;
129+
case ValueTag::FF:
130+
break;
131+
}
132+
133+
return from_tag_truncating(tag, value);
134+
}
135+
136+
TaggedValue TaggedValue::from_tag_truncating(ValueTag tag, FF value)
137+
{
138+
switch (tag) {
139+
case ValueTag::U1:
140+
return TaggedValue(static_cast<uint1_t>(static_cast<uint8_t>(value) % 2));
141+
case ValueTag::U8:
142+
return TaggedValue(static_cast<uint8_t>(value));
143+
case ValueTag::U16:
144+
return TaggedValue(static_cast<uint16_t>(value));
145+
case ValueTag::U32:
146+
return TaggedValue(static_cast<uint32_t>(value));
147+
case ValueTag::U64:
148+
return TaggedValue(static_cast<uint64_t>(value));
149+
case ValueTag::U128:
150+
return TaggedValue(static_cast<uint128_t>(value));
151+
case ValueTag::FF:
152+
return TaggedValue(value);
153+
default:
154+
throw std::runtime_error("Invalid tag");
155+
}
156+
}
157+
158+
// Arithmetic operators
159+
TaggedValue TaggedValue::operator+(const TaggedValue& other) const
160+
{
161+
return std::visit(BinaryOperationVisitor<std::plus<>>(), value, other.value);
162+
}
163+
164+
TaggedValue TaggedValue::operator-(const TaggedValue& other) const
165+
{
166+
return std::visit(BinaryOperationVisitor<std::minus<>>(), value, other.value);
167+
}
168+
169+
TaggedValue TaggedValue::operator*(const TaggedValue& other) const
170+
{
171+
return std::visit(BinaryOperationVisitor<std::multiplies<>>(), value, other.value);
172+
}
173+
174+
TaggedValue TaggedValue::operator/(const TaggedValue& other) const
175+
{
176+
return std::visit(BinaryOperationVisitor<std::divides<>>(), value, other.value);
177+
}
178+
179+
// Bitwise operators
180+
TaggedValue TaggedValue::operator&(const TaggedValue& other) const
181+
{
182+
return std::visit(BinaryOperationVisitor<std::bit_and<>>(), value, other.value);
183+
}
184+
185+
TaggedValue TaggedValue::operator|(const TaggedValue& other) const
186+
{
187+
return std::visit(BinaryOperationVisitor<std::bit_or<>>(), value, other.value);
188+
}
189+
190+
TaggedValue TaggedValue::operator^(const TaggedValue& other) const
191+
{
192+
return std::visit(BinaryOperationVisitor<std::bit_xor<>>(), value, other.value);
193+
}
194+
195+
TaggedValue TaggedValue::operator<<(const TaggedValue& other) const
196+
{
197+
return std::visit(ShiftOperationVisitor<shift_left>(), value, other.value);
198+
}
199+
200+
TaggedValue TaggedValue::operator>>(const TaggedValue& other) const
201+
{
202+
return std::visit(ShiftOperationVisitor<shift_right>(), value, other.value);
203+
}
204+
205+
TaggedValue TaggedValue::operator~() const
206+
{
207+
return std::visit(UnaryOperationVisitor<std::bit_not<>>(), value);
208+
}
209+
210+
FF TaggedValue::as_ff() const
211+
{
212+
const auto visitor = overloads{ [](FF val) -> FF { return val; },
213+
[](uint1_t val) -> FF { return val.value(); },
214+
[](uint128_t val) -> FF { return uint256_t::from_uint128(val); },
215+
[](auto&& val) -> FF { return val; } };
216+
217+
return std::visit(visitor, value);
218+
}
219+
220+
ValueTag TaggedValue::get_tag() const
221+
{
222+
// The tag is implicit in the type.
223+
if (std::holds_alternative<uint8_t>(value)) {
224+
return ValueTag::U8;
225+
} else if (std::holds_alternative<uint1_t>(value)) {
226+
return ValueTag::U1;
227+
} else if (std::holds_alternative<uint16_t>(value)) {
228+
return ValueTag::U16;
229+
} else if (std::holds_alternative<uint32_t>(value)) {
230+
return ValueTag::U32;
231+
} else if (std::holds_alternative<uint64_t>(value)) {
232+
return ValueTag::U64;
233+
} else if (std::holds_alternative<uint128_t>(value)) {
234+
return ValueTag::U128;
235+
} else if (std::holds_alternative<FF>(value)) {
236+
return ValueTag::FF;
237+
} else {
238+
throw std::runtime_error("Unknown value type");
239+
}
240+
241+
assert(false && "This should never happen.");
242+
return ValueTag::FF; // Only to make the compiler happy.
243+
}
244+
245+
std::string TaggedValue::to_string() const
246+
{
247+
std::string v = std::visit(
248+
overloads{ [](const FF& val) -> std::string { return field_to_string(val); },
249+
[](const uint128_t& val) -> std::string { return field_to_string(uint256_t::from_uint128(val)); },
250+
[](const uint1_t& val) -> std::string { return val.value() == 0 ? "0" : "1"; },
251+
[](auto&& val) -> std::string { return std::to_string(val); } },
252+
value);
253+
return "TaggedValue(" + v + ", " + std::to_string(get_tag()) + ")";
254+
}
255+
256+
} // namespace bb::avm2
257+
258+
std::string std::to_string(bb::avm2::ValueTag tag)
259+
{
260+
using namespace bb::avm2;
261+
switch (tag) {
262+
case ValueTag::U1:
263+
return "U1";
264+
case ValueTag::U8:
265+
return "U8";
266+
case ValueTag::U16:
267+
return "U16";
268+
case ValueTag::U32:
269+
return "U32";
270+
case ValueTag::U64:
271+
return "U64";
272+
case ValueTag::U128:
273+
return "U128";
274+
case ValueTag::FF:
275+
return "FF";
276+
default:
277+
return "Unknown";
278+
}
279+
}
280+
281+
std::string std::to_string(const bb::avm2::TaggedValue& value)
282+
{
283+
return value.to_string();
284+
}

0 commit comments

Comments
 (0)