 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
+from vllm.attention.layer import Attention
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
                                            get_flash_attn_version)
+from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
@@ -273,20 +275,35 @@ def make_local_attention_virtual_batches(
         block_table_local
 
 
+def _get_sliding_window_configs(
+        vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
+    """Get the set of all sliding window configs used in the model."""
+    sliding_window_configs: set[Optional[tuple[int, int]]] = set()
+    layers = get_layers_from_vllm_config(vllm_config, Attention)
+    for layer in layers.values():
+        assert isinstance(layer.impl, FlashAttentionImpl)
+        sliding_window_configs.add(layer.impl.sliding_window)
+    return sliding_window_configs
+
+
 class FlashAttentionMetadataBuilder:
 
     def __init__(self, runner: "GPUModelRunner"):
         model_config = runner.model_config
 
         self.runner = runner
-        self.aot_schedule = (get_flash_attn_version() == 3)
         self.num_heads_q = model_config.get_num_attention_heads(
             runner.parallel_config)
         self.num_heads_kv = model_config.get_num_kv_heads(
             runner.parallel_config)
         self.headdim = model_config.get_head_size()
         self.page_size = self.runner.block_size
 
+        self.aot_schedule = (get_flash_attn_version() == 3)
+        # Sliding window size to be used with the AOT scheduler will be
+        # populated on first build() call.
+        self.aot_sliding_window: Optional[tuple[int, int]] = None
+
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False
@@ -304,6 +321,22 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             self.runner.device, non_blocking=True).long()
 
+        if self.aot_sliding_window is None:
+            self.aot_sliding_window = (-1, -1)
+            # For the AOT scheduler we need the sliding window value to be
+            # constant for all layers. We have to populate it on the first
+            # build() call, once the layers have been constructed (it cannot
+            # be populated in __init__).
+            if self.aot_schedule:
+                sliding_window_configs = _get_sliding_window_configs(
+                    self.runner.vllm_config)
+                if len(sliding_window_configs) == 1:
+                    sliding_window_config = sliding_window_configs.pop()
+                    if sliding_window_config is not None:
+                        self.aot_sliding_window = sliding_window_config
+                elif len(sliding_window_configs) > 1:
+                    self.aot_schedule = False
+
         def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                      max_seq_len, causal):
             if self.aot_schedule:
@@ -318,6 +351,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                 page_size=self.page_size,
                 cu_seqlens_q=cu_query_lens,
                 causal=causal,
+                window_size=self.aot_sliding_window,
             )
         return None
 
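Taken together, these hunks make FlashAttention 3's ahead-of-time (AOT) scheduler sliding-window aware: the builder collects the sliding-window config of every Attention layer, passes it as window_size when all layers agree, and disables AOT scheduling when the model mixes different window sizes. The standalone sketch below mirrors that resolution logic outside of vLLM; resolve_aot_window and the example tuples are illustrative only, not part of the patch or of vLLM's API.

from typing import Optional

def resolve_aot_window(
        configs: set[Optional[tuple[int, int]]]) -> tuple[bool, tuple[int, int]]:
    """Mimic the patch's rules: return (aot_schedule_enabled, window_size)."""
    aot_schedule = True
    window = (-1, -1)  # (-1, -1) means "no sliding window" in FlashAttention
    if len(configs) == 1:
        only = next(iter(configs))
        if only is not None:
            window = only  # every layer shares one sliding window
    elif len(configs) > 1:
        aot_schedule = False  # mixed window sizes: give up on AOT scheduling
    return aot_schedule, window

# All layers use a 4096-token left window -> AOT stays on with that window.
print(resolve_aot_window({(4096, 0)}))        # (True, (4096, 0))
# Full-attention and sliding-window layers mixed -> AOT scheduling disabled.
print(resolve_aot_window({None, (4096, 0)}))  # (False, (-1, -1))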