Hybrid recurrent cache #13904

Closed
Changes from all commits (24 commits)
582792b
kv-cache : simplify the "struct llama_kv_cache" interface
ggerganov May 25, 2025
99653c3
kv-cache : revert the (n_swa + n_ubatch) change (for next PR)
ggerganov May 25, 2025
052f3f3
kv-cache : some comments
ggerganov May 25, 2025
5693eb6
context : fix graph reserve for multiple sequences
ggerganov May 25, 2025
cb2175f
kv-cache : fix typo [no ci]
ggerganov May 25, 2025
3c6b330
kv-cache : fix find_slot() logic for free slots
ggerganov May 25, 2025
f98b8d0
llama : add TODO for deprecating the defrag API in the future
ggerganov May 26, 2025
7e6d403
kv-cache : improve find_slot() using min/max seq pos info
ggerganov May 27, 2025
47e570c
llama : handle aborts and compute errors
ggerganov May 28, 2025
2b984f4
memory : extract state into llama_memory_state
ggerganov May 28, 2025
f23e4cc
kv-cache : add comments
ggerganov May 30, 2025
3fd6dd5
feat: Add llama_model_is_hybrid API call
gabe-l-hart May 9, 2025
dbad513
feat: Add c++ side constants for attention layer indices hparam
gabe-l-hart May 9, 2025
453d253
feat: Add support for distinguishing recurrent vs non-recurrent layer…
gabe-l-hart May 9, 2025
26e51f4
feat: Auto-fill hparams.recurrent_layer_arr based on whether the mode…
gabe-l-hart May 9, 2025
33a41f5
refactor: rename *_is_hybrid -> *_is_hybrid_recurrent
gabe-l-hart May 28, 2025
162639c
feat: Move layer_filter_cb up to llama_kv_cache
gabe-l-hart May 20, 2025
a886cc1
feat: Add layer filter to recurrent cache
gabe-l-hart May 20, 2025
5c149d2
fix: Fix indexing into k_l for recurrent cache with filter
gabe-l-hart May 20, 2025
4470221
fix: Use per-layer sizing everywhere in kv caches
gabe-l-hart May 14, 2025
ec7695f
feat: First pass at llama_kv_cache_hybrid_recurrent
gabe-l-hart May 30, 2025
728f514
feat: Construct hybrid recurrent cache for hybrid recurrent models
gabe-l-hart May 28, 2025
b58351e
fix: Fix wrong bool condition for split equal in hybrid cache
gabe-l-hart May 28, 2025
4a2709f
feat: Support hybrid recurrent cache in llm_graph_context
gabe-l-hart May 30, 2025
include/llama.h (13 changes: 10 additions & 3 deletions)
@@ -259,9 +259,9 @@ extern "C" {
         llama_token  *  token;
         float        *  embd;
         llama_pos    *  pos;
-        int32_t      *  n_seq_id;
-        llama_seq_id ** seq_id;
-        int8_t       *  logits;   // TODO: rename this to "output"
+        int32_t      *  n_seq_id; // TODO: remove, should belong to only 1 sequence
+        llama_seq_id ** seq_id;   // TODO: become llama_seq_id * seq_id;
+        int8_t       *  logits;   // TODO: rename this to "output"
     } llama_batch;

     enum llama_model_kv_override_type {
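
For orientation, the TODOs above refer to the per-token bookkeeping that callers currently fill in by hand; a minimal sketch of the existing usage pattern (the token values are placeholders and error handling is omitted):

    // build a 4-token batch for sequence 0
    llama_token toks[4] = { 1, 2, 3, 4 }; // placeholder token ids

    llama_batch batch = llama_batch_init(/*n_tokens =*/ 4, /*embd =*/ 0, /*n_seq_max =*/ 1);
    batch.n_tokens = 4;

    for (int i = 0; i < 4; ++i) {
        batch.token   [i] = toks[i];
        batch.pos     [i] = i;
        batch.n_seq_id[i] = 1;        // per-token count that the TODO proposes to remove
        batch.seq_id  [i][0] = 0;     // per-token list that the TODO proposes to flatten
        batch.logits  [i] = (i == 3); // request output only for the last token
    }

    // ... llama_decode(ctx, batch) ...
    llama_batch_free(batch);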
@@ -554,6 +554,9 @@ extern "C" {
     // Returns true if the model is recurrent (like Mamba, RWKV, etc.)
     LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);

+    // Returns true if the model is hybrid-recurrent (like Jamba, Bamba, etc.)
+    LLAMA_API bool llama_model_is_hybrid_recurrent(const struct llama_model * model);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
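
A possible caller-side use of the new predicate, shown only as an illustrative sketch (model is assumed to be a loaded llama_model handle, e.g. from llama_model_load_from_file):

    if (llama_model_is_hybrid_recurrent(model)) {
        // hybrid: attention layers get a KV cache, recurrent layers get a state cache
    } else if (llama_model_is_recurrent(model)) {
        // purely recurrent model (Mamba, RWKV, ...): recurrent state cache only
    } else {
        // regular attention model: unified KV cache
    }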
@@ -677,12 +680,14 @@ extern "C" {

     // Returns the smallest position present in the KV cache for the specified sequence
     // This is typically non-zero only for SWA caches
+    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
     // Return -1 if the sequence is empty
     LLAMA_API llama_pos llama_kv_self_seq_pos_min(
             struct llama_context * ctx,
                    llama_seq_id   seq_id);

     // Returns the largest position present in the KV cache for the specified sequence
+    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
     // Return -1 if the sequence is empty
     LLAMA_API llama_pos llama_kv_self_seq_pos_max(
             struct llama_context * ctx,
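
As an illustration of the guarantee documented above, the two calls together describe one contiguous range of cached positions. A sketch, assuming ctx is an initialized llama_context and sequence 0 has been decoded:

    const llama_pos p_min = llama_kv_self_seq_pos_min(ctx, 0);
    const llama_pos p_max = llama_kv_self_seq_pos_max(ctx, 0);

    if (p_min == -1) {
        // sequence 0 has nothing in the cache
    } else {
        // every position in [p_min, p_max] is present,
        // so decoding for this sequence can resume at p_max + 1
    }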
@@ -692,12 +697,14 @@ extern "C" {
     // This will be applied:
     //   - lazily on next llama_decode()
     //   - explicitly with llama_kv_self_update()
+    // TODO: deprecate and always update the cache lazily [TAG: API_KV_NO_DEFRAG]
     LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx);

     // Check if the context supports KV cache shifting
     LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);

     // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
+    // TODO: deprecate and always update the cache lazily [TAG: API_KV_NO_DEFRAG]
     LLAMA_API void llama_kv_self_update(struct llama_context * ctx);

     //
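
Until that deprecation happens, both paths described in the comment remain available; a minimal sketch of forcing the update immediately instead of waiting for the next decode:

    llama_kv_self_defrag(ctx); // schedule defragmentation
    llama_kv_self_update(ctx); // apply it now rather than lazily on the next llama_decode()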
src/llama-arch.cpp (23 changes: 23 additions & 0 deletions)
@@ -144,6 +144,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_SCALE,            "%s.attention.scale"            },
     { LLM_KV_ATTENTION_KEY_LENGTH_MLA,   "%s.attention.key_length_mla"   },
     { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
+    { LLM_KV_ATTENTION_LAYER_INDICES,    "%s.attention.layer_indices"    },

     { LLM_KV_ROPE_DIMENSION_COUNT,       "%s.rope.dimension_count"       },
     { LLM_KV_ROPE_DIMENSION_SECTIONS,    "%s.rope.dimension_sections"    },
@@ -1747,3 +1748,25 @@ llm_arch llm_arch_from_string(const std::string & name) {
 const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
     return LLM_TENSOR_INFOS.at(tensor);
 }
+
+bool llm_arch_is_recurrent(const llm_arch & arch) {
+    switch (arch) {
+        case LLM_ARCH_MAMBA:
+        case LLM_ARCH_RWKV6:
+        case LLM_ARCH_RWKV6QWEN2:
+        case LLM_ARCH_RWKV7:
+        case LLM_ARCH_ARWKV7:
+            return true;
+        default:
+            return false;
+    }
+}
+
+bool llm_arch_is_hybrid_recurrent(const llm_arch & arch) {
+    // TODO: There are currently no hybrid models! Once there are, this will be
+    // the place to identify them
+    switch (arch) {
+        default:
+            return false;
+    }
+}
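
The second switch is intentionally empty: once a hybrid-recurrent architecture is merged, it would be listed there the same way the recurrent ones are above. A hypothetical sketch only (LLM_ARCH_JAMBA is an illustrative placeholder, not an enum value added by this PR):

bool llm_arch_is_hybrid_recurrent(const llm_arch & arch) {
    switch (arch) {
        // hypothetical: a Jamba-style model mixing attention and recurrent layers
        case LLM_ARCH_JAMBA:
            return true;
        default:
            return false;
    }
}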
src/llama-arch.h (4 changes: 4 additions & 0 deletions)
@@ -148,6 +148,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_SCALE,
     LLM_KV_ATTENTION_KEY_LENGTH_MLA,
     LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
+    LLM_KV_ATTENTION_LAYER_INDICES,

     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -437,3 +438,6 @@ const char * llm_arch_name(llm_arch arch);
 llm_arch llm_arch_from_string(const std::string & name);

 const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
+
+bool llm_arch_is_recurrent(const llm_arch & arch);
+bool llm_arch_is_hybrid_recurrent(const llm_arch & arch);
src/llama-batch.cpp (31 changes: 19 additions & 12 deletions)
@@ -15,24 +15,31 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
             break;
         }
     }
-    ubatch_token.resize(!has_embd ? n_ubatch : 0);
-    ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
-    ubatch_pos.resize(n_ubatch);
-    ubatch_n_seq_id.resize(n_ubatch);
-    ubatch_seq_id.resize(n_ubatch);
-    ubatch_output.resize(n_ubatch);
+
+    udatas.push_back({});
+
+    auto & udata = udatas.back();
+
+    udata.token.resize(!has_embd ? n_ubatch : 0);
+    udata.embd.resize(has_embd ? n_embd * n_ubatch : 0);
+    udata.pos.resize(n_ubatch);
+    udata.n_seq_id.resize(n_ubatch);
+    udata.seq_id.resize(n_ubatch);
+    udata.output.resize(n_ubatch);

     llama_ubatch ubatch = {
         /*equal_seqs   =*/ true,
         /*n_tokens     =*/ 0,
         /*n_seq_tokens =*/ 0,
         /*n_seqs       =*/ 0,
-        /*token        =*/ !has_embd ? ubatch_token.data() : nullptr,
-        /*embd         =*/ has_embd ? ubatch_embd.data() : nullptr,
-        /*pos          =*/ ubatch_pos.data(),
-        /*n_seq_id     =*/ ubatch_n_seq_id.data(),
-        /*seq_id       =*/ ubatch_seq_id.data(),
-        /*output       =*/ ubatch_output.data(),
+        /*token        =*/ !has_embd ? udata.token.data() : nullptr,
+        /*embd         =*/ has_embd ? udata.embd.data() : nullptr,
+        /*pos          =*/ udata.pos.data(),
+        /*n_seq_id     =*/ udata.n_seq_id.data(),
+        /*seq_id       =*/ udata.seq_id.data(),
+        /*output       =*/ udata.output.data(),
     };

     return ubatch;
 }
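
The per-ubatch buffers mean each call to reserve_ubatch() gets its own backing storage, presumably so that pointers inside a previously returned llama_ubatch stay valid when another ubatch is reserved from the same llama_sbatch. A sketch of the pattern this enables (sbatch and n_ubatch are assumed to be in scope):

    // with per-ubatch buffers, both ubatches keep valid pointers;
    // previously the second call would have reused (and overwritten) the same vectors
    llama_ubatch u0 = sbatch.reserve_ubatch(n_ubatch, /*has_embd =*/ false);
    llama_ubatch u1 = sbatch.reserve_ubatch(n_ubatch, /*has_embd =*/ false);
    // u0.pos and u1.pos now point into different udatas entries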

src/llama-batch.h (25 changes: 15 additions & 10 deletions)
@@ -11,15 +11,15 @@ struct llama_ubatch {
     bool equal_seqs;
     // TODO: whole_seqs for embeddings?

-    uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
+    uint32_t n_tokens;     // total tokens (n_seq_tokens * n_seqs)
     uint32_t n_seq_tokens; // tokens per sequence
     uint32_t n_seqs;

     llama_token  *  token;    // [n_tokens]
     float        *  embd;     // [n_embd, n_tokens]
     llama_pos    *  pos;      // [n_tokens]
-    int32_t      *  n_seq_id; // [n_seqs]
-    llama_seq_id ** seq_id;   // [n_seqs]
+    int32_t      *  n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence
+    llama_seq_id ** seq_id;   // [n_seqs] // TODO: become llama_seq_id * seq_id;
     int8_t       *  output;   // [n_tokens]
 };

@@ -49,13 +49,18 @@ struct llama_sbatch {

     const llama_batch * batch = nullptr;

-    // buffers for the ubatch
-    std::vector<llama_token>    ubatch_token;
-    std::vector<float>          ubatch_embd;
-    std::vector<llama_pos>      ubatch_pos;
-    std::vector<int32_t>        ubatch_n_seq_id;
-    std::vector<llama_seq_id *> ubatch_seq_id;
-    std::vector<int8_t>         ubatch_output;
+    // buffers for the ubatches
+    // TODO: very hacky, this needs a complete rework
+    struct ubatch_data {
+        std::vector<llama_token>    token;
+        std::vector<float>          embd;
+        std::vector<llama_pos>      pos;
+        std::vector<int32_t>        n_seq_id;
+        std::vector<llama_seq_id *> seq_id;
+        std::vector<int8_t>         output;
+    };
+
+    std::vector<ubatch_data> udatas;

     llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
