@@ -1494,32 +1494,15 @@ ggml_tensor * llm_graph_context::build_attn(
1494
1494
1495
1495
return cur;
1496
1496
}
1497
-
1498
- llm_graph_input_rs * llm_graph_context::build_rs_inp_recurrent () const {
1499
- const auto * kv_state = static_cast <const llama_kv_cache_recurrent_state *>(mstate);
1500
-
1501
- auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
1502
-
1503
- const auto n_kv = kv_state->get_n_kv ();
1504
-
1505
- auto & cur = inp->s_copy ;
1506
-
1507
- cur = ggml_new_tensor_1d (ctx0, GGML_TYPE_I32, n_kv);
1508
- ggml_set_input (cur);
1509
-
1510
- return (llm_graph_input_rs *) res->add_input (std::move (inp));
1511
- }
1512
-
1513
- ggml_tensor * llm_graph_context::build_rs (
1514
- llm_graph_input_rs * inp,
1497
+ ggml_tensor * llm_graph_context::build_recurrent_state (
1498
+ const llama_kv_cache_recurrent_state * kv_state,
1515
1499
ggml_cgraph * gf,
1516
1500
ggml_tensor * s,
1501
+ ggml_tensor * state_copy,
1517
1502
int32_t state_size,
1518
1503
int32_t n_seqs,
1519
1504
bool avoid_copies) const {
1520
1505
1521
- const auto * kv_state = static_cast <const llama_kv_cache_recurrent_state *>(mstate);
1522
-
1523
1506
const auto n_kv = kv_state->get_n_kv ();
1524
1507
const auto kv_head = kv_state->get_head ();
1525
1508
const auto rs_zero = kv_state->get_rs_z ();
@@ -1537,7 +1520,7 @@ ggml_tensor * llm_graph_context::build_rs(
1537
1520
// copy states
1538
1521
// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
1539
1522
// {state_size, kv_size} -> {state_size, n_seqs}
1540
- output_states = ggml_get_rows (ctx0, states, ggml_view_1d (ctx0, inp-> s_copy , n_seqs, 0 ));
1523
+ output_states = ggml_get_rows (ctx0, states, ggml_view_1d (ctx0, state_copy , n_seqs, 0 ));
1541
1524
ggml_build_forward_expand (gf, output_states);
1542
1525
} else {
1543
1526
// FIXME: make the gathering operation happen before the copy below
@@ -1546,7 +1529,7 @@ ggml_tensor * llm_graph_context::build_rs(
1546
1529
}
1547
1530
1548
1531
// copy extra states which won't be changed further (between n_seqs and n_kv)
1549
- ggml_tensor * states_extra = ggml_get_rows (ctx0, states, ggml_view_1d (ctx0, inp-> s_copy , n_kv - n_seqs, n_seqs*inp-> s_copy ->nb [0 ]));
1532
+ ggml_tensor * states_extra = ggml_get_rows (ctx0, states, ggml_view_1d (ctx0, state_copy , n_kv - n_seqs, n_seqs*state_copy ->nb [0 ]));
1550
1533
ggml_build_forward_expand (gf,
1551
1534
ggml_cpy (ctx0,
1552
1535
states_extra,
@@ -1555,63 +1538,57 @@ ggml_tensor * llm_graph_context::build_rs(
1555
1538
return output_states;
1556
1539
}
1557
1540
1558
- llm_graph_input_rs_hybrid_recurrent * llm_graph_context::build_rs_inp_hybrid_recurrent () const {
1559
- auto inp = std::make_unique<llm_graph_input_rs_hybrid_recurrent>(
1560
- static_cast <const llama_kv_cache_hybrid_recurrent_state *>(mstate));
1541
+ llm_graph_input_rs * llm_graph_context::build_rs_inp_recurrent () const {
1542
+ const auto * kv_state = static_cast <const llama_kv_cache_recurrent_state *>(mstate);
1561
1543
1562
- const auto n_kv = inp->kv_state ->get_n_kv ();
1544
+ auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
1545
+
1546
+ const auto n_kv = kv_state->get_n_kv ();
1563
1547
1564
1548
auto & cur = inp->s_copy ;
1565
1549
1566
1550
cur = ggml_new_tensor_1d (ctx0, GGML_TYPE_I32, n_kv);
1567
1551
ggml_set_input (cur);
1568
1552
1569
- return (llm_graph_input_rs_hybrid_recurrent *) res->add_input (std::move (inp));
1553
+ return (llm_graph_input_rs *) res->add_input (std::move (inp));
1570
1554
}
1571
1555
1572
1556
ggml_tensor * llm_graph_context::build_rs (
1573
- llm_graph_input_rs_hybrid_recurrent * inp,
1557
+ llm_graph_input_rs * inp,
1574
1558
ggml_cgraph * gf,
1575
1559
ggml_tensor * s,
1576
1560
int32_t state_size,
1577
1561
int32_t n_seqs,
1578
1562
bool avoid_copies) const {
1579
1563
1580
- const auto * kv_state = static_cast <const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_recurrent ();
1564
+ const auto * kv_state = static_cast <const llama_kv_cache_recurrent_state *>(mstate);
1565
+ return build_recurrent_state (kv_state, gf, s, inp->s_copy , state_size, n_seqs, avoid_copies);
1566
+ }
1581
1567
1582
- const auto n_kv = kv_state-> get_n_kv ();
1583
- const auto kv_head = kv_state-> get_head ();
1584
- const auto rs_zero = kv_state-> get_rs_z ( );
1568
+ llm_graph_input_rs_hybrid_recurrent * llm_graph_context::build_rs_inp_hybrid_recurrent () const {
1569
+ auto inp = std::make_unique<llm_graph_input_rs_hybrid_recurrent>(
1570
+ static_cast < const llama_kv_cache_hybrid_recurrent_state *>(mstate) );
1585
1571
1586
- ggml_tensor * states = ggml_reshape_2d (ctx0, s, state_size, kv_state->get_size () );
1572
+ const auto n_kv = inp-> kv_state ->get_n_kv ( );
1587
1573
1588
- // Clear a single state which will then be copied to the other cleared states.
1589
- // Note that this is a no-op when the view is zero-sized.
1590
- ggml_tensor * state_zero = ggml_view_1d (ctx0, states, state_size*(rs_zero >= 0 ), rs_zero*states->nb [1 ]*(rs_zero >= 0 ));
1591
- ggml_build_forward_expand (gf, ggml_scale_inplace (ctx0, state_zero, 0 ));
1574
+ auto & cur = inp->s_copy ;
1592
1575
1593
- ggml_tensor * output_states;
1576
+ cur = ggml_new_tensor_1d (ctx0, GGML_TYPE_I32, n_kv);
1577
+ ggml_set_input (cur);
1594
1578
1595
- if (!avoid_copies) {
1596
- // copy states
1597
- // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
1598
- // {state_size, kv_size} -> {state_size, n_seqs}
1599
- output_states = ggml_get_rows (ctx0, states, ggml_view_1d (ctx0, inp->s_copy , n_seqs, 0 ));
1600
- ggml_build_forward_expand (gf, output_states);
1601
- } else {
1602
- // FIXME: make the gathering operation happen before the copy below
1603
- // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
1604
- output_states = states;
1605
- }
1579
+ return (llm_graph_input_rs_hybrid_recurrent *) res->add_input (std::move (inp));
1580
+ }
1606
1581
1607
- // copy extra states which won't be changed further (between n_seqs and n_kv)
1608
- ggml_tensor * states_extra = ggml_get_rows (ctx0, states, ggml_view_1d (ctx0, inp->s_copy , n_kv - n_seqs, n_seqs*inp->s_copy ->nb [0 ]));
1609
- ggml_build_forward_expand (gf,
1610
- ggml_cpy (ctx0,
1611
- states_extra,
1612
- ggml_view_1d (ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size (s))));
1582
+ ggml_tensor * llm_graph_context::build_rs (
1583
+ llm_graph_input_rs_hybrid_recurrent * inp,
1584
+ ggml_cgraph * gf,
1585
+ ggml_tensor * s,
1586
+ int32_t state_size,
1587
+ int32_t n_seqs,
1588
+ bool avoid_copies) const {
1613
1589
1614
- return output_states;
1590
+ const auto * kv_state = static_cast <const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_recurrent ();
1591
+ return build_recurrent_state (kv_state, gf, s, inp->s_copy , state_size, n_seqs, avoid_copies);
1615
1592
}
1616
1593
1617
1594
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load (
0 commit comments