Commit 2d61132

Support Eagle2 for Triton backend (#3466)
1 parent cddb1cd commit 2d61132

5 files changed (+286, -42 lines)


python/sglang/srt/layers/attention/triton_backend.py

Lines changed: 223 additions & 24 deletions
@@ -3,6 +3,7 @@
 from typing import TYPE_CHECKING, Optional
 
 import torch
+import triton
 
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_backend import (
@@ -18,7 +19,12 @@
 
 
 class TritonAttnBackend(AttentionBackend):
-    def __init__(self, model_runner: ModelRunner):
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+    ):
         # Lazy import to avoid the initialization of cuda context
         from sglang.srt.layers.attention.triton_ops.decode_attention import (
             decode_attention_fwd,
@@ -33,14 +39,25 @@ def __init__(self, model_runner: ModelRunner):
         self.extend_attention_fwd = extend_attention_fwd
 
         max_bs = model_runner.req_to_token_pool.size
-        self.kv_indptr = torch.zeros(
-            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
-        )
+
+        if kv_indptr_buf is None:
+            self.kv_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+            )
+        else:
+            self.kv_indptr = kv_indptr_buf
+
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.qo_indptr = torch.zeros(
             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
         )
 
+        self.mask_indptr = torch.zeros(
+            (max_bs + 1,), dtype=torch.int64, device=model_runner.device
+        )
+
+        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
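
Note: the new skip_prefill / kv_indptr_buf arguments let a caller hand this backend a preallocated indptr buffer instead of allocating one per instance; the TritonMultiStepDraftBackend added later in this diff does exactly that, giving each draft step one row of a shared buffer. A minimal sketch of the pattern (the sizes below are made-up placeholders, not values from the code):

    import torch

    # Placeholder sizes standing in for model_runner.req_to_token_pool.size etc.
    speculative_num_steps, max_bs = 4, 8

    # One shared (steps, max_bs + 1) buffer; each per-step backend would get one row.
    kv_indptr = torch.zeros((speculative_num_steps, max_bs + 1), dtype=torch.int32)
    # backends = [
    #     TritonAttnBackend(model_runner, skip_prefill=True, kv_indptr_buf=kv_indptr[i])
    #     for i in range(speculative_num_steps)
    # ]
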
@@ -50,7 +67,7 @@ def __init__(self, model_runner: ModelRunner):
 
         self.forward_metadata = None
 
-        self.cuda_graph_max_seq_len = model_runner.model_config.context_len
+        self.max_context_len = model_runner.model_config.context_len
 
         self.device = model_runner.device
 
@@ -59,11 +76,31 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
 
         bs = forward_batch.batch_size
         kv_indptr = self.kv_indptr
-
-        if forward_batch.forward_mode.is_decode():
-            attn_logits = torch.empty(
+        spec_info = forward_batch.spec_info
+
+        if forward_batch.forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                kv_indices = torch.zeros(
+                    forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
+                )
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
+                bs = kv_indptr.shape[0] - 1
+
+            attn_logits = torch.zeros(
                 (
-                    forward_batch.batch_size,
+                    bs,
                     self.num_head,
                     self.num_kv_splits,
                     self.v_head_dim + 1,
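
Note: in the plain decode path, kv_indptr / kv_indices form a CSR-style index into the paged KV cache: kv_indptr[i]:kv_indptr[i + 1] delimits request i's slice of kv_indices, and create_flashinfer_kv_indices_triton fills that slice from the request-to-token table. A rough PyTorch rendering of that layout with toy sizes (an illustration only, not the Triton kernel itself):

    import torch

    req_to_token = torch.arange(4 * 16, dtype=torch.int32).reshape(4, 16)  # toy table
    req_pool_indices = [2, 0]                                              # running requests
    seq_lens = torch.tensor([3, 5])                                        # their lengths

    kv_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)                          # [0, 3, 8]

    kv_indices = torch.cat(
        [req_to_token[r, :l] for r, l in zip(req_pool_indices, seq_lens.tolist())]
    )                                                                      # 8 KV slot ids
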
@@ -72,12 +109,24 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                 device=self.device,
             )
 
+            qo_indptr = None
+            custom_mask = None
+            mask_indptr = None
             max_extend_len = None
-
+        elif forward_batch.forward_mode.is_target_verify():
+            bs = len(forward_batch.req_pool_indices)
+            qo_indptr = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            # Different with flashinfer kv_indptr and kv_indices construction
             kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.empty(
-                forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
+            kv_indices = torch.zeros(
+                kv_indptr[-1], dtype=torch.int32, device=self.device
             )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
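
Note: in the target-verify branch every request contributes exactly num_draft_tokens query positions, so qo_indptr is a fixed-stride ramp rather than a cumsum over variable lengths. A tiny sketch with assumed values (bs = 3, num_draft_tokens = 4 are illustrative only):

    import torch

    bs, num_draft_tokens = 3, 4
    qo_indptr = torch.arange(
        0, (1 + bs) * num_draft_tokens, step=num_draft_tokens, dtype=torch.int32
    )
    # tensor([ 0,  4,  8, 12]): request i owns query rows qo_indptr[i]:qo_indptr[i + 1]
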
@@ -89,15 +138,32 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                 self.req_to_token.stride(0),
             )
 
-            qo_indptr = None
-            custom_mask = None
-            mask_offsets = None
+            custom_mask = spec_info.custom_mask
+            seq_mask_len = self.num_draft_tokens * (
+                forward_batch.seq_lens + self.num_draft_tokens
+            )
+            mask_indptr = self.mask_indptr
+            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
+            mask_indptr = mask_indptr[: bs + 1]
+            max_extend_len = self.num_draft_tokens
+            attn_logits = None
+        elif forward_batch.forward_mode.is_draft_extend():
+            kv_indices, kv_indptr, qo_indptr, custom_mask = (
+                spec_info.generate_attn_arg_prefill(
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    self.req_to_token,
+                )
+            )
+            mask_indptr = None
+            max_extend_len = torch.max(spec_info.accept_length).item()
+            attn_logits = None
         else:
             kv_indptr[1 : bs + 1] = torch.cumsum(
                 forward_batch.extend_prefix_lens, dim=0
             )
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.empty(
+            kv_indices = torch.zeros(
                 forward_batch.extend_prefix_lens.sum().item(),
                 dtype=torch.int32,
                 device=self.device,
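
Note: the verify mask is stored flattened, one block of num_draft_tokens * (seq_len + num_draft_tokens) entries per request, and mask_indptr records where each request's block starts via a prefix sum. A small worked example with assumed lengths:

    import torch

    num_draft_tokens = 4
    seq_lens = torch.tensor([3, 5])

    seq_mask_len = num_draft_tokens * (seq_lens + num_draft_tokens)  # tensor([28, 36])
    mask_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int64)
    mask_indptr[1:] = torch.cumsum(seq_mask_len, dim=0)              # tensor([ 0, 28, 64])
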
@@ -116,8 +182,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
             qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
-            mask_offsets = None
-
+            mask_indptr = None
             attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
 
@@ -128,22 +193,22 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
             kv_indices,
             qo_indptr,
             custom_mask,
-            mask_offsets,
+            mask_indptr,
         )
 
     def init_cuda_graph_state(self, max_bs: int):
-        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
+        self.cuda_graph_max_total_num_tokens = max_bs * self.max_context_len
 
         self.cuda_graph_start_loc = torch.zeros(
             (max_bs,), dtype=torch.int32, device=self.device
         )
-        self.cuda_graph_attn_logits = torch.empty(
+        self.cuda_graph_attn_logits = torch.zeros(
             (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
             dtype=torch.float32,
             device=self.device,
         )
         self.cuda_graph_kv_indices = torch.zeros(
-            (max_bs * self.cuda_graph_max_seq_len),
+            (max_bs * self.max_context_len),
             dtype=torch.int32,
             device=self.device,
         )
@@ -244,8 +309,9 @@ def forward_extend(
             kv_indices,
             qo_indptr,
             custom_mask,
-            mask_offsets,
+            mask_indptr,
         ) = self.forward_metadata
+
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -257,7 +323,7 @@ def forward_extend(
             kv_indptr,
             kv_indices,
             custom_mask,
-            mask_offsets,
+            mask_indptr,
             max_extend_len,
             layer.scaling,
             layer.logit_cap,
@@ -303,3 +369,136 @@ def forward_decode(
             layer.logit_cap,
         )
         return o
+
+
+class TritonMultiStepDraftBackend:
+    """
+    Wrap multiple triton attention backends as one for multiple consecutive
+    draft decoding steps.
+    """
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+        max_bs = model_runner.req_to_token_pool.size
+        self.kv_indptr = torch.zeros(
+            (
+                self.speculative_num_steps,
+                max_bs + 1,
+            ),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                TritonAttnBackend(
+                    model_runner,
+                    skip_prefill=True,
+                    kv_indptr_buf=self.kv_indptr[i],
+                )
+            )
+        self.max_context_len = self.attn_backends[0].max_context_len
+        # Cached variables for generate_draft_decode_kv_indices
+        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+
+    def common_template(
+        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+    ):
+        num_seqs = forward_batch.batch_size
+        bs = self.topk * num_seqs
+        seq_lens_sum = forward_batch.seq_lens_sum
+
+        self.generate_draft_decode_kv_indices[
+            (self.speculative_num_steps, num_seqs, self.topk)
+        ](
+            forward_batch.req_pool_indices,
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.seq_lens,
+            kv_indices_buffer,
+            self.kv_indptr,
+            forward_batch.positions,
+            num_seqs,
+            self.topk,
+            self.pool_len,
+            kv_indices_buffer.shape[1],
+            self.kv_indptr.shape[1],
+            triton.next_power_of_2(num_seqs),
+            triton.next_power_of_2(self.speculative_num_steps),
+            triton.next_power_of_2(bs),
+        )
+
+        for i in range(self.speculative_num_steps):
+            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
+            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
+                : seq_lens_sum * self.topk + bs * (i + 1)
+            ]
+            call_fn(i, forward_batch)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        kv_indices = torch.zeros(
+            (
+                self.speculative_num_steps,
+                forward_batch.batch_size * self.topk * self.max_context_len,
+            ),
+            dtype=torch.int32,
+            device="cuda",
+        )
+
+        def call_fn(i, forward_batch):
+            forward_batch.spec_info.kv_indptr = (
+                forward_batch.spec_info.kv_indptr.clone()
+            )
+            forward_batch.spec_info.kv_indices = (
+                forward_batch.spec_info.kv_indices.clone()
+            )
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+        self.common_template(forward_batch, kv_indices, call_fn)
+
+    def init_cuda_graph_state(self, max_bs: int):
+        self.cuda_graph_kv_indices = torch.zeros(
+            (self.speculative_num_steps, max_bs * self.max_context_len),
+            dtype=torch.int32,
+            device="cuda",
+        )
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+            )
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+    def init_forward_metadata_replay_cuda_graph(self, forward_batch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
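
Note: inside common_template, draft step i looks at a growing slice of the per-step kv_indices buffer: the prefix KV indices replicated once per draft path (seq_lens_sum * topk) plus one freshly drafted token per path for each step completed so far. A self-contained sketch of that arithmetic with toy sizes (all numbers below are assumptions for illustration):

    # Toy sizes: 2 requests, topk = 2 draft paths each, 3 draft steps.
    num_seqs, topk, speculative_num_steps = 2, 2, 3
    seq_lens_sum = 7                      # total cached tokens across requests
    bs = topk * num_seqs                  # 4 draft paths in the flattened batch

    for i in range(speculative_num_steps):
        step_len = seq_lens_sum * topk + bs * (i + 1)
        print(f"step {i}: kv_indices slice length = {step_len}")
    # step 0: 18, step 1: 22, step 2: 26
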

python/sglang/srt/layers/attention/triton_ops/extend_attention.py

Lines changed: 4 additions & 4 deletions
@@ -50,7 +50,7 @@ def _fwd_kernel(
     kv_indptr,
     kv_indices,
     mask_ptr,
-    mask_offsets,
+    mask_indptr,
     sm_scale,
     kv_group_num,
     stride_qbs,
@@ -87,7 +87,7 @@ def _fwd_kernel(
     cur_seq_len = cur_seq_len_prefix + cur_seq_len_extend
 
     if USE_CUSTOM_MASK:
-        cur_seq_mask_start_idx = tl.load(mask_offsets + cur_seq)
+        cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq)
 
     offs_d = tl.arange(0, BLOCK_DMODEL)
     offs_dv = tl.arange(0, BLOCK_DV)
@@ -288,7 +288,7 @@ def extend_attention_fwd(
     kv_indptr,
     kv_indices,
     custom_mask,
-    mask_offsets,
+    mask_indptr,
     max_len_extend,
     sm_scale=None,
     logit_cap=0.0,
@@ -364,7 +364,7 @@ def extend_attention_fwd(
         kv_indptr,
         kv_indices,
         custom_mask,
-        mask_offsets,
+        mask_indptr,
         sm_scale,
         kv_group_num,
         q_extend.stride(0),
