@@ -3,6 +3,7 @@
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Callable, Optional

import torch
+import triton

from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.attention.flashinfer_backend import (
@@ -18,7 +19,12 @@


class TritonAttnBackend(AttentionBackend):
-    def __init__(self, model_runner: ModelRunner):
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+    ):
        # Lazy import to avoid the initialization of cuda context
        from sglang.srt.layers.attention.triton_ops.decode_attention import (
            decode_attention_fwd,
@@ -33,14 +39,25 @@ def __init__(self, model_runner: ModelRunner):
        self.extend_attention_fwd = extend_attention_fwd

        max_bs = model_runner.req_to_token_pool.size
-        self.kv_indptr = torch.zeros(
-            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
-        )
+
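+        # A shared kv_indptr buffer can be passed in (e.g. one row of the
+        # per-step buffer owned by TritonMultiStepDraftBackend below);
+        # otherwise allocate a private one.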
+        if kv_indptr_buf is None:
+            self.kv_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+            )
+        else:
+            self.kv_indptr = kv_indptr_buf
+
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.qo_indptr = torch.zeros(
            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
        )

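+        # mask_indptr holds per-request offsets into the flattened custom
+        # attention mask used by the target-verify path.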
+        self.mask_indptr = torch.zeros(
+            (max_bs + 1,), dtype=torch.int64, device=model_runner.device
+        )
+
+        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+
        self.num_head = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
@@ -50,7 +67,7 @@ def __init__(self, model_runner: ModelRunner):

        self.forward_metadata = None

-        self.cuda_graph_max_seq_len = model_runner.model_config.context_len
+        self.max_context_len = model_runner.model_config.context_len

        self.device = model_runner.device

@@ -59,11 +76,31 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):

        bs = forward_batch.batch_size
        kv_indptr = self.kv_indptr
-
-        if forward_batch.forward_mode.is_decode():
-            attn_logits = torch.empty(
+        spec_info = forward_batch.spec_info
+
+        if forward_batch.forward_mode.is_decode_or_idle():
+            if spec_info is None:
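+                # Regular decode: build CSR-style indices into the KV cache
+                # pool. kv_indptr[i] is the start offset of request i's
+                # tokens inside the flat kv_indices array.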
+                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                kv_indices = torch.zeros(
+                    forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
+                )
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
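+                # Speculative draft decode: kv_indptr / kv_indices were
+                # already built for this step by TritonMultiStepDraftBackend.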
+                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
+                bs = kv_indptr.shape[0] - 1
+
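+            # Workspace for the split-KV decode kernel: one partial result
+            # per KV split, with the extra last column holding the
+            # log-sum-exp used to combine the splits.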
+            attn_logits = torch.zeros(
                (
-                    forward_batch.batch_size,
+                    bs,
                    self.num_head,
                    self.num_kv_splits,
                    self.v_head_dim + 1,
@@ -72,12 +109,24 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                device=self.device,
            )

+            qo_indptr = None
+            custom_mask = None
+            mask_indptr = None
            max_extend_len = None
-
+        elif forward_batch.forward_mode.is_target_verify():
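+            # Target-verify: the batch carries exactly num_draft_tokens query
+            # tokens per request, so qo_indptr advances by a fixed stride.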
+            bs = len(forward_batch.req_pool_indices)
+            qo_indptr = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            # Different from the flashinfer kv_indptr and kv_indices construction
            kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.empty(
-                forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
+            kv_indices = torch.zeros(
+                kv_indptr[-1], dtype=torch.int32, device=self.device
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
@@ -89,15 +138,32 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                self.req_to_token.stride(0),
            )

-            qo_indptr = None
-            custom_mask = None
-            mask_offsets = None
+            custom_mask = spec_info.custom_mask
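+            # Request i's mask block is num_draft_tokens rows by
+            # (seq_len_i + num_draft_tokens) columns, flattened; mask_indptr
+            # records where each request's block starts.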
+            seq_mask_len = self.num_draft_tokens * (
+                forward_batch.seq_lens + self.num_draft_tokens
+            )
+            mask_indptr = self.mask_indptr
+            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
+            mask_indptr = mask_indptr[: bs + 1]
+            max_extend_len = self.num_draft_tokens
+            attn_logits = None
+        elif forward_batch.forward_mode.is_draft_extend():
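+            # Draft-extend after verification: spec_info knows how many draft
+            # tokens each request accepted and builds prefill-style arguments.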
+            kv_indices, kv_indptr, qo_indptr, custom_mask = (
+                spec_info.generate_attn_arg_prefill(
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    self.req_to_token,
+                )
+            )
+            mask_indptr = None
+            max_extend_len = torch.max(spec_info.accept_length).item()
+            attn_logits = None
        else:
            kv_indptr[1 : bs + 1] = torch.cumsum(
                forward_batch.extend_prefix_lens, dim=0
            )
            kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.empty(
+            kv_indices = torch.zeros(
                forward_batch.extend_prefix_lens.sum().item(),
                dtype=torch.int32,
                device=self.device,
@@ -116,8 +182,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
            qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
            qo_indptr = qo_indptr[: bs + 1]
            custom_mask = None
-            mask_offsets = None
-
+            mask_indptr = None
            attn_logits = None
            max_extend_len = torch.max(forward_batch.extend_seq_lens).item()

@@ -128,22 +193,22 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
            kv_indices,
            qo_indptr,
            custom_mask,
-            mask_offsets,
+            mask_indptr,
        )

-    def init_cuda_graph_state(self, max_bs: int):
-        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
+    def init_cuda_graph_state(
+        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+    ):
+        self.cuda_graph_max_total_num_tokens = max_bs * self.max_context_len

        self.cuda_graph_start_loc = torch.zeros(
            (max_bs,), dtype=torch.int32, device=self.device
        )
-        self.cuda_graph_attn_logits = torch.empty(
+        self.cuda_graph_attn_logits = torch.zeros(
            (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
            dtype=torch.float32,
            device=self.device,
        )
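+        # Reuse a caller-provided kv_indices buffer when given; the
+        # multi-step draft backend below passes one row per draft step.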
-        self.cuda_graph_kv_indices = torch.zeros(
-            (max_bs * self.cuda_graph_max_seq_len),
-            dtype=torch.int32,
-            device=self.device,
-        )
+        if kv_indices_buf is None:
+            self.cuda_graph_kv_indices = torch.zeros(
+                (max_bs * self.max_context_len),
+                dtype=torch.int32,
+                device=self.device,
+            )
+        else:
+            self.cuda_graph_kv_indices = kv_indices_buf
@@ -244,8 +309,9 @@ def forward_extend(
            kv_indices,
            qo_indptr,
            custom_mask,
-            mask_offsets,
+            mask_indptr,
        ) = self.forward_metadata
+
        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k.contiguous(),
@@ -257,7 +323,7 @@ def forward_extend(
            kv_indptr,
            kv_indices,
            custom_mask,
-            mask_offsets,
+            mask_indptr,
            max_extend_len,
            layer.scaling,
            layer.logit_cap,
@@ -303,3 +369,136 @@ def forward_decode(
            layer.logit_cap,
        )
        return o
+
+
+class TritonMultiStepDraftBackend:
+    """
+    Wrap multiple triton attention backends as one for multiple consecutive
+    draft decoding steps.
+    """
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+        max_bs = model_runner.req_to_token_pool.size
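+        # One kv_indptr row per draft step; row i is shared with the i-th
+        # per-step backend through its kv_indptr_buf argument.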
+        self.kv_indptr = torch.zeros(
+            (
+                self.speculative_num_steps,
+                max_bs + 1,
+            ),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                TritonAttnBackend(
+                    model_runner,
+                    skip_prefill=True,
+                    kv_indptr_buf=self.kv_indptr[i],
+                )
+            )
+        self.max_context_len = self.attn_backends[0].max_context_len
+        # Cached variables for generate_draft_decode_kv_indices
+        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+
+    def common_template(
+        self,
+        forward_batch: ForwardBatch,
+        kv_indices_buffer: torch.Tensor,
+        call_fn: Callable,
+    ):
+        num_seqs = forward_batch.batch_size
+        bs = self.topk * num_seqs
+        seq_lens_sum = forward_batch.seq_lens_sum
+
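+        # A single Triton launch fills the kv indices for every
+        # (step, request, branch) triple:
+        # grid = (speculative_num_steps, num_seqs, topk).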
+        self.generate_draft_decode_kv_indices[
+            (self.speculative_num_steps, num_seqs, self.topk)
+        ](
+            forward_batch.req_pool_indices,
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.seq_lens,
+            kv_indices_buffer,
+            self.kv_indptr,
+            forward_batch.positions,
+            num_seqs,
+            self.topk,
+            self.pool_len,
+            kv_indices_buffer.shape[1],
+            self.kv_indptr.shape[1],
+            triton.next_power_of_2(num_seqs),
+            triton.next_power_of_2(self.speculative_num_steps),
+            triton.next_power_of_2(bs),
+        )
+
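+        # After step i, each of the bs draft branches has i + 1 freshly
+        # appended tokens on top of the shared prefix, hence the slice length.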
+        for i in range(self.speculative_num_steps):
+            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
+            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
+                : seq_lens_sum * self.topk + bs * (i + 1)
+            ]
+            call_fn(i, forward_batch)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        kv_indices = torch.zeros(
+            (
+                self.speculative_num_steps,
+                forward_batch.batch_size * self.topk * self.max_context_len,
+            ),
+            dtype=torch.int32,
+            device="cuda",
+        )
+
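+        # Clone so each step's stored metadata is a private snapshot instead
+        # of a view into shared buffers that are rewritten on the next batch.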
+        def call_fn(i, forward_batch):
+            forward_batch.spec_info.kv_indptr = (
+                forward_batch.spec_info.kv_indptr.clone()
+            )
+            forward_batch.spec_info.kv_indices = (
+                forward_batch.spec_info.kv_indices.clone()
+            )
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+        self.common_template(forward_batch, kv_indices, call_fn)
+
+    def init_cuda_graph_state(self, max_bs: int):
+        self.cuda_graph_kv_indices = torch.zeros(
+            (self.speculative_num_steps, max_bs * self.max_context_len),
+            dtype=torch.int32,
+            device="cuda",
+        )
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+            )
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+    def init_forward_metadata_replay_cuda_graph(self, forward_batch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)