
Commit de55333

Update Triton extend backend interface (#3309)
1 parent 7aad8d1 commit de55333

5 files changed: +427 −69 lines changed

python/sglang/srt/layers/attention/double_sparsity_backend.py

Lines changed: 1 addition & 3 deletions
@@ -17,12 +17,10 @@ class DoubleSparseAttnBackend(AttentionBackend):
     def __init__(self, model_runner: ModelRunner):
         # Lazy import to avoid the initialization of cuda context
         from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
+            extend_attention_fwd,
             flash_decode_attention_fwd,
             flash_decode_sparse_attention_fwd,
         )
-        from sglang.srt.layers.attention.triton_ops.extend_attention import (
-            extend_attention_fwd,
-        )
 
         super().__init__()

python/sglang/srt/layers/attention/triton_backend.py

Lines changed: 52 additions & 16 deletions
@@ -37,6 +37,9 @@ def __init__(self, model_runner: ModelRunner):
             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
         )
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.qo_indptr = torch.zeros(
+            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+        )
 
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -54,6 +57,9 @@ def __init__(self, model_runner: ModelRunner):
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
 
+        bs = forward_batch.batch_size
+        kv_indptr = self.kv_indptr
+
         if forward_batch.forward_mode.is_decode():
             attn_logits = torch.empty(
                 (
@@ -68,31 +74,59 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
 
             max_extend_len = None
 
-            kv_indptr = self.kv_indptr
-            bs = len(forward_batch.req_pool_indices)
             kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
-                forward_batch.seq_lens_sum, dtype=torch.int32, device="cuda"
+                forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
             )
             create_flashinfer_kv_indices_triton[(bs,)](
-                forward_batch.req_to_token_pool.req_to_token,
+                self.req_to_token,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 kv_indptr,
                 None,
                 kv_indices,
-                forward_batch.req_to_token_pool.req_to_token.stride(0),
+                self.req_to_token.stride(0),
             )
 
+            qo_indptr = None
+            custom_mask = None
         else:
+            kv_indptr[1 : bs + 1] = torch.cumsum(
+                forward_batch.extend_prefix_lens, dim=0
+            )
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                forward_batch.extend_prefix_lens.sum().item(),
+                dtype=torch.int32,
+                device=self.device,
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                forward_batch.req_pool_indices,
+                forward_batch.extend_prefix_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+
+            qo_indptr = self.qo_indptr
+            qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
+            qo_indptr = qo_indptr[: bs + 1]
+            custom_mask = None
+
             attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
 
-            kv_indptr = None
-            kv_indices = None
-
-        self.forward_metadata = attn_logits, max_extend_len, kv_indptr, kv_indices
+        self.forward_metadata = (
+            attn_logits,
+            max_extend_len,
+            kv_indptr,
+            kv_indices,
+            qo_indptr,
+            custom_mask,
+        )
 
     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
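
For context (not part of the diff): after this change, both branches publish CSR-style index arrays instead of the raw req_to_token table. kv_indptr is the cumulative sum of per-request KV lengths (seq_lens on the decode path, extend_prefix_lens on the extend path) and delimits each request's slice of kv_indices, while qo_indptr does the same for the packed extend query tokens. A minimal standalone sketch of that layout, using made-up lengths rather than values from this commit:

import torch

# Hypothetical per-request lengths for a batch of three extend requests.
extend_prefix_lens = torch.tensor([4, 0, 2], dtype=torch.int32)  # tokens already in the KV cache
extend_seq_lens = torch.tensor([3, 5, 1], dtype=torch.int32)     # new tokens being extended

bs = extend_prefix_lens.numel()

# kv_indptr[i] : kv_indptr[i + 1] delimits request i's slots in kv_indices.
kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(extend_prefix_lens, dim=0)

# qo_indptr[i] : qo_indptr[i + 1] delimits request i's rows in the packed q tensor.
qo_indptr = torch.zeros(bs + 1, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(extend_seq_lens, dim=0)

print(kv_indptr.tolist())  # [0, 4, 4, 6]
print(qo_indptr.tolist())  # [0, 3, 8, 9]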
@@ -144,6 +178,8 @@ def init_forward_metadata_capture_cuda_graph(
             None,
             kv_indptr,
             kv_indices,
+            None,
+            None,
         )
 
     def init_forward_metadata_replay_cuda_graph(
@@ -197,19 +233,19 @@ def forward_extend(
                 layer, forward_batch.out_cache_loc, k, v
             )
 
-        _, max_extend_len, _, _ = self.forward_metadata
+        _, max_extend_len, kv_indptr, kv_indices, qo_indptr, custom_mask = (
+            self.forward_metadata
+        )
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
             v.contiguous(),
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
-            forward_batch.req_to_token_pool.req_to_token,
-            forward_batch.req_pool_indices,
-            forward_batch.seq_lens,
-            forward_batch.extend_seq_lens,
-            forward_batch.extend_start_loc,
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
             max_extend_len,
             layer.scaling,
             layer.logit_cap,
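
As the hunk above shows, forward_extend no longer passes req_to_token, req_pool_indices, seq_lens, extend_seq_lens, or extend_start_loc to extend_attention_fwd; the per-request ranges are carried by qo_indptr, kv_indptr, and kv_indices instead. A small sketch, continuing the hypothetical values above (not code from this commit), of how the old-style quantities fall out of the new arrays:

import torch

qo_indptr = torch.tensor([0, 3, 8, 9], dtype=torch.int32)
kv_indptr = torch.tensor([0, 4, 4, 6], dtype=torch.int32)

# Old-style per-request start offsets and lengths come from adjacent entries.
extend_start_loc = qo_indptr[:-1]                   # [0, 3, 8]
extend_seq_lens = qo_indptr[1:] - qo_indptr[:-1]    # [3, 5, 1]
prefix_lens = kv_indptr[1:] - kv_indptr[:-1]        # [4, 0, 2]

print(extend_start_loc.tolist(), extend_seq_lens.tolist(), prefix_lens.tolist())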
@@ -235,7 +271,7 @@ def forward_decode(
         else:
             o = torch.empty_like(q)
 
-        attn_logits, _, kv_indptr, kv_indices = self.forward_metadata
+        attn_logits, _, kv_indptr, kv_indices, _, _ = self.forward_metadata
 
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
