
Commit faf4119

refactor: Use a common build_recurrent_state method that is cache-agnostic
This reduces the code duplication between the different build_rs impls and also retains a similar signature to the previous build_recurrent_state method, while standardizing on the input-dispatched build_rs implementation.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <[email protected]>
1 parent: 5046d41
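For context, the resulting call pattern looks like the following hypothetical sketch (conv_states_all, ssm_states_all, and the hparams.n_embd_k_s() / hparams.n_embd_v_s() state sizes are illustrative placeholders, not part of this commit). The static type of the input object selects the matching build_rs overload, and both overloads forward to the shared, cache-agnostic build_recurrent_state:

    // Hypothetical call sites inside an llm_graph_context member (sketch only):

    // recurrent-only model:
    llm_graph_input_rs * inp = build_rs_inp_recurrent();
    ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_k_s(), n_seqs);

    // hybrid (attention + recurrent) model: same call shape, different input type
    llm_graph_input_rs_hybrid_recurrent * hinp = build_rs_inp_hybrid_recurrent();
    ggml_tensor * ssm = build_rs(hinp, gf, ssm_states_all, hparams.n_embd_v_s(), n_seqs);

Overload resolution on the input type picks the correct cache unwrapping; both paths end in build_recurrent_state(kv_state, gf, s, inp->s_copy, state_size, n_seqs, avoid_copies).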

File tree

2 files changed (+42, −56 lines)

src/llama-graph.cpp

Lines changed: 33 additions & 56 deletions
@@ -1494,32 +1494,15 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_rs * llm_graph_context::build_rs_inp_recurrent() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
-
-    auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
-
-    const auto n_kv = kv_state->get_n_kv();
-
-    auto & cur = inp->s_copy;
-
-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
-    ggml_set_input(cur);
-
-    return (llm_graph_input_rs *) res->add_input(std::move(inp));
-}
-
-ggml_tensor * llm_graph_context::build_rs(
-        llm_graph_input_rs * inp,
+ggml_tensor * llm_graph_context::build_recurrent_state(
+        const llama_kv_cache_recurrent_state * kv_state,
         ggml_cgraph * gf,
         ggml_tensor * s,
+        ggml_tensor * state_copy,
         int32_t state_size,
         int32_t n_seqs,
         bool avoid_copies) const {
 
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
-
     const auto n_kv = kv_state->get_n_kv();
     const auto kv_head = kv_state->get_head();
     const auto rs_zero = kv_state->get_rs_z();
@@ -1537,7 +1520,7 @@ ggml_tensor * llm_graph_context::build_rs(
         // copy states
         // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
         // {state_size, kv_size} -> {state_size, n_seqs}
-        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0));
+        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
         ggml_build_forward_expand(gf, output_states);
     } else {
         // FIXME: make the gathering operation happen before the copy below
@@ -1546,7 +1529,7 @@ ggml_tensor * llm_graph_context::build_rs(
     }
 
     // copy extra states which won't be changed further (between n_seqs and n_kv)
-    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_kv - n_seqs, n_seqs*inp->s_copy->nb[0]));
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0,
             states_extra,
@@ -1555,63 +1538,57 @@ ggml_tensor * llm_graph_context::build_rs(
     return output_states;
 }
 
-llm_graph_input_rs_hybrid_recurrent * llm_graph_context::build_rs_inp_hybrid_recurrent() const {
-    auto inp = std::make_unique<llm_graph_input_rs_hybrid_recurrent>(
-        static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate));
+llm_graph_input_rs * llm_graph_context::build_rs_inp_recurrent() const {
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
-    const auto n_kv = inp->kv_state->get_n_kv();
+    auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
+
+    const auto n_kv = kv_state->get_n_kv();
 
     auto & cur = inp->s_copy;
 
     cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
     ggml_set_input(cur);
 
-    return (llm_graph_input_rs_hybrid_recurrent *) res->add_input(std::move(inp));
+    return (llm_graph_input_rs *) res->add_input(std::move(inp));
 }
 
 ggml_tensor * llm_graph_context::build_rs(
-        llm_graph_input_rs_hybrid_recurrent * inp,
+        llm_graph_input_rs * inp,
         ggml_cgraph * gf,
         ggml_tensor * s,
         int32_t state_size,
         int32_t n_seqs,
         bool avoid_copies) const {
 
-    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_recurrent();
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    return build_recurrent_state(kv_state, gf, s, inp->s_copy, state_size, n_seqs, avoid_copies);
+}
 
-    const auto n_kv = kv_state->get_n_kv();
-    const auto kv_head = kv_state->get_head();
-    const auto rs_zero = kv_state->get_rs_z();
+llm_graph_input_rs_hybrid_recurrent * llm_graph_context::build_rs_inp_hybrid_recurrent() const {
+    auto inp = std::make_unique<llm_graph_input_rs_hybrid_recurrent>(
+        static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate));
 
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
+    const auto n_kv = inp->kv_state->get_n_kv();
 
-    // Clear a single state which will then be copied to the other cleared states.
-    // Note that this is a no-op when the view is zero-sized.
-    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
-    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
+    auto & cur = inp->s_copy;
 
-    ggml_tensor * output_states;
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
+    ggml_set_input(cur);
 
-    if (!avoid_copies) {
-        // copy states
-        // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-        // {state_size, kv_size} -> {state_size, n_seqs}
-        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0));
-        ggml_build_forward_expand(gf, output_states);
-    } else {
-        // FIXME: make the gathering operation happen before the copy below
-        // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
-        output_states = states;
-    }
+    return (llm_graph_input_rs_hybrid_recurrent *) res->add_input(std::move(inp));
+}
 
-    // copy extra states which won't be changed further (between n_seqs and n_kv)
-    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_kv - n_seqs, n_seqs*inp->s_copy->nb[0]));
-    ggml_build_forward_expand(gf,
-        ggml_cpy(ctx0,
-            states_extra,
-            ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
+ggml_tensor * llm_graph_context::build_rs(
+        llm_graph_input_rs_hybrid_recurrent * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        int32_t state_size,
+        int32_t n_seqs,
+        bool avoid_copies) const {
 
-    return output_states;
+    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_recurrent();
+    return build_recurrent_state(kv_state, gf, s, inp->s_copy, state_size, n_seqs, avoid_copies);
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
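The heart of the shared helper is the ggml_get_rows() gather driven by the I32 s_copy index tensor: output row i is read from states row s_copy[i]. The standalone program below demonstrates just that mechanism outside of llama.cpp; the sizes and the identity copy map are made-up demo values, and it assumes a ggml build that exposes ggml_graph_compute_with_ctx (declared in ggml-cpu.h on newer trees, in ggml.h on older ones):

    #include "ggml.h"
    #include "ggml-cpu.h" // older ggml trees declare ggml_graph_compute_with_ctx in ggml.h instead

    int main() {
        // small CPU-only context; plenty of room for the demo tensors and the work buffer
        ggml_init_params params = { /*mem_size=*/ 16*1024*1024, /*mem_buffer=*/ nullptr, /*no_alloc=*/ false };
        ggml_context * ctx = ggml_init(params);

        const int state_size = 4, kv_size = 8, n_kv = 8, n_seqs = 2;

        // {state_size, kv_size}: one recurrent-state row per cache cell (contents left uninitialized; demo only)
        ggml_tensor * states = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, state_size, kv_size);

        // I32 copy map: output row i is taken from states row s_copy[i]
        ggml_tensor * s_copy = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_kv);
        for (int i = 0; i < n_kv; ++i) {
            ((int32_t *) s_copy->data)[i] = i; // identity map for the demo
        }

        // gather the first n_seqs rows: {state_size, kv_size} -> {state_size, n_seqs}
        ggml_tensor * out = ggml_get_rows(ctx, states, ggml_view_1d(ctx, s_copy, n_seqs, 0));

        ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, out);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/ 1);

        ggml_free(ctx);
        return 0;
    }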

src/llama-graph.h

Lines changed: 9 additions & 0 deletions
@@ -622,6 +622,15 @@ struct llm_graph_context {
     // recurrent
     //
 
+    ggml_tensor * build_recurrent_state(
+        const llama_kv_cache_recurrent_state * kv_state,
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        ggml_tensor * state_copy,
+        int32_t state_size,
+        int32_t n_seqs,
+        bool avoid_copies = false) const;
+
     llm_graph_input_rs * build_rs_inp_recurrent() const;
 
     ggml_tensor * build_rs(
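Condensed from the .cpp diff above, the two build_rs overloads now differ only in how they obtain the recurrent cache state before delegating (the two kv_state lines live in separate overloads; they are shown together here for comparison only):

    // recurrent-only overload:
    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);

    // hybrid overload (unwraps the recurrent half of the hybrid cache):
    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_recurrent();

    // both then delegate identically:
    return build_recurrent_state(kv_state, gf, s, inp->s_copy, state_size, n_seqs, avoid_copies);

Since avoid_copies defaults to false in the new declaration, callers get the gather-and-copy behavior unless they explicitly opt out; per the FIXME retained in the helper, the opt-out path still carries a known ordering caveat.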
