
Hybrid recurrent cache #13979


Open
wants to merge 23 commits into base: master

Conversation

gabe-l-hart (Contributor)

This is a re-opened version of #13904 after #13746 was merged

Description

This PR introduces the llama_kv_cache_hybrid_recurrent cache implementation. It follows the pattern of llama_kv_cache_unified_iswa by holding two child cache instances and implementing the interface logic such that it manages both correctly for the appropriate layers.
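
As a rough structural sketch of that pattern (the kv_attn member appears later in this thread; everything else here is an assumption for illustration, not the PR's exact code):

// hypothetical sketch only: one parent memory object owning two children,
// mirroring the shape of llama_kv_cache_unified_iswa
class llama_kv_cache_hybrid_recurrent : public llama_memory_i {
public:
    // the llama_memory_i overrides dispatch to the child that owns each layer
    // (attention layers -> kv_attn, recurrent layers -> kv_recurrent)
    // ...

private:
    const llama_hparams & hparams;

    const std::unique_ptr<llama_kv_cache_unified>   kv_attn;
    const std::unique_ptr<llama_kv_cache_recurrent> kv_recurrent;
};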

Changes

The main change in this PR is the addition of llama_kv_cache_hybrid_recurrent in llama-kv-cache-hybrid-recurrent.*. In addition to this, the PR does the following:

  • Add the llama_model_is_hybrid_recurrent public API (akin to llama_model_is_recurrent)
  • Add LLM_KV_ATTENTION_LAYER_INDICES as an hparam to hold the indices of the layers that should use attention (versus recurrent)
    • This part is not well aligned with iswa, but that mechanism also isn't particularly extensible. It might be more appropriate to have a generic mechanism for indicating the type of caching to use for each layer, but that would start to approach the generic hybrid implementation that I originally attempted, which ended up being too abstract (feat: Hybrid unified/recurrent cache #13276).
  • Abstract the utilities in llm_graph_context that need a specific type of cache behind getters (get_state_unified / get_state_recurrent) that properly handle llama_kv_cache_hybrid_recurrent
  • Make n_embd_k_s / n_embd_v_s layer-dependent and pass layer indices when calling them in the existing cache implementations
  • Add layer filtering to llama_kv_cache_recurrent
  • Update the logic in llama_model::create_memory to use llm_arch_is_recurrent and llm_arch_is_hybrid_recurrent rather than relying on adding models to the switch statement, which was redundant with the implementation of these functions (see the sketch after this list)
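
A minimal sketch of that last point, assuming create_memory otherwise keeps its current structure (the variable name res and the constructor arguments are placeholders):

// inside llama_model::create_memory(), illustrative only
llama_memory_i * res = nullptr;
if (llm_arch_is_hybrid_recurrent(arch)) {
    // attention layers go to the unified child, recurrent layers to the recurrent child
    res = new llama_kv_cache_hybrid_recurrent(/* model, cparams, ... */);
} else if (llm_arch_is_recurrent(arch)) {
    res = new llama_kv_cache_recurrent(/* model, ... */);
} else {
    // unified KV cache (or iswa when SWA is enabled)
    res = new llama_kv_cache_unified(/* model, ... */);
}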

@gabe-l-hart mentioned this pull request on Jun 2, 2025
Comment on lines 1063 to 1086
const llama_kv_cache_unified_state * llm_graph_context::get_state_unified() const {
    const auto * umstate = dynamic_cast<const llama_kv_cache_unified_state *>(mstate);
    if (!umstate) {
        const auto hmstate = dynamic_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
        if (hmstate) {
            umstate = hmstate->get_state_attn();
        }
    }
    GGML_ASSERT(umstate);
    return umstate;
}

const llama_kv_cache_recurrent_state * llm_graph_context::get_state_recurrent() const {
    const auto * rmstate = dynamic_cast<const llama_kv_cache_recurrent_state *>(mstate);
    if (!rmstate) {
        const auto hmstate = dynamic_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
        if (hmstate) {
            rmstate = hmstate->get_state_recurrent();
        }
    }
    GGML_ASSERT(rmstate);
    return rmstate;
}

Member

These dynamic casts should not be necessary. Instead you need a new llm_graph_context::build_attn_inp_kv_hybrid_recurrent() method, similar to llm_graph_context::build_attn_inp_kv_unified_iswa().
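
In other words, roughly the following shape, with the hybrid state exposing get_state_attn() as in the snippet above (the body and return type here are a sketch, not the suggested implementation; whether it returns the unified input class or a dedicated one is an open choice):

llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
    // cast the memory state exactly once, here
    const auto * hstate = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);

    // then build the attention inputs against the attention child, the same way
    // build_attn_inp_kv_unified() builds them against a plain unified state
    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, hstate->get_state_attn());
    // ... create the kq_mask input tensor(s) as build_attn_inp_kv_unified() does ...
    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
}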

Contributor Author

I'm working through this now and a couple of questions are coming up:

  1. Would it be best to combine build_inp_s_copy with build_attn_inp_kv for hybrid so that models call just one "build inputs" function, or keep them separate for simplicity?
  2. For the build_attn methods, each has a corresponding llm_graph_input_attn_* class. The build_inp_s_* methods don't have this pattern, which makes it a bit harder to reuse code. Are there plans to refactor that further @compilade?
  3. In the mamba2 branch, s_mask seems to be totally removed. I'd prefer not to do all of the boilerplate of duplicating build_inp_s_mask for the hybrid recurrent case if it's definitely going away. Is there any reason it might stick around past the merge of mamba2?

Collaborator @compilade Jun 4, 2025

Answering out of order, but it should still make sense:

2. For the build_attn methods, each has a corresponding llm_graph_input_attn_* class. The build_inp_s_* methods don't have this pattern

They do follow this pattern, see

class llm_graph_input_s_copy : public llm_graph_input_i {

(this is on the current master)

I think you might mean the build_attn_* methods also return instances of llm_graph_input_attn_*?
That seems to be directly related to llm_graph_context::build_attn(), which has multiple overloads that differ in the type of the first argument (e.g. llm_graph_input_attn_kv_unified, llm_graph_input_attn_no_cache, etc.)

Are there plans to refactor that further @compilade?

Not really, outside of removing s_mask (and related functions and classes) as part of #13834.

  1. Would it be best to combine build_inp_s_copy with build_attn_inp_kv for hybrid so that models call just one "build inputs" function, or keep them separate for simplicity?

Personally, I think it would be simpler to keep them separate, because they are fundamentally different (one is intended to be used by build_copy_mask_state (renamed to build_recurrent_state in #13834), while the other is used by build_attn), and they are pretty much independent, even in hybrid models (at least for Jamba, the recurrent and self-attention layers are mostly independent on that front).

I don't see how build_attn would ever need s_copy.

build_inp_s_copy and build_attn_inp_kv_* are called once at the beginning of the graph, while build_attn and build_recurrent_state are called once per layer (where applicable, and so usually on different layers for both).

3. Is there any reason [s_mask] might stick around past the merge of mamba2?

No reason to keep it, s_mask will be removed. Its functionality is redundant with s_copy, and otherwise prevents minimizing unnecessary state copies. It was used to clear the states, but the same can be done through inp_s_copy and clearing by copying a zero-ed state (which is the rs_z'th state in the mamba2 branch (and #13834)).

Contributor Author

Thanks, that's super helpful! I was missing the distinction between build_attn_inp and build_attn which makes perfect sense.

Personally, I think it would be simpler to keep them separate

My gut agrees, so I'll go with this and see how it feels once complete.

Contributor Author

Ok, I think this feels a lot cleaner now. For build_inp_s_copy, I opted to add an optional parameter so that the caller can take ownership of casting the cache state rather than duplicating the function into build_inp_s_copy_hybrid. That felt a little cleaner w.r.t. code reuse, but I'm happy to do a separate method if that's preferred.
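
Concretely, call sites would look roughly like this under the optional-parameter approach (a sketch; the accessor names are the ones already used in this thread):

// hybrid graph: the caller casts the hybrid state once and passes the recurrent child
const auto * hstate = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
ggml_tensor * s_copy = build_inp_s_copy(hstate->get_state_recurrent());

// pure recurrent graph: the default nullptr keeps the existing behavior
// (build_inp_s_copy casts mstate internally, as it does today)
ggml_tensor * s_copy_default = build_inp_s_copy();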

Contributor Author

It looks like there's one more place that will need changing in build_copy_mask_state (renamed to build_recurrent_state on mamba2). Similar to build_inp_s_copy, I think the cleanest way to do this for code reuse is to add an optional parameter that, if unset, will use the current logic of casting mstate.

Comment on lines +118 to +75
// TODO: will the recurrent cache be in an undefined state at this point?
LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
Collaborator

Yes, but that will be fixed in #13834

(Noting here in case this gets merged first so that I don't forget to update the comment)

Contributor Author

Should I strip out this TODO at this point?

@gabe-l-hart force-pushed the HybridRecurrentCache branch 7 times, most recently from ab918bb to 60ca3ba on June 9, 2025 16:06
@gabe-l-hart (Contributor Author)

@ggerganov I've noticed that the Server tests are consistently failing on this branch, but I see them passing on other PRs (eg #14081). The failures seem to be around retried connections to /slots/0?action=restore (here). I've noticed a couple of things when trying to repro locally:

  1. For all of the server tests, despite the health check ping loop, I need to manually inject a sleep after server.start() to get any of the tests to pass on my M3 Max 64GB; otherwise the first endpoint call will almost always return an unexpected result (invalid JSON, or a 5xx error). To me, this suggests the server reports /health success before it's logically "ready." That is probably correct in the sense that the server is healthy enough to take requests, but the tests treat it as a readiness check indicating the server is fully initialized, which seems incorrect (similar to liveness vs readiness in Kubernetes). I wasn't able to find any open issues related to this, so I'd be curious if there's a discussion somewhere of readiness vs liveness.

  2. Once I do insert the manual sleep in test_slot_save_restore, the test passes locally for me. This speaks to some difference between my local system and the GH Action runners. I'm not familiar enough with the server code to know whether this would be caused by some kind of memory constraint, disk speed constraint, etc., but my guess would be something like slot1.bin not being fully written somehow. That wouldn't explain why it passes on other branches, though. Any insight on the nature of these tests would be much appreciated as I try to debug!

gabe-l-hart added a commit to gabe-l-hart/llama.cpp that referenced this pull request Jun 9, 2025
This is an attempt to handle race conditions between /health returning OK
and the other endpoints not returning successfully.

ggml-org#13979 (comment)

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <[email protected]>
@gabe-l-hart requested a review from ngxson as a code owner on June 9, 2025 17:36
@github-actions bot added the examples, python (python script changes), and server labels on Jun 9, 2025
@gabe-l-hart (Contributor Author)

I've tried adding retry logic for all requests in 39a93b3 to work around the race between /health and the other endpoints, but I'm not sure if this is just going to mask an underlying issue.

@ggerganov (Member)

The changes to the server tests should not be needed. Let's revert the commit for now and I'll investigate.

@gabe-l-hart force-pushed the HybridRecurrentCache branch from 39a93b3 to 60ca3ba on June 9, 2025 20:13
@gabe-l-hart (Contributor Author)

@ggerganov Thanks, it looks like those changes didn't fix the failures anyway, so definitely not the right fix. I've reverted them and can open an issue with details of what I see locally.

@gabe-l-hart (Contributor Author)

Issue for follow up on /health race condition with tests: #14092

@gabe-l-hart force-pushed the HybridRecurrentCache branch 2 times, most recently from 7958d84 to 3669876 on June 10, 2025 21:22
@gabe-l-hart marked this pull request as draft on June 10, 2025 21:22
@gabe-l-hart (Contributor Author)

I've rebased on #13834. Drafting for now until it's merged

@gabe-l-hart force-pushed the HybridRecurrentCache branch from 3669876 to 8c59841 on June 10, 2025 22:22
@gabe-l-hart marked this pull request as ready for review on June 10, 2025 22:22
@gabe-l-hart (Contributor Author)

That was quick! Undrafting now that #13834 is merged

@@ -242,7 +243,7 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
cparams(cparams),
kv_state(kv_state) {
}
~llm_graph_input_attn_kv_unified() = default;
virtual ~llm_graph_input_attn_kv_unified() = default;
Contributor Author

I don't think this virtual is needed. I think I added it when I was attempting to have the hybrid input inherit from the unified input

@gabe-l-hart force-pushed the HybridRecurrentCache branch from bb87dbf to b216ed3 on June 11, 2025 13:09
uint32_t n_seq_max,
bool offload) :
hparams(model.hparams),
kv_attn(new llama_kv_cache_unified(
Contributor

Hi @gabe-l-hart - thanks for your huge work. Chiming in on this PR to ask a question: how do you think we should handle cache formats where, for every layer, we have both the attention and SSM states? This design seems to be tailored for architectures that have either SSM or attention for each decoder block.

Contributor Author

Hi @younesbelkada! That sounds like a really interesting architecture. Can you point to any papers or draft models that use this approach? I'd be really curious to check them out. The main advantage of the hybrid recurrent architecture where each layer is either SSM or Attention (but not both) is that you get the cache size scaling advantages of SSMs for most layers without sacrificing the ability to attend to all tokens for the whole model.

My first pass at a standalone hybrid cache (#13276) attempted to be a more generic representation of "each layer can have its own cache type" where the parent cache could have any number of children. This ran into a number of challenges during the recent interface refactoring, and ultimately was incompatible with some of the work to encapsulate the interactions with the cache, so at Georgi's suggestion, I pivoted to this less-abstract implementation.

Following the pattern established by iswa and now hybrid_recurrent, I would think implementing such a cache would look like:

  1. Creating another new subclass of llama_memory_i that owns the child caches (similar to hybrid_recurrent)
    • The various interfaces would probably be implemented with "AND" semantics instead of "OR" semantics for child layers
  2. Creating the corresponding subclass of llama_memory_state_i
  3. Use a different set of layer filters when constructing the parent cache with the children such that the children apply to the right layers (which are now overlapping sets)

The place where I see the most incompatibility is in the hparams.recurrent_layer() method which right now is a "yes/no" and if true implies ! attention_layer (explicitly in the n_embd_k_s / n_embd_v_s methods). One possible solution here would be for the hparams to hold some kind of get_cache_type(il) -> set{CACHE_TYPE} method that would return a set of cache type enum values for a given layer index. This might also be able to replace the llm_arch_is_recurrent / llm_arch_is_hybrid_recurrent interfaces.

@compilade @ggerganov do you foresee the need to allow a single layer to use multiple cache types as something we'll want to support in the future? If so, we might want to implement that more-generic interface in hparams now to avoid an API breaking change down the road.
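
A hypothetical sketch of that more generic hparams interface (none of these names exist in this PR; purely illustrative):

// hypothetical replacement for the boolean recurrent_layer(il) check
enum llama_cache_type {
    LLAMA_CACHE_TYPE_ATTN,
    LLAMA_CACHE_TYPE_RECURRENT,
};

// a hybrid-recurrent layer would return exactly one entry;
// a Falcon-H1-style parallel-hybrid layer would return both
std::set<llama_cache_type> llama_hparams::get_cache_type(uint32_t il) const;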

Contributor @younesbelkada Jun 11, 2025

Thank you @gabe-l-hart for your detailed response 🙏
Currently there is an effort to integrate the Falcon-H1 models into mainline llama.cpp: https://huggingface.co/collections/tiiuae/falcon-h1-6819f2795bc406da60fab8df (for more details you can check the technical blogpost here; this architecture employs a parallel hybrid design where you have both SSM and attention in each decoder block). We have a public fork of llama.cpp that we are currently maintaining here: https://github.com/tiiuae/llama.cpp-Falcon-H1, but we are refactoring our changes so that they are compatible with master (including this PR and #9126) and can be merged into llama.cpp master (related issue: #13681).
Right now we tried to build on top of this PR and the Mamba2 PR from @compilade and ran into this issue, so we are not sure which approach to take. What you suggested sounds good, and we will be happy to work on whatever you think is the best approach!

Contributor Author

Ah! I actually think someone pointed me at your fork a week ago and I lost it in my exploded number of tabs. I'll definitely check it out and see if there are any places this should tweak to accommodate the Falcon-H1 arch.

Contributor Author

For context, my motivation is very similar to yours. I'm trying to get Granite 4 support added which is also a combo of mamba2 and attention (GraniteMoeShared for the attention layers). My draft PR is #13550, but my most recent attempt to merge master and mamba2 failed pretty badly, so I'm still debugging why (and I'm guessing compilade is also working on doing it the right way).

Contributor Author

Yes, probably we have to introduce the notion of has_recurrent_state(il), has_attention(il) and these can be both true in such cases.

This makes sense. I'll give this a shot soon.

Of these uses, I think the two that would be the most problematic are n_embd_k/v_s, since they're currently implemented using hparams.recurrent_layer, and hparams.n_embd, since it is used by both the unified and recurrent caches and might need different values for the two if they were shared by the same layer.

@ggerganov with the removal of n_embd_k/v_s layer indexing, I think that just leaves hparams.n_embd. Do you have a quick thought on whether this will need to be split between attn/recurrent portions for a single layer?
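
A hedged sketch of what those predicates could look like in hparams (the per-layer recurrent array is mentioned in a commit note later in this thread; its name and the attention array are assumptions):

// hypothetical per-layer predicates; the array names are assumptions
bool llama_hparams::has_recurrent_state(uint32_t il) const { return recurrent_layer_arr[il]; }
bool llama_hparams::has_attention      (uint32_t il) const { return attention_layer_arr[il]; }
// for a Falcon-H1-style block both would return true for the same il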

Member

Does H1 actually need different hparams.n_embd for the different layers? The hparams.n_embd is the size of the hidden state that goes through the layers. I haven't looked at the links, but I doubt this parameter would vary with layers.

Contributor Author

I'm not clear on that (haven't unboxed it far enough either). @younesbelkada I haven't dug into your fork deeply yet, but I'm guessing you may have had to tackle this question and might know which hparams would need to be different between the attention / recurrent portions of the cache for a given layer.

Contributor Author

I've added the ability to specify custom filters in the constructor to llama_kv_cache_hybrid_recurrent which will default to checking hparams.recurrent_layer if not set. This should allow H1 to implement its own "always on" filter for all layers and then add a clause to create_memory in the model-specific section to construct the hybrid cache with that filter.
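
For H1 that would look roughly like the following at construction time (the filter parameter names are assumptions; the defaults shown in comments are the behavior described above):

// hypothetical Falcon-H1 call site: every layer gets both caches
auto filter_attn = [](int32_t il) { return true; };
auto filter_recr = [](int32_t il) { return true; };
// when no filters are passed, the defaults would be the per-layer split:
//   attn: [&](int32_t il) { return !hparams.recurrent_layer(il); }
//   recr: [&](int32_t il) { return  hparams.recurrent_layer(il); }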

Contributor @younesbelkada Jun 12, 2025

Thank you @gabe-l-hart @ggerganov!
If I understand correctly from @gabe-l-hart's last comment, we can already give it a shot for H1 now. If you confirm, we'll start from the synced branch you shared on the Mamba2 PR.

Does H1 actually need different hparams.n_embd for the different layers?

I can confirm we actually use the same n_embd across all layers, since we aggregate the hidden states of the attention and SSM mixers by summing them; however, the KV cache will have different shapes:

-> SSM cache shape: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon_h1/modeling_falcon_h1.py#L100-L117
-> KV cache shape: head_dim x num_kv_heads

Also, split llama_model_is_recurrent into llm_arch_is_recurrent in
llama-arch with llama_model_is_recurrent delegating to
llm_arch_is_recurrent. The same split is done for hybrid. This is needed
because there are places where the llama_model has not yet been initialized
but we need to check if the model is recurrent (specifically for the
per-layer recurrent check array in hparams).

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
…s in hparams

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
…l is recurrent

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
The implementation of the hybrid cache intentionally does not specify the
types of the child caches, so there was a naming mismatch with these
predicate functions that used "hybrid" to imply "hybrid recurrent."

Branch: HybridCache

Signed-off-by: Gabe Goodhart <[email protected]>
Branch: HybridCache

Signed-off-by: Gabe Goodhart <[email protected]>
Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
This follows the pattern in iswa where the two child caches are held
explicitly to support the case where a model requires a single attention
cache and a single recurrent cache where each layer uses exactly one of the
caches.

This is a rewrite of the more generic approach in the original hybrid cache
PR: ggml-org#13276

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <[email protected]>
This includes a refactor of the create_memory logic to avoid needing to use
the arch enum explicitly unless a model needs explicit cache instantiation
logic beyond the standard logic for recurrent, hybrid, unified, and iswa.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <[email protected]>
Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <[email protected]>
Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <[email protected]>
NOTE: I intentionally did not add support for s_mask since it will be going
away soon

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <[email protected]>
…he interface

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <[email protected]>
Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <[email protected]>
…empt

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <[email protected]>
Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <[email protected]>
@gabe-l-hart force-pushed the HybridRecurrentCache branch from b216ed3 to 1309384 on June 11, 2025 17:32
No longer needed now that unified isn't also supporting recurrent

ggml-org#13979 (comment)

Branch: HybridRecurrentCache
Now that it's not used at all in the unified cache, we don't need to use
the layer index to zero it out for attention layers.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <[email protected]>
Comment on lines 72 to 73
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
Member

Aren't n_embd_k_gqa and n_embd_v_gqa equal to 0 for recurrent models? If yes, you can remove them from the recurrent implementation.

Contributor Author

I think that sounds right. It looks like the value here is ultimately determined by hparams.n_head_kv_arr which by default is populated as a copy of hparams.n_head_arr from %s.attention.head_count. For mamba, I see mamba.attention.head_count set to 0 and I suspect this would be true of all other recurrent models, but I'm not sure how to fully verify. I'll try removing them and see if the recurrent models I have work as expected, but I'll defer to @compilade on whether this is an "always" thing or a "just mamba* models" thing.

This is no longer needed now that there are separate implementations

ggml-org#13979 (comment)

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <[email protected]>
This should help support architectures like Falcon H1 where there is
overlap between layers that need attention and recurrent caches.

ggml-org#13979 (comment)

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <[email protected]>
Member @ggerganov left a comment

I feel like the naming of the recurrent cache is not great and it is causing some confusion. In a follow-up PR we have to do the following naming changes:

  • Rename llama_kv_cache_recurrent -> llama_recurrent_state
  • Rename llama_kv_cache_hybrid_recurrent -> llama_memory_hybrid

This should be a good first step, although not ideal because "hybrid" will implicitly refer to "unified KV cache + recurrent state". For the future we have to figure out a naming scheme that would be able to differentiate between different types of hybrid memories. But this should be good enough for now.

@@ -508,7 +509,7 @@ struct llm_graph_context {
ggml_tensor * build_inp_out_ids() const;
ggml_tensor * build_inp_mean() const;
ggml_tensor * build_inp_cls() const;
ggml_tensor * build_inp_s_copy() const;
ggml_tensor * build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state = nullptr) const;
Member @ggerganov Jun 12, 2025

This method is analogous to the build_attn_inp_ methods, so we have to model it in a similar way.

Replace this method with:

    // similar to build_attn_inp_kv_unified()
    llm_graph_input_rs * build_rs_inp() const;

Introduce a new input class:

// similar to llm_graph_input_attn_kv_unified
// put the `s_copy` tensor in this class (similar to `kq_mask`)
class llm_graph_input_rs : public llm_graph_input_i;

In the future, this input class could be extended with additional input tensors that are needed by the recurrent cache if necessary (similar to the attention input classes).

Replace build_recurrent_state() and build_rwkv_token_shift_load() with overloads:

    // similar to build_attn()
    ggml_tensor * build_rs(
        llm_graph_input_rs * inp,
             ggml_cgraph * gf,
             ggml_tensor * s,             
                 int32_t   state_size,
                 int32_t   n_seqs,
                    bool   avoid_copies = false) const;

    // similar to build_attn()
    ggml_tensor * build_rwkv_token_shift_load(
        llm_graph_input_rs * inp,
             ggml_cgraph * gf,
      const llama_ubatch & ubatch,
                     int   il) const;
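
For reference, a hypothetical call site in a model's graph build under this proposal (the signatures are the ones suggested above; the tensor and count variables are placeholders):

    // once, at the start of the graph
    llm_graph_input_rs * inp_rs = build_rs_inp();

    // once per recurrent layer (il): gather/copy the states for the current sequences
    ggml_tensor * ssm = build_rs(inp_rs, gf, ssm_states_all, hparams.n_embd_v_s(il), n_seqs);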

Member @ggerganov Jun 12, 2025

When this change is applied, we have to do a similar addition for the hybrid implementation. The basic pattern is that you need to introduce a new input class similar to llm_graph_input_attn_kv_unified and llm_graph_input_rs, but this one will contain inputs for both the attention and for the recurrent state.

So probably something like:

// this input class will have both the input tensors needed for the attention and for
// the recurrent state. see llm_graph_input_attn_kv_unified_iswa for example
class llm_graph_input_mem_hybrid : public llm_graph_input_i;

We then add overloads for build_attn(), build_rs() and build_rwkv_token_shift_load().
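
So, roughly (a sketch only; the member names mirror the existing input classes and are assumptions here):

// hypothetical combined input class holding both the attention and the
// recurrent-state input tensors
class llm_graph_input_mem_hybrid : public llm_graph_input_i {
public:
    void set_input(const llama_ubatch * ubatch) override;

    ggml_tensor * s_copy;       // as in llm_graph_input_rs
    ggml_tensor * self_kq_mask; // as in llm_graph_input_attn_kv_unified

    const llama_kv_cache_hybrid_recurrent_state * mstate;
};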

Comment on lines +1258 to +1260
// NOTE: For hybrid caches, this may be a child of mstate, so we use the one
// encapsulated in inp
const auto * kv_state = inp->kv_state;
Member

This change is not needed. We should aim not to mix the recurrent logic with the attention logic.

Comment on lines +149 to +150
const llama_kv_cache_unified_state_ptr state_attn;
const llama_kv_cache_recurrent_state_ptr state_recurrent;
Member @ggerganov Jun 12, 2025
The reason will be displayed to describe this comment to others. Learn more.

These can be changed like this, similar to the iswa state:

Suggested change
-    const llama_kv_cache_unified_state_ptr state_attn;
-    const llama_kv_cache_recurrent_state_ptr state_recurrent;
+    llama_memory_state_ptr state_attn;
+    llama_memory_state_ptr state_recurrent;

Comment on lines +572 to +574
// Returns true if the model is hybrid-recurrent (like Jamba, Bamba, etc.)
LLAMA_API bool llama_model_is_hybrid_recurrent(const struct llama_model * model);

Member

For now, let's hold off on adding this API. I don't think llama_model_is_recurrent should exist either. The idea is that the user app should not have to know what the underlying memory mechanism is.
