@@ -59,6 +59,7 @@ def __init__(
         self.device = model_runner.device
         self.decode_cuda_graph_metadata = {}
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.page_size = model_runner.page_size
         self.use_mla = (
             model_runner.model_config.attention_arch == AttentionArch.MLA
         ) and (not global_server_args_dict["disable_mla"])
@@ -83,6 +84,17 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
             forward_batch.req_pool_indices, : metadata.max_seq_len_k
         ]
+
+        # Precompute strided indices
+        # [0, page_size, 2 * page_size, ...]
+        if self.page_size > 1:
+            self.strided_indices = torch.arange(
+                0, metadata.page_table.shape[1], self.page_size, device=self.device
+            )
+            metadata.page_table = (
+                metadata.page_table[:, self.strided_indices] // self.page_size
+            )
+
         if forward_batch.forward_mode == ForwardMode.DECODE:
             # Precompute cumulative sequence lengths
             metadata.cu_seqlens_q = torch.arange(
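Note on the strided-index conversion above: with page_size > 1, the token-granularity req_to_token map is sampled at every page_size-th slot and divided by page_size, turning token-slot indices into page indices. A minimal sketch with made-up numbers (assuming the KV pool hands out page-aligned, contiguous token slots, which is what makes the integer division valid):

import torch

page_size = 4
# One request whose 6 tokens occupy token slots 8..13, i.e. pages 2 and 3.
req_to_token = torch.tensor([[8, 9, 10, 11, 12, 13, 0, 0]], dtype=torch.int32)

strided_indices = torch.arange(0, req_to_token.shape[1], page_size)  # tensor([0, 4])
page_table = req_to_token[:, strided_indices] // page_size
print(page_table)  # tensor([[2, 3]], dtype=torch.int32)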
@@ -141,16 +153,24 @@ def forward_extend(
             else (-1, -1)
         )

+        page_table = metadata.page_table
+
         # # Use Flash Attention for prefill
         if not self.use_mla:
             # Do multi-head attention
             kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
             key_cache, value_cache = kv_cache[0], kv_cache[1]
+            key_cache = key_cache.view(
+                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+            )
+            value_cache = value_cache.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+            )
             o = flash_attn_with_kvcache(
                 q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                k_cache=key_cache.unsqueeze(1),
-                v_cache=value_cache.unsqueeze(1),
-                page_table=metadata.page_table,
+                k_cache=key_cache,
+                v_cache=value_cache,
+                page_table=page_table,
                 cache_seqlens=metadata.cache_seqlens_int32,
                 cu_seqlens_q=metadata.cu_seqlens_q,
                 cu_seqlens_k_new=metadata.cu_seqlens_k,
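The view calls above replace the old unsqueeze(1): instead of presenting each token slot as its own single-entry page, the flat KV buffer is reinterpreted as (num_pages, page_size, num_heads, head_dim), the paged layout flash_attn_with_kvcache consumes together with a page-granularity page_table. A quick shape check with illustrative sizes (assuming the pool stores KV as a flat [num_tokens, num_heads, head_dim] buffer whose length is a multiple of page_size):

import torch

num_pages, page_size, num_heads, head_dim = 3, 4, 2, 8
flat_kv = torch.zeros(num_pages * page_size, num_heads, head_dim)

paged = flat_kv.view(-1, page_size, num_heads, head_dim)
assert paged.shape == (num_pages, page_size, num_heads, head_dim)

# The previous unsqueeze(1) is the degenerate case: one token per "page",
# which only lines up with the page table when page_size == 1.
assert flat_kv.unsqueeze(1).shape == (num_pages * page_size, 1, num_heads, head_dim)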
@@ -176,7 +196,7 @@ def forward_extend(
                 k_cache=k_rope.unsqueeze(1),
                 v_cache=c_kv.unsqueeze(1),
                 qv=q_nope,
-                page_table=metadata.page_table,
+                page_table=page_table,
                 cache_seqlens=metadata.cache_seqlens_int32,
                 cu_seqlens_q=metadata.cu_seqlens_q,
                 cu_seqlens_k_new=metadata.cu_seqlens_k,
@@ -231,22 +251,30 @@ def forward_decode(
             else (-1, -1)
         )

+        page_table = metadata.page_table
+
         if not self.use_mla:
             # Do multi-head attention

             # Get KV cache
             kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
             key_cache, value_cache = kv_cache[0], kv_cache[1]
+            key_cache = key_cache.view(
+                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+            )
+            value_cache = value_cache.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+            )

             # Pre-reshape query tensor
             q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)

             # Run attention with precomputed values
             o = flash_attn_with_kvcache(
                 q=q_reshaped,
-                k_cache=key_cache.unsqueeze(1),
-                v_cache=value_cache.unsqueeze(1),
-                page_table=metadata.page_table,
+                k_cache=key_cache,
+                v_cache=value_cache,
+                page_table=page_table,
                 cache_seqlens=metadata.cache_seqlens_int32,
                 cu_seqlens_q=metadata.cu_seqlens_q,
                 cu_seqlens_k_new=metadata.cu_seqlens_k,
@@ -273,7 +301,7 @@ def forward_decode(
                 k_cache=k_rope.unsqueeze(1),
                 v_cache=c_kv.unsqueeze(1),
                 qv=q_nope,
-                page_table=metadata.page_table,
+                page_table=page_table,
                 cache_seqlens=metadata.cache_seqlens_int32,
                 cu_seqlens_q=metadata.cu_seqlens_q,
                 cu_seqlens_k_new=metadata.cu_seqlens_k,
@@ -300,7 +328,13 @@ def init_cuda_graph_state(self, max_bs: int):
         self.decode_cuda_graph_metadata = {
             # Page table for token mapping (batch_size, max_context_len)
             "page_table": torch.zeros(
-                max_bs, self.max_context_len, dtype=torch.int32, device=self.device
+                max_bs,
+                (self.max_context_len + self.page_size - 1) // self.page_size,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "strided_indices": torch.arange(
+                0, self.max_context_len, self.page_size, device=self.device
             ),
         }
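The page_table buffer captured for CUDA graphs is now sized in pages rather than tokens, using ceil division so any context up to max_context_len fits in whole pages. A quick check with made-up numbers:

max_context_len, page_size = 10, 4
num_page_slots = (max_context_len + page_size - 1) // page_size
assert num_page_slots == 3  # 10 tokens span 3 pages of 4 slots each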