Hybrid recurrent cache #13979

Open · wants to merge 18 commits into base: master
Commits (18)
b9e8f94
feat: Add llama_model_is_hybrid API call
gabe-l-hart May 9, 2025
897dce5
feat: Add c++ side constants for attention layer indices hparam
gabe-l-hart May 9, 2025
eb6c979
feat: Add support for distinguishing recurrent vs non-recurrent layer…
gabe-l-hart May 9, 2025
ba362dc
feat: Auto-fill hparams.recurrent_layer_arr based on whether the mode…
gabe-l-hart May 9, 2025
a391645
refactor: rename *_is_hybrid -> *_is_hybrid_recurrent
gabe-l-hart May 28, 2025
e00e525
feat: Add layer filter to recurrent cache
gabe-l-hart May 20, 2025
f0cb485
fix: Use per-layer sizing everywhere in kv caches
gabe-l-hart May 14, 2025
6a822b7
feat: First pass at llama_kv_cache_hybrid_recurrent
gabe-l-hart May 30, 2025
fe39803
feat: Construct hybrid recurrent cache for hybrid recurrent models
gabe-l-hart May 28, 2025
8cb41a9
fix: Fix wrong bool condition for split equal in hybrid cache
gabe-l-hart May 28, 2025
980ae73
fix: Fix shift logic to defer to unified cache
gabe-l-hart Jun 3, 2025
31be8ae
feat: Support hybrid recurrent in llama-graph
gabe-l-hart Jun 4, 2025
2e9e969
fix: Fix logic for initializing inputs and attn layers for hybrid caches
gabe-l-hart Jun 4, 2025
b576579
fix: Use @compilade's suggested fix for seq_id indexing with equal sp…
gabe-l-hart May 28, 2025
c84f607
fix: Update recurrent cache for changes to remove intermediate kv_cac…
gabe-l-hart Jun 5, 2025
ecaac6b
fix: Fix status for init_update sig for recurrent cache state
gabe-l-hart Jun 5, 2025
f69d82c
fix: Add missing padding to n_ctx for hybrid cache construction
gabe-l-hart Jun 5, 2025
60ca3ba
fix: Update clear signature for data argument after rebase
gabe-l-hart Jun 6, 2025
3 changes: 3 additions & 0 deletions include/llama.h
@@ -569,6 +569,9 @@ extern "C" {
// Returns true if the model is recurrent (like Mamba, RWKV, etc.)
LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);

// Returns true if the model is hybrid-recurrent (like Jamba, Bamba, etc.)
LLAMA_API bool llama_model_is_hybrid_recurrent(const struct llama_model * model);

// Returns 0 on success
LLAMA_API uint32_t llama_model_quantize(
const char * fname_inp,
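For reference, a minimal sketch of how a caller might branch on the new predicate alongside the existing `llama_model_is_recurrent`; the model handle is assumed to have been loaded elsewhere through the usual llama.cpp API.

```cpp
// Sketch only: branch on the new predicate next to the existing one.
// Assumes `model` was loaded elsewhere with the standard llama.cpp loading API.
#include "llama.h"
#include <cstdio>

static void report_memory_kind(const struct llama_model * model) {
    if (llama_model_is_hybrid_recurrent(model)) {
        // some layers use an attention KV cache, others a recurrent state (e.g. Jamba, Bamba)
        std::printf("hybrid recurrent cache\n");
    } else if (llama_model_is_recurrent(model)) {
        // purely recurrent models (e.g. Mamba, RWKV)
        std::printf("recurrent state cache\n");
    } else {
        // attention-only models
        std::printf("unified KV cache\n");
    }
}
```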
1 change: 1 addition & 0 deletions src/CMakeLists.txt
@@ -23,6 +23,7 @@ add_library(llama
llama-kv-cache-unified.cpp
llama-kv-cache-unified-iswa.cpp
llama-kv-cache-recurrent.cpp
llama-kv-cache-hybrid-recurrent.cpp
llama-memory.cpp
llama-mmap.cpp
llama-model-loader.cpp
23 changes: 23 additions & 0 deletions src/llama-arch.cpp
@@ -144,6 +144,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
{ LLM_KV_ATTENTION_LAYER_INDICES, "%s.attention.layer_indices" },

{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
@@ -1752,3 +1753,25 @@ llm_arch llm_arch_from_string(const std::string & name) {
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
return LLM_TENSOR_INFOS.at(tensor);
}

bool llm_arch_is_recurrent(const llm_arch & arch) {
switch (arch) {
case LLM_ARCH_MAMBA:
case LLM_ARCH_RWKV6:
case LLM_ARCH_RWKV6QWEN2:
case LLM_ARCH_RWKV7:
case LLM_ARCH_ARWKV7:
return true;
default:
return false;
}
}

bool llm_arch_is_hybrid_recurrent(const llm_arch & arch) {
// TODO: There are currently no hybrid models! Once there are, this will be
// the place to identify them
switch (arch) {
default:
return false;
}
}
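To illustrate the TODO above, a hypothetical sketch of how a future hybrid architecture would be registered; `LLM_ARCH_JAMBA` is a placeholder name, not an enum value that exists in this PR.

```cpp
// Hypothetical illustration only: once a hybrid architecture lands in the
// llm_arch enum (the placeholder LLM_ARCH_JAMBA below does not exist yet),
// it would simply be added as a case returning true.
bool llm_arch_is_hybrid_recurrent(const llm_arch & arch) {
    switch (arch) {
        // case LLM_ARCH_JAMBA:
        //     return true;
        default:
            return false;
    }
}
```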
4 changes: 4 additions & 0 deletions src/llama-arch.h
@@ -148,6 +148,7 @@ enum llm_kv {
LLM_KV_ATTENTION_SCALE,
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
LLM_KV_ATTENTION_LAYER_INDICES,

LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -436,3 +437,6 @@ const char * llm_arch_name(llm_arch arch);
llm_arch llm_arch_from_string(const std::string & name);

const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);

bool llm_arch_is_recurrent(const llm_arch& arch);
bool llm_arch_is_hybrid_recurrent(const llm_arch& arch);
49 changes: 39 additions & 10 deletions src/llama-graph.cpp
@@ -7,6 +7,7 @@
#include "llama-kv-cache-unified.h"
#include "llama-kv-cache-unified-iswa.h"
#include "llama-kv-cache-recurrent.h"
#include "llama-kv-cache-hybrid-recurrent.h"

#include <cassert>
#include <cmath>
@@ -969,8 +970,10 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
return cur;
}

ggml_tensor * llm_graph_context::build_inp_s_copy() const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
ggml_tensor * llm_graph_context::build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state) const {
if (kv_state == nullptr) {
kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
}

auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);

@@ -1284,7 +1287,9 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_build_forward_expand(gf, k_cur);
ggml_build_forward_expand(gf, v_cur);

const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
// NOTE: For hybrid caches, this may be a child of mstate, so we use the one
// encapsulated in inp
const auto * kv_state = inp->kv_state;

// store to KV cache
{
@@ -1316,6 +1321,26 @@
return cur;
}

llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);

auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state->get_state_attn());

{
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");

const auto n_kv = kv_state->get_state_attn()->get_n_kv();

inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->self_kq_mask);

inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}

return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
}

llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);

@@ -1456,13 +1481,17 @@ ggml_tensor * llm_graph_context::build_attn(
}

ggml_tensor * llm_graph_context::build_copy_mask_state(
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
int32_t n_state,
int32_t n_seqs) const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
int32_t n_state,
int32_t n_seqs,
const llama_kv_cache_recurrent_state * kv_state) const {

if (kv_state == nullptr) {
kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
}

const auto n_kv = kv_state->get_n_kv();
const auto kv_head = kv_state->get_head();
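The recurring pattern in this file is that graph-input builders now take an optional `llama_kv_cache_recurrent_state` and fall back to `mstate` when none is given, so a hybrid cache can pass in its child states explicitly. A self-contained sketch of that fallback shape (the type names below are illustrative placeholders, not the real llama.cpp classes):

```cpp
// Illustration of the fallback used by build_inp_s_copy() and
// build_copy_mask_state(): if the caller does not pass a recurrent-cache
// state explicitly, the graph context's own mstate is used.
#include <cassert>

struct recurrent_state_t { int n_kv; };

struct graph_ctx_t {
    const recurrent_state_t * mstate;

    int n_kv_for(const recurrent_state_t * kv_state = nullptr) const {
        if (kv_state == nullptr) {
            kv_state = mstate;  // standalone recurrent cache: use the context's own state
        }
        return kv_state->n_kv;  // hybrid cache: caller passed the child state explicitly
    }
};

int main() {
    recurrent_state_t own   {4};
    recurrent_state_t child {8};
    graph_ctx_t gc {&own};
    assert(gc.n_kv_for()       == 4);  // default: falls back to mstate
    assert(gc.n_kv_for(&child) == 8);  // hybrid path: explicit child state wins
    return 0;
}
```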
20 changes: 12 additions & 8 deletions src/llama-graph.h
@@ -22,6 +22,7 @@ struct llama_memory_state_i;
class llama_kv_cache_unified_state;
class llama_kv_cache_unified_iswa_state;
class llama_kv_cache_recurrent_state;
class llama_kv_cache_hybrid_recurrent_state;

// certain models (typically multi-modal) can produce different types of graphs
enum llm_graph_type {
@@ -254,7 +255,7 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
cparams(cparams),
kv_state(kv_state) {
}
~llm_graph_input_attn_kv_unified() = default;
virtual ~llm_graph_input_attn_kv_unified() = default;

void set_input(const llama_ubatch * ubatch) override;

@@ -520,7 +521,7 @@ struct llm_graph_context {
ggml_tensor * build_inp_out_ids() const;
ggml_tensor * build_inp_mean() const;
ggml_tensor * build_inp_cls() const;
ggml_tensor * build_inp_s_copy() const;
ggml_tensor * build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state = nullptr) const;
ggml_tensor * build_inp_s_mask() const;

ggml_tensor * build_inp_cross_embd() const;
@@ -587,6 +588,8 @@ struct llm_graph_context {
float kq_scale,
int il) const;

llm_graph_input_attn_kv_unified * build_attn_inp_kv_hybrid_recurrent() const;

llm_graph_input_attn_cross * build_attn_inp_cross() const;

ggml_tensor * build_attn(
@@ -607,12 +610,13 @@
//

ggml_tensor * build_copy_mask_state(
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
int32_t n_state,
int32_t n_seqs) const;
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
int32_t n_state,
int32_t n_seqs,
const llama_kv_cache_recurrent_state * kv_state = nullptr) const;

ggml_tensor * build_rwkv_token_shift_load(
ggml_cgraph * gf,
14 changes: 12 additions & 2 deletions src/llama-hparams.cpp
@@ -65,7 +65,10 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
return n_embd_head_v * n_head_kv;
}

uint32_t llama_hparams::n_embd_k_s() const {
uint32_t llama_hparams::n_embd_k_s(uint32_t il) const {
if (!recurrent_layer(il)) {
return 0;
}
if (wkv_head_size != 0) {
// for RWKV models
return token_shift_count * n_embd;
@@ -76,7 +79,10 @@ uint32_t llama_hparams::n_embd_k_s() const {
return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
}

uint32_t llama_hparams::n_embd_v_s() const {
uint32_t llama_hparams::n_embd_v_s(uint32_t il) const {
if (!recurrent_layer(il)) {
return 0;
}
if (wkv_head_size != 0) {
// corresponds to RWKV's wkv_states size
return n_embd * wkv_head_size;
@@ -86,6 +92,10 @@
return ssm_d_state * ssm_d_inner;
}

bool llama_hparams::recurrent_layer(uint32_t il) const {
return recurrent_layer_arr[il];
}

bool llama_hparams::is_swa(uint32_t il) const {
if (il < n_layer) {
return swa_layers[il];
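Because the per-layer overloads now return 0 for non-recurrent layers, a caller can sum state sizes across all layers without special-casing hybrid models. A hedged sketch (the helper below is illustrative, not part of the PR):

```cpp
// Illustrative helper (not in the PR): total recurrent-state size across layers.
// Non-recurrent layers contribute 0 because n_embd_k_s(il) / n_embd_v_s(il)
// now check recurrent_layer(il) first.
#include "llama-hparams.h"
#include <cstdint>

uint64_t total_recurrent_state_elems(const llama_hparams & hparams, uint32_t n_layer) {
    uint64_t total = 0;
    for (uint32_t il = 0; il < n_layer; ++il) {
        total += (uint64_t) hparams.n_embd_k_s(il) + hparams.n_embd_v_s(il);
    }
    return total;
}
```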
10 changes: 8 additions & 2 deletions src/llama-hparams.h
@@ -115,6 +115,9 @@ struct llama_hparams {
uint32_t ssm_d_state = 0;
uint32_t ssm_dt_rank = 0;

// for hybrid state space models
std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;

bool ssm_dt_b_c_rms = false;

float f_clamp_kqv = 0.0f;
@@ -181,10 +184,13 @@

// dimension of the rolling state embeddings
// corresponds to Mamba's conv_states size or RWKV's token_shift states size
uint32_t n_embd_k_s() const;
uint32_t n_embd_k_s(uint32_t il = 0) const;

// dimension of the recurrent state embeddings
uint32_t n_embd_v_s() const;
uint32_t n_embd_v_s(uint32_t il = 0) const;

// whether or not the given layer is recurrent (for hybrid models)
bool recurrent_layer(uint32_t il) const;

bool is_swa(uint32_t il) const;
};
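One of the commits auto-fills `recurrent_layer_arr`, and the new `%s.attention.layer_indices` KV identifies which layers are attention layers. The loader code is not part of this diff, so the following is only a hypothetical sketch of that idea, using `std::vector<bool>` as a stand-in for the fixed-size array in `llama_hparams`:

```cpp
// Hypothetical sketch only: the actual loader code is not shown in this diff.
// Assumes attn_layer_indices was read from the new %s.attention.layer_indices KV;
// recurrent_by_default would be true for recurrent and hybrid-recurrent models,
// false for attention-only models.
#include <cstdint>
#include <vector>

static void fill_recurrent_layer_arr(std::vector<bool> & recurrent_layer_arr,
                                     uint32_t n_layer,
                                     bool recurrent_by_default,
                                     const std::vector<uint32_t> & attn_layer_indices) {
    recurrent_layer_arr.assign(n_layer, recurrent_by_default);  // all-or-nothing default
    for (uint32_t il : attn_layer_indices) {
        if (il < n_layer) {
            recurrent_layer_arr[il] = false;  // hybrid: attention layers are not recurrent
        }
    }
}
```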