
Commit a87d3be

NickLucche authored and yeqcharlotte committed
[V1][Kernel] Flashinfer HND KV cache layout (vllm-project#19280)
Signed-off-by: NickLucche <[email protected]>
1 parent 9490bd1 commit a87d3be

6 files changed, +64 -20 lines

vllm/attention/backends/flashinfer.py

Lines changed: 1 addition & 3 deletions

@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import dataclasses
-import os
 from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
@@ -50,8 +49,7 @@
 from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                       ModelInputForGPUWithSamplingMetadata)
 
-FLASHINFER_KV_CACHE_LAYOUT: str = os.getenv("FLASHINFER_KV_CACHE_LAYOUT",
-                                            "NHD").upper()
+FLASHINFER_KV_CACHE_LAYOUT: str = envs.VLLM_KV_CACHE_LAYOUT or "NHD"
 
 
 class FlashInferBackend(AttentionBackend):
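
The module-level constant now defers to the env-backed setting and falls back to "NHD" when it is unset. A minimal standalone sketch of that fallback semantics (the helper name is made up for illustration; this is not vLLM source):

# `envs.VLLM_KV_CACHE_LAYOUT or "NHD"` treats an unset (None) value as the default.
from typing import Optional


def resolve_v0_flashinfer_layout(configured: Optional[str]) -> str:
    # Hypothetical helper mirroring the one-liner above.
    return configured or "NHD"


assert resolve_v0_flashinfer_layout(None) == "NHD"
assert resolve_v0_flashinfer_layout("HND") == "HND"

Note that the old path upper-cased the FLASHINFER_KV_CACHE_LAYOUT env value, while the new path takes the value as provided by vllm.envs.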

vllm/distributed/kv_transfer/kv_connector/utils.py

Lines changed: 5 additions & 4 deletions

@@ -3,7 +3,6 @@
 """
 KV cache helper for store.
 """
-
 import torch
 
 import vllm.envs as envs
@@ -94,15 +93,17 @@ def put_kv_to_cache(self, model_executable: torch.nn.Module, keys, values,
 
 
 def get_kv_connector_cache_layout():
+    # NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
+    # used for faster transfer.
     vllm_config = get_current_vllm_config()
     kv_config = vllm_config.kv_transfer_config
-    if vllm_config.model_config is None:
-        logger.warning("Unable to detect current VLLM config. " \
+    if vllm_config.model_config is None or kv_config is None:
+        logger.warning_once("Unable to detect current VLLM config. " \
            "Defaulting to NHD kv cache layout.")
    else:
        use_mla = vllm_config.model_config.use_mla
        if not use_mla and kv_config.kv_connector == "NixlConnector":
-            logger.info("NixlConnector detected. Setting KV cache " \
+            logger.info_once("NixlConnector detected. Setting KV cache " \
                "layout to HND for better xfer performance.")
            return "HND"
    return "NHD"

vllm/envs.py

Lines changed: 11 additions & 0 deletions

@@ -129,6 +129,7 @@
     VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
     VLLM_SLEEP_WHEN_IDLE: bool = False
     VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
+    VLLM_KV_CACHE_LAYOUT: Optional[str] = None
 
 
 def get_default_cache_root():
@@ -886,6 +887,16 @@ def get_vllm_port() -> Optional[int]:
     # processes via zmq.
     "VLLM_MQ_MAX_CHUNK_BYTES_MB":
     lambda: int(os.getenv("VLLM_MQ_MAX_CHUNK_BYTES_MB", "16")),
+
+    # KV Cache layout used throughout vllm.
+    # Some common values are:
+    # - NHD
+    # - HND
+    # Where N=num_blocks, H=num_heads and D=head_size. The default value will
+    # leave the layout choice to the backend. Mind that backends may only
+    # implement and support a subset of all possible layouts.
+    "VLLM_KV_CACHE_LAYOUT":
+    lambda: os.getenv("VLLM_KV_CACHE_LAYOUT", None)
 }
 
 # --8<-- [end:env-vars-definition]
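
In the stride orders used by the V1 backends below, NHD keeps the token (block) dimension ahead of the head dimension within each block, while HND puts heads first. A toy illustration of the two orderings for a single KV cache block (sizes are arbitrary; this is illustrative, not vLLM code):

# Toy per-block view of the two layouts; only the dimension order matters.
import torch

block_size, num_kv_heads, head_size = 16, 8, 128

# NHD: (block_size, num_kv_heads, head_size) -- the token dim has the largest stride.
nhd_block = torch.empty(block_size, num_kv_heads, head_size)

# HND: (num_kv_heads, block_size, head_size) -- all tokens of one head are contiguous.
hnd_block = torch.empty(num_kv_heads, block_size, head_size)

# Same logical data, two memory orders related by transposing the first two dims.
assert nhd_block.permute(1, 0, 2).shape == hnd_block.shape

Setting VLLM_KV_CACHE_LAYOUT=HND in the environment before starting vLLM requests the heads-first layout; leaving it unset lets the backend (or the KV-connector heuristic above) decide.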

vllm/v1/attention/backends/flash_attn.py

Lines changed: 5 additions & 7 deletions

@@ -16,13 +16,12 @@
 from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
                                            get_flash_attn_version)
 from vllm.config import VllmConfig, get_layers_from_vllm_config
-from vllm.distributed.kv_transfer.kv_connector.utils import (
-    get_kv_connector_cache_layout)
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata)
+                                              CommonAttentionMetadata,
+                                              get_kv_cache_layout)
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable
 
@@ -73,16 +72,15 @@ def get_kv_cache_shape(
 
     @staticmethod
     def get_kv_cache_stride_order() -> tuple[int, ...]:
-        # NOTE When running disaggregated PD with NIXL, HND layout is used for
-        # faster transfer. `stride_order` indicates the permutation that gets
+        # `stride_order` indicates the permutation that gets
         # us from `get_kv_cache_shape` to the actual memory layout we want.
-        cache_layout = get_kv_connector_cache_layout()
+        cache_layout = get_kv_cache_layout()
         if cache_layout == "NHD":
             stride_order = (0, 1, 2, 3, 4)
         elif cache_layout == "HND":
             stride_order = (0, 1, 3, 2, 4)
         else:
-            raise ValueError("Unknown cache layout format %s.", cache_layout)
+            raise ValueError(f"Unknown cache layout format {cache_layout}.")
         return stride_order
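
Both V1 backends touched in this commit use the same two permutations. Taking the FlashInfer cache shape shown below, (num_blocks, 2, block_size, num_kv_heads, head_size), as the reference, a quick sketch of what each stride order does (illustrative only):

# Apply the two stride orders to the 5-D KV cache shape by dimension name.
shape = ("num_blocks", "kv", "block_size", "num_kv_heads", "head_size")

nhd_order = (0, 1, 2, 3, 4)   # identity: keep the allocation order as-is
hnd_order = (0, 1, 3, 2, 4)   # swap block_size and num_kv_heads

print([shape[i] for i in nhd_order])
# ['num_blocks', 'kv', 'block_size', 'num_kv_heads', 'head_size']
print([shape[i] for i in hnd_order])
# ['num_blocks', 'kv', 'num_kv_heads', 'block_size', 'head_size']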

vllm/v1/attention/backends/flashinfer.py

Lines changed: 21 additions & 6 deletions

@@ -19,7 +19,8 @@
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata)
+                                              CommonAttentionMetadata,
+                                              get_kv_cache_layout)
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable
 
@@ -66,6 +67,19 @@ def get_kv_cache_shape(
     ) -> tuple[int, ...]:
         return (num_blocks, 2, block_size, num_kv_heads, head_size)
 
+    @staticmethod
+    def get_kv_cache_stride_order() -> tuple[int, ...]:
+        # `stride_order` indicates the permutation that gets us from
+        # `get_kv_cache_shape` to the actual memory layout we want.
+        cache_layout = get_kv_cache_layout()
+        if cache_layout == "NHD":
+            stride_order = (0, 1, 2, 3, 4)
+        elif cache_layout == "HND":
+            stride_order = (0, 1, 3, 2, 4)
+        else:
+            raise ValueError(f"Unknown cache layout format {cache_layout}.")
+        return stride_order
+
 
 @dataclass
 class PerLayerParameters:
@@ -290,7 +304,7 @@ def _get_workspace_buffer(self):
     def _get_prefill_wrapper(self):
         if self._prefill_wrapper is None:
             self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
-                self._get_workspace_buffer(), "NHD")
+                self._get_workspace_buffer(), get_kv_cache_layout())
         return self._prefill_wrapper
 
     def _get_decode_wrapper(self):
@@ -303,14 +317,14 @@ def _get_decode_wrapper(self):
                 num_qo_heads // num_kv_heads > 4)
             self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                 self._get_workspace_buffer(),
-                "NHD",
+                get_kv_cache_layout(),
                 use_tensor_cores=use_tensor_cores)
         return self._decode_wrapper
 
     def _get_cascade_wrapper(self):
         if self._cascade_wrapper is None:
             self._cascade_wrapper = MultiLevelCascadeAttentionWrapper(
-                2, self._get_workspace_buffer(), "NHD")
+                2, self._get_workspace_buffer(), get_kv_cache_layout())
         return self._cascade_wrapper
 
     def _plan(self, attn_metadata: FlashInferMetadata):
@@ -620,6 +634,7 @@ def forward(
         num_decode_tokens = attn_metadata.num_decode_tokens
         num_prefill_tokens = attn_metadata.num_prefill_tokens
 
+        stride_order = FlashInferBackend.get_kv_cache_stride_order()
         # Regular attention (common case).
         # Decodes are at the front and prefills are at the back,
         # according to reorder_batch()
@@ -634,7 +649,7 @@ def forward(
             assert prefill_wrapper._sm_scale == self.scale
             prefill_wrapper.run(
                 prefill_query,
-                kv_cache,
+                kv_cache.permute(*stride_order),
                 k_scale=layer._k_scale_float,
                 v_scale=layer._v_scale_float,
                 out=output[num_decode_tokens:],
@@ -650,7 +665,7 @@ def forward(
             assert decode_wrapper._sm_scale == self.scale
             decode_wrapper.run(
                 decode_query,
-                kv_cache,
+                kv_cache.permute(*stride_order),
                 k_scale=layer._k_scale_float,
                 v_scale=layer._v_scale_float,
                 out=output[:num_decode_tokens],
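
In the forward path the cache is permuted to the configured layout before being handed to the FlashInfer wrappers. torch.Tensor.permute returns a view over the same storage, so this adds no copy; a small sketch using the cache shape from get_kv_cache_shape above (illustrative, not vLLM code):

import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 128
kv_cache = torch.empty(num_blocks, 2, block_size, num_kv_heads, head_size)

# HND stride order from get_kv_cache_stride_order(): swap block_size and heads.
hnd_view = kv_cache.permute(0, 1, 3, 2, 4)
assert hnd_view.data_ptr() == kv_cache.data_ptr()  # a view, not a copy
assert hnd_view.shape == (num_blocks, 2, num_kv_heads, block_size, head_size)

The permuted view matches the layout string now passed to BatchPrefillWithPagedKVCacheWrapper, BatchDecodeWithPagedKVCacheWrapper and MultiLevelCascadeAttentionWrapper, which are constructed with get_kv_cache_layout() instead of the hardcoded "NHD".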

vllm/v1/attention/backends/utils.py

Lines changed: 21 additions & 0 deletions

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import abc
+import functools
 from abc import abstractmethod
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar
@@ -12,6 +13,13 @@
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
 
+import vllm.envs as envs
+from vllm.distributed.kv_transfer.kv_connector.utils import (
+    get_kv_connector_cache_layout)
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
 
 @dataclass
 class CommonAttentionMetadata:
@@ -119,3 +127,16 @@ def validate_kv_sharing_target(current_layer_name, target_layer_name,
         raise ValueError(
             error_msg +
             f"must be the same type as the current layer ({expected}).")
+
+
+@functools.lru_cache
+def get_kv_cache_layout():
+    # Override with format specified by the user.
+    cache_layout = envs.VLLM_KV_CACHE_LAYOUT
+    if cache_layout is None:
+        cache_layout = get_kv_connector_cache_layout()
+    else:
+        logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
+            "detected. Setting KV cache layout to %s.", cache_layout)
+
+    return cache_layout
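
The resolution order is: an explicit VLLM_KV_CACHE_LAYOUT value wins, otherwise the KV-connector heuristic decides, and functools.lru_cache caches the answer after the first call. A standalone sketch of that behavior (make_layout_resolver is a hypothetical helper for illustration, not vLLM code):

import functools
from typing import Callable, Optional


def make_layout_resolver(env_value: Optional[str],
                         connector_choice: Callable[[], str]):
    # Mirrors the precedence of get_kv_cache_layout(): env override first,
    # then the connector-based heuristic; the result is cached.
    @functools.lru_cache
    def resolve() -> str:
        return env_value if env_value is not None else connector_choice()
    return resolve


resolve = make_layout_resolver(None, lambda: "HND")
assert resolve() == "HND"          # falls back to the connector heuristic
assert resolve() == "HND"          # second call is served from the cache
assert resolve.cache_info().hits == 1

assert make_layout_resolver("NHD", lambda: "HND")() == "NHD"  # env value wins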
