
Commit 232d26b

LucasWilkinson authored and mawong-amd committed
[BugFix][Attention] Fix sliding window attention in V1 giving incorrect results (vllm-project#17574)
Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent 3da9787 commit 232d26b

File tree

1 file changed: +35 -1 lines changed


vllm/v1/attention/backends/flash_attn.py

Lines changed: 35 additions & 1 deletion
@@ -10,9 +10,11 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
+from vllm.attention.layer import Attention
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
                                            get_flash_attn_version)
+from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
@@ -273,20 +275,35 @@ def make_local_attention_virtual_batches(
     block_table_local


+def _get_sliding_window_configs(
+        vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
+    """Get the set of all sliding window configs used in the model."""
+    sliding_window_configs: set[Optional[tuple[int, int]]] = set()
+    layers = get_layers_from_vllm_config(vllm_config, Attention)
+    for layer in layers.values():
+        assert isinstance(layer.impl, FlashAttentionImpl)
+        sliding_window_configs.add(layer.impl.sliding_window)
+    return sliding_window_configs
+
+
 class FlashAttentionMetadataBuilder:

     def __init__(self, runner: "GPUModelRunner"):
         model_config = runner.model_config

         self.runner = runner
-        self.aot_schedule = (get_flash_attn_version() == 3)
         self.num_heads_q = model_config.get_num_attention_heads(
             runner.parallel_config)
         self.num_heads_kv = model_config.get_num_kv_heads(
             runner.parallel_config)
         self.headdim = model_config.get_head_size()
         self.page_size = self.runner.block_size

+        self.aot_schedule = (get_flash_attn_version() == 3)
+        # Sliding window size to be used with the AOT scheduler will be
+        # populated on first build() call.
+        self.aot_sliding_window: Optional[tuple[int, int]] = None
+
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False
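
The new _get_sliding_window_configs helper walks every Attention layer registered with the vLLM config and collects each layer's sliding window setting into a set, so the metadata builder can later tell whether the model uses no sliding window, one uniform window, or a mix of windows across layers. Below is a minimal standalone sketch of that set-based deduplication idea; FakeLayerImpl and collect_window_configs are hypothetical stand-ins, not vLLM's real Attention/FlashAttentionImpl types.

from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class FakeLayerImpl:
    # Stand-in for a per-layer attention impl: a (left, right) window or None.
    sliding_window: Optional[tuple[int, int]]

def collect_window_configs(
        layers: dict[str, FakeLayerImpl]) -> set[Optional[tuple[int, int]]]:
    """Deduplicate per-layer sliding window settings into a set."""
    return {layer.sliding_window for layer in layers.values()}

# A toy model whose layers alternate between full attention and a 1024-token window.
layers = {
    "layers.0.attn": FakeLayerImpl(sliding_window=None),
    "layers.1.attn": FakeLayerImpl(sliding_window=(1023, 0)),
}
print(collect_window_configs(layers))  # two distinct entries -> mixed window configs
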
@@ -304,6 +321,22 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             self.runner.device, non_blocking=True).long()

+        if self.aot_sliding_window is None:
+            self.aot_sliding_window = (-1, -1)
+            # For the AOT scheduler we need the sliding window value to be
+            # constant for all layers. We have to populate this on the first
+            # build() call, once the layers have been constructed (it cannot
+            # be populated in __init__).
+            if self.aot_schedule:
+                sliding_window_configs = _get_sliding_window_configs(
+                    self.runner.vllm_config)
+                if len(sliding_window_configs) == 1:
+                    sliding_window_config = sliding_window_configs.pop()
+                    if sliding_window_config is not None:
+                        self.aot_sliding_window = sliding_window_config
+                elif len(sliding_window_configs) > 1:
+                    self.aot_schedule = False
+
         def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                      max_seq_len, causal):
             if self.aot_schedule:
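
The block added to build() runs only once, on the first call (by which point the model's layers exist, which is why it cannot live in __init__), and reduces to a simple rule: if every layer shares one concrete sliding window, hand that window to the ahead-of-time (AOT) scheduler; if different layers use different windows, disable AOT scheduling; otherwise keep the "no window" default of (-1, -1). Here is a hedged sketch of that rule in isolation; resolve_aot_window and NO_WINDOW are illustrative names, not part of vLLM.

from typing import Optional

NO_WINDOW = (-1, -1)  # FlashAttention convention for "no sliding window"

def resolve_aot_window(
    window_configs: set[Optional[tuple[int, int]]],
    aot_schedule: bool,
) -> tuple[tuple[int, int], bool]:
    """Return (window for the AOT scheduler, whether AOT scheduling stays enabled)."""
    aot_window = NO_WINDOW
    if not aot_schedule:
        return aot_window, False
    if len(window_configs) == 1:
        only = next(iter(window_configs))
        if only is not None:
            aot_window = only          # one uniform window: safe to schedule with it
    elif len(window_configs) > 1:
        aot_schedule = False           # mixed windows: a single AOT schedule cannot fit all layers
    return aot_window, aot_schedule

# Uniform 4096-token causal window across layers -> schedule with that window.
print(resolve_aot_window({(4095, 0)}, aot_schedule=True))        # ((4095, 0), True)
# Mixed full/windowed layers -> disable AOT scheduling entirely.
print(resolve_aot_window({None, (4095, 0)}, aot_schedule=True))  # ((-1, -1), False)
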
@@ -318,6 +351,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                 page_size=self.page_size,
                 cu_seqlens_q=cu_query_lens,
                 causal=causal,
+                window_size=self.aot_sliding_window,
             )
         return None
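
Passing window_size=self.aot_sliding_window is the substance of the fix: the precomputed (AOT) schedule now reflects the same (left, right) window the attention kernels apply at run time, whereas before it was built without any window, i.e. as if every layer attended to the full context. In FlashAttention's convention a query at position i attends to keys in [i - left, i + right], with -1 meaning unbounded on that side. The small sketch below of that window check is for intuition only and is not code from vLLM or FlashAttention.

def in_window(q_pos: int, k_pos: int, left: int, right: int) -> bool:
    """FlashAttention-style (left, right) window check; -1 means unbounded."""
    if left != -1 and k_pos < q_pos - left:
        return False
    if right != -1 and k_pos > q_pos + right:
        return False
    return True

# A causal sliding window of 4096 tokens is commonly expressed as (4095, 0):
assert in_window(5000, 1000, 4095, 0)       # within the last 4096 positions
assert not in_window(5000, 100, 4095, 0)    # too far in the past
assert not in_window(5000, 5001, 4095, 0)   # future token, blocked by right=0
assert in_window(5000, 100, -1, -1)         # (-1, -1): no restriction
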
