Hybrid recurrent cache #13979
Conversation
src/llama-graph.cpp
Outdated
const llama_kv_cache_unified_state * llm_graph_context::get_state_unified() const {
    const auto * umstate = dynamic_cast<const llama_kv_cache_unified_state *>(mstate);
    if (!umstate) {
        const auto hmstate = dynamic_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
        if (hmstate) {
            umstate = hmstate->get_state_attn();
        }
    }
    GGML_ASSERT(umstate);
    return umstate;
}

const llama_kv_cache_recurrent_state * llm_graph_context::get_state_recurrent() const {
    const auto * rmstate = dynamic_cast<const llama_kv_cache_recurrent_state *>(mstate);
    if (!rmstate) {
        const auto hmstate = dynamic_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
        if (hmstate) {
            rmstate = hmstate->get_state_recurrent();
        }
    }
    GGML_ASSERT(rmstate);
    return rmstate;
}
These dynamic casts should not be necessary. Instead, you need a new llm_graph_context::build_attn_inp_kv_hybrid_recurrent() method, similar to llm_graph_context::build_attn_inp_kv_unified_iswa().
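For illustration, the suggested builder could be declared roughly like this (a sketch only; the exact name and return type are assumptions based on the existing unified/iswa builders, not something taken from this PR):

// llama-graph.h (sketch) -- parallels build_attn_inp_kv_unified_iswa(); the cast of
// mstate to the hybrid state would then happen once, inside this builder, instead of
// being scattered through the graph code as dynamic_casts
llm_graph_input_attn_kv_unified * build_attn_inp_kv_hybrid_recurrent() const;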
I'm working through this now and a couple of questions are coming up:

1. Would it be best to combine build_inp_s_copy with build_attn_inp_kv for hybrid so that models call just one "build inputs" function, or keep them separate for simplicity?
2. For the build_attn methods, each has a corresponding llm_graph_input_attn_* class. The build_inp_s_* methods don't have this pattern, which would make it a bit harder to have code reuse. Are there plans to refactor that further @compilade?
3. In the mamba2 branch, s_mask seems to be totally removed. I'd prefer not to do all of the boilerplate for duplicating build_inp_s_mask for the hybrid recurrent case if that's definitely going away. Is there any reason it might stick around past the merge of mamba2?
Answering out of order, but it should still make sense:
2. For the build_attn methods, each has a corresponding llm_graph_input_attn_* class. The build_inp_s_* methods don't have this pattern

They do follow this pattern, see line 190 in 7e00e60 (on the current master):

class llm_graph_input_s_copy : public llm_graph_input_i {
I think you might mean that the build_attn_* methods also return instances of llm_graph_input_attn_*?

That seems to be directly related to llm_graph_context::build_attn(), which has multiple implementations that differ by the type of the first argument (e.g. llm_graph_input_attn_kv_unified, llm_graph_input_attn_no_cache, etc.).
Are there plans to refactor that further @compilade?
Not really, outside of removing s_mask (and related functions and classes) as part of #13834.
1. Would it be best to combine build_inp_s_copy with build_attn_inp_kv for hybrid so that models call just one "build inputs" function, or keep them separate for simplicity?

Personally, I think it would be simpler to keep them separate, because they are fundamentally different (one is intended to be used by build_copy_mask_state (renamed to build_recurrent_state in #13834), while the other is used by build_attn), and they are pretty much independent, even in hybrid models (at least for Jamba, the recurrent and self-attention layers are mostly independent on that front).

I don't see how build_attn would ever need s_copy.

build_inp_s_copy and build_inp_attn_kv_* are called once at the beginning of the graph, while build_attn and build_recurrent_state are called once per layer (where applicable, and so usually on different layers for each).
3. Is there any reason [s_mask] might stick around past the merge of mamba2?

No reason to keep it, s_mask will be removed. Its functionality is redundant with s_copy, and otherwise prevents minimizing unnecessary state copies. It was used to clear the states, but the same can be done through inp_s_copy and clearing by copying a zero-ed state (which is the rs_z'th state in the mamba2 branch and in #13834).
Thanks, that's super helpful! I was missing the distinction between build_attn_inp and build_attn, which makes perfect sense.
Personally, I think it would be simpler to keep them separate
I agree on my personal gut feeling, so I'll go with this and see how it feels once complete.
Ok, I think this feels a lot cleaner now. For build_inp_s_copy, I opted to add an optional parameter so that the caller can take ownership of casting the cache state rather than duplicating the function into build_inp_s_copy_hybrid. That felt a little cleaner w.r.t. code reuse, but I'm happy to do a separate method if that's preferred.
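Roughly, the signature and a hybrid call site look like this (the signature matches the header change further down in this review; the call site is only a sketch, not the exact diff):

// llama-graph.h -- the default argument is the only change; when kv_state is nullptr
// the method keeps the existing behavior of casting mstate itself
ggml_tensor * build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state = nullptr) const;

// hybrid call site (sketch): the caller owns the cast and hands over the recurrent child
const auto * hstate = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
ggml_tensor * s_copy = build_inp_s_copy(hstate->get_state_recurrent());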
It looks like there's one more place that will need changing in build_copy_mask_state (renamed to build_recurrent_state on mamba2). Similar to build_inp_s_copy, I think the cleanest way to do this for code reuse is to add an optional parameter that, if unset, will use the current logic of casting mstate.
// TODO: will the recurrent cache be in an undefined state at this point?
LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
Yes, but that will be fixed in #13834
(Noting here in case this gets merged first so that I don't forget to update the comment)
Should I strip out this TODO at this point?
@ggerganov I've noticed that the …
This is an attempt to handle race conditions between /health returning OK and the other endpoints not returning successfully. ggml-org#13979 (comment) Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart <[email protected]>
I've tried adding retry logic for all requests in 39a93b3 to work around the race between …
The changes to the server tests should not be needed. Let's revert the commit for now and I'll investigate.
@ggerganov Thanks, it looks like those changes didn't fix the failures anyway, so definitely not the right fix. I've reset them out and can open an issue with details of what I see locally.
Issue for follow up on …
I've rebased on #13834. Drafting for now until it's merged.
That was quick! Undrafting now that #13834 is merged.
src/llama-graph.h
Outdated
@@ -242,7 +243,7 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
         cparams(cparams),
         kv_state(kv_state) {
     }
-    ~llm_graph_input_attn_kv_unified() = default;
+    virtual ~llm_graph_input_attn_kv_unified() = default;
I don't think this virtual is needed. I think I added it when I was attempting to have the hybrid input inherit from the unified input.
        uint32_t n_seq_max,
        bool offload) :
    hparams(model.hparams),
    kv_attn(new llama_kv_cache_unified(
Hi @gabe-l-hart - thanks for your huge work. Chiming in on this PR to ask a question: how do you think we should handle cache formats where, for every layer, we have both the attention and SSM states? This design seems to be tailored for architectures that have either SSM or attention for each decoder block.
Hi @younesbelkada! That sounds like a really interesting architecture. Can you point to any papers or draft models that use this approach? I'd be really curious to check them out. The main advantage of the hybrid recurrent architecture where each layer is either SSM or Attention (but not both) is that you get the cache size scaling advantages of SSMs for most layers without sacrificing the ability to attend to all tokens for the whole model.
My first pass at a standalone hybrid cache (#13276) attempted to be a more generic representation of "each layer can have its own cache type" where the parent cache could have any number of children. This ran into a number of challenges during the recent interface refactoring, and ultimately was incompatible with some of the work to encapsulate the interactions with the cache, so at Georgi's suggestion, I pivoted to this less-abstract implementation.
Following the pattern established by iswa and now hybrid_recurrent, I would think implementing such a cache would look like:

- Creating another new subclass of llama_memory_i that owns the child caches (similar to hybrid_recurrent)
  - The various interfaces would probably be implemented with "AND" semantics instead of "OR" semantics for child layers
- Creating the corresponding subclass of llama_memory_state_i
- Using a different set of layer filters when constructing the parent cache with the children such that the children apply to the right layers (which are now overlapping sets)
The place where I see the most incompatibility is in the hparams.recurrent_layer() method, which right now is a "yes/no" and, if true, implies !attention_layer (explicitly in the n_embd_k_s / n_embd_v_s methods). One possible solution here would be for the hparams to hold some kind of get_cache_type(il) -> set{CACHE_TYPE} method that would return a set of cache type enum values for a given layer index. This might also be able to replace the llm_arch_is_recurrent / llm_arch_is_hybrid_recurrent interfaces.

@compilade @ggerganov do you foresee the need to allow a single layer to use multiple cache types as something we'll want to support in the future? If so, we might want to implement that more-generic interface in hparams now to avoid an API-breaking change down the road.
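To make that concrete, the hypothetical accessor might look something like this (a pure sketch; none of these names exist in llama.cpp today):

// hypothetical addition to llama_hparams -- placeholder names
enum llama_cache_type {
    LLAMA_CACHE_TYPE_ATTN,
    LLAMA_CACHE_TYPE_RECURRENT,
};

// returns the set of cache types used by layer il; a block that has both
// SSM and attention would return both values for the same il
std::set<llama_cache_type> get_cache_type(uint32_t il) const;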
Thank you @gabe-l-hart for your detailed response 🙏
Currently there is an effort to integrate Falcon-H1 models into main llama.cpp: https://huggingface.co/collections/tiiuae/falcon-h1-6819f2795bc406da60fab8df (for more details you can check the technical blogpost; this architecture employs a parallel hybrid design where you have both SSM and attention in each decoder block). We have a public fork of llama.cpp that we are currently maintaining here: https://github.com/tiiuae/llama.cpp-Falcon-H1, but we are refactoring the changes we made so that they will be compatible with the changes on master (including this PR and #9126) and can be merged into llama.cpp master (related issue: #13681).
Right now we tried building on top of this PR and the Mamba2 PR from @compilade and faced this issue, so we are not sure what approach to take. What you suggested sounds good - we will be happy to work on whatever you think is the best suggestion!
Ah! I actually think someone pointed me at your fork a week ago and I lost it in my exploded number of tabs. I'll definitely check it out and see if there are any places this should be tweaked to accommodate the Falcon-H1 arch.
For context, my motivation is very similar to yours. I'm trying to get Granite 4 support added, which is also a combo of mamba2 and attention (GraniteMoeShared for the attention layers). My draft PR is #13550, but my most recent attempt to merge master and mamba2 failed pretty badly, so I'm still debugging why (and I'm guessing compilade is also working on doing it the right way).
Yes, probably we have to introduce the notion of has_recurrent_state(il) and has_attention(il), and these can both be true in such cases.

This makes sense. I'll give this a shot soon.

Of these uses, I think the two that would be the most problematic are n_embd_k/v_s, since they're currently implemented using hparams.recurrent_layer, and hparams.n_embd, since that is used by both the unified and recurrent caches and, I think, might need to have different values for the two if they were being shared by the same layer.

@ggerganov with the removal of n_embd_k/v_s layer indexing, I think that just leaves hparams.n_embd. Do you have a quick thought on whether this will need to be split between attn/recurrent portions for a single layer?
Does H1 actually need different hparams.n_embd for the different layers? The hparams.n_embd is the size of the hidden state that goes through the layers. I haven't looked at the links, but I doubt this parameter would vary with layers.
I'm not clear on that (haven't unboxed it far enough either). @younesbelkada I haven't dug into your fork deeply yet, but I'm guessing you may have had to tackle this question and might know which hparams would need to be different between the attention / recurrent portions of the cache for a given layer.
I've added the ability to specify custom filters in the constructor of llama_kv_cache_hybrid_recurrent, which will default to checking hparams.recurrent_layer if not set. This should allow H1 to implement its own "always on" filter for all layers and then add a clause to create_memory in the model-specific section to construct the hybrid cache with that filter.
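As a sketch, the constructor shape is along these lines (parameter names here are approximate, not the literal diff):

// llama-kv-cache-hybrid-recurrent.h (sketch)
using layer_filter_cb = std::function<bool(int32_t il)>;

llama_kv_cache_hybrid_recurrent(
        const llama_model & model,
        /* ...existing cache parameters... */
        layer_filter_cb && filter_attn      = nullptr,  // default: !hparams.recurrent_layer(il)
        layer_filter_cb && filter_recurrent = nullptr); // default:  hparams.recurrent_layer(il)

// Falcon-H1-style usage in create_memory() (sketch): give every layer both caches
// by passing [](int32_t) { return true; } for both filters.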
Thank you @gabe-l-hart @ggerganov!
If I understand correctly from the last comment from @gabe-l-hart, we can already give it a shot for H1 now. If you confirm, we'll start from the synced branch you shared on the Mamba2 PR.
Does H1 actually need different hparams.n_embd for the different layers?
I can confirm we actually use the same n_embd across all layers, since we aggregate the hidden states of the attention and SSM mixers by summing them - however, the KV cache will have different shapes:
-> SSM cache shape: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon_h1/modeling_falcon_h1.py#L100-L117
-> KV cache shape: head_dim x num_kv_heads
Also, split llama_model_is_recurrent into llm_arch_is_recurrent in llama-arch with llama_model_is_recurrent delegating to llm_arch_is_recurrent. The same split is done for hybrid. This is needed because there are places where the llama_model has not yet been initialized but we need to check if the model is recurrent (specifically for the per-layer recurrent check array in hparams). Branch: GraniteFour Signed-off-by: Gabe Goodhart <[email protected]>
Branch: GraniteFour
…s in hparams Branch: GraniteFour Signed-off-by: Gabe Goodhart <[email protected]>
…l is recurrent Branch: GraniteFour Signed-off-by: Gabe Goodhart <[email protected]>
The implementation of the hybrid cache intentionally does not specify the types of the child caches, so there was a naming mismatch with these predicate functions that used "hybrid" to imply "hybrid recurrent." Branch: HybridCache Signed-off-by: Gabe Goodhart <[email protected]>
Branch: HybridCache Signed-off-by: Gabe Goodhart <[email protected]>
Branch: GraniteFour Signed-off-by: Gabe Goodhart <[email protected]>
This follows the pattern in iswa where the two child caches are held explicitly to support the case where a model requires a single attention cache and a single recurrent cache where each layer uses exactly one of the caches. This is a rewrite of the more generic approach in the original hybrid cache PR: ggml-org#13276 Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart <[email protected]>
This includes a refactor of the create_memory logic to avoid needing to use the arch enum explicitly unless a model needs explicit cache instantiation logic beyond the standard logic for recurrent, hybrid, unified, and iswa. Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart <[email protected]>
Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart <[email protected]>
Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart <[email protected]>
NOTE: I intentionally did not add support for s_mask since it will be going away soon Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart <[email protected]>
Branch: GraniteFour Signed-off-by: Gabe Goodhart <[email protected]>
…he interface Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart <[email protected]>
Branch: GraniteFour Signed-off-by: Gabe Goodhart <[email protected]>
Branch: GraniteFour Signed-off-by: Gabe Goodhart <[email protected]>
Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart <[email protected]>
…empt Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart <[email protected]>
Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart <[email protected]>
No longer needed now that unified isn't also supporting recurrent ggml-org#13979 (comment) Branch: HybridRecurrentCache
Now that it's not used at all in the unified cache, we don't need to use the layer index to zero it out for attention layers. Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart <[email protected]>
src/llama-kv-cache-recurrent.cpp
Outdated
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
Aren't n_embd_k_gqa and n_embd_v_gqa equal to 0 for recurrent models? If yes, you can remove them from the recurrent implementation.
I think that sounds right. It looks like the value here is ultimately determined by hparams.n_head_kv_arr, which by default is populated as a copy of hparams.n_head_arr from %s.attention.head_count. For mamba, I see mamba.attention.head_count set to 0, and I suspect this would be true of all other recurrent models, but I'm not sure how to fully verify. I'll try removing them and see if the recurrent models I have work as expected, but I'll defer to @compilade on whether this is an "always" thing or a "just mamba* models" thing.
This is no longer needed now that there are separate implementations ggml-org#13979 (comment) Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart <[email protected]>
This should help support architectures like Falcon H1 where there is overlap between layers that need attention and recurrent caches. ggml-org#13979 (comment) Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart <[email protected]>
I feel like the naming of the recurrent cache is not great and it is causing some confusion. In a follow-up PR we have to do the following naming changes:

- Rename llama_kv_cache_recurrent -> llama_recurrent_state
- Rename llama_kv_cache_hybrid_recurrent -> llama_memory_hybrid

This should be a good first step, although not ideal because "hybrid" will implicitly refer to "unified KV cache + recurrent state". For the future we have to figure out a naming scheme that would be able to differentiate between different types of hybrid memories. But this should be good enough for now.
@@ -508,7 +509,7 @@ struct llm_graph_context {
     ggml_tensor * build_inp_out_ids() const;
     ggml_tensor * build_inp_mean() const;
     ggml_tensor * build_inp_cls() const;
-    ggml_tensor * build_inp_s_copy() const;
+    ggml_tensor * build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state = nullptr) const;
This method is analogous to the build_attn_inp_ methods, so we have to model it in a similar way.
Replace this method with:
// similar to build_attn_inp_kv_unified()
llm_graph_input_rs * build_rs_inp() const;
Introduce new input class:
// similar to llm_graph_input_attn_kv_unified
// put the `s_copy` tensor in this class (similar to `kq_mask`)
class llm_graph_input_rs : public llm_graph_input_i;
In the future, this input class could be extended with additional input tensors that are needed by the recurrent cache if necessary (similar to the attention input classes).
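For reference, a sketch of what that input class could look like, modeled on the existing llm_graph_input_s_copy (the member names here are assumptions, not the final design):

// sketch -- mirrors llm_graph_input_s_copy, but owned/returned by build_rs_inp()
class llm_graph_input_rs : public llm_graph_input_i {
public:
    llm_graph_input_rs(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
    virtual ~llm_graph_input_rs() = default;

    void set_input(const llama_ubatch * ubatch) override;

    ggml_tensor * s_copy; // I32 [kv_size]

    const llama_kv_cache_recurrent_state * kv_state;
};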
Replace build_recurrent_state() and build_rwkv_shift_load() with overloads:
// similar to build_attn()
ggml_tensor * build_rs(
llm_graph_input_rs * inp,
ggml_cgraph * gf,
ggml_tensor * s,
int32_t state_size,
int32_t n_seqs,
bool avoid_copies = false) const;
// similar to build_attn()
ggml_tensor * build_rwkv_token_shift_load(
llm_graph_input_rs * inp,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
int il) const;
When this change is applied, we have to do a similar addition for the hybrid implementation. The basic pattern is that you need to introduce a new input class, similar to llm_graph_input_attn_kv_unified and llm_graph_input_rs, but this one will contain inputs for both the attention and for the recurrent state.

So probably something like:

// this input class will have both the input tensors needed for the attention and for
// the recurrent state. see llm_graph_input_attn_kv_unified_iswa for example
class llm_graph_input_mem_hybrid : public llm_graph_input_i;
We then add overloads for build_attn(), build_rs() and build_rwkv_token_shift_load().
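Roughly (again only a sketch of the shape, with the members borrowed from the attention and recurrent input classes):

// sketch -- carries the inputs for both children of the hybrid state
class llm_graph_input_mem_hybrid : public llm_graph_input_i {
public:
    void set_input(const llama_ubatch * ubatch) override;

    ggml_tensor * s_copy;       // I32 [kv_size]       -- consumed by build_rs()
    ggml_tensor * self_kq_mask; // F32 [n_kv, n_batch] -- consumed by build_attn()

    const llama_kv_cache_hybrid_recurrent_state * mstate;
};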
// NOTE: For hybrid caches, this may be a child of mstate, so we use the one
// encapsulated in inp
const auto * kv_state = inp->kv_state;
This change is not needed. We have to aim to not mix the recurrent logic with the attention logic.
const llama_kv_cache_unified_state_ptr state_attn;
const llama_kv_cache_recurrent_state_ptr state_recurrent;
These can be changed like this, similar to the iswa state:

-    const llama_kv_cache_unified_state_ptr state_attn;
-    const llama_kv_cache_recurrent_state_ptr state_recurrent;
+    llama_memory_state_ptr state_attn;
+    llama_memory_state_ptr state_recurrent;
// Returns true if the model is hybrid-recurrent (like Jamba, Bamba, etc.)
LLAMA_API bool llama_model_is_hybrid_recurrent(const struct llama_model * model);
For now let's hold off adding this API. I don't think llama_model_is_recurrent should exist either. The idea is that the user app should not have to know what the underlying memory mechanism is.
This is a re-opened version of #13904 after #13746 was merged
Description
This PR introduces the llama_kv_cache_hybrid_recurrent cache implementation. It follows the pattern of llama_kv_cache_unified_iswa by holding two child cache instances and implementing the interface logic such that it manages both correctly for the appropriate layers.

Changes

The main change in this PR is the addition of llama_kv_cache_hybrid_recurrent in llama-kv-cache-hybrid-recurrent.*. In addition to this, the PR does the following:

- Add the llama_model_is_hybrid_recurrent public API (akin to llama_model_is_recurrent)
- Add LLM_KV_ATTENTION_LAYER_INDICES as an hparam to hold the indices of the layers that should use attention (versus recurrent)
  - NOTE: This is a different mechanism than the one used for iswa, but that mechanism also isn't particularly extensible. It might be more appropriate to have a generic mechanism for indicating the type of caching to use for each layer, but that would start to approach the generic hybrid implementation that I originally attempted, which ended up being too abstract (feat: Hybrid unified/recurrent cache #13276)
- Update the members of llm_graph_context that need a specific type of cache to use getters (get_state_unified / get_state_recurrent) that will properly handle llama_kv_cache_hybrid_recurrent
- Make n_embd_k_s / n_embd_v_s layer-dependent and use layer indices when calling them in the existing cache implementations
- … llama_kv_cache_recurrent
- Update llama_model::create_memory to use llm_arch_is_recurrent and llm_arch_is_hybrid_recurrent rather than relying on adding models to the switch statement, which was redundant with the implementation of these functions
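For the last point, the dispatch in create_memory now conceptually looks like this (a condensed sketch, not the literal diff):

// llama-model.cpp, inside llama_model::create_memory() -- condensed sketch
if (llm_arch_is_recurrent(arch)) {
    res = new llama_kv_cache_recurrent(/* ... */);
} else if (llm_arch_is_hybrid_recurrent(arch)) {
    res = new llama_kv_cache_hybrid_recurrent(/* ... */);
} else {
    // unified KV cache (optionally iswa) for all other architectures
    res = new llama_kv_cache_unified(/* ... */);
}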