Skip to content

fuzz-tests: add test for amount-{sat, msat} arithmetic #8298

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions common/amount.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <common/overflows.h>
#include <common/utils.h>
#include <inttypes.h>
#include <stdio.h>
#include <wire/wire.h>

bool amount_sat_to_msat(struct amount_msat *msat,
Expand Down Expand Up @@ -616,6 +617,12 @@ struct amount_msat amount_msat_sub_fee(struct amount_msat in,
* Since we round the fee down, out can be a bit bigger than
* expected, so we iterate upwards.
*/
FILE *logf = fopen("/tmp/amount_msat_sub_fee.log", "a");
if (logf) {
fprintf(logf, "CALL: in=%" PRIu64 "msat, fee_base_msat=%u, fee_ppm=%u\n",
in.millisatoshis, fee_base_msat, fee_proportional_millionths);
fclose(logf);
}
if (!amount_msat_sub(&out, in, amount_msat(fee_base_msat)))
return AMOUNT_MSAT(0);
if (!amount_msat_mul_div(&out, out, 1000000,
Expand Down
270 changes: 270 additions & 0 deletions tests/fuzz/fuzz-amount-arith.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
#include "config.h"
#include <assert.h>
#include <math.h>
#include <common/amount.h>
#include <common/overflows.h>
#include <tests/fuzz/libfuzz.h>

void init(int *argc, char ***argv) {}

enum op {
OP_MSAT_ADD,
OP_MSAT_SUB,
OP_MSAT_MUL,
OP_MSAT_DIV,
OP_MSAT_RATIO,
OP_MSAT_RATIO_FLOOR,
OP_MSAT_RATIO_CEIL,
OP_MSAT_SCALE,
OP_MSAT_ADD_SAT,
OP_MSAT_SUB_SAT,
OP_SAT_ADD,
OP_SAT_SUB,
OP_SAT_MUL,
OP_SAT_DIV,
OP_SAT_SCALE,
OP_FEE,
OP_ADD_FEE,
OP_SUB_FEE,
OP_TX_FEE,
OP_FEERATE,
OP_COUNT
};

void run(const uint8_t *data, size_t size) {
if (size < sizeof(uint8_t) + 2 * sizeof(struct amount_msat) + sizeof(double))
return;

uint8_t op = *data++ % OP_COUNT;

struct amount_msat a = fromwire_amount_msat(&data, &size);
struct amount_msat b = fromwire_amount_msat(&data, &size);

double f;
memcpy(&f, data, sizeof(f));
data += sizeof(f);

struct amount_sat sa = amount_msat_to_sat_round_down(a);
struct amount_sat sb = amount_msat_to_sat_round_down(b);

u64 u64_param;
memcpy(&u64_param, &f, sizeof(u64_param));

struct amount_msat out_ms;
struct amount_sat out_s;

switch (op) {
case OP_MSAT_ADD:
{
if (amount_msat_add(&out_ms, a, b)) {
assert(out_ms.millisatoshis == a.millisatoshis + b.millisatoshis);
}
break;
}

case OP_MSAT_SUB:
{
if (amount_msat_sub(&out_ms, a, b)) {
assert(out_ms.millisatoshis + b.millisatoshis == a.millisatoshis);
}
break;
}

case OP_MSAT_MUL:
{
if (amount_msat_mul(&out_ms, a, u64_param)) {
assert(out_ms.millisatoshis == a.millisatoshis * u64_param);
}
break;
}

case OP_MSAT_DIV:
{
if (u64_param == 0)
break;
out_ms = amount_msat_div(a, u64_param);
assert(out_ms.millisatoshis == a.millisatoshis / u64_param);
break;
}

case OP_MSAT_RATIO:
{
if (b.millisatoshis == 0)
break;
double ratio = amount_msat_ratio(a, b);
double expected = (double)a.millisatoshis / b.millisatoshis;
assert(ratio == expected);
break;
}

case OP_MSAT_RATIO_FLOOR:
{
if (b.millisatoshis == 0)
break;
u64 floor = amount_msat_ratio_floor(a, b);
assert(floor == a.millisatoshis / b.millisatoshis);
break;
}

case OP_MSAT_RATIO_CEIL:
{
if (b.millisatoshis == 0)
break;

// The assertion remains valid ONLY if there's no overflow
if (a.millisatoshis > UINT64_MAX - b.millisatoshis + 1) {
break;
}

u64 ceil = amount_msat_ratio_ceil(a, b);
u64 quotient = a.millisatoshis / b.millisatoshis;
u64 remainder = a.millisatoshis % b.millisatoshis;

assert(ceil == quotient + (remainder != 0));
break;
}

case OP_MSAT_SCALE:
{
// if (amount_msat_scale(&out_ms, a, f)) {
// double expect = (double)a.millisatoshis * f;
// assert(fabs((double)out_ms.millisatoshis - expect) < 1.0);
// }
break;
}

case OP_MSAT_ADD_SAT:
{
if (amount_msat_add_sat(&out_ms, a, sa)) {
assert(out_ms.millisatoshis == sa.satoshis * MSAT_PER_SAT + a.millisatoshis);
}
break;
}

case OP_MSAT_SUB_SAT:
{
if (amount_msat_sub_sat(&out_ms, a, sa)) {
assert(out_ms.millisatoshis + sa.satoshis * MSAT_PER_SAT == a.millisatoshis);
}
break;
}

case OP_SAT_ADD:
{
if (amount_sat_add(&out_s, sa, sb)) {
assert(out_s.satoshis == sa.satoshis + sb.satoshis);
}
break;
}

case OP_SAT_SUB:
{
if (amount_sat_sub(&out_s, sa, sb)) {
assert(out_s.satoshis == sa.satoshis - sb.satoshis);
}
break;
}

case OP_SAT_MUL:
{
if (amount_sat_mul(&out_s, sa, u64_param)) {
assert(out_s.satoshis == sa.satoshis * u64_param);
}
break;
}

case OP_SAT_DIV:
{
if (u64_param == 0)
break;
out_s = amount_sat_div(sa, u64_param);
assert(out_s.satoshis == sa.satoshis / u64_param);
break;
}

case OP_SAT_SCALE:
{
// if (amount_sat_scale(&out_s, sa, f)) {
// double expect = sa.satoshis * f;
// assert(fabs((double)out_s.satoshis - expect) < 1.0);
// }
break;
}

case OP_FEE:
{
if (amount_msat_fee(&out_ms, a, (u32)(a.millisatoshis & UINT32_MAX), (u32)(b.millisatoshis & UINT32_MAX))) {
assert(out_ms.millisatoshis >= (a.millisatoshis & UINT32_MAX));
}
break;
}

case OP_ADD_FEE:
{
u32 fee_base = (u32)(a.millisatoshis & UINT32_MAX);
u32 fee_prop = (u32)(b.millisatoshis & UINT32_MAX);

struct amount_msat original = a;
struct amount_msat fee;

if (amount_msat_fee(&fee, original, fee_base, fee_prop)) {
struct amount_msat total;
if (amount_msat_add(&total, original, fee)) {
assert(amount_msat_greater_eq(total, fee));

struct amount_msat expected_total;
assert(amount_msat_add(&expected_total, original, fee));
assert(amount_msat_eq(total, expected_total));

a = total;
}
}
}

case OP_SUB_FEE:
{
u32 fee_base = (u32)(a.millisatoshis & UINT32_MAX);
u32 fee_prop = (u32)(b.millisatoshis & UINT32_MAX);
struct amount_msat input = a;
struct amount_msat output = amount_msat_sub_fee(input, fee_base, fee_prop);
struct amount_msat fee;
if (amount_msat_fee(&fee, output, fee_base, fee_prop)) {
struct amount_msat sum;
if (amount_msat_add(&sum, output, fee))
assert(amount_msat_less_eq(sum, input));
}
break;
}

case OP_TX_FEE:
{
if (b.millisatoshis > SIZE_MAX)
break;
u32 fee_per_kw = (u32)(a.millisatoshis & UINT32_MAX);
size_t weight = (size_t)(b.millisatoshis);

/* weights > 2^32 are not real tx and hence, discarded */
if (mul_overflows_u64(fee_per_kw, weight))
break;
struct amount_sat fee = amount_tx_fee(fee_per_kw, weight);
u64 expected = (fee_per_kw * weight) / MSAT_PER_SAT;
assert(fee.satoshis == expected);
break;
}

case OP_FEERATE:
{
struct amount_sat fee = amount_msat_to_sat_round_down(a);
size_t weight = (size_t)(b.millisatoshis);
u32 feerate;
if (weight && amount_feerate(&feerate, fee, weight)) {
u64 expected = (fee.satoshis * MSAT_PER_SAT) / weight;
assert(feerate == expected);
}
break;
}

default:
assert(false && "unknown operation");
}
}
38 changes: 38 additions & 0 deletions tests/test_askrene.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,44 @@ def direction(src, dst):
return 1


def test_fee_overflow(node_factory):
"""Test for integer overflow in fee calculation with extreme parameters"""
input = 1000
fee_base = 8
fee_prop = 4295000

# Setup: Create a line graph with 2 nodes
l1, l2 = node_factory.line_graph(2, wait_for_announce=True)

# Create a new layer for fee modifications
l1.rpc.askrene_create_layer('fee_update_layer')

# Get channel ID between l1 and l2 (example)
scid = first_scid(l1, l2)
scid_dir = f"{scid}/{direction(l1.info['id'], l2.info['id'])}"

# Update fee parameters for a specific channel direction
l1.rpc.askrene_update_channel(
layer='fee_update_layer',
short_channel_id_dir=scid_dir,
htlc_minimum_msat=100,
# Any values larger than these get truncated to a set of safe values
htlc_maximum_msat='9999999999999999sat',
fee_base_msat=9999999,
fee_proportional_millionths=999999,
cltv_expiry_delta=18
)

# Now call getroutes with the modified fee layer
routes = l1.rpc.getroutes(
source=l1.info['id'],
destination=l2.info['id'],
amount_msat=input,
layers=['fee_update_layer'],
maxfee_msat=0xFFFFFFFFFFFFFFFF,
final_cltv=1440
)

def test_reserve(node_factory):
"""Test reserving channels"""
l1, l2, l3 = node_factory.line_graph(3, wait_for_announce=True)
Expand Down