
Commit 799c4bb

Fuse MLA set kv cache kernel (#5748)
1 parent 02723e1 commit 799c4bb

4 files changed: +100 -9 lines changed


python/sglang/srt/layers/attention/flashattention_backend.py

Lines changed: 6 additions & 4 deletions

@@ -625,6 +625,7 @@ def forward_extend(
         save_kv_cache=True,
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
         if k is not None:
             assert v is not None
@@ -639,11 +640,11 @@ def forward_extend(
                         layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                     )
                 else:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                         layer,
                         cache_loc,
                         k,
-                        v,
+                        k_rope,
                     )
 
         # Use precomputed metadata across all layers
@@ -887,6 +888,7 @@ def forward_decode(
         save_kv_cache=True,
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if k is not None:
             assert v is not None
@@ -901,11 +903,11 @@ def forward_decode(
                         layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                     )
                 else:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                         layer,
                         cache_loc,
                         k,
-                        v,
+                        k_rope,
                     )
 
         # Use precomputed metadata across all layers
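For context, a minimal sketch of what the MLA branch now avoids (the names below are illustrative, not the backend's own locals): previously the caller had to materialize the concatenated key before the pool could copy it into the cache; with k_rope threaded through, the pool's fused set_mla_kv_buffer writes both pieces of the row in a single launch.

# Sketch only: "pool" stands for forward_batch.token_to_kv_pool; the real call
# sites are inside forward_extend/forward_decode shown above.
import torch

def cache_mla_kv_old(pool, layer, cache_loc, k_nope, k_rope):
    # old path: build a temporary [tokens, heads, nope+rope] tensor, then copy it in
    k_full = torch.cat([k_nope, k_rope], dim=-1)
    pool.set_kv_buffer(layer, cache_loc, k_full, k_nope)

def cache_mla_kv_new(pool, layer, cache_loc, k_nope, k_rope):
    # new path: one fused Triton launch, no intermediate concatenation
    pool.set_mla_kv_buffer(layer, cache_loc, k_nope, k_rope)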

python/sglang/srt/layers/radix_attention.py

Lines changed: 5 additions & 2 deletions

@@ -92,8 +92,11 @@ def forward(
         if k is not None:
             # For cross-layer sharing, kv can be None
             assert v is not None
-            k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
-            v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
+            if "k_rope" not in kwargs:
+                k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+                v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
+            else:
+                k = k.view(-1, self.tp_k_head_num, self.v_head_dim)
 
         return forward_batch.attn_backend.forward(
             q,
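A shape sketch of the new branch, assuming DeepSeek-V2-style MLA sizes (kv_lora_rank = 512, qk_rope_head_dim = 64, so the absorbed MQA layer has qk_head_dim = 576 and v_head_dim = 512); the numbers are illustrative, not read from this diff.

# Illustrative shapes only; the values mirror what an absorbed MLA attention
# layer would carry, they are not taken from the PR itself.
import torch

tokens, kv_lora_rank, rope_dim = 8, 512, 64
tp_k_head_num = 1                         # MQA: a single latent KV head
qk_head_dim = kv_lora_rank + rope_dim     # 576, used when k already contains rope
v_head_dim = kv_lora_rank                 # 512, used when k_rope arrives separately

k_nope = torch.randn(tokens, kv_lora_rank)

# With k_rope passed via kwargs, k holds only the latent (nope) part, so it is
# viewed with v_head_dim instead of qk_head_dim, and v is left untouched.
k = k_nope.view(-1, tp_k_head_num, v_head_dim)
assert k.shape == (tokens, 1, kv_lora_rank)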

python/sglang/srt/mem_cache/memory_pool.py

Lines changed: 87 additions & 0 deletions

@@ -34,6 +34,8 @@
 import numpy as np
 import psutil
 import torch
+import triton
+import triton.language as tl
 
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import debug_timing, get_compiler_backend
@@ -405,6 +407,72 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
     dst_2[loc] = src_2.to(dtype).view(store_dtype)
 
 
+@triton.jit
+def set_mla_kv_buffer_kernel(
+    kv_buffer_ptr,
+    cache_k_nope_ptr,
+    cache_k_rope_ptr,
+    loc_ptr,
+    buffer_stride: tl.constexpr,
+    nope_stride: tl.constexpr,
+    rope_stride: tl.constexpr,
+    nope_dim: tl.constexpr,
+    rope_dim: tl.constexpr,
+    BLOCK: tl.constexpr,
+):
+    pid_loc = tl.program_id(0)
+    pid_blk = tl.program_id(1)
+
+    base = pid_blk * BLOCK
+    offs = base + tl.arange(0, BLOCK)
+    total_dim = nope_dim + rope_dim
+    mask = offs < total_dim
+
+    loc = tl.load(loc_ptr + pid_loc)
+    dst_ptr = kv_buffer_ptr + loc * buffer_stride + offs
+
+    if base + BLOCK <= nope_dim:
+        src = tl.load(
+            cache_k_nope_ptr + pid_loc * nope_stride + offs,
+            mask=mask,
+        )
+    else:
+        offs_rope = offs - nope_dim
+        src = tl.load(
+            cache_k_rope_ptr + pid_loc * rope_stride + offs_rope,
+            mask=mask,
+        )
+
+    tl.store(dst_ptr, src, mask=mask)
+
+
+def set_mla_kv_buffer_triton(
+    kv_buffer: torch.Tensor,
+    loc: torch.Tensor,
+    cache_k_nope: torch.Tensor,
+    cache_k_rope: torch.Tensor,
+):
+    nope_dim = cache_k_nope.shape[-1]
+    rope_dim = cache_k_rope.shape[-1]
+    total_dim = nope_dim + rope_dim
+    BLOCK = 128
+    n_loc = loc.numel()
+    grid = (n_loc, triton.cdiv(total_dim, BLOCK))
+
+    set_mla_kv_buffer_kernel[grid](
+        kv_buffer,
+        cache_k_nope,
+        cache_k_rope,
+        loc,
+        kv_buffer.stride(0),
+        cache_k_nope.stride(0),
+        cache_k_rope.stride(0),
+        nope_dim,
+        rope_dim,
+        BLOCK=BLOCK,
+    )
+
+
 class MLATokenToKVPool(KVCache):
     def __init__(
         self,
@@ -504,6 +572,25 @@ def set_kv_buffer(
         else:
             self.kv_buffer[layer_id][loc] = cache_k
 
+    def set_mla_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k_nope: torch.Tensor,
+        cache_k_rope: torch.Tensor,
+    ):
+        layer_id = layer.layer_id
+        if cache_k_nope.dtype != self.dtype:
+            cache_k_nope = cache_k_nope.to(self.dtype)
+            cache_k_rope = cache_k_rope.to(self.dtype)
+        if self.store_dtype != self.dtype:
+            cache_k_nope = cache_k_nope.view(self.store_dtype)
+            cache_k_rope = cache_k_rope.view(self.store_dtype)
+
+        set_mla_kv_buffer_triton(
+            self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
+        )
+
     def get_flat_data(self, indices):
         # prepare a large chunk of contiguous data for efficient transfer
         return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])
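A standalone way to sanity-check the fused write, assuming a CUDA device and that the two module-level helpers above are importable; the sizes mirror DeepSeek-V2-style MLA (512 + 64), so every 128-wide block falls entirely inside either the nope or the rope region.

# Quick equivalence check (a sketch, not a test shipped with the PR): the fused
# Triton write should match the concat-then-index path it replaces.
import torch

from sglang.srt.mem_cache.memory_pool import set_mla_kv_buffer_triton

num_slots, n_tok, nope_dim, rope_dim = 64, 4, 512, 64
kv_buffer = torch.zeros(
    num_slots, 1, nope_dim + rope_dim, device="cuda", dtype=torch.bfloat16
)
loc = torch.tensor([3, 17, 42, 63], device="cuda", dtype=torch.int64)
k_nope = torch.randn(n_tok, 1, nope_dim, device="cuda", dtype=torch.bfloat16)
k_rope = torch.randn(n_tok, 1, rope_dim, device="cuda", dtype=torch.bfloat16)

set_mla_kv_buffer_triton(kv_buffer, loc, k_nope, k_rope)

# reference: explicit concatenation followed by an indexed copy
ref = torch.zeros_like(kv_buffer)
ref[loc] = torch.cat([k_nope, k_rope], dim=-1)
assert torch.equal(kv_buffer[loc], ref[loc])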

python/sglang/srt/models/deepseek_v2.py

Lines changed: 2 additions & 3 deletions

@@ -757,14 +757,13 @@ def forward_absorb(
 
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
-        k = torch.cat([k_nope, k_pe], dim=-1)
-
         if self.attention_backend == "fa3":
             attn_output = self.attn_mqa(
-                q_nope_out, k, k_nope, forward_batch, q_rope=q_pe
+                q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
            )
         else:
             q = torch.cat([q_nope_out, q_pe], dim=-1)
+            k = torch.cat([k_nope, k_pe], dim=-1)
             attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
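To get a feel for what skipping the torch.cat on the fa3 hot path buys, a hypothetical micro-benchmark sketch (not part of the PR; assumes a CUDA device and the set_mla_kv_buffer_triton helper from memory_pool.py, with illustrative sizes):

# Hypothetical harness: times the concat-then-write path against the fused
# Triton write under the same workload.
import torch

from sglang.srt.mem_cache.memory_pool import set_mla_kv_buffer_triton

def time_ms(fn, iters=200, warmup=10):
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

slots, n_tok, nope_dim, rope_dim = 65536, 4096, 512, 64
buf = torch.zeros(slots, 1, nope_dim + rope_dim, device="cuda", dtype=torch.bfloat16)
loc = torch.randperm(slots, device="cuda")[:n_tok]
k_nope = torch.randn(n_tok, 1, nope_dim, device="cuda", dtype=torch.bfloat16)
k_rope = torch.randn(n_tok, 1, rope_dim, device="cuda", dtype=torch.bfloat16)

def concat_write():
    buf[loc] = torch.cat([k_nope, k_rope], dim=-1)

def fused_write():
    set_mla_kv_buffer_triton(buf, loc, k_nope, k_rope)

print(f"concat+write: {time_ms(concat_write):.3f} ms, fused: {time_ms(fused_write):.3f} ms")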