diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 70be604e4b0d3..8f9cd652447ab 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -22,8 +22,9 @@ add_library(llama
             llama-io.cpp
             llama-kv-cache-unified.cpp
             llama-kv-cache-unified-iswa.cpp
-            llama-kv-cache-recurrent.cpp
             llama-memory.cpp
+            llama-memory-hybrid.cpp
+            llama-memory-recurrent.cpp
             llama-mmap.cpp
             llama-model-loader.cpp
             llama-model-saver.cpp
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index de8d289cf967e..0bc60565df12c 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -147,6 +147,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_SCALE,                        "%s.attention.scale"                        },
     { LLM_KV_ATTENTION_KEY_LENGTH_MLA,               "%s.attention.key_length_mla"               },
     { LLM_KV_ATTENTION_VALUE_LENGTH_MLA,             "%s.attention.value_length_mla"             },
+    { LLM_KV_ATTENTION_LAYER_INDICES,                "%s.attention.layer_indices"                },
 
     { LLM_KV_ROPE_DIMENSION_COUNT,      "%s.rope.dimension_count"                 },
     { LLM_KV_ROPE_DIMENSION_SECTIONS,   "%s.rope.dimension_sections"              },
@@ -1816,3 +1817,25 @@ llm_arch llm_arch_from_string(const std::string & name) {
 const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
     return LLM_TENSOR_INFOS.at(tensor);
 }
+
+bool llm_arch_is_recurrent(const llm_arch & arch) {
+    switch (arch) {
+        case LLM_ARCH_MAMBA:
+        case LLM_ARCH_RWKV6:
+        case LLM_ARCH_RWKV6QWEN2:
+        case LLM_ARCH_RWKV7:
+        case LLM_ARCH_ARWKV7:
+            return true;
+        default:
+            return false;
+    }
+}
+
+bool llm_arch_is_hybrid(const llm_arch & arch) {
+    // TODO: There are currently no hybrid models! Once there are, this will be
+    //  the place to identify them
+    switch (arch) {
+        default:
+            return false;
+    }
+}
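These two predicates give callers a single place to map an architecture onto a memory implementation. A minimal caller-side sketch, assuming only the declarations added in this patch; the `memory_kind` enum and `pick_memory_kind` helper are illustrative names, not part of llama.cpp:

```cpp
// Illustrative dispatch over the new predicates; memory_kind and
// pick_memory_kind are hypothetical names used only for this sketch.
enum class memory_kind { kv_unified, recurrent, hybrid };

static memory_kind pick_memory_kind(llm_arch arch) {
    if (llm_arch_is_hybrid(arch)) {
        return memory_kind::hybrid;     // mixed attention + recurrent layers
    }
    if (llm_arch_is_recurrent(arch)) {
        return memory_kind::recurrent;  // e.g. LLM_ARCH_MAMBA, LLM_ARCH_RWKV6
    }
    return memory_kind::kv_unified;     // default attention KV cache
}
```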
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 3e8a61da3c13e..51b242c66b824 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -151,6 +151,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_SCALE,
     LLM_KV_ATTENTION_KEY_LENGTH_MLA,
     LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
+    LLM_KV_ATTENTION_LAYER_INDICES,
 
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -439,3 +440,6 @@ const char * llm_arch_name(llm_arch arch);
 llm_arch llm_arch_from_string(const std::string & name);
 
 const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
+
+bool llm_arch_is_recurrent(const llm_arch & arch);
+bool llm_arch_is_hybrid   (const llm_arch & arch);
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 337fb5cb0df36..65d98cbbb3987 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -6,7 +6,8 @@
 
 #include "llama-kv-cache-unified.h"
 #include "llama-kv-cache-unified-iswa.h"
-#include "llama-kv-cache-recurrent.h"
+#include "llama-memory-hybrid.h"
+#include "llama-memory-recurrent.h"
 
 #include <cassert>
 #include <cmath>
@@ -238,18 +239,18 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
-    const int64_t n_kv = kv_state->get_n_kv();
+    const int64_t n_rs = mem_state->get_n_rs();
 
     if (s_copy) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
         int32_t * data = (int32_t *) s_copy->data;
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
-        for (uint32_t i = 0; i < n_kv; ++i) {
-            data[i] = kv_state->s_copy(i);
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mem_state->s_copy(i);
         }
     }
 }
@@ -403,6 +404,24 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
+    if (self_kq_mask) {
+        mem_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+    }
+
+    const int64_t n_rs = mem_state->get_state_recr()->get_n_rs();
+
+    if (s_copy) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
+        int32_t * data = (int32_t *) s_copy->data;
+
+        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mem_state->get_state_recr()->s_copy(i);
+        }
+    }
+}
+
 //
 // llm_graph_context
 //
@@ -961,23 +980,6 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
-
-    auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
-
-    const auto n_kv = kv_state->get_n_kv();
-
-    auto & cur = inp->s_copy;
-
-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
-    ggml_set_input(cur);
-
-    res->add_input(std::move(inp));
-
-    return cur;
-}
-
 ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
     auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
 
@@ -1047,6 +1049,33 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
     return pos_bias;
 }
 
+llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
+    const auto * mem_state = static_cast<const llama_memory_hybrid_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mem_state);
+
+    {
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
+
+        const auto n_kv = inp->mem_state->get_state_attn()->get_n_kv();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    {
+        const auto n_rs = mem_state->get_state_recr()->get_n_rs();
+
+        inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
+        ggml_set_input(inp->s_copy);
+    }
+
+    return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
+}
+
 ggml_tensor * llm_graph_context::build_attn_mha(
          ggml_cgraph * gf,
          ggml_tensor * q,
@@ -1291,36 +1320,6 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
-
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
-
-    {
-        const auto n_kv = kv_state->get_base()->get_n_kv();
-
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
-        ggml_set_input(inp->self_kq_mask);
-
-        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
-    }
-
-    {
-        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
-
-        const auto n_kv = kv_state->get_swa()->get_n_kv();
-
-        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
-        ggml_set_input(inp->self_kq_mask_swa);
-
-        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
-    }
-
-    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
-}
-
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_kv_unified_iswa * inp,
         ggml_cgraph * gf,
@@ -1430,20 +1429,99 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_recurrent_state(
-         ggml_cgraph * gf,
-         ggml_tensor * s,
-         ggml_tensor * state_copy,
-             int32_t   state_size,
-             int32_t   n_seqs,
-                bool   avoid_copies) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
-
-    const auto n_kv    = kv_state->get_n_kv();
-    const auto kv_head = kv_state->get_head();
-    const auto rs_zero = kv_state->get_rs_z();
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_mem_hybrid * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+            float     kq_scale,
+            int       il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
+    const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_attn();
+
+    // store to KV cache
+    {
+        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+    }
+
+    const auto & kq_mask = inp->get_kq_mask();
+
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = kv_state->get_k(ctx0, il);
+    ggml_tensor * v = kv_state->get_v(ctx0, il);
+
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
+    }
+
+    if (wo_b) {
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
+}
+
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
 
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
+    {
+        const auto n_kv = kv_state->get_base()->get_n_kv();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    {
+        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
+
+        const auto n_kv = kv_state->get_swa()->get_n_kv();
+
+        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(inp->self_kq_mask_swa);
+
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+    }
+
+    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        ggml_tensor * state_copy,
+            int32_t   state_size,
+            int32_t   n_seqs,
+           uint32_t   n_kv,
+           uint32_t   kv_head,
+           uint32_t   kv_size,
+            int32_t   rs_zero,
+               bool   avoid_copies) const {
+
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
 
     // Clear a single state which will then be copied to the other cleared states.
     // Note that this is a no-op when the view is zero-sized.
@@ -1474,22 +1552,59 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
     return output_states;
 }
 
+llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
+    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
+
+    const auto n_rs = kv_state->get_n_rs();
+
+    inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
+    ggml_set_input(inp->s_copy);
+
+    return (llm_graph_input_rs *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        llm_graph_input_rs * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+            int32_t   state_size,
+            int32_t   n_seqs,
+               bool   avoid_copies) const {
+    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        llm_graph_input_mem_hybrid * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+            int32_t   state_size,
+            int32_t   n_seqs,
+               bool   avoid_copies) const {
+    const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_recr();
+
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
+}
+
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
-         ggml_cgraph * gf,
-         ggml_tensor * state_copy,
-  const llama_ubatch & ubatch,
+    llm_graph_input_rs * inp,
+           ggml_cgraph * gf,
+    const llama_ubatch & ubatch,
                  int   il) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
 
     const auto token_shift_count = hparams.token_shift_count;
 
     const int64_t n_seqs  = ubatch.n_seqs;
 
-    ggml_tensor * token_shift_all = kv_state->get_k_l(il);
+    ggml_tensor * token_shift_all = kv_state->get_r_l(il);
 
-    ggml_tensor * token_shift = build_recurrent_state(
-            gf, token_shift_all, state_copy,
-            hparams.n_embd_k_s(), n_seqs);
+    ggml_tensor * token_shift = build_rs(
+            inp, gf, token_shift_all,
+            hparams.n_embd_r(), n_seqs);
 
     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
 
@@ -1500,7 +1615,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
          ggml_tensor * token_shift,
   const llama_ubatch & ubatch,
                  int   il) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;
@@ -1512,7 +1627,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
     return ggml_cpy(
         ctx0,
         ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
-        ggml_view_1d(ctx0, kv_state->get_k_l(il), hparams.n_embd_k_s()*n_seqs, hparams.n_embd_k_s()*kv_head*ggml_element_size(kv_state->get_k_l(il)))
+        ggml_view_1d(ctx0, kv_state->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(kv_state->get_r_l(il)))
     );
 }
 
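Taken together, the graph-side changes above replace the old `build_inp_s_copy()` + `build_recurrent_state()` pair with a shared input object plus `build_rs()` overloads. A minimal sketch of the new call pattern, assuming only the methods declared in `llama-graph.h` below; the wrapper function and its parameters are hypothetical:

```cpp
// Hypothetical helper showing the new recurrent-state plumbing; the graph
// context methods used here are the ones declared in this patch.
ggml_tensor * example_load_layer_state(const llm_graph_context & g,
                                       ggml_cgraph * gf,
                                       ggml_tensor * s_l,      // this layer's state tensor
                                       int32_t state_size,     // e.g. hparams.n_embd_s()
                                       const llama_ubatch & ubatch) {
    // in a real graph this input is created once and shared by all recurrent layers
    llm_graph_input_rs * inp = g.build_rs_inp();

    // gather/copy the states of the sequences in this ubatch for this layer
    return g.build_rs(inp, gf, s_l, state_size, ubatch.n_seqs);
}
```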
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 87813119b1a3c..58845e284abed 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -21,7 +21,8 @@ struct llama_memory_state_i;
 
 class llama_kv_cache_unified_state;
 class llama_kv_cache_unified_iswa_state;
-class llama_kv_cache_recurrent_state;
+class llama_memory_recurrent_state;
+class llama_memory_hybrid_state;
 
 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
@@ -188,16 +189,16 @@ class llm_graph_input_cls : public llm_graph_input_i {
     const llama_cparams & cparams;
 };
 
-class llm_graph_input_s_copy : public llm_graph_input_i {
+class llm_graph_input_rs : public llm_graph_input_i {
 public:
-    llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
-    virtual ~llm_graph_input_s_copy() = default;
+    llm_graph_input_rs(const llama_memory_recurrent_state * mem_state) : mem_state(mem_state) {}
+    virtual ~llm_graph_input_rs() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * s_copy; // I32 [kv_size]
 
-    const llama_kv_cache_recurrent_state * kv_state;
+    const llama_memory_recurrent_state * mem_state;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -300,6 +301,33 @@ class llm_graph_input_attn_cross : public llm_graph_input_i {
     const llama_cross * cross = nullptr;
 };
 
+class llm_graph_input_mem_hybrid : public llm_graph_input_i {
+public:
+    llm_graph_input_mem_hybrid(
+            const llama_hparams & hparams,
+            const llama_cparams & cparams,
+            const llama_memory_hybrid_state * mem_state) :
+        hparams(hparams),
+        cparams(cparams),
+        mem_state(mem_state) {
+    }
+    virtual ~llm_graph_input_mem_hybrid() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * s_copy; // I32 [kv_size]
+
+    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
+
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
+
+    const llama_hparams & hparams;
+    const llama_cparams & cparams;
+
+    const llama_memory_hybrid_state * mem_state;
+};
+
 //
 // llm_graph_result
 //
@@ -508,13 +536,14 @@ struct llm_graph_context {
     ggml_tensor * build_inp_out_ids() const;
     ggml_tensor * build_inp_mean() const;
     ggml_tensor * build_inp_cls() const;
-    ggml_tensor * build_inp_s_copy() const;
 
     ggml_tensor * build_inp_cross_embd() const;
     ggml_tensor * build_inp_pos_bucket_enc() const;
     ggml_tensor * build_inp_pos_bucket_dec() const;
     ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
 
+    llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
+
     //
     // attention
     //
@@ -589,22 +618,62 @@ struct llm_graph_context {
                   float   kq_scale,
                     int   il) const;
 
+    ggml_tensor * build_attn(
+            llm_graph_input_mem_hybrid * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+            ggml_tensor * kq_b,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+                  float   kq_scale,
+                    int   il) const;
     //
     // recurrent
     //
 
-    ggml_tensor * build_recurrent_state(
-             ggml_cgraph * gf,
-             ggml_tensor * s,
-             ggml_tensor * state_copy,
-                 int32_t   state_size,
-                 int32_t   n_seqs,
-                    bool   avoid_copies = false) const;
+    // TODO: avoid notion of "kv"
+    // TODO: move this implementation to llama_memory_recurrent.
+    //       this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
+    //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
+    //         implementation into two separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
+    //         `llama_memory_recurrent`
+    ggml_tensor * build_rs(
+            ggml_cgraph * gf,
+            ggml_tensor * s,
+            ggml_tensor * state_copy,
+                int32_t   state_size,
+                int32_t   n_seqs,
+               uint32_t   n_kv,
+               uint32_t   kv_head,
+               uint32_t   kv_size,
+                int32_t   rs_zero,
+                   bool   avoid_copies = false) const;
+
+    llm_graph_input_rs * build_rs_inp() const;
+
+    ggml_tensor * build_rs(
+            llm_graph_input_rs * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * s,
+                int32_t   state_size,
+                int32_t   n_seqs,
+                   bool   avoid_copies = false) const;
+
+    ggml_tensor * build_rs(
+            llm_graph_input_mem_hybrid * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * s,
+                int32_t   state_size,
+                int32_t   n_seqs,
+                   bool   avoid_copies = false) const;
 
     ggml_tensor * build_rwkv_token_shift_load(
-             ggml_cgraph * gf,
-             ggml_tensor * state_copy,
-      const llama_ubatch & ubatch,
+        llm_graph_input_rs * inp,
+               ggml_cgraph * gf,
+        const llama_ubatch & ubatch,
                      int   il) const;
 
     ggml_tensor * build_rwkv_token_shift_store(
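For hybrid models, a single `llm_graph_input_mem_hybrid` created via `build_inp_mem_hybrid()` feeds both the new `build_attn()` overload and the new `build_rs()` overload. A rough per-layer dispatch sketch under those declarations; the wrapper function, its parameters, and the dispatch itself are illustrative, not code from this patch:

```cpp
// Hypothetical per-layer dispatch for a hybrid model; all tensors are assumed
// to be prepared by the caller, and `inp` comes from g.build_inp_mem_hybrid().
ggml_tensor * example_hybrid_layer(const llm_graph_context & g,
                                   llm_graph_input_mem_hybrid * inp,
                                   ggml_cgraph * gf, int il,
                                   ggml_tensor * q, ggml_tensor * k, ggml_tensor * v,
                                   ggml_tensor * wo, ggml_tensor * s_l,
                                   int32_t state_size, int32_t n_seqs, float kq_scale) {
    if (g.hparams.is_recurrent(il)) {
        // recurrent layer: route through the recurrent half of the hybrid state
        return g.build_rs(inp, gf, s_l, state_size, n_seqs);
    }

    // attention layer: KV store + attention over the hybrid state's attention half
    return g.build_attn(inp, gf, wo, /*wo_b*/ nullptr, q, k, v,
                        /*kq_b*/ nullptr, /*v_mla*/ nullptr, kq_scale, il);
}
```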
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index 1499eb08a5dd9..b40566ced99ee 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -65,7 +65,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
     return n_embd_head_v * n_head_kv;
 }
 
-uint32_t llama_hparams::n_embd_k_s() const {
+uint32_t llama_hparams::n_embd_r() const {
     if (wkv_head_size != 0) {
         // for RWKV models
         return token_shift_count * n_embd;
@@ -76,7 +76,7 @@ uint32_t llama_hparams::n_embd_k_s() const {
     return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
 }
 
-uint32_t llama_hparams::n_embd_v_s() const {
+uint32_t llama_hparams::n_embd_s() const {
     if (wkv_head_size != 0) {
         // corresponds to RWKV's wkv_states size
         return n_embd * wkv_head_size;
@@ -86,6 +86,10 @@ uint32_t llama_hparams::n_embd_v_s() const {
     return ssm_d_state * ssm_d_inner;
 }
 
+bool llama_hparams::is_recurrent(uint32_t il) const {
+    return recurrent_layer_arr[il];
+}
+
 bool llama_hparams::is_swa(uint32_t il) const {
     if (il < n_layer) {
         return swa_layers[il];
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index b2bcb8b01a18b..82bb5b6084946 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -115,6 +115,9 @@ struct llama_hparams {
     uint32_t ssm_d_state = 0;
     uint32_t ssm_dt_rank = 0;
 
+    // for hybrid state space models
+    std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
+
     bool ssm_dt_b_c_rms = false;
 
     float f_clamp_kqv      = 0.0f;
@@ -181,10 +184,13 @@ struct llama_hparams {
 
     // dimension of the rolling state embeddings
     // corresponds to Mamba's conv_states size or RWKV's token_shift states size
-    uint32_t n_embd_k_s() const;
+    uint32_t n_embd_r() const;
 
     // dimension of the recurrent state embeddings
-    uint32_t n_embd_v_s() const;
+    uint32_t n_embd_s() const;
+
+    // whether or not the given layer is recurrent (for hybrid models)
+    bool is_recurrent(uint32_t il) const;
 
     bool is_swa(uint32_t il) const;
 };
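The renamed helpers keep the formulas from `llama-hparams.cpp`: `n_embd_r()` is the rolling (conv/token-shift) state per layer and `n_embd_s()` the recurrent (SSM/wkv) state per layer. A self-contained check of the non-RWKV branches with made-up Mamba-style numbers, not taken from any real model:

```cpp
#include <cstdint>
#include <cstdio>

// Mirrors the non-RWKV branches of n_embd_r() / n_embd_s(); the hparam values
// below are placeholders, not a real model configuration.
int main() {
    const uint32_t ssm_d_conv = 4, ssm_d_inner = 1536, ssm_d_state = 16;

    const uint32_t n_embd_r = (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; // 3 * 1536 = 4608
    const uint32_t n_embd_s = ssm_d_state * ssm_d_inner;                           // 16 * 1536 = 24576

    std::printf("n_embd_r = %u, n_embd_s = %u\n", n_embd_r, n_embd_s);
    return 0;
}
```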
diff --git a/src/llama-kv-cache-unified-iswa.cpp b/src/llama-kv-cache-unified-iswa.cpp
index a4a4c2b1b859d..a869b1de8c2a3 100644
--- a/src/llama-kv-cache-unified-iswa.cpp
+++ b/src/llama-kv-cache-unified-iswa.cpp
@@ -197,21 +197,19 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
 
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
-        llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
-    state_base = kv->get_base()->init_full();
-    state_swa  = kv->get_swa ()->init_full();
-
-    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+        llama_kv_cache_unified_iswa * kv) :
+    state_base(kv->get_base()->init_full()),
+    state_swa (kv->get_swa ()->init_full()),
+    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
 }
 
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
         llama_kv_cache_unified_iswa * kv,
         llama_context * lctx,
-        bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
-    state_base = kv->get_base()->init_update(lctx, optimize);
-    state_swa  = kv->get_swa ()->init_update(lctx, optimize);
-
-    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+        bool optimize) :
+    state_base(kv->get_base()->init_update(lctx, optimize)),
+    state_swa (kv->get_swa ()->init_update(lctx, optimize)),
+    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
 }
 
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
@@ -219,15 +217,13 @@ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
         llama_sbatch sbatch,
         std::vector<uint32_t> heads_base,
         std::vector<uint32_t> heads_swa,
-        std::vector<llama_ubatch> ubatches)
-        : status(LLAMA_MEMORY_STATUS_SUCCESS),
-        sbatch(std::move(sbatch)),
-        ubatches(std::move(ubatches)) {
+        std::vector<llama_ubatch> ubatches) :
+    sbatch(std::move(sbatch)),
+    ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
-    state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa),  this->ubatches));
-
-    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+    state_base(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches)),
+    state_swa (new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa),  this->ubatches)),
+    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
 }
 
 llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
diff --git a/src/llama-kv-cache-unified-iswa.h b/src/llama-kv-cache-unified-iswa.h
index 6e941e1a41b88..813eaf39b25b0 100644
--- a/src/llama-kv-cache-unified-iswa.h
+++ b/src/llama-kv-cache-unified-iswa.h
@@ -117,8 +117,6 @@ class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
     const llama_kv_cache_unified_state * get_swa()  const;
 
 private:
-    llama_memory_status status;
-
     //llama_kv_cache_unified_iswa * kv;
 
     llama_sbatch sbatch;
@@ -128,6 +126,8 @@ class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
 
     std::vector<llama_ubatch> ubatches;
 
-    llama_memory_state_ptr state_base;
-    llama_memory_state_ptr state_swa;
+    const llama_memory_state_ptr state_base;
+    const llama_memory_state_ptr state_swa;
+
+    const llama_memory_status status;
 };
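Moving `status` after the two state members is what makes the member-initializer-list constructors above well defined: non-static data members are always initialized in declaration order, so a member computed from the others has to be declared last. A small standalone illustration, not llama.cpp code:

```cpp
// Declaration order, not initializer-list order, decides initialization order.
struct combined_status_example {
    const int state_base;
    const int state_swa;
    const int status;      // reads the two members above, so it must be declared last

    combined_status_example(int base, int swa) :
        state_base(base),
        state_swa (swa),
        status(state_base + state_swa) {} // safe: both members are already initialized
};
```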
diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp
index 3b37679859d39..d4412288925c3 100644
--- a/src/llama-kv-cache-unified.cpp
+++ b/src/llama-kv-cache-unified.cpp
@@ -68,8 +68,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
             continue;
         }
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
         const char * dev_name = "CPU";
 
@@ -1430,7 +1430,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
         // Write key type
         const int32_t k_type_i = (int32_t)layer.k->type;
@@ -1452,7 +1452,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
         for (const auto & layer : layers) {
             const uint32_t il = layer.il;
 
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
             // Write value type
             const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1476,7 +1476,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
         for (const auto & layer : layers) {
             const uint32_t il = layer.il;
 
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
             // Write value type
             const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1621,7 +1621,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
         // Read type of key
         int32_t k_type_i_ref;
@@ -1651,7 +1651,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
         for (const auto & layer : layers) {
             const uint32_t il = layer.il;
 
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
             // Read type of value
             int32_t v_type_i_ref;
@@ -1681,7 +1681,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
         for (const auto & layer : layers) {
             const uint32_t il = layer.il;
 
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
             // Read type of value
             int32_t v_type_i_ref;
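With the recurrent state moved out into `llama_memory_recurrent`, the unified cache sizes its per-layer rows by the plain GQA dimensions alone; the old `+ n_embd_k_s()` / `+ n_embd_v_s()` terms disappear. An illustrative computation with placeholder head counts, not a real model:

```cpp
#include <cstdint>
#include <cstdio>

// Placeholder numbers only: per-layer K rows are now just the GQA dimension,
// with no recurrent-state term added on top.
int main() {
    const uint32_t n_embd_head_k = 128, n_head_kv = 8;

    const uint32_t n_embd_k_gqa = n_embd_head_k * n_head_kv; // 1024; previously + n_embd_k_s()

    std::printf("n_embd_k_gqa = %u\n", n_embd_k_gqa);
    return 0;
}
```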
diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp
new file mode 100644
index 0000000000000..d4b260db4c8e7
--- /dev/null
+++ b/src/llama-memory-hybrid.cpp
@@ -0,0 +1,247 @@
+#include "llama-memory-hybrid.h"
+
+#include "llama-impl.h"
+#include "llama-model.h"
+#include "llama-context.h"
+
+//
+// llama_memory_hybrid
+//
+
+llama_memory_hybrid::llama_memory_hybrid(
+    const llama_model & model,
+                         /* attn */
+            ggml_type    type_k,
+            ggml_type    type_v,
+                 bool    v_trans,
+             uint32_t    kv_size,
+             uint32_t    n_pad,
+             uint32_t    n_swa,
+       llama_swa_type    swa_type,
+                         /* recurrent */
+            ggml_type    type_r,
+            ggml_type    type_s,
+             uint32_t    rs_size,
+                         /* common */
+             uint32_t    n_seq_max,
+                 bool    offload,
+                         /* layer filters */
+      layer_filter_cb && filter_attn,
+      layer_filter_cb && filter_recr) :
+    hparams(model.hparams),
+    mem_attn(new llama_kv_cache_unified(
+        model,
+        filter_attn == nullptr ?
+            [&](int32_t il) { return !model.hparams.is_recurrent(il); }
+            : filter_attn,
+        type_k,
+        type_v,
+        v_trans,
+        offload,
+        kv_size,
+        n_seq_max,
+        n_pad,
+        n_swa,
+        swa_type
+    )),
+    mem_recr(new llama_memory_recurrent(
+        model,
+        filter_recr == nullptr ?
+            [&](int32_t il) { return model.hparams.is_recurrent(il); }
+            : filter_recr,
+        type_r,
+        type_s,
+        offload,
+        rs_size,
+        n_seq_max
+    )) {}
+
+llama_memory_state_ptr llama_memory_hybrid::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
+
+    // since this includes a recurrent cache, we cannot use split_simple
+    auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
+
+    // follow the recurrent pattern for creating the ubatch splits
+    std::vector<llama_ubatch> ubatches;
+    while (sbatch.n_tokens > 0) {
+        llama_ubatch ubatch;
+
+        if (embd_pooled) {
+            // Pooled embeddings cannot be split across ubatches (yet)
+            ubatch = sbatch.split_seq(n_ubatch);
+        } else {
+            ubatch = sbatch.split_equal(n_ubatch);
+        }
+
+        ubatches.push_back(ubatch);
+    }
+
+    // prepare the recurrent batches first
+    if (!mem_recr->prepare(ubatches)) {
+        // TODO: will the recurrent cache be in an undefined state at this point?
+        LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
+        return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    }
+
+    // prepare the attention cache
+    auto heads_attn = mem_attn->prepare(ubatches);
+    if (heads_attn.empty()) {
+        LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
+        return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    }
+
+    return std::make_unique<llama_memory_hybrid_state>(
+        this, std::move(sbatch), std::move(heads_attn), std::move(ubatches));
+}
+
+llama_memory_state_ptr llama_memory_hybrid::init_full() {
+    return std::make_unique<llama_memory_hybrid_state>(this);
+}
+
+llama_memory_state_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_memory_hybrid_state>(this, lctx, optimize);
+}
+
+bool llama_memory_hybrid::get_can_shift() const {
+    // Shifting is trivially supported by the recurrent cache, so only the attention cache matters here
+    return mem_attn->get_can_shift();
+}
+
+void llama_memory_hybrid::clear(bool data) {
+    mem_attn->clear(data);
+    mem_recr->clear(data);
+}
+
+bool llama_memory_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    // Try removing from the recurrent cache first since it may fail. If it does
+    // fail, the cache will not have been mutated.
+    if (!mem_recr->seq_rm(seq_id, p0, p1)) {
+        return false;
+    }
+    return mem_attn->seq_rm(seq_id, p0, p1);
+}
+
+void llama_memory_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_memory_hybrid::seq_keep(llama_seq_id seq_id) {
+    mem_attn->seq_keep(seq_id);
+    mem_recr->seq_keep(seq_id);
+}
+
+void llama_memory_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    mem_attn->seq_add(seq_id, p0, p1, shift);
+    mem_recr->seq_add(seq_id, p0, p1, shift);
+}
+
+void llama_memory_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    mem_attn->seq_div(seq_id, p0, p1, d);
+    mem_recr->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_memory_hybrid::seq_pos_min(llama_seq_id seq_id) const {
+    // the min of the total cache is the max of the two caches' min values
+    return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id));
+}
+
+llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
+    // the max of the total cache is the min of the two caches' max values
+    return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
+}
+
+void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+    mem_attn->state_write(io, seq_id);
+    mem_recr->state_write(io, seq_id);
+}
+
+void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+    mem_attn->state_read(io, seq_id);
+    mem_recr->state_read(io, seq_id);
+}
+
+llama_kv_cache_unified * llama_memory_hybrid::get_mem_attn() const {
+    return mem_attn.get();
+}
+
+llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
+    return mem_recr.get();
+}
+
+llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_status status) : status(status) {}
+
+llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_hybrid * mem) :
+    state_attn(mem->get_mem_attn()->init_full()),
+    state_recr(mem->get_mem_recr()->init_full()),
+    status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
+}
+
+llama_memory_hybrid_state::llama_memory_hybrid_state(
+        llama_memory_hybrid * mem,
+              llama_context * lctx,
+                       bool   optimize) :
+    state_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
+    state_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
+    status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
+}
+
+llama_memory_hybrid_state::llama_memory_hybrid_state(
+              llama_memory_hybrid * mem,
+                     llama_sbatch   sbatch,
+            std::vector<uint32_t>   heads_attn,
+        std::vector<llama_ubatch>   ubatches) :
+    sbatch(std::move(sbatch)),
+    ubatches(std::move(ubatches)),
+    // note: here we copy the ubatches. not sure if this is ideal
+    state_attn(new llama_kv_cache_unified_state(mem->get_mem_attn(), {}, std::move(heads_attn), this->ubatches)),
+    state_recr(new llama_memory_recurrent_state(mem->get_mem_recr(), {},                        this->ubatches)),
+    status(LLAMA_MEMORY_STATUS_SUCCESS) {
+}
+
+bool llama_memory_hybrid_state::next() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    state_attn->next();
+    state_recr->next();
+
+    if (++i_next >= ubatches.size()) {
+        return false;
+    }
+
+    return true;
+}
+
+bool llama_memory_hybrid_state::apply() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    bool res = true;
+
+    res = res & state_attn->apply();
+    res = res & state_recr->apply();
+
+    return res;
+}
+
+std::vector<int64_t> & llama_memory_hybrid_state::out_ids() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return sbatch.out_ids;
+}
+
+llama_memory_status llama_memory_hybrid_state::get_status() const {
+    return status;
+}
+
+const llama_ubatch & llama_memory_hybrid_state::get_ubatch() const {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+    return ubatches[i_next];
+}
+
+const llama_kv_cache_unified_state * llama_memory_hybrid_state::get_state_attn() const {
+    return static_cast<const llama_kv_cache_unified_state *>(state_attn.get());
+}
+
+const llama_memory_recurrent_state * llama_memory_hybrid_state::get_state_recr() const {
+    return static_cast<const llama_memory_recurrent_state *>(state_recr.get());
+}
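The position bounds reported by the hybrid cache are the intersection of what the two children hold, hence max-of-mins and min-of-maxes. A tiny standalone check with made-up values:

```cpp
#include <algorithm>
#include <cassert>

// Made-up position ranges; the real values come from mem_attn and mem_recr.
int main() {
    const int attn_min = 0,  attn_max = 100; // attention cache holds [0, 100]
    const int recr_min = 10, recr_max = 100; // recurrent cache holds [10, 100]

    // positions valid in *both* caches: [10, 100]
    assert(std::max(attn_min, recr_min) == 10);
    assert(std::min(attn_max, recr_max) == 100);
    return 0;
}
```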
diff --git a/src/llama-memory-hybrid.h b/src/llama-memory-hybrid.h
new file mode 100644
index 0000000000000..b5700c5225f18
--- /dev/null
+++ b/src/llama-memory-hybrid.h
@@ -0,0 +1,143 @@
+#pragma once
+
+#include "llama-batch.h"
+#include "llama-graph.h"
+#include "llama-kv-cache-unified.h"
+#include "llama-memory.h"
+#include "llama-memory-recurrent.h"
+
+#include <memory>
+#include <vector>
+
+//
+// llama_memory_hybrid
+//
+
+// utilizes instances of llama_memory_recurrent and llama_kv_cache_unified to
+//   support models where each layer may be either attention-based or recurrent
+
+class llama_memory_hybrid : public llama_memory_i {
+public:
+
+    // this callback is used to filter out layers that should not be included in the cache
+    using layer_filter_cb = std::function<bool(int32_t il)>;
+
+    llama_memory_hybrid(
+        const llama_model & model,
+                            /* attn */
+                ggml_type    type_k,
+                ggml_type    type_v,
+                     bool    v_trans,
+                 uint32_t    kv_size,
+                 uint32_t    n_pad,
+                 uint32_t    n_swa,
+           llama_swa_type    swa_type,
+                             /* recurrent */
+                ggml_type    type_r,
+                ggml_type    type_s,
+                 uint32_t    rs_size,
+                             /* common */
+                 uint32_t    n_seq_max,
+                     bool    offload,
+                             /* layer filters */
+          layer_filter_cb && filter_attn = nullptr,
+          layer_filter_cb && filter_recr = nullptr);
+
+    ~llama_memory_hybrid() = default;
+
+    //
+    // llama_memory_i
+    //
+
+    llama_memory_state_ptr init_batch(
+            const llama_batch & batch,
+            uint32_t n_ubatch,
+            bool embd_pooled) override;
+
+    llama_memory_state_ptr init_full() override;
+
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+
+    bool get_can_shift() const override;
+
+    void clear(bool data) override;
+
+    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
+    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id)                                                          override;
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
+    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
+    // state write/load
+
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1)       override;
+
+    //
+    // llama_memory_hybrid specific API
+    //
+
+    llama_kv_cache_unified * get_mem_attn() const;
+    llama_memory_recurrent * get_mem_recr() const;
+
+private:
+    const llama_hparams & hparams;
+
+    const std::unique_ptr<llama_kv_cache_unified> mem_attn;
+    const std::unique_ptr<llama_memory_recurrent> mem_recr;
+};
+
+class llama_memory_hybrid_state : public llama_memory_state_i {
+public:
+    // init failure
+    explicit llama_memory_hybrid_state(llama_memory_status status);
+
+    // init full
+    explicit llama_memory_hybrid_state(llama_memory_hybrid * mem);
+
+    // init update
+    explicit llama_memory_hybrid_state(
+        llama_memory_hybrid * mem,
+              llama_context * lctx,
+                       bool   optimize);
+
+    // init success
+    llama_memory_hybrid_state(
+              llama_memory_hybrid * mem,
+                     llama_sbatch   sbatch,
+            std::vector<uint32_t>   heads_attn,
+        std::vector<llama_ubatch>   ubatches);
+
+    ~llama_memory_hybrid_state() = default;
+
+    bool next()  override;
+    bool apply() override;
+
+    std::vector<int64_t> & out_ids() override;
+
+    llama_memory_status  get_status() const override;
+    const llama_ubatch & get_ubatch() const override;
+
+    //
+    // llama_memory_hybrid_state
+    //
+
+    const llama_kv_cache_unified_state * get_state_attn() const;
+    const llama_memory_recurrent_state * get_state_recr() const;
+
+private:
+    llama_sbatch sbatch;
+
+    // the index of the next ubatch to process
+    size_t i_next = 0;
+
+    std::vector<llama_ubatch> ubatches;
+
+    const llama_memory_state_ptr state_attn;
+    const llama_memory_state_ptr state_recr;
+
+    const llama_memory_status status;
+};
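A sketch of how a caller might wire up the hybrid memory, assuming only the constructor above; the helper function, the numeric values, and the call site are all placeholders rather than code from this patch. With the default `nullptr` filters, layers where `hparams.is_recurrent(il)` is false go to the attention cache and the remaining layers go to the recurrent cache.

```cpp
// Hypothetical construction; every literal below is a placeholder.
// Assumes "llama-memory-hybrid.h" is included.
std::unique_ptr<llama_memory_i> make_hybrid_memory_example(const llama_model & model) {
    return std::make_unique<llama_memory_hybrid>(
        model,
        /* attn      */ GGML_TYPE_F16, GGML_TYPE_F16, /* v_trans */ true,
        /* kv_size   */ 4096, /* n_pad */ 32, /* n_swa */ 0, LLAMA_SWA_TYPE_NONE,
        /* recurrent */ GGML_TYPE_F32, GGML_TYPE_F32, /* rs_size */ 1,
        /* common    */ /* n_seq_max */ 1, /* offload */ true);
}
```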
diff --git a/src/llama-kv-cache-recurrent.cpp b/src/llama-memory-recurrent.cpp
similarity index 67%
rename from src/llama-kv-cache-recurrent.cpp
rename to src/llama-memory-recurrent.cpp
index 8f6f120f682b7..c4f9a6f1ddc98 100644
--- a/src/llama-kv-cache-recurrent.cpp
+++ b/src/llama-memory-recurrent.cpp
@@ -1,4 +1,4 @@
-#include "llama-kv-cache-recurrent.h"
+#include "llama-memory-recurrent.h"
 
 #include "llama-impl.h"
 #include "llama-io.h"
@@ -12,27 +12,28 @@
 #include <stdexcept>
 
 //
-// llama_kv_cache_recurrent
+// llama_memory_recurrent
 //
 
-llama_kv_cache_recurrent::llama_kv_cache_recurrent(
-        const llama_model & model,
-                ggml_type   type_k,
-                ggml_type   type_v,
-                     bool   offload,
-                 uint32_t   kv_size,
-                 uint32_t   n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
+llama_memory_recurrent::llama_memory_recurrent(
+        const llama_model &  model,
+          layer_filter_cb && filter,
+                ggml_type    type_r,
+                ggml_type    type_s,
+                     bool    offload,
+                 uint32_t    mem_size,
+                 uint32_t    n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
     const int32_t n_layer = hparams.n_layer;
 
-    LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
-            __func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
+    LLAMA_LOG_INFO("%s: mem_size = %u, n_seq_max = %u, type_r = '%s', type_s = '%s', n_layer = %d\n",
+            __func__, mem_size, n_seq_max, ggml_type_name(type_r), ggml_type_name(type_s), n_layer);
 
     head = 0;
-    size = kv_size;
+    size = mem_size;
     used = 0;
 
     cells.clear();
-    cells.resize(kv_size);
+    cells.resize(mem_size);
 
     // create a context for each buffer type
     std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
@@ -59,12 +60,14 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
         return it->second;
     };
 
-    k_l.reserve(n_layer);
-    v_l.reserve(n_layer);
+    r_l.resize(n_layer);
+    s_l.resize(n_layer);
 
     for (int i = 0; i < n_layer; i++) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
+        if (filter && !filter(i)) {
+            LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i);
+            continue;
+        }
 
         const char * dev_name = "CPU";
 
@@ -84,12 +87,12 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
             throw std::runtime_error("failed to create ggml context for kv cache");
         }
 
-        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
-        ggml_format_name(k, "cache_k_l%d", i);
-        ggml_format_name(v, "cache_v_l%d", i);
-        k_l.push_back(k);
-        v_l.push_back(v);
+        ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size);
+        ggml_tensor * s = ggml_new_tensor_1d(ctx, type_s, hparams.n_embd_s()*mem_size);
+        ggml_format_name(r, "cache_r_l%d", i);
+        ggml_format_name(s, "cache_s_l%d", i);
+        r_l[i] = r;
+        s_l[i] = s;
     }
 
     // allocate tensors and initialize the buffers to avoid NaNs in the padding
@@ -107,17 +110,17 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
     }
 
     {
-        const size_t memory_size_k = size_k_bytes();
-        const size_t memory_size_v = size_v_bytes();
+        const size_t memory_size_r = size_r_bytes();
+        const size_t memory_size_s = size_s_bytes();
 
-        LLAMA_LOG_INFO("%s: KV self size  = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
-                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
-                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
+        LLAMA_LOG_INFO("%s: KV self size  = %7.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f),
+                ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f),
+                ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f));
     }
 }
 
-void llama_kv_cache_recurrent::clear(bool data) {
+void llama_memory_recurrent::clear(bool data) {
     for (int32_t i = 0; i < (int32_t) size; ++i) {
         cells[i].pos = -1;
         cells[i].seq_id.clear();
@@ -135,7 +138,7 @@ void llama_kv_cache_recurrent::clear(bool data) {
     }
 }
 
-bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
     uint32_t new_head = size;
 
     if (p0 < 0) {
@@ -154,7 +157,7 @@ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_p
     if (0 <= seq_id) {
         int32_t & tail_id = cells[seq_id].tail;
         if (tail_id >= 0) {
-            const kv_cell & cell = cells[tail_id];
+            const auto & cell = cells[tail_id];
             // partial intersection is invalid
             if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
                 return false;
@@ -202,7 +205,7 @@ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_p
     return true;
 }
 
-void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+void llama_memory_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
     if (seq_id_src == seq_id_dst) {
         return;
     }
@@ -216,11 +219,11 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
     }
 
     if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
-        kv_cell & tail_src = cells[seq_id_src];
-        kv_cell & tail_dst = cells[seq_id_dst];
+        auto & tail_src = cells[seq_id_src];
+        auto & tail_dst = cells[seq_id_dst];
         if (tail_dst.tail >= 0) {
             // clear destination seq_id if it wasn't empty
-            kv_cell & cell_dst = cells[tail_dst.tail];
+            auto & cell_dst = cells[tail_dst.tail];
 
             cell_dst.seq_id.erase(seq_id_dst);
             tail_dst.tail = -1;
@@ -231,7 +234,7 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
             }
         }
         if (tail_src.tail >= 0) {
-            kv_cell & cell_src = cells[tail_src.tail];
+            auto & cell_src = cells[tail_src.tail];
 
             cell_src.seq_id.insert(seq_id_dst);
             tail_dst.tail = tail_src.tail;
@@ -239,7 +242,7 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
     }
 }
 
-void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
+void llama_memory_recurrent::seq_keep(llama_seq_id seq_id) {
     uint32_t new_head = size;
 
     for (uint32_t i = 0; i < size; ++i) {
@@ -271,7 +274,7 @@ void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
     }
 }
 
-void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+void llama_memory_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
     if (shift == 0) {
         return;
     }
@@ -293,7 +296,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
     if (0 <= seq_id && seq_id < (int64_t) size) {
         const int32_t tail_id = cells[seq_id].tail;
         if (tail_id >= 0) {
-            kv_cell & cell = cells[tail_id];
+            auto & cell = cells[tail_id];
             if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
                 cell.pos += shift;
             }
@@ -301,7 +304,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
     }
 }
 
-void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+void llama_memory_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
     if (d == 1) {
         return;
     }
@@ -323,7 +326,7 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
     if (0 <= seq_id && seq_id < (int64_t) size) {
         const int32_t tail_id = cells[seq_id].tail;
         if (tail_id >= 0) {
-            kv_cell & cell = cells[tail_id];
+            auto & cell = cells[tail_id];
             if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
                 cell.pos /= d;
             }
@@ -331,7 +334,7 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
     }
 }
 
-llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
+llama_pos llama_memory_recurrent::seq_pos_min(llama_seq_id seq_id) const {
     llama_pos result = std::numeric_limits<llama_pos>::max();
 
     for (uint32_t i = 0; i < size; ++i) {
@@ -347,7 +350,7 @@ llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
     return result;
 }
 
-llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
+llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     llama_pos result = -1;
 
     for (uint32_t i = 0; i < size; ++i) {
@@ -359,7 +362,7 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     return result;
 }
 
-llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
+llama_memory_state_ptr llama_memory_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
     auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
 
     std::vector<llama_ubatch> ubatches;
@@ -378,24 +381,24 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch &
     }
 
     if (!prepare(ubatches)) {
-        return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+        return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
     }
 
-    return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this, std::move(sbatch), std::move(ubatches));
+    return std::make_unique<llama_memory_recurrent_state>(this, std::move(sbatch), std::move(ubatches));
 }
 
-llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
-    return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
+llama_memory_state_ptr llama_memory_recurrent::init_full() {
+    return std::make_unique<llama_memory_recurrent_state>(this);
 }
 
-llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) {
+llama_memory_state_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
     GGML_UNUSED(lctx);
     GGML_UNUSED(optimize);
 
-    return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
+    return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
 }
 
-bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
+bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
     // simply remember the full state because it is very small for this type of cache
     // TODO: optimize
     auto org_cells = cells;
@@ -419,7 +422,7 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
     return success;
 }
 
-bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
+bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
     const uint32_t n_seqs = ubatch.n_seqs;
 
     const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
@@ -453,9 +456,9 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
                 return false;
             }
             if (j > 0) {
-                kv_cell & seq = cells[seq_id];
+                auto & seq = cells[seq_id];
                 if (seq.tail >= 0) {
-                    kv_cell & cell = cells[seq.tail];
+                    auto & cell = cells[seq.tail];
                     // clear cells from seq_ids that become shared
                     // (should not normally happen, but let's handle it anyway)
                     cell.seq_id.erase(seq_id);
@@ -475,7 +478,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
         std::vector<int32_t> tails_verif;
         tails_verif.assign(size, -1);
         for (uint32_t i = 0; i < size; ++i) {
-            kv_cell & cell = cells[i];
+            auto & cell = cells[i];
             for (llama_seq_id seq_id : cell.seq_id) {
                 if (tails_verif[seq_id] != -1) {
                     LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
@@ -496,7 +499,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
 
     for (uint32_t i = 0; i < size; ++i) {
         if (next_empty_cell >= size) { next_empty_cell -= size; }
-        kv_cell & cell = cells[next_empty_cell];
+        auto & cell = cells[next_empty_cell];
         if (cell.is_empty()) { break; }
         next_empty_cell += 1;
     }
@@ -504,20 +507,20 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
     // find usable cell range
     for (uint32_t s = 0; s < n_seqs; ++s) {
         const llama_seq_id seq_id = ubatch.seq_id[s][0];
-        kv_cell & seq_meta = cells[seq_id];
+        auto & seq_meta = cells[seq_id];
         bool has_cell = false;
         if (seq_meta.tail >= 0) {
-            kv_cell & cell = cells[seq_meta.tail];
+            auto & cell = cells[seq_meta.tail];
             GGML_ASSERT(cell.has_seq_id(seq_id));
             // does this seq_id "own" the cell?
             if (cell.seq_id.size() == 1) { has_cell = true; }
         }
         if (!has_cell) {
-            kv_cell & empty_cell = cells[next_empty_cell];
+            auto & empty_cell = cells[next_empty_cell];
             GGML_ASSERT(empty_cell.is_empty());
             // copy old tail into the empty cell
             if (seq_meta.tail >= 0) {
-                kv_cell & orig_cell = cells[seq_meta.tail];
+                auto & orig_cell = cells[seq_meta.tail];
                 empty_cell.pos = orig_cell.pos;
                 empty_cell.src = orig_cell.src;
                 orig_cell.seq_id.erase(seq_id);
@@ -530,7 +533,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
                 for (uint32_t i = 0; i < size; ++i) {
                     next_empty_cell += 1;
                     if (next_empty_cell >= size) { next_empty_cell -= size; }
-                    kv_cell & cell = cells[next_empty_cell];
+                    auto & cell = cells[next_empty_cell];
                     if (cell.is_empty()) { break; }
                 }
             }
@@ -544,8 +547,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
         const int32_t dst_id = s + min;
         const int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
         if (dst_id != src_id) {
-            kv_cell & dst_cell = cells[dst_id];
-            kv_cell & src_cell = cells[src_id];
+            auto & dst_cell = cells[dst_id];
+            auto & src_cell = cells[src_id];
 
             std::swap(dst_cell.pos, src_cell.pos);
             std::swap(dst_cell.src, src_cell.src);
@@ -567,7 +570,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
     for (uint32_t s = 0; s < n_seqs; ++s) {
         const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
         const int32_t cell_id = s + min;
-        kv_cell & cell = cells[cell_id];
+        auto & cell = cells[cell_id];
 
         if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
             // What should happen when the pos backtracks or skips a value?
@@ -620,18 +623,18 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
     head = min;
     n    = max - min + 1;
     used = std::count_if(cells.begin(), cells.end(),
-        [](const kv_cell & cell){ return !cell.is_empty(); });
+        [](const mem_cell & cell){ return !cell.is_empty(); });
 
     // sanity check
     return n >= n_seqs;
 }
 
-bool llama_kv_cache_recurrent::get_can_shift() const {
+bool llama_memory_recurrent::get_can_shift() const {
     // shifting the pos is trivial for recurrent models
     return true;
 }
 
-size_t llama_kv_cache_recurrent::total_size() const {
+size_t llama_memory_recurrent::total_size() const {
     size_t size = 0;
     for (const auto & buf : bufs) {
         size += ggml_backend_buffer_get_size(buf.get());
@@ -640,27 +643,31 @@ size_t llama_kv_cache_recurrent::total_size() const {
     return size;
 }
 
-size_t llama_kv_cache_recurrent::size_k_bytes() const {
-    size_t size_k_bytes = 0;
+size_t llama_memory_recurrent::size_r_bytes() const {
+    size_t size_r_bytes = 0;
 
-    for (const auto & k : k_l) {
-        size_k_bytes += ggml_nbytes(k);
+    for (const auto & r : r_l) {
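+        // r_l entries can be nullptr for layers excluded by the layer filter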
+        if (r != nullptr) {
+            size_r_bytes += ggml_nbytes(r);
+        }
     }
 
-    return size_k_bytes;
+    return size_r_bytes;
 }
 
-size_t llama_kv_cache_recurrent::size_v_bytes() const {
-    size_t size_v_bytes = 0;
+size_t llama_memory_recurrent::size_s_bytes() const {
+    size_t size_s_bytes = 0;
 
-    for (const auto & v : v_l) {
-        size_v_bytes += ggml_nbytes(v);
+    for (const auto & s : s_l) {
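+        // s_l entries can be nullptr for layers excluded by the layer filter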
+        if (s != nullptr) {
+            size_s_bytes += ggml_nbytes(s);
+        }
     }
 
-    return size_v_bytes;
+    return size_s_bytes;
 }
 
-void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
     std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
     uint32_t cell_count = 0;
 
@@ -698,7 +705,7 @@ void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id s
     state_write_data(io, cell_ranges);
 }
 
-void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
     uint32_t cell_count;
     io.read_to(&cell_count, sizeof(cell_count));
 
@@ -717,7 +724,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq
     }
 }
 
-void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
+void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
     for (const auto & range : cell_ranges) {
         for (uint32_t i = range.first; i < range.second; ++i) {
             const auto & cell = cells[i];
@@ -736,87 +743,85 @@ void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std
     }
 }
 
-void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
-    const uint32_t v_trans = 0;
+void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
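+    // recurrent states are never transposed; s_trans is still written so the read path can reject incompatible data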
+    const uint32_t s_trans = 0;
     const uint32_t n_layer = hparams.n_layer;
 
-    io.write(&v_trans, sizeof(v_trans));
-    io.write(&n_layer, sizeof(n_layer));
+    io.write(&s_trans, sizeof(s_trans));
+    io.write(&n_layer, sizeof(n_layer));
 
     std::vector<uint8_t> tmp_buf;
 
-    // Iterate and write all the keys first, each row is a cell
+    // Iterate and write all the r states first, each row is a cell
     // Get whole range at a time
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
 
-        // Write key type
+        // Write r state type
-        const int32_t k_type_i = (int32_t)k_l[il]->type;
-        io.write(&k_type_i, sizeof(k_type_i));
+        const int32_t r_type_i = (int32_t)r_l[il]->type;
+        io.write(&r_type_i, sizeof(r_type_i));
 
-        // Write row size of key
+        // Write row size of r state
-        const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
-        io.write(&k_size_row, sizeof(k_size_row));
+        const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
+        io.write(&r_size_row, sizeof(r_size_row));
 
-        // Read each range of cells of k_size length each into tmp_buf and write out
+        // Write out each range of cells, one contiguous block of r_size_row bytes per cell
         for (const auto & range : cell_ranges) {
             const size_t range_size = range.second - range.first;
-            const size_t buf_size = range_size * k_size_row;
-            io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
+            const size_t buf_size = range_size * r_size_row;
+            io.write_tensor(r_l[il], range.first * r_size_row, buf_size);
         }
     }
 
-    if (!v_trans) {
+    if (!s_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
-            // Write value type
+            // Write s state type
-            const int32_t v_type_i = (int32_t)v_l[il]->type;
-            io.write(&v_type_i, sizeof(v_type_i));
+            const int32_t s_type_i = (int32_t)s_l[il]->type;
+            io.write(&s_type_i, sizeof(s_type_i));
 
-            // Write row size of value
+            // Write row size of s state
-            const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
-            io.write(&v_size_row, sizeof(v_size_row));
+            const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
+            io.write(&s_size_row, sizeof(s_size_row));
 
-            // Read each range of cells of v_size length each into tmp_buf and write out
+            // Write out each range of cells, one contiguous block of s_size_row bytes per cell
             for (const auto & range : cell_ranges) {
                 const size_t range_size = range.second - range.first;
-                const size_t buf_size = range_size * v_size_row;
-                io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
+                const size_t buf_size = range_size * s_size_row;
+                io.write_tensor(s_l[il], range.first * s_size_row, buf_size);
             }
         }
     } else {
-        // When v is transposed, we also need the element size and get the element ranges from each row
+        // When the states are transposed, we also need the element size and get the element ranges from each row
-        const uint32_t kv_size = size;
+        const uint32_t mem_size = size;
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_s = hparams.n_embd_s();
 
-            // Write value type
+            // Write s state type
-            const int32_t v_type_i = (int32_t)v_l[il]->type;
-            io.write(&v_type_i, sizeof(v_type_i));
+            const int32_t s_type_i = (int32_t)s_l[il]->type;
+            io.write(&s_type_i, sizeof(s_type_i));
 
             // Write element size
-            const uint32_t v_size_el = ggml_type_size(v_l[il]->type);
-            io.write(&v_size_el, sizeof(v_size_el));
+            const uint32_t s_size_el = ggml_type_size(s_l[il]->type);
+            io.write(&s_size_el, sizeof(s_size_el));
 
-            // Write GQA embedding size
+            // Write state embedding size
-            io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
+            io.write(&n_embd_s, sizeof(n_embd_s));
 
             // For each row, we get the element values of each cell
-            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+            for (uint32_t j = 0; j < n_embd_s; ++j) {
-                // Read each range of cells of v_size_el length each into tmp_buf and write out
+                // Read each range of cells of s_size_el length each and write out
                 for (const auto & range : cell_ranges) {
                     const size_t range_size = range.second - range.first;
-                    const size_t src_offset = (range.first + j * kv_size) * v_size_el;
-                    const size_t buf_size = range_size * v_size_el;
-                    io.write_tensor(v_l[il], src_offset, buf_size);
+                    const size_t src_offset = (range.first + j * mem_size) * s_size_el;
+                    const size_t buf_size = range_size * s_size_el;
+                    io.write_tensor(s_l[il], src_offset, buf_size);
                 }
             }
         }
     }
 }
 
-bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
+bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
     if (dest_seq_id != -1) {
         // single sequence
 
@@ -869,7 +874,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
         clear(true);
 
         for (uint32_t i = 0; i < cell_count; ++i) {
-            kv_cell & cell = cells[i];
+            auto & cell = cells[i];
 
             llama_pos pos;
             uint32_t  n_seq_id;
@@ -883,7 +888,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
                 llama_seq_id seq_id;
                 io.read_to(&seq_id, sizeof(seq_id));
 
-                // TODO: llama_kv_cache_recurrent should have a notion of max sequences
+                // TODO: llama_memory_recurrent should have a notion of max sequences
                 //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
                 if (seq_id < 0) {
                     //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
@@ -915,10 +920,10 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
     return true;
 }
 
-bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
-    uint32_t v_trans;
+bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
+    uint32_t s_trans;
     uint32_t n_layer;
-    io.read_to(&v_trans, sizeof(v_trans));
+    io.read_to(&s_trans, sizeof(s_trans));
     io.read_to(&n_layer, sizeof(n_layer));
 
     if (n_layer != hparams.n_layer) {
@@ -929,102 +934,100 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
         LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
         return false;
     }
-    if (false != (bool) v_trans) {
-        LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
+    if (false != (bool) s_trans) {
+        LLAMA_LOG_ERROR("%s: incompatible s transposition\n", __func__);
         return false;
     }
 
-    // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
+    // For each layer, read the r states for each cell, one row is one cell, read as one contiguous block
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
 
-        // Read type of key
+        // Read r state type
-        int32_t k_type_i_ref;
-        io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
-        const int32_t k_type_i = (int32_t) k_l[il]->type;
-        if (k_type_i != k_type_i_ref) {
-            LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
+        int32_t r_type_i_ref;
+        io.read_to(&r_type_i_ref, sizeof(r_type_i_ref));
+        const int32_t r_type_i = (int32_t) r_l[il]->type;
+        if (r_type_i != r_type_i_ref) {
+            LLAMA_LOG_ERROR("%s: mismatched r type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il);
             return false;
         }
 
-        // Read row size of key
+        // Read row size of r state
-        uint64_t k_size_row_ref;
-        io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
-        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
-        if (k_size_row != k_size_row_ref) {
-            LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
+        uint64_t r_size_row_ref;
+        io.read_to(&r_size_row_ref, sizeof(r_size_row_ref));
+        const size_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
+        if (r_size_row != r_size_row_ref) {
+            LLAMA_LOG_ERROR("%s: mismatched r row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il);
             return false;
         }
 
         if (cell_count) {
-            // Read and set the keys for the whole cell range
+            // Read and set the r states for the whole cell range
-            ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
+            ggml_backend_tensor_set(r_l[il], io.read(cell_count * r_size_row), head * r_size_row, cell_count * r_size_row);
         }
     }
 
-    if (!v_trans) {
+    if (!s_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
-            // Read type of value
+            // Read s state type
-            int32_t v_type_i_ref;
-            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
-            const int32_t v_type_i = (int32_t)v_l[il]->type;
-            if (v_type_i != v_type_i_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
+            int32_t s_type_i_ref;
+            io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
+            const int32_t s_type_i = (int32_t)s_l[il]->type;
+            if (s_type_i != s_type_i_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
                 return false;
             }
 
-            // Read row size of value
+            // Read row size of s state
-            uint64_t v_size_row_ref;
-            io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
-            const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
-            if (v_size_row != v_size_row_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
+            uint64_t s_size_row_ref;
+            io.read_to(&s_size_row_ref, sizeof(s_size_row_ref));
+            const size_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
+            if (s_size_row != s_size_row_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il);
                 return false;
             }
 
             if (cell_count) {
-                // Read and set the values for the whole cell range
+                // Read and set the s states for the whole cell range
-                ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
+                ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_row), head * s_size_row, cell_count * s_size_row);
             }
         }
     } else {
-        // For each layer, read the values for each cell (transposed)
+        // For each layer, read the s states for each cell (transposed)
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_s = hparams.n_embd_s();
 
-            // Read type of value
+            // Read s state type
-            int32_t v_type_i_ref;
-            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
-            const int32_t v_type_i = (int32_t)v_l[il]->type;
-            if (v_type_i != v_type_i_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
+            int32_t s_type_i_ref;
+            io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
+            const int32_t s_type_i = (int32_t)s_l[il]->type;
+            if (s_type_i != s_type_i_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
                 return false;
             }
 
-            // Read element size of value
+            // Read element size of s state
-            uint32_t v_size_el_ref;
-            io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
-            const size_t v_size_el = ggml_type_size(v_l[il]->type);
-            if (v_size_el != v_size_el_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
+            uint32_t s_size_el_ref;
+            io.read_to(&s_size_el_ref, sizeof(s_size_el_ref));
+            const size_t s_size_el = ggml_type_size(s_l[il]->type);
+            if (s_size_el != s_size_el_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched s element size (%zu != %zu, layer %d)\n", __func__, s_size_el, (size_t) s_size_el_ref, il);
                 return false;
             }
 
-            // Read GQA embedding size
-            uint32_t n_embd_v_gqa_ref;
-            io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
-            if (n_embd_v_gqa != n_embd_v_gqa_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
+            // Read state embedding size
+            uint32_t n_embd_s_ref;
+            io.read_to(&n_embd_s_ref, sizeof(n_embd_s_ref));
+            if (n_embd_s != n_embd_s_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched s embedding size (%u != %u, layer %d)\n", __func__, n_embd_s, n_embd_s_ref, il);
                 return false;
             }
 
             if (cell_count) {
                 // For each row in the transposed matrix, read the values for the whole cell range
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                    const size_t dst_offset = (head + j * size) * v_size_el;
-                    ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
+                for (uint32_t j = 0; j < n_embd_s; ++j) {
+                    const size_t dst_offset = (head + j * size) * s_size_el;
+                    ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_el), dst_offset, cell_count * s_size_el);
                 }
             }
         }
@@ -1034,25 +1037,23 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 }
 
 //
-// llama_kv_cache_recurrent_state
+// llama_memory_recurrent_state
 //
 
-llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(llama_memory_status status) : status(status) {}
+llama_memory_recurrent_state::llama_memory_recurrent_state(llama_memory_status status) : status(status) {}
 
-llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
-        llama_memory_status status,
-        llama_kv_cache_recurrent * kv) : status(status), kv(kv), is_full(true) {
+llama_memory_recurrent_state::llama_memory_recurrent_state(
+        llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
 }
 
-llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
-        llama_memory_status status,
-        llama_kv_cache_recurrent * kv,
+llama_memory_recurrent_state::llama_memory_recurrent_state(
+        llama_memory_recurrent * mem,
         llama_sbatch sbatch,
-        std::vector<llama_ubatch> ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
 
-llama_kv_cache_recurrent_state::~llama_kv_cache_recurrent_state() = default;
+llama_memory_recurrent_state::~llama_memory_recurrent_state() = default;
 
-bool llama_kv_cache_recurrent_state::next() {
+bool llama_memory_recurrent_state::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     if (++i_next >= ubatches.size()) {
@@ -1062,54 +1063,54 @@ bool llama_kv_cache_recurrent_state::next() {
     return true;
 }
 
-bool llama_kv_cache_recurrent_state::apply() {
+bool llama_memory_recurrent_state::apply() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    kv->find_slot(ubatches[i_next]);
+    mem->find_slot(ubatches[i_next]);
 
     return true;
 }
 
-std::vector<int64_t> & llama_kv_cache_recurrent_state::out_ids() {
+std::vector<int64_t> & llama_memory_recurrent_state::out_ids() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     return sbatch.out_ids;
 }
 
-llama_memory_status llama_kv_cache_recurrent_state::get_status() const {
+llama_memory_status llama_memory_recurrent_state::get_status() const {
     return status;
 }
 
-const llama_ubatch & llama_kv_cache_recurrent_state::get_ubatch() const {
+const llama_ubatch & llama_memory_recurrent_state::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     return ubatches[i_next];
 }
 
-uint32_t llama_kv_cache_recurrent_state::get_n_kv() const {
-    return is_full ? kv->size : kv->n;
+uint32_t llama_memory_recurrent_state::get_n_rs() const {
+    return is_full ? mem->size : mem->n;
 }
 
-uint32_t llama_kv_cache_recurrent_state::get_head() const {
-    return is_full ? 0 : kv->head;
+uint32_t llama_memory_recurrent_state::get_head() const {
+    return is_full ? 0 : mem->head;
 }
 
-int32_t llama_kv_cache_recurrent_state::get_rs_z() const {
-    return is_full ? 0 : kv->rs_z;
+int32_t llama_memory_recurrent_state::get_rs_z() const {
+    return is_full ? 0 : mem->rs_z;
 }
 
-uint32_t llama_kv_cache_recurrent_state::get_size() const {
-    return kv->size;
+uint32_t llama_memory_recurrent_state::get_size() const {
+    return mem->size;
 }
 
-ggml_tensor * llama_kv_cache_recurrent_state::get_k_l(int32_t il) const {
-    return kv->k_l[il];
+ggml_tensor * llama_memory_recurrent_state::get_r_l(int32_t il) const {
+    return mem->r_l[il];
 }
 
-ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const {
-    return kv->v_l[il];
+ggml_tensor * llama_memory_recurrent_state::get_s_l(int32_t il) const {
+    return mem->s_l[il];
 }
 
-int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
-    return  kv->cells[i + kv->head].src0;
+int32_t llama_memory_recurrent_state::s_copy(int i) const {
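+    // src0 is the cell index that the i-th state in the current window should be copied from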
+    return  mem->cells[i + mem->head].src0;
 }
diff --git a/src/llama-kv-cache-recurrent.h b/src/llama-memory-recurrent.h
similarity index 72%
rename from src/llama-kv-cache-recurrent.h
rename to src/llama-memory-recurrent.h
index f9b01a6513393..290cc84ab3fbc 100644
--- a/src/llama-kv-cache-recurrent.h
+++ b/src/llama-memory-recurrent.h
@@ -8,22 +8,27 @@
 #include <vector>
 
 //
-// llama_kv_cache_recurrent
+// llama_memory_recurrent
 //
 
-// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
+// TODO: extract the cache state used for graph computation into llama_memory_recurrent_state_i
 //       see the implementation of llama_kv_cache_unified_state_i for an example how to do it
-class llama_kv_cache_recurrent : public llama_memory_i {
+class llama_memory_recurrent : public llama_memory_i {
 public:
-    llama_kv_cache_recurrent(
-            const llama_model & model,
-                    ggml_type   type_k,
-                    ggml_type   type_v,
-                         bool   offload,
-                     uint32_t   kv_size,
-                     uint32_t   n_seq_max);
 
-    ~llama_kv_cache_recurrent() = default;
+    // this callback is used to filter out layers that should not be included in the cache
+    using layer_filter_cb = std::function<bool(int32_t il)>;
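+    // (e.g. a hybrid cache can pass a filter that keeps recurrent state only for its recurrent layers)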
+
+    llama_memory_recurrent(
+            const llama_model &  model,
+              layer_filter_cb && filter,
+                    ggml_type    type_r,
+                    ggml_type    type_s,
+                         bool    offload,
+                     uint32_t    mem_size,
+                     uint32_t    n_seq_max);
+
+    ~llama_memory_recurrent() = default;
 
     //
     // llama_memory_i
@@ -51,7 +56,7 @@ class llama_kv_cache_recurrent : public llama_memory_i {
 
     bool prepare(const std::vector<llama_ubatch> & ubatches);
 
-    // find a contiguous slot of kv cells and emplace the ubatch there
+    // find a contiguous slot of memory cells and emplace the ubatch there
     bool find_slot(const llama_ubatch & ubatch);
 
     bool get_can_shift() const override;
@@ -72,7 +77,7 @@ class llama_kv_cache_recurrent : public llama_memory_i {
     int32_t rs_z = -1;
 
     // TODO: optimize for recurrent state needs
-    struct kv_cell {
+    struct mem_cell {
         llama_pos pos  = -1;
         int32_t   src  = -1; // used to know where states should be copied from
         int32_t   src0 = -1; // like src, but only used when setting the inputs (allowing to copy once)
@@ -88,15 +93,16 @@ class llama_kv_cache_recurrent : public llama_memory_i {
             return seq_id.empty();
         }
 
-        bool is_same_seq(const kv_cell & other) const {
+        bool is_same_seq(const mem_cell & other) const {
             return seq_id == other.seq_id;
         }
     };
 
-    std::vector<kv_cell> cells;
+    std::vector<mem_cell> cells;
 
-    std::vector<ggml_tensor *> k_l; // per layer
-    std::vector<ggml_tensor *> v_l;
+    // per-layer recurrent state tensors (e.g. conv/token-shift states in r_l, SSM/wkv states in s_l)
+    std::vector<ggml_tensor *> r_l;
+    std::vector<ggml_tensor *> s_l;
 
 private:
     //const llama_model & model;
@@ -109,8 +115,8 @@ class llama_kv_cache_recurrent : public llama_memory_i {
 
     size_t total_size() const;
 
-    size_t size_k_bytes() const;
-    size_t size_v_bytes() const;
+    size_t size_r_bytes() const;
+    size_t size_s_bytes() const;
 
     void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
     void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
@@ -119,24 +125,22 @@ class llama_kv_cache_recurrent : public llama_memory_i {
     bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
 };
 
-class llama_kv_cache_recurrent_state : public llama_memory_state_i {
+class llama_memory_recurrent_state : public llama_memory_state_i {
 public:
     // used for errors
-    llama_kv_cache_recurrent_state(llama_memory_status status);
+    llama_memory_recurrent_state(llama_memory_status status);
 
     // used to create a full-cache state
-    llama_kv_cache_recurrent_state(
-            llama_memory_status status,
-            llama_kv_cache_recurrent * kv);
+    llama_memory_recurrent_state(
+            llama_memory_recurrent * mem);
 
     // used to create a state from a batch
-    llama_kv_cache_recurrent_state(
-            llama_memory_status status,
-            llama_kv_cache_recurrent * kv,
+    llama_memory_recurrent_state(
+            llama_memory_recurrent * mem,
             llama_sbatch sbatch,
             std::vector<llama_ubatch> ubatches);
 
-    virtual ~llama_kv_cache_recurrent_state();
+    virtual ~llama_memory_recurrent_state();
 
     //
     // llama_memory_state_i
@@ -151,23 +155,23 @@ class llama_kv_cache_recurrent_state : public llama_memory_state_i {
     const llama_ubatch & get_ubatch() const override;
 
     //
-    // llama_kv_cache_recurrent_state specific API
+    // llama_memory_recurrent_state specific API
     //
 
-    uint32_t get_n_kv() const;
+    uint32_t get_n_rs() const;
     uint32_t get_head() const;
     int32_t  get_rs_z() const;
     uint32_t get_size() const;
 
-    ggml_tensor * get_k_l(int32_t il) const;
-    ggml_tensor * get_v_l(int32_t il) const;
+    ggml_tensor * get_r_l(int32_t il) const;
+    ggml_tensor * get_s_l(int32_t il) const;
 
     int32_t s_copy(int i) const;
 
 private:
     const llama_memory_status status;
 
-    llama_kv_cache_recurrent * kv;
+    llama_memory_recurrent * mem;
 
     llama_sbatch sbatch;
 
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index a5eb122f998d8..a5853f8b12dc0 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -8,7 +8,8 @@
 
 #include "llama-kv-cache-unified.h"
 #include "llama-kv-cache-unified-iswa.h"
-#include "llama-kv-cache-recurrent.h"
+#include "llama-memory-hybrid.h"
+#include "llama-memory-recurrent.h"
 
 #include "ggml-cpp.h"
 
@@ -470,6 +471,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     std::fill(hparams.n_head_arr.begin(),    hparams.n_head_arr.end(),    0);
     std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
     std::fill(hparams.n_ff_arr.begin(),      hparams.n_ff_arr.end(),      0);
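+    // mark every layer as recurrent if and only if the architecture itself is a recurrent one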
+    std::fill(
+        hparams.recurrent_layer_arr.begin(),
+        hparams.recurrent_layer_arr.end(),
+        llm_arch_is_recurrent(ml.get_arch()));
 
     std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0);
 
@@ -9111,7 +9116,7 @@ struct llm_build_mamba : public llm_graph_context {
         // {n_embd, n_tokens}
         inpL = build_inp_embd(model.tok_embd);
 
-        ggml_tensor * state_copy = build_inp_s_copy();
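+        // rs_inp carries the state-copy indices (s_copy) shared by every recurrent layer in this graph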
+        auto * rs_inp = build_rs_inp();
 
         for (int il = 0; il < n_layer; ++il) {
             // norm
@@ -9120,7 +9125,7 @@ struct llm_build_mamba : public llm_graph_context {
                     LLM_NORM_RMS, il);
             cb(cur, "attn_norm", il);
 
-            cur = build_mamba_layer(gf, cur, state_copy, ubatch, il);
+            cur = build_mamba_layer(rs_inp, gf, cur, ubatch, il);
 
             if (il == n_layer - 1) {
                 // skip computing output for unused tokens
@@ -9158,12 +9163,12 @@ struct llm_build_mamba : public llm_graph_context {
 
     // TODO: split
     ggml_tensor * build_mamba_layer(
-             ggml_cgraph * gf,
-             ggml_tensor * cur,
-             ggml_tensor * state_copy,
-      const llama_ubatch & ubatch,
-                     int   il) const {
-        const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+        llm_graph_input_rs * inp,
+               ggml_cgraph * gf,
+               ggml_tensor * cur,
+        const llama_ubatch & ubatch,
+                       int   il) const {
+        const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
 
         const auto kv_head = kv_state->get_head();
 
@@ -9183,17 +9188,17 @@ struct llm_build_mamba : public llm_graph_context {
         GGML_ASSERT(ubatch.equal_seqs);
         GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
-        ggml_tensor * conv_states_all = kv_state->get_k_l(il);
-        ggml_tensor * ssm_states_all  = kv_state->get_v_l(il);
+        ggml_tensor * conv_states_all = kv_state->get_r_l(il);
+        ggml_tensor * ssm_states_all  = kv_state->get_s_l(il);
 
-        // (ab)using the KV cache to store the states
+        // use the recurrent state memory to store the conv and ssm states
-        ggml_tensor * conv = build_recurrent_state(
-                gf, conv_states_all, state_copy,
-                hparams.n_embd_k_s(), n_seqs);
+        ggml_tensor * conv = build_rs(
+                inp, gf, conv_states_all,
+                hparams.n_embd_r(), n_seqs);
         conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
-        ggml_tensor * ssm = build_recurrent_state(
-                gf, ssm_states_all, state_copy,
-                hparams.n_embd_v_s(), n_seqs);
+        ggml_tensor * ssm = build_rs(
+                inp, gf, ssm_states_all,
+                hparams.n_embd_s(), n_seqs);
         ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);
 
         // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
@@ -11904,13 +11909,13 @@ struct llm_build_rwkv6_base : public llm_graph_context {
     }
 
     ggml_tensor * build_rwkv6_time_mix(
+            llm_graph_input_rs * inp,
             ggml_cgraph * gf,
             ggml_tensor * cur,
             ggml_tensor * x_prev,
-            ggml_tensor * state_copy,
             const llama_ubatch & ubatch,
             int   il) const {
-        const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+        const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
 
         const auto n_tokens = ubatch.n_tokens;
         const auto n_seqs = ubatch.n_seqs;
@@ -12031,9 +12036,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
             k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
         }
 
-        ggml_tensor * wkv_state = build_recurrent_state(
-                gf, kv_state->get_v_l(il), state_copy,
-                hparams.n_embd_v_s(), n_seqs);
+        ggml_tensor * wkv_state = build_rs(
+                inp, gf, kv_state->get_s_l(il),
+                hparams.n_embd_s(), n_seqs);
 
         ggml_tensor * wkv_output;
         if (is_qrwkv) {
@@ -12051,9 +12056,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
                     wkv_state,
                     ggml_view_1d(
                         ctx0,
-                        kv_state->get_v_l(il),
-                        hparams.n_embd_v_s() * n_seqs,
-                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il))
+                        kv_state->get_s_l(il),
+                        hparams.n_embd_s() * n_seqs,
+                        hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il))
                         )
                     )
                 );
@@ -12087,7 +12092,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
         inpL = build_inp_embd(model.tok_embd);
         inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
 
-        ggml_tensor * state_copy = build_inp_s_copy();
+        auto * rs_inp = build_rs_inp();
 
         const auto n_embd = hparams.n_embd;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12097,9 +12102,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
             const llama_layer * layer = &model.layers[il];
             inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
 
-            ggml_tensor * token_shift = build_rwkv_token_shift_load(
-                    gf, state_copy, ubatch, il
-                    );
+            ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
 
             ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
             ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
@@ -12114,7 +12117,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
                     1
                     );
 
-            cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il);
+            cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il);
 
             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
             cb(ffn_inp, "ffn_inp", il);
@@ -12177,14 +12180,14 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
 // ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py
 struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
     llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv6_base(model, params) {
-        GGML_ASSERT(n_embd == hparams.n_embd_k_s());
+        GGML_ASSERT(n_embd == hparams.n_embd_r());
 
         ggml_tensor * cur;
         ggml_tensor * inpL;
 
         inpL = build_inp_embd(model.tok_embd);
 
-        ggml_tensor * state_copy = build_inp_s_copy();
+        auto * rs_inp = build_rs_inp();
 
         const auto n_embd = hparams.n_embd;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12194,9 +12197,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
             const llama_layer * layer = &model.layers[il];
             inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
 
-            ggml_tensor * token_shift = build_rwkv_token_shift_load(
-                    gf, state_copy, ubatch, il
-                    );
+            ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
 
             ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
             cb(att_norm, "attn_norm", il);
@@ -12208,7 +12209,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
                     1
                     );
 
-            cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il);
+            cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il);
 
             token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
             ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
@@ -12296,14 +12297,14 @@ struct llm_build_rwkv7_base : public llm_graph_context {
     }
 
     ggml_tensor * build_rwkv7_time_mix(
+            llm_graph_input_rs * inp,
             ggml_cgraph * gf,
             ggml_tensor * cur,
             ggml_tensor * x_prev,
-            ggml_tensor * state_copy,
             ggml_tensor *& first_layer_value,
             const llama_ubatch & ubatch,
             int   il) const {
-        const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+        const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
 
         const auto n_tokens = ubatch.n_tokens;
         const auto n_seqs = ubatch.n_seqs;
@@ -12382,9 +12383,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
         v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
         a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
 
-        ggml_tensor * wkv_state = build_recurrent_state(
-                gf, kv_state->get_v_l(il), state_copy,
-                hparams.n_embd_v_s(), n_seqs);
+        ggml_tensor * wkv_state = build_rs(
+                inp, gf, kv_state->get_s_l(il),
+                hparams.n_embd_s(), n_seqs);
 
         ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
         cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
@@ -12397,9 +12398,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
                     wkv_state,
                     ggml_view_1d(
                         ctx0,
-                        kv_state->get_v_l(il),
-                        hparams.n_embd_v_s() * n_seqs,
-                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il))
+                        kv_state->get_s_l(il),
+                        hparams.n_embd_s() * n_seqs,
+                        hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il))
                         )
                     )
                 );
@@ -12440,7 +12441,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
         inpL = build_inp_embd(model.tok_embd);
         inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
 
-        ggml_tensor * state_copy = build_inp_s_copy();
+        auto * rs_inp = build_rs_inp();
 
         const auto n_embd = hparams.n_embd;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12450,9 +12451,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
             const llama_layer * layer = &model.layers[il];
             inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
 
-            ggml_tensor * token_shift = build_rwkv_token_shift_load(
-                    gf, state_copy, ubatch, il
-                    );
+            ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
 
             ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
             ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
@@ -12467,7 +12466,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
                     1
                     );
 
-            cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il);
+            cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il);
 
             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
             cb(ffn_inp, "ffn_inp", il);
@@ -12525,7 +12524,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
 
 struct llm_build_arwkv7 : public llm_build_rwkv7_base {
     llm_build_arwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) {
-        GGML_ASSERT(n_embd == hparams.n_embd_k_s());
+        GGML_ASSERT(n_embd == hparams.n_embd_r());
 
         ggml_tensor * cur;
         ggml_tensor * inpL;
@@ -12533,7 +12532,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        ggml_tensor * state_copy = build_inp_s_copy();
+        auto * rs_inp = build_rs_inp();
 
         const auto n_embd = hparams.n_embd;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12543,9 +12542,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
             const llama_layer * layer = &model.layers[il];
             inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
 
-            ggml_tensor * token_shift = build_rwkv_token_shift_load(
-                    gf, state_copy, ubatch, il
-                    );
+            ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
 
             ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
             cb(att_norm, "attn_norm", il);
@@ -12557,7 +12554,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
                     1
                     );
 
-            cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il);
+            cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il);
 
             token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
             ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
@@ -13738,6 +13735,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     llama_memory_i * res;
 
     switch (arch) {
+        // Models that need specific instantiation should be handled in the
+        // switch statement
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
@@ -13747,57 +13746,75 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
             {
                 res = nullptr;
             } break;
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_RWKV7:
-        case LLM_ARCH_ARWKV7:
-            {
-                res = new llama_kv_cache_recurrent(
-                        *this,
-                        GGML_TYPE_F32,
-                        GGML_TYPE_F32,
-                        cparams.offload_kqv,
-                        std::max((uint32_t) 1, cparams.n_seq_max),
-                        cparams.n_seq_max);
-            } break;
+        // Models that need standard caching should rely on recurrent/hybrid
+        // checks
         default:
             {
-                const auto padding = llama_kv_cache_unified::get_padding(cparams);
-
-                cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
-
-                LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
-
-                if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-                    GGML_ASSERT(hparams.is_swa_any());
-
-                    res = new llama_kv_cache_unified_iswa(
-                            *this,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
-                            cparams.offload_kqv,
-                            params.swa_full,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            cparams.n_ubatch,
-                            padding);
-                } else {
-                    GGML_ASSERT(!hparams.is_swa_any());
-
-                    res = new llama_kv_cache_unified(
+                if (llm_arch_is_recurrent(arch)) {
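+                    // the recurrent cache keeps one state per sequence, so its size is n_seq_max (at least 1)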
+                    res = new llama_memory_recurrent(
                             *this,
                             nullptr,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
+                            GGML_TYPE_F32,
+                            GGML_TYPE_F32,
                             cparams.offload_kqv,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            padding,
-                            hparams.n_swa,
-                            hparams.swa_type);
+                            std::max((uint32_t) 1, cparams.n_seq_max),
+                            cparams.n_seq_max);
+                } else if (llm_arch_is_hybrid(arch)) {
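+                    // hybrid models pair a unified attention KV cache with a recurrent state cache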
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+
+                    res = new llama_memory_hybrid(
+                        /* model             */ *this,
+                        /* attn_type_k       */ params.type_k,
+                        /* attn_type_v       */ params.type_v,
+                        /* attn_v_trans      */ !cparams.flash_attn,
+                        /* attn_kv_size      */ cparams.n_ctx,
+                        /* attn_n_pad        */ padding,
+                        /* attn_n_swa        */ hparams.n_swa,
+                        /* attn_swa_type     */ hparams.swa_type,
+                        /* recurrent_type_k  */ GGML_TYPE_F32,
+                        /* recurrent_type_v  */ GGML_TYPE_F32,
+                        /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
+                        /* n_seq_max         */ cparams.n_seq_max,
+                        /* offload           */ cparams.offload_kqv);
+                } else {
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+                        GGML_ASSERT(hparams.is_swa_any());
+
+                        res = new llama_kv_cache_unified_iswa(
+                                *this,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                params.swa_full,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                cparams.n_ubatch,
+                                padding);
+                    } else {
+                        GGML_ASSERT(!hparams.is_swa_any());
+
+                        res = new llama_kv_cache_unified(
+                                *this,
+                                nullptr,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                padding,
+                                hparams.n_swa,
+                                hparams.swa_type);
+                    }
                 }
             }
     }
@@ -14377,14 +14394,7 @@ llama_token llama_model_decoder_start_token(const llama_model * model) {
 }
 
 bool llama_model_is_recurrent(const llama_model * model) {
-    switch (model->arch) {
-        case     LLM_ARCH_MAMBA:      return true;
-        case     LLM_ARCH_RWKV6:      return true;
-        case     LLM_ARCH_RWKV6QWEN2: return true;
-        case     LLM_ARCH_RWKV7:      return true;
-        case     LLM_ARCH_ARWKV7:     return true;
-        default: return false;
-    }
+    return llm_arch_is_recurrent(model->arch);
 }
 
 const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {