@@ -45,6 +45,9 @@ def __init__(
45
45
self ,
46
46
model_runner : ModelRunner ,
47
47
skip_prefill : bool = False ,
48
+ topk = 0 ,
49
+ speculative_num_steps = 0 ,
50
+ step_id = 0 ,
48
51
):
49
52
super ().__init__ ()
50
53
@@ -63,6 +66,10 @@ def __init__(
63
66
self .use_mla = (
64
67
model_runner .model_config .attention_arch == AttentionArch .MLA
65
68
) and (not global_server_args_dict ["disable_mla" ])
69
+ self .skip_prefill = skip_prefill
70
+ self .topk = topk
71
+ self .speculative_num_steps = speculative_num_steps
72
+ self .step_id = step_id
66
73
67
74
def init_forward_metadata (self , forward_batch : ForwardBatch ):
68
75
"""Initialize forward metadata to cache repetitive calculations."""
@@ -72,37 +79,125 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
72
79
# Get sequence information
73
80
seqlens_in_batch = forward_batch .seq_lens
74
81
# Precompute int32 version of sequence lengths
75
- metadata .cache_seqlens_int32 = seqlens_in_batch .to (torch .int32 )
76
82
batch_size = len (seqlens_in_batch )
77
83
device = seqlens_in_batch .device
78
- metadata .cu_seqlens_k = torch .nn .functional .pad (
79
- torch .cumsum (seqlens_in_batch , dim = 0 , dtype = torch .int32 ), (1 , 0 )
80
- )
81
- # Precompute maximum sequence length
82
- metadata .max_seq_len_k = forward_batch .seq_lens_cpu .max ().item ()
83
- # Precompute page table
84
- metadata .page_table = forward_batch .req_to_token_pool .req_to_token [
85
- forward_batch .req_pool_indices , : metadata .max_seq_len_k
86
- ]
87
-
88
- # Precompute strided indices
89
- # [0, page_size, 2 * page_size, ...]
90
- if self .page_size > 1 :
91
- self .strided_indices = torch .arange (
92
- 0 , metadata .page_table .shape [1 ], self .page_size , device = self .device
93
- )
94
- metadata .page_table = (
95
- metadata .page_table [:, self .strided_indices ] // self .page_size
96
- )
97
84
98
85
if forward_batch .forward_mode == ForwardMode .DECODE :
99
- # Precompute cumulative sequence lengths
86
+ if self .skip_prefill :
87
+ metadata .cu_seqlens_q = torch .arange (
88
+ 0 , batch_size * self .topk + 1 , dtype = torch .int32 , device = device
89
+ )
90
+ seq_lens_with_decode = seqlens_in_batch + (self .step_id + 1 )
91
+ metadata .cache_seqlens_int32 = (
92
+ (seq_lens_with_decode ).repeat_interleave (self .topk ).to (torch .int32 )
93
+ )
94
+ metadata .cu_seqlens_k = torch .nn .functional .pad (
95
+ torch .cumsum (
96
+ metadata .cache_seqlens_int32 , dim = 0 , dtype = torch .int32
97
+ ),
98
+ (1 , 0 ),
99
+ )
100
+ metadata .max_seq_len_k = forward_batch .seq_lens_cpu .max ().item () + (
101
+ self .step_id + 1
102
+ )
103
+ metadata .page_table = forward_batch .req_to_token_pool .req_to_token [
104
+ forward_batch .req_pool_indices , : metadata .max_seq_len_k
105
+ ]
106
+ metadata .page_table = metadata .page_table .repeat_interleave (
107
+ self .topk , dim = 0
108
+ )
109
+ cache_loc = forward_batch .out_cache_loc .view (
110
+ self .speculative_num_steps , - 1
111
+ ).T
112
+ # Calculate page table indices and cache location indices to update the page table.
113
+ batch_indices = torch .arange (
114
+ batch_size , device = device
115
+ ).repeat_interleave (self .topk * (self .step_id + 1 ))
116
+ topk_indices = torch .arange (self .topk , device = device ).repeat (
117
+ batch_size * (self .step_id + 1 )
118
+ )
119
+ row_indices = batch_indices * self .topk + topk_indices
120
+
121
+ page_table_col_base_indices = seqlens_in_batch .unsqueeze (
122
+ 1
123
+ ) + torch .arange (self .step_id + 1 , device = device )
124
+ page_table_col_indices = page_table_col_base_indices .view (- 1 ).repeat (
125
+ self .topk
126
+ )
127
+
128
+ cache_loc_col_indices = torch .arange (
129
+ self .step_id + 1 , device = device , dtype = torch .int32
130
+ ).repeat (batch_size * self .topk )
131
+
132
+ metadata .page_table [row_indices , page_table_col_indices ] = cache_loc [
133
+ row_indices , cache_loc_col_indices
134
+ ].to (torch .int32 )
135
+ else :
136
+ metadata .cache_seqlens_int32 = seqlens_in_batch .to (torch .int32 )
137
+ metadata .cu_seqlens_k = torch .nn .functional .pad (
138
+ torch .cumsum (seqlens_in_batch , dim = 0 , dtype = torch .int32 ), (1 , 0 )
139
+ )
140
+ # Precompute maximum sequence length
141
+ metadata .max_seq_len_k = forward_batch .seq_lens_cpu .max ().item ()
142
+ # Precompute page table
143
+ metadata .page_table = forward_batch .req_to_token_pool .req_to_token [
144
+ forward_batch .req_pool_indices , : metadata .max_seq_len_k
145
+ ]
146
+ metadata .cu_seqlens_q = torch .arange (
147
+ 0 , batch_size + 1 , dtype = torch .int32 , device = device
148
+ )
149
+ elif forward_batch .forward_mode == ForwardMode .TARGET_VERIFY :
150
+ draft_token_num = forward_batch .spec_info .draft_token_num
151
+
100
152
metadata .cu_seqlens_q = torch .arange (
101
- 0 , batch_size + 1 , dtype = torch .int32 , device = device
153
+ 0 , batch_size * draft_token_num + 1 , dtype = torch .int32 , device = device
154
+ )
155
+
156
+ aug_seq_lens = (forward_batch .seq_lens + draft_token_num ).to (torch .int32 )
157
+ metadata .cache_seqlens_int32 = aug_seq_lens .repeat_interleave (
158
+ forward_batch .spec_info .draft_token_num
159
+ )
160
+ metadata .cu_seqlens_k = torch .nn .functional .pad (
161
+ torch .cumsum (metadata .cache_seqlens_int32 , dim = 0 , dtype = torch .int32 ),
162
+ (1 , 0 ),
102
163
)
164
+ metadata .max_seq_len_k = (
165
+ forward_batch .seq_lens_cpu .max ().item () + draft_token_num
166
+ )
167
+ metadata .page_table = forward_batch .req_to_token_pool .req_to_token [
168
+ forward_batch .req_pool_indices , : metadata .max_seq_len_k
169
+ ].repeat_interleave (draft_token_num , dim = 0 )
170
+ aug_cum_len = torch .nn .functional .pad (
171
+ torch .cumsum (aug_seq_lens , dim = 0 , dtype = torch .int32 ), (1 , 0 )
172
+ )
173
+ for idx , single_seq_len in enumerate (aug_seq_lens ):
174
+ metadata .page_table [
175
+ idx * draft_token_num : (idx + 1 ) * draft_token_num , :single_seq_len
176
+ ] *= forward_batch .spec_info .custom_mask [
177
+ aug_cum_len [idx ]
178
+ * draft_token_num : aug_cum_len [idx + 1 ]
179
+ * draft_token_num
180
+ ].view (
181
+ draft_token_num , - 1
182
+ )
183
+
184
+ metadata .max_seq_len_q = 1
103
185
else :
186
+ metadata .cache_seqlens_int32 = seqlens_in_batch .to (torch .int32 )
187
+ metadata .cu_seqlens_k = torch .nn .functional .pad (
188
+ torch .cumsum (seqlens_in_batch , dim = 0 , dtype = torch .int32 ), (1 , 0 )
189
+ )
190
+ # Precompute maximum sequence length
191
+ metadata .max_seq_len_k = forward_batch .seq_lens_cpu .max ().item ()
192
+ # Precompute page table
193
+ metadata .page_table = forward_batch .req_to_token_pool .req_to_token [
194
+ forward_batch .req_pool_indices , : metadata .max_seq_len_k
195
+ ]
104
196
# Precompute cumulative sequence lengths
105
- if any (forward_batch .extend_prefix_lens_cpu ):
197
+ if (
198
+ any (forward_batch .extend_prefix_lens_cpu )
199
+ or forward_batch .forward_mode == ForwardMode .DRAFT_EXTEND
200
+ ):
106
201
extend_seq_lens = forward_batch .extend_seq_lens
107
202
metadata .cu_seqlens_q = torch .nn .functional .pad (
108
203
torch .cumsum (extend_seq_lens , dim = 0 , dtype = torch .int32 ), (1 , 0 )
@@ -111,6 +206,16 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
111
206
else :
112
207
metadata .cu_seqlens_q = metadata .cu_seqlens_k
113
208
metadata .max_seq_len_q = metadata .max_seq_len_k
209
+
210
+ # Precompute strided indices
211
+ # [0, page_size, 2 * page_size, ...]
212
+ if self .page_size > 1 :
213
+ self .strided_indices = torch .arange (
214
+ 0 , metadata .page_table .shape [1 ], self .page_size , device = self .device
215
+ )
216
+ metadata .page_table = (
217
+ metadata .page_table [:, self .strided_indices ] // self .page_size
218
+ )
114
219
self .forward_metadata = metadata
115
220
116
221
def forward_extend (
@@ -281,8 +386,6 @@ def forward_decode(
281
386
282
387
# Pre-reshape query tensor
283
388
q_reshaped = q .contiguous ().view (- 1 , layer .tp_q_head_num , layer .head_dim )
284
-
285
- # Run attention with precomputed values
286
389
o = flash_attn_with_kvcache (
287
390
q = q_reshaped ,
288
391
k_cache = key_cache ,
@@ -346,7 +449,11 @@ def init_cuda_graph_state(self, max_bs: int):
346
449
This creates fixed-size tensors that will be reused during CUDA graph replay
347
450
to avoid memory allocations.
348
451
"""
349
- # Initialize fixed size tensors for decode operations
452
+ if self .speculative_num_steps > 0 :
453
+ raise NotImplementedError (
454
+ "FlashAttentionBackend Spec Decoding does not support CUDA graph yet, stay tuned!"
455
+ )
456
+
350
457
self .decode_cuda_graph_metadata = {
351
458
# Page table for token mapping (batch_size, max_context_len)
352
459
"page_table" : torch .zeros (
@@ -385,7 +492,7 @@ def init_forward_metadata_capture_cuda_graph(
385
492
metadata .page_table = self .decode_cuda_graph_metadata ["page_table" ][
386
493
req_pool_indices , :
387
494
]
388
- if forward_mode == ForwardMode . DECODE :
495
+ if forward_mode . is_cuda_graph () :
389
496
# Precompute cumulative sequence lengths
390
497
metadata .cu_seqlens_q = torch .arange (
391
498
0 , batch_size + 1 , dtype = torch .int32 , device = device
@@ -432,3 +539,66 @@ def init_forward_metadata_replay_cuda_graph(
432
539
def get_cuda_graph_seq_len_fill_value(self):
    """Return the value used to fill sequence-length slots in the CUDA graph.

    Returns:
        The integer fill value, which is always ``0`` for this backend.
    """
    fill_value = 0
    return fill_value
542
+
543
+
544
class FlashAttentionMultiStepBackend:
    """Hold one ``FlashAttentionBackend`` per speculative step and fan calls out to them.

    Each per-step backend is built with ``skip_prefill=True`` and its own
    ``step_id``. The metadata-initialisation methods forward to the first
    ``speculative_num_steps - 1`` backends only, while
    ``init_cuda_graph_state`` forwards to every backend.
    """

    def __init__(
        self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
    ):
        """Create the per-step backends.

        Args:
            model_runner: Runner handed unchanged to every per-step backend.
            topk: Value forwarded as ``topk`` to every per-step backend.
            speculative_num_steps: Number of per-step backends to create;
                also forwarded to each of them.
        """
        self.model_runner = model_runner
        self.topk = topk
        self.speculative_num_steps = speculative_num_steps

        # Backend at position ``step`` is configured with ``step_id=step``.
        self.attn_backends = [
            FlashAttentionBackend(
                model_runner,
                skip_prefill=True,
                topk=self.topk,
                speculative_num_steps=self.speculative_num_steps,
                step_id=step,
            )
            for step in range(self.speculative_num_steps)
        ]

    def _leading_step_backends(self):
        """Yield the backends for every step except the final one."""
        # NOTE(review): the final step's backend is presumably unused for
        # attention metadata; confirm against the draft worker.
        for step in range(self.speculative_num_steps - 1):
            yield self.attn_backends[step]

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Initialise forward metadata on each non-final per-step backend."""
        for step_backend in self._leading_step_backends():
            step_backend.init_forward_metadata(forward_batch)

    def init_cuda_graph_state(self, max_bs: int):
        """Forward ``init_cuda_graph_state(max_bs)`` to every per-step backend."""
        for step in range(self.speculative_num_steps):
            step_backend = self.attn_backends[step]
            step_backend.init_cuda_graph_state(max_bs)

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
        """Run the CUDA-graph capture metadata setup on each non-final backend.

        ``forward_batch.spec_info`` must be an ``EagleDraftInput``. Every
        backend is called in ``ForwardMode.DECODE`` with
        ``batch_size * topk`` as its second positional argument.
        """
        assert forward_batch.spec_info is not None
        assert isinstance(forward_batch.spec_info, EagleDraftInput)

        for step_backend in self._leading_step_backends():
            batch_size = forward_batch.batch_size
            step_backend.init_forward_metadata_capture_cuda_graph(
                batch_size,
                forward_batch.batch_size * self.topk,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
            )

    def init_forward_metadata_replay_cuda_graph(
        self, forward_batch: ForwardBatch, bs: int
    ):
        """Run the CUDA-graph replay metadata setup on each non-final backend.

        ``forward_batch.spec_info`` must be an ``EagleDraftInput``. Every
        backend is called in ``ForwardMode.DECODE`` with the batch's request
        pool indices and sequence-length fields.
        """
        assert forward_batch.spec_info is not None
        assert isinstance(forward_batch.spec_info, EagleDraftInput)

        for step_backend in self._leading_step_backends():
            step_backend.init_forward_metadata_replay_cuda_graph(
                bs,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
                seq_lens_cpu=forward_batch.seq_lens_cpu,
            )
0 commit comments