Skip to content

Update Triton extend backend interface #3309

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@ class DoubleSparseAttnBackend(AttentionBackend):
def __init__(self, model_runner: ModelRunner):
# Lazy import to avoid the initialization of cuda context
from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
extend_attention_fwd,
flash_decode_attention_fwd,
flash_decode_sparse_attention_fwd,
)
from sglang.srt.layers.attention.triton_ops.extend_attention import (
extend_attention_fwd,
)

super().__init__()

Expand Down
68 changes: 52 additions & 16 deletions python/sglang/srt/layers/attention/triton_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def __init__(self, model_runner: ModelRunner):
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.qo_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)

self.num_head = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
Expand All @@ -54,6 +57,9 @@ def __init__(self, model_runner: ModelRunner):
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Init auxiliary variables for triton attention backend."""

bs = forward_batch.batch_size
kv_indptr = self.kv_indptr

if forward_batch.forward_mode.is_decode():
attn_logits = torch.empty(
(
Expand All @@ -68,31 +74,59 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):

max_extend_len = None

kv_indptr = self.kv_indptr
bs = len(forward_batch.req_pool_indices)
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
forward_batch.seq_lens_sum, dtype=torch.int32, device="cuda"
forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
)
create_flashinfer_kv_indices_triton[(bs,)](
forward_batch.req_to_token_pool.req_to_token,
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
kv_indptr,
None,
kv_indices,
forward_batch.req_to_token_pool.req_to_token.stride(0),
self.req_to_token.stride(0),
)

qo_indptr = None
custom_mask = None
else:
kv_indptr[1 : bs + 1] = torch.cumsum(
forward_batch.extend_prefix_lens, dim=0
)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
forward_batch.extend_prefix_lens.sum().item(),
dtype=torch.int32,
device=self.device,
)
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.extend_prefix_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)

qo_indptr = self.qo_indptr
qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
qo_indptr = qo_indptr[: bs + 1]
custom_mask = None

attn_logits = None
max_extend_len = torch.max(forward_batch.extend_seq_lens).item()

kv_indptr = None
kv_indices = None

self.forward_metadata = attn_logits, max_extend_len, kv_indptr, kv_indices
self.forward_metadata = (
attn_logits,
max_extend_len,
kv_indptr,
kv_indices,
qo_indptr,
custom_mask,
)

def init_cuda_graph_state(self, max_bs: int):
self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
Expand Down Expand Up @@ -144,6 +178,8 @@ def init_forward_metadata_capture_cuda_graph(
None,
kv_indptr,
kv_indices,
None,
None,
)

def init_forward_metadata_replay_cuda_graph(
Expand Down Expand Up @@ -197,19 +233,19 @@ def forward_extend(
layer, forward_batch.out_cache_loc, k, v
)

_, max_extend_len, _, _ = self.forward_metadata
_, max_extend_len, kv_indptr, kv_indices, qo_indptr, custom_mask = (
self.forward_metadata
)
self.extend_attention_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
k.contiguous(),
v.contiguous(),
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
forward_batch.req_to_token_pool.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.extend_seq_lens,
forward_batch.extend_start_loc,
qo_indptr,
kv_indptr,
kv_indices,
max_extend_len,
layer.scaling,
layer.logit_cap,
Expand All @@ -235,7 +271,7 @@ def forward_decode(
else:
o = torch.empty_like(q)

attn_logits, _, kv_indptr, kv_indices = self.forward_metadata
attn_logits, _, kv_indptr, kv_indices, _, _ = self.forward_metadata

if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
Expand Down
Loading
Loading