Skip to content

Commit 57c998b

Browse files
committed
rebase all codes
1 parent 8f527e2 commit 57c998b

File tree

7 files changed

+964
-2
lines changed

7 files changed

+964
-2
lines changed

python/sglang/srt/layers/attention_backend.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,3 +478,185 @@ def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadat
478478
layer.logit_cap,
479479
)
480480
return o
481+
482+
483+
class DoubleSparseAttnBackend(AttentionBackend):
484+
def __init__(self, model_runner: ModelRunner):
485+
# Lazy import to avoid the initialization of cuda context
486+
from sglang.srt.layers.triton_attention.decode_attention import (
487+
decode_attention_fwd,
488+
)
489+
from sglang.srt.layers.triton_attention.extend_attention import (
490+
extend_attention_fwd,
491+
)
492+
from sglang.srt.layers.triton_attention.sparse_decode_attention import (
493+
decode_sparse_attention_fwd
494+
)
495+
496+
super().__init__()
497+
498+
self.decode_attention_fwd = decode_attention_fwd
499+
self.decode_sparse_attention_fwd = decode_sparse_attention_fwd
500+
self.extend_attention_fwd = extend_attention_fwd
501+
self.num_head = model_runner.model_config.num_attention_heads
502+
503+
if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
504+
self.reduce_dtype = torch.float32
505+
else:
506+
self.reduce_dtype = torch.float16
507+
508+
self.forward_metadata = None
509+
510+
self.cuda_graph_max_seq_len = model_runner.model_config.context_len
511+
512+
def init_forward_metadata(
513+
self, batch: ScheduleBatch, input_metadata: InputMetadata
514+
):
515+
"""Init auxiliary variables for triton attention backend."""
516+
517+
if input_metadata.forward_mode.is_decode():
518+
start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32)
519+
start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0)
520+
521+
total_num_tokens = torch.sum(input_metadata.seq_lens).item()
522+
attn_logits = torch.empty(
523+
(self.num_head, total_num_tokens),
524+
dtype=self.reduce_dtype,
525+
device="cuda",
526+
)
527+
528+
max_seq_len = torch.max(input_metadata.seq_lens).item()
529+
max_extend_len = None
530+
#NOTE: Align sequence order with req_to_token order
531+
ds_req_to_token = input_metadata.req_to_token_pool.req_to_token[input_metadata.req_pool_indices]
532+
else:
533+
start_loc = attn_logits = max_seq_len = None
534+
prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
535+
max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()
536+
ds_req_to_token = None
537+
538+
self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len, ds_req_to_token
539+
540+
def init_cuda_graph_state(self, max_bs: int):
541+
#TODO(Andy): Support CUDA graph for double sparse attention
542+
raise ValueError("Double sparse attention does not support CUDA graph for now. Please --disable-cuda-graph")
543+
self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
544+
545+
self.cuda_graph_start_loc = torch.zeros(
546+
(max_bs,), dtype=torch.int32, device="cuda"
547+
)
548+
self.cuda_graph_attn_logits = torch.empty(
549+
(
550+
self.num_head,
551+
self.cuda_graph_max_total_num_tokens,
552+
),
553+
dtype=self.reduce_dtype,
554+
device="cuda",
555+
)
556+
557+
def init_forward_metadata_capture_cuda_graph(
558+
self, bs: int, req_pool_indices, seq_lens
559+
):
560+
self.forward_metadata = (
561+
self.cuda_graph_start_loc,
562+
self.cuda_graph_attn_logits,
563+
self.cuda_graph_max_seq_len,
564+
None,
565+
)
566+
567+
def init_forward_metadata_replay_cuda_graph(
568+
self, bs: int, req_pool_indices, seq_lens
569+
):
570+
self.cuda_graph_start_loc.zero_()
571+
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
572+
573+
def get_cuda_graph_seq_len_fill_value(self):
574+
return 1
575+
576+
def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
577+
# TODO: reuse the buffer across layers
578+
if layer.qk_head_dim != layer.v_head_dim:
579+
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
580+
else:
581+
o = torch.empty_like(q)
582+
583+
k_label = torch.gather(k, 2, input_metadata.sorted_channels[layer.layer_id].unsqueeze(0).expand(k.shape[0], -1, -1))
584+
585+
input_metadata.token_to_kv_pool.set_kv_buffer(
586+
layer.layer_id, input_metadata.out_cache_loc, k, v, k_label
587+
)
588+
589+
start_loc, attn_logits, max_seq_len, max_extend_len, ds_req_to_token = self.forward_metadata
590+
self.extend_attention_fwd(
591+
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
592+
k.contiguous(),
593+
v.contiguous(),
594+
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
595+
input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
596+
input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
597+
input_metadata.req_to_token_pool.req_to_token,
598+
input_metadata.req_pool_indices,
599+
input_metadata.seq_lens,
600+
input_metadata.extend_seq_lens,
601+
input_metadata.extend_start_loc,
602+
max_extend_len,
603+
layer.scaling,
604+
layer.logit_cap,
605+
)
606+
return o
607+
608+
def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
609+
# During torch.compile, there is a bug in rotary_emb that causes the
610+
# output value to have a 3D tensor shape. This reshapes the output correctly.
611+
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
612+
613+
# TODO: reuse the buffer across layers
614+
if layer.qk_head_dim != layer.v_head_dim:
615+
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
616+
else:
617+
o = torch.empty_like(q)
618+
619+
start_loc, attn_logits, max_seq_len, max_extend_len, ds_req_to_token = self.forward_metadata
620+
621+
k_label = torch.gather(k, 2, input_metadata.sorted_channels[layer.layer_id].unsqueeze(0).expand(k.shape[0], -1, -1))
622+
623+
input_metadata.token_to_kv_pool.set_kv_buffer(
624+
layer.layer_id, input_metadata.out_cache_loc, k, v, k_label
625+
)
626+
627+
# NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
628+
# and set a minimum value for sparse_decode
629+
if max_seq_len < input_metadata.heavy_token_num or max_seq_len < input_metadata.sparse_decode_thresold:
630+
self.decode_attention_fwd(
631+
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
632+
input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
633+
input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
634+
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
635+
input_metadata.req_to_token_pool.req_to_token,
636+
input_metadata.req_pool_indices,
637+
start_loc,
638+
input_metadata.seq_lens,
639+
attn_logits,
640+
max_seq_len,
641+
layer.scaling,
642+
layer.logit_cap,
643+
)
644+
else:
645+
#TODO(Andy): indexing with torch.gather or torch.index_select or customized kernel
646+
q_label = torch.gather(q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), 2, input_metadata.sorted_channels[layer.layer_id].unsqueeze(0).expand(q.shape[0], -1, -1))
647+
self.decode_sparse_attention_fwd(
648+
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
649+
input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
650+
input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
651+
o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
652+
q_label,
653+
input_metadata.token_to_kv_pool.get_label_buffer(layer.layer_id),
654+
ds_req_to_token,
655+
input_metadata.seq_lens,
656+
max_seq_len,
657+
layer.scaling,
658+
layer.logit_cap,
659+
input_metadata.heavy_token_num,
660+
)
661+
662+
return o

0 commit comments

Comments
 (0)