From c0c6d8c9c4edbfcb3b7cac5e5674eb6b4809e565 Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Fri, 2 May 2025 05:03:16 +0000
Subject: [PATCH] fix sliding window v1

Signed-off-by: Lucas Wilkinson
---
 vllm/v1/attention/backends/flash_attn.py | 36 +++++++++++++++++++++++-
 1 file changed, 35 insertions(+), 1 deletion(-)

diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 217dcd7c33a..f986d797f2b 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -10,9 +10,11 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
+from vllm.attention.layer import Attention
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
                                            get_flash_attn_version)
+from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
@@ -273,13 +275,23 @@ def make_local_attention_virtual_batches(
         block_table_local
 
 
+def _get_sliding_window_configs(
+        vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
+    """Get the set of all sliding window configs used in the model."""
+    sliding_window_configs: set[Optional[tuple[int, int]]] = set()
+    layers = get_layers_from_vllm_config(vllm_config, Attention)
+    for layer in layers.values():
+        assert isinstance(layer.impl, FlashAttentionImpl)
+        sliding_window_configs.add(layer.impl.sliding_window)
+    return sliding_window_configs
+
+
 class FlashAttentionMetadataBuilder:
 
     def __init__(self, runner: "GPUModelRunner"):
         model_config = runner.model_config
 
         self.runner = runner
-        self.aot_schedule = (get_flash_attn_version() == 3)
         self.num_heads_q = model_config.get_num_attention_heads(
             runner.parallel_config)
         self.num_heads_kv = model_config.get_num_kv_heads(
@@ -287,6 +299,11 @@ def __init__(self, runner: "GPUModelRunner"):
         self.headdim = model_config.get_head_size()
         self.page_size = self.runner.block_size
 
+        self.aot_schedule = (get_flash_attn_version() == 3)
+        # Sliding window size to be used with the AOT scheduler will be
+        # populated on the first build() call.
+        self.aot_sliding_window: Optional[tuple[int, int]] = None
+
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False
@@ -304,6 +321,22 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             self.runner.device, non_blocking=True).long()
 
+        if self.aot_sliding_window is None:
+            self.aot_sliding_window = (-1, -1)
+            # For the AOT scheduler we need the sliding window value to be
+            # constant across all layers. We have to populate this on the
+            # first build() call because the layers are not constructed yet
+            # in __init__, so it cannot be populated there.
+            if self.aot_schedule:
+                sliding_window_configs = _get_sliding_window_configs(
+                    self.runner.vllm_config)
+                if len(sliding_window_configs) == 1:
+                    sliding_window_config = sliding_window_configs.pop()
+                    if sliding_window_config is not None:
+                        self.aot_sliding_window = sliding_window_config
+                elif len(sliding_window_configs) > 1:
+                    self.aot_schedule = False
+
         def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                      max_seq_len, causal):
             if self.aot_schedule:
@@ -318,6 +351,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     page_size=self.page_size,
                     cu_seqlens_q=cu_query_lens,
                     causal=causal,
+                    window_size=self.aot_sliding_window,
                 )
             return None
 
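Note (illustrative sketch, not part of the patch): the decision the new build() code makes is to hand the shared sliding window to the FA3 AOT scheduler only when every attention layer uses the same config, and to fall back to the non-AOT path when configs are mixed. The helper name resolve_aot_sliding_window and the example window values below are assumptions for illustration, not code from vLLM.

from typing import Optional


def resolve_aot_sliding_window(
    configs: set[Optional[tuple[int, int]]],
) -> tuple[bool, tuple[int, int]]:
    """Hypothetical helper mirroring the branches added to build().

    Returns (aot_schedule_enabled, window_size). (-1, -1) means
    "no sliding window", matching the FlashAttention convention.
    """
    aot_schedule = True
    window: tuple[int, int] = (-1, -1)
    if len(configs) == 1:
        only = next(iter(configs))
        if only is not None:
            # All layers share one sliding window, so the AOT scheduler
            # can bake it into the precomputed metadata.
            window = only
    elif len(configs) > 1:
        # Mixed configs (e.g. interleaved sliding-window and full-attention
        # layers): one AOT schedule cannot describe every layer, so disable
        # AOT scheduling rather than produce incorrect metadata.
        aot_schedule = False
    return aot_schedule, window


# Uniform sliding window (example value): AOT scheduling stays enabled.
assert resolve_aot_sliding_window({(4096, 0)}) == (True, (4096, 0))
# Interleaved sliding-window / full-attention layers: AOT is disabled.
assert resolve_aot_sliding_window({(4096, 0), None}) == (False, (-1, -1))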