|
| 1 | +#include "barretenberg/vm2/common/tagged_value.hpp" |
| 2 | + |
| 3 | +#include <cassert> |
| 4 | +#include <functional> |
| 5 | +#include <stdexcept> |
| 6 | +#include <variant> |
| 7 | + |
| 8 | +#include "barretenberg/numeric/bitop/get_msb.hpp" |
| 9 | +#include "barretenberg/numeric/uint128/uint128.hpp" |
| 10 | +#include "barretenberg/numeric/uint256/uint256.hpp" |
| 11 | +#include "barretenberg/vm2/common/stringify.hpp" |
| 12 | +#include "barretenberg/vm2/common/uint1.hpp" |
| 13 | + |
| 14 | +namespace bb::avm2 { |
| 15 | + |
| 16 | +namespace { |
| 17 | + |
| 18 | +// Helper type for ad-hoc visitors. See https://en.cppreference.com/w/cpp/utility/variant/visit2. |
| 19 | +template <class... Ts> struct overloads : Ts... { |
| 20 | + using Ts::operator()...; |
| 21 | +}; |
| 22 | +// This is a deduction guide. Apparently not needed in C++20, but we somehow still need it. |
| 23 | +template <class... Ts> overloads(Ts...) -> overloads<Ts...>; |
| 24 | + |
| 25 | +struct shift_left { |
| 26 | + template <typename T, typename U> T operator()(const T& a, const U& b) const |
| 27 | + { |
| 28 | + if constexpr (std::is_same_v<T, uint1_t>) { |
| 29 | + return static_cast<T>(a.operator<<(b)); |
| 30 | + } else { |
| 31 | + return static_cast<T>(a << b); |
| 32 | + } |
| 33 | + } |
| 34 | +}; |
| 35 | + |
| 36 | +struct shift_right { |
| 37 | + template <typename T, typename U> T operator()(const T& a, const U& b) const |
| 38 | + { |
| 39 | + if constexpr (std::is_same_v<T, uint1_t>) { |
| 40 | + return static_cast<T>(a.operator>>(b)); |
| 41 | + } else { |
| 42 | + return static_cast<T>(a >> b); |
| 43 | + } |
| 44 | + } |
| 45 | +}; |
| 46 | + |
| 47 | +template <typename Op> |
| 48 | +constexpr bool is_bitwise_operation_v = |
| 49 | + std::is_same_v<Op, std::bit_and<>> || std::is_same_v<Op, std::bit_or<>> || std::is_same_v<Op, std::bit_xor<>> || |
| 50 | + std::is_same_v<Op, std::bit_not<>> || std::is_same_v<Op, shift_left> || std::is_same_v<Op, shift_right>; |
| 51 | + |
| 52 | +// Helper visitor for binary operations. These assume that the types are the same. |
| 53 | +template <typename Op> struct BinaryOperationVisitor { |
| 54 | + template <typename T, typename U> TaggedValue::value_type operator()(const T& a, const U& b) const |
| 55 | + { |
| 56 | + if constexpr (std::is_same_v<T, U>) { |
| 57 | + if constexpr (std::is_same_v<T, FF> && is_bitwise_operation_v<Op>) { |
| 58 | + throw std::runtime_error("Bitwise operations not valid for FF"); |
| 59 | + } else { |
| 60 | + // Note: IDK why this static_cast is needed, but if you remove it, operations seem to go through FF. |
| 61 | + return static_cast<T>(Op{}(a, b)); |
| 62 | + } |
| 63 | + } else { |
| 64 | + throw std::runtime_error("Type mismatch in BinaryOperationVisitor!"); |
| 65 | + } |
| 66 | + } |
| 67 | +}; |
| 68 | + |
| 69 | +// Helper visitor for shift operations. The right hand side is a different type. |
| 70 | +template <typename Op> struct ShiftOperationVisitor { |
| 71 | + template <typename T, typename U> TaggedValue::value_type operator()(const T& a, const U& b) const |
| 72 | + { |
| 73 | + if constexpr (std::is_same_v<T, FF> || std::is_same_v<U, FF>) { |
| 74 | + throw std::runtime_error("Bitwise operations not valid for FF"); |
| 75 | + } else { |
| 76 | + return static_cast<T>(Op{}(a, b)); |
| 77 | + } |
| 78 | + } |
| 79 | +}; |
| 80 | + |
| 81 | +// Helper visitor for unary operations |
| 82 | +template <typename Op> struct UnaryOperationVisitor { |
| 83 | + template <typename T> TaggedValue::value_type operator()(const T& a) const |
| 84 | + { |
| 85 | + if constexpr (std::is_same_v<T, FF> && is_bitwise_operation_v<Op>) { |
| 86 | + throw std::runtime_error("Can't do unary bitwise operations on an FF"); |
| 87 | + } else { |
| 88 | + // Note: IDK why this static_cast is needed, but if you remove it, operations seem to go through FF. |
| 89 | + return static_cast<T>(Op{}(a)); |
| 90 | + } |
| 91 | + } |
| 92 | +}; |
| 93 | + |
| 94 | +} // namespace |
| 95 | + |
| 96 | +// Constructor |
| 97 | +TaggedValue::TaggedValue(TaggedValue::value_type value_) |
| 98 | + : value(value_) |
| 99 | +{} |
| 100 | + |
| 101 | +TaggedValue TaggedValue::from_tag(ValueTag tag, FF value) |
| 102 | +{ |
| 103 | + auto assert_bounds = [](const FF& value, uint8_t bits) { |
| 104 | + if (static_cast<uint256_t>(value).get_msb() >= bits) { |
| 105 | + throw std::runtime_error("Value out of bounds"); |
| 106 | + } |
| 107 | + }; |
| 108 | + |
| 109 | + // Check bounds first. |
| 110 | + switch (tag) { |
| 111 | + case ValueTag::U1: |
| 112 | + assert_bounds(value, 1); |
| 113 | + break; |
| 114 | + case ValueTag::U8: |
| 115 | + assert_bounds(value, 8); |
| 116 | + break; |
| 117 | + case ValueTag::U16: |
| 118 | + assert_bounds(value, 16); |
| 119 | + break; |
| 120 | + case ValueTag::U32: |
| 121 | + assert_bounds(value, 32); |
| 122 | + break; |
| 123 | + case ValueTag::U64: |
| 124 | + assert_bounds(value, 64); |
| 125 | + break; |
| 126 | + case ValueTag::U128: |
| 127 | + assert_bounds(value, 128); |
| 128 | + break; |
| 129 | + case ValueTag::FF: |
| 130 | + break; |
| 131 | + } |
| 132 | + |
| 133 | + return from_tag_truncating(tag, value); |
| 134 | +} |
| 135 | + |
| 136 | +TaggedValue TaggedValue::from_tag_truncating(ValueTag tag, FF value) |
| 137 | +{ |
| 138 | + switch (tag) { |
| 139 | + case ValueTag::U1: |
| 140 | + return TaggedValue(static_cast<uint1_t>(static_cast<uint8_t>(value) % 2)); |
| 141 | + case ValueTag::U8: |
| 142 | + return TaggedValue(static_cast<uint8_t>(value)); |
| 143 | + case ValueTag::U16: |
| 144 | + return TaggedValue(static_cast<uint16_t>(value)); |
| 145 | + case ValueTag::U32: |
| 146 | + return TaggedValue(static_cast<uint32_t>(value)); |
| 147 | + case ValueTag::U64: |
| 148 | + return TaggedValue(static_cast<uint64_t>(value)); |
| 149 | + case ValueTag::U128: |
| 150 | + return TaggedValue(static_cast<uint128_t>(value)); |
| 151 | + case ValueTag::FF: |
| 152 | + return TaggedValue(value); |
| 153 | + default: |
| 154 | + throw std::runtime_error("Invalid tag"); |
| 155 | + } |
| 156 | +} |
| 157 | + |
| 158 | +// Arithmetic operators |
| 159 | +TaggedValue TaggedValue::operator+(const TaggedValue& other) const |
| 160 | +{ |
| 161 | + return std::visit(BinaryOperationVisitor<std::plus<>>(), value, other.value); |
| 162 | +} |
| 163 | + |
| 164 | +TaggedValue TaggedValue::operator-(const TaggedValue& other) const |
| 165 | +{ |
| 166 | + return std::visit(BinaryOperationVisitor<std::minus<>>(), value, other.value); |
| 167 | +} |
| 168 | + |
| 169 | +TaggedValue TaggedValue::operator*(const TaggedValue& other) const |
| 170 | +{ |
| 171 | + return std::visit(BinaryOperationVisitor<std::multiplies<>>(), value, other.value); |
| 172 | +} |
| 173 | + |
| 174 | +TaggedValue TaggedValue::operator/(const TaggedValue& other) const |
| 175 | +{ |
| 176 | + return std::visit(BinaryOperationVisitor<std::divides<>>(), value, other.value); |
| 177 | +} |
| 178 | + |
| 179 | +// Bitwise operators |
| 180 | +TaggedValue TaggedValue::operator&(const TaggedValue& other) const |
| 181 | +{ |
| 182 | + return std::visit(BinaryOperationVisitor<std::bit_and<>>(), value, other.value); |
| 183 | +} |
| 184 | + |
| 185 | +TaggedValue TaggedValue::operator|(const TaggedValue& other) const |
| 186 | +{ |
| 187 | + return std::visit(BinaryOperationVisitor<std::bit_or<>>(), value, other.value); |
| 188 | +} |
| 189 | + |
| 190 | +TaggedValue TaggedValue::operator^(const TaggedValue& other) const |
| 191 | +{ |
| 192 | + return std::visit(BinaryOperationVisitor<std::bit_xor<>>(), value, other.value); |
| 193 | +} |
| 194 | + |
| 195 | +TaggedValue TaggedValue::operator<<(const TaggedValue& other) const |
| 196 | +{ |
| 197 | + return std::visit(ShiftOperationVisitor<shift_left>(), value, other.value); |
| 198 | +} |
| 199 | + |
| 200 | +TaggedValue TaggedValue::operator>>(const TaggedValue& other) const |
| 201 | +{ |
| 202 | + return std::visit(ShiftOperationVisitor<shift_right>(), value, other.value); |
| 203 | +} |
| 204 | + |
| 205 | +TaggedValue TaggedValue::operator~() const |
| 206 | +{ |
| 207 | + return std::visit(UnaryOperationVisitor<std::bit_not<>>(), value); |
| 208 | +} |
| 209 | + |
| 210 | +FF TaggedValue::as_ff() const |
| 211 | +{ |
| 212 | + const auto visitor = overloads{ [](FF val) -> FF { return val; }, |
| 213 | + [](uint1_t val) -> FF { return val.value(); }, |
| 214 | + [](uint128_t val) -> FF { return uint256_t::from_uint128(val); }, |
| 215 | + [](auto&& val) -> FF { return val; } }; |
| 216 | + |
| 217 | + return std::visit(visitor, value); |
| 218 | +} |
| 219 | + |
| 220 | +ValueTag TaggedValue::get_tag() const |
| 221 | +{ |
| 222 | + // The tag is implicit in the type. |
| 223 | + if (std::holds_alternative<uint8_t>(value)) { |
| 224 | + return ValueTag::U8; |
| 225 | + } else if (std::holds_alternative<uint1_t>(value)) { |
| 226 | + return ValueTag::U1; |
| 227 | + } else if (std::holds_alternative<uint16_t>(value)) { |
| 228 | + return ValueTag::U16; |
| 229 | + } else if (std::holds_alternative<uint32_t>(value)) { |
| 230 | + return ValueTag::U32; |
| 231 | + } else if (std::holds_alternative<uint64_t>(value)) { |
| 232 | + return ValueTag::U64; |
| 233 | + } else if (std::holds_alternative<uint128_t>(value)) { |
| 234 | + return ValueTag::U128; |
| 235 | + } else if (std::holds_alternative<FF>(value)) { |
| 236 | + return ValueTag::FF; |
| 237 | + } else { |
| 238 | + throw std::runtime_error("Unknown value type"); |
| 239 | + } |
| 240 | + |
| 241 | + assert(false && "This should never happen."); |
| 242 | + return ValueTag::FF; // Only to make the compiler happy. |
| 243 | +} |
| 244 | + |
| 245 | +std::string TaggedValue::to_string() const |
| 246 | +{ |
| 247 | + std::string v = std::visit( |
| 248 | + overloads{ [](const FF& val) -> std::string { return field_to_string(val); }, |
| 249 | + [](const uint128_t& val) -> std::string { return field_to_string(uint256_t::from_uint128(val)); }, |
| 250 | + [](const uint1_t& val) -> std::string { return val.value() == 0 ? "0" : "1"; }, |
| 251 | + [](auto&& val) -> std::string { return std::to_string(val); } }, |
| 252 | + value); |
| 253 | + return "TaggedValue(" + v + ", " + std::to_string(get_tag()) + ")"; |
| 254 | +} |
| 255 | + |
| 256 | +} // namespace bb::avm2 |
| 257 | + |
| 258 | +std::string std::to_string(bb::avm2::ValueTag tag) |
| 259 | +{ |
| 260 | + using namespace bb::avm2; |
| 261 | + switch (tag) { |
| 262 | + case ValueTag::U1: |
| 263 | + return "U1"; |
| 264 | + case ValueTag::U8: |
| 265 | + return "U8"; |
| 266 | + case ValueTag::U16: |
| 267 | + return "U16"; |
| 268 | + case ValueTag::U32: |
| 269 | + return "U32"; |
| 270 | + case ValueTag::U64: |
| 271 | + return "U64"; |
| 272 | + case ValueTag::U128: |
| 273 | + return "U128"; |
| 274 | + case ValueTag::FF: |
| 275 | + return "FF"; |
| 276 | + default: |
| 277 | + return "Unknown"; |
| 278 | + } |
| 279 | +} |
| 280 | + |
| 281 | +std::string std::to_string(const bb::avm2::TaggedValue& value) |
| 282 | +{ |
| 283 | + return value.to_string(); |
| 284 | +} |
0 commit comments