Skip to content

Commit 394c0c4 — Revert "swap impl draft"

This reverts commit d43dc76.
Parent: 6439307

File tree

6 files changed: +73 additions, −162 deletions

vllm/distributed/kv_transfer/kv_connector/v1/swap_connector.py

Lines changed: 15 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import torch
2020

2121
from vllm import _custom_ops as ops
22-
from vllm.model_executor.layers.attention import Attention
22+
from vllm.attention.layer import Attention
2323
from vllm.config import VllmConfig, get_layers_from_vllm_config
2424
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
2525
from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data
@@ -511,9 +511,6 @@ def __init__(self, spec: OffloadingSpec):
511511
self._gpu_tensors: dict[str, torch.Tensor] = {}
512512
self._cpu_tensors: dict[str, torch.Tensor] = {}
513513
self._kv_dim_before_num_blocks: dict[str, bool] = {}
514-
self._block_size_bytes: dict[str, int] = {}
515-
# True when all layers share a single GPU tensor
516-
self._single_tensor_mode: bool = False
517514

518515
# CUDA streams for per-layer transfers
519516
self._load_stream: torch.cuda.Stream | None = None
@@ -530,16 +527,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
530527
Create per-layer CPU tensors and register handlers for bulk transfers.
531528
"""
532529
layer_names = list(kv_caches.keys())
533-
534-
# Detect single-tensor mode: all layers share the same GPU tensor
535-
gpu_ptrs = {t.data_ptr() for t in kv_caches.values()}
536-
self._single_tensor_mode = (len(gpu_ptrs) == 1 and len(layer_names) > 1)
537-
if self._single_tensor_mode:
538-
logger.info(
539-
"Single-tensor swap mode: %d layers share 1 GPU KV buffer",
540-
len(layer_names),
541-
)
542-
543530
layers = get_layers_from_vllm_config(
544531
self.spec.vllm_config, Attention, layer_names
545532
)
@@ -548,15 +535,11 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
548535
for layer_name in layer_names
549536
}
550537

551-
# Register handlers for bulk transfers (prefix cache loads/stores).
552-
# Skip in single-tensor mode: bulk handlers assume per-layer GPU
553-
# tensors and won't work with a shared buffer. The per-layer
554-
# load/store cycle handles all transfers instead.
555-
if not self._single_tensor_mode:
556-
for src_cls, dst_cls, handler in self.spec.get_handlers(
557-
kv_caches, attn_backends
558-
):
559-
self.worker.register_handler(src_cls, dst_cls, handler)
538+
# Register handlers for bulk transfers (prefix cache loads/stores)
539+
for src_cls, dst_cls, handler in self.spec.get_handlers(
540+
kv_caches, attn_backends
541+
):
542+
self.worker.register_handler(src_cls, dst_cls, handler)
560543

561544
pin_memory = is_pin_memory_available()
562545
num_cpu_blocks = self.spec.num_blocks
@@ -597,24 +580,12 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
597580
layer_name,
598581
cpu_shape,
599582
)
600-
cpu_tensor = torch.zeros(
583+
self._cpu_tensors[layer_name] = torch.zeros(
601584
cpu_shape,
602585
dtype=gpu_tensor.dtype,
603586
device="cpu",
604587
pin_memory=pin_memory,
605588
)
606-
self._cpu_tensors[layer_name] = cpu_tensor
607-
608-
# Compute block size in bytes for ops.swap_blocks (new 4-arg API).
609-
# When kv_dim=True, shape is (2, num_blocks, ...) and we swap
610-
# each K/V half separately via tensor[0]/tensor[1].
611-
if self._kv_dim_before_num_blocks[layer_name]:
612-
ref = cpu_tensor[0]
613-
else:
614-
ref = cpu_tensor
615-
self._block_size_bytes[layer_name] = (
616-
ref.element_size() * ref.stride(0)
617-
)
618589

619590
# Summary log after all CPU tensors are registered
620591
total_cpu_bytes = sum(
@@ -640,9 +611,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
640611
)
641612

642613
def handle_preemptions(self, preempted_req_ids: set[str]):
643-
if self._single_tensor_mode:
644-
return
645-
646614
for job_id, transfer_spec in self._unsubmitted_store_jobs:
647615
success = self.worker.transfer_async(job_id, transfer_spec)
648616
assert success
@@ -655,11 +623,6 @@ def handle_preemptions(self, preempted_req_ids: set[str]):
655623

656624
def start_kv_transfers(self, metadata: SwapConnectorMetadata):
657625
"""Submit deferred stores and start prefix cache loads."""
658-
# In single-tensor mode, skip bulk transfers entirely.
659-
# Per-layer load/store handles everything.
660-
if self._single_tensor_mode:
661-
return
662-
663626
# Submit deferred store jobs from the previous step
664627
for job_id, transfer_spec in self._unsubmitted_store_jobs:
665628
success = self.worker.transfer_async(job_id, transfer_spec)
@@ -740,14 +703,13 @@ def load_layer_from_cpu(
740703
cpu_tensor = self._cpu_tensors[layer_name]
741704
gpu_tensor = self._gpu_tensors[layer_name]
742705
kv_dim = self._kv_dim_before_num_blocks[layer_name]
743-
block_size_bytes = self._block_size_bytes[layer_name]
744706

745707
with torch.cuda.stream(self._load_stream):
746708
if kv_dim:
747-
ops.swap_blocks(cpu_tensor[0], gpu_tensor[0], block_size_bytes, src_to_dst_tensor)
748-
ops.swap_blocks(cpu_tensor[1], gpu_tensor[1], block_size_bytes, src_to_dst_tensor)
709+
ops.swap_blocks(cpu_tensor[0], gpu_tensor[0], src_to_dst_tensor)
710+
ops.swap_blocks(cpu_tensor[1], gpu_tensor[1], src_to_dst_tensor)
749711
else:
750-
ops.swap_blocks(cpu_tensor, gpu_tensor, block_size_bytes, src_to_dst_tensor)
712+
ops.swap_blocks(cpu_tensor, gpu_tensor, src_to_dst_tensor)
751713

752714
# Must synchronize: attention needs the data to be ready
753715
self._load_stream.synchronize()
@@ -814,14 +776,13 @@ def store_layer_to_cpu(
814776
cpu_tensor = self._cpu_tensors[layer_name]
815777
gpu_tensor = self._gpu_tensors[layer_name]
816778
kv_dim = self._kv_dim_before_num_blocks[layer_name]
817-
block_size_bytes = self._block_size_bytes[layer_name]
818779

819780
with torch.cuda.stream(self._store_stream):
820781
if kv_dim:
821-
ops.swap_blocks(gpu_tensor[0], cpu_tensor[0], block_size_bytes, src_to_dst_tensor)
822-
ops.swap_blocks(gpu_tensor[1], cpu_tensor[1], block_size_bytes, src_to_dst_tensor)
782+
ops.swap_blocks(gpu_tensor[0], cpu_tensor[0], src_to_dst_tensor)
783+
ops.swap_blocks(gpu_tensor[1], cpu_tensor[1], src_to_dst_tensor)
823784
else:
824-
ops.swap_blocks(gpu_tensor, cpu_tensor, block_size_bytes, src_to_dst_tensor)
785+
ops.swap_blocks(gpu_tensor, cpu_tensor, src_to_dst_tensor)
825786
# Record event for the load stream to wait on
826787
if self._store_event is None:
827788
self._store_event = torch.Event()
@@ -836,9 +797,6 @@ def wait_for_all_stores(self):
836797

837798
def prepare_store_kv(self, metadata: SwapConnectorMetadata):
838799
"""Prepare bulk store jobs for the scheduler's reqs_to_store."""
839-
if self._single_tensor_mode:
840-
return
841-
842800
for req_id, transfer_spec in metadata.reqs_to_store.items():
843801
job_id = self._generate_job_id()
844802
self._jobs[job_id] = (req_id, True)
@@ -848,15 +806,10 @@ def prepare_store_kv(self, metadata: SwapConnectorMetadata):
848806
def get_finished(
849807
self, finished_req_ids: set[str]
850808
) -> tuple[set[str], set[str]]:
851-
# In single-tensor mode, no bulk jobs are submitted
852-
if self._single_tensor_mode:
853-
return set(), set()
854-
855809
finished_sending = set()
856810
finished_recving = set()
857-
for result in self.worker.get_finished():
858-
assert result.success
859-
job_id = result.job_id
811+
for job_id, success in self.worker.get_finished():
812+
assert success
860813
req_id, store = self._jobs.pop(job_id)
861814
if store:
862815
req_jobs = self._store_jobs[req_id]

vllm/v1/core/kv_cache_utils.py

Lines changed: 1 addition & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,95 +1145,12 @@ def get_kv_cache_config_from_groups(
11451145
KVCacheTensor(size=page_size * num_blocks, shared_by=shared_by)
11461146
)
11471147

1148-
config = KVCacheConfig(
1148+
return KVCacheConfig(
11491149
num_blocks=num_blocks,
11501150
kv_cache_tensors=kv_cache_tensors,
11511151
kv_cache_groups=kv_cache_groups,
11521152
)
11531153

1154-
# In single-GPU-tensor swap mode, all layers share 1 GPU KV cache tensor.
1155-
# num_blocks is determined by CPU capacity, not GPU.
1156-
config = _maybe_apply_single_gpu_tensor(vllm_config, config,
1157-
available_memory)
1158-
1159-
return config
1160-
1161-
1162-
def _maybe_apply_single_gpu_tensor(
1163-
vllm_config: VllmConfig,
1164-
config: KVCacheConfig,
1165-
available_memory: int,
1166-
) -> KVCacheConfig:
1167-
"""
1168-
If single_gpu_tensor swap mode is enabled, replace the N per-layer
1169-
KVCacheTensors with a single shared tensor. num_blocks is capped by
1170-
both CPU capacity and available GPU memory.
1171-
"""
1172-
kv_transfer_config = vllm_config.kv_transfer_config
1173-
if kv_transfer_config is None:
1174-
return config
1175-
extra = kv_transfer_config.kv_connector_extra_config or {}
1176-
if not extra.get("single_gpu_tensor", False):
1177-
return config
1178-
1179-
cpu_bytes_to_use = extra.get("cpu_bytes_to_use")
1180-
if not cpu_bytes_to_use:
1181-
raise ValueError(
1182-
"cpu_bytes_to_use must be set when single_gpu_tensor is enabled"
1183-
)
1184-
cpu_bytes_to_use = int(cpu_bytes_to_use)
1185-
1186-
# Collect all layer names and compute page size
1187-
all_layer_names: list[str] = []
1188-
for group in config.kv_cache_groups:
1189-
all_layer_names.extend(group.layer_names)
1190-
num_layers = len(all_layer_names)
1191-
assert num_layers > 0
1192-
1193-
# All tensors should have the same page size
1194-
page_sizes = set()
1195-
for group in config.kv_cache_groups:
1196-
page_sizes.add(group.kv_cache_spec.page_size_bytes)
1197-
assert len(page_sizes) == 1, (
1198-
"single_gpu_tensor mode requires uniform page sizes across groups"
1199-
)
1200-
page_size = page_sizes.pop()
1201-
1202-
# num_blocks determined by CPU capacity, capped by GPU memory.
1203-
# CPU holds all layers' data; GPU holds just 1 layer's worth.
1204-
cpu_num_blocks = int(cpu_bytes_to_use // (page_size * num_layers))
1205-
gpu_num_blocks = int(available_memory // page_size)
1206-
num_blocks = min(cpu_num_blocks, gpu_num_blocks)
1207-
assert num_blocks > 0, (
1208-
f"Cannot allocate any KV blocks. "
1209-
f"cpu_bytes_to_use={cpu_bytes_to_use} -> {cpu_num_blocks} blocks, "
1210-
f"available_gpu_memory={available_memory} -> {gpu_num_blocks} blocks, "
1211-
f"num_layers={num_layers}, page_size={page_size}"
1212-
)
1213-
1214-
bottleneck = "GPU" if gpu_num_blocks < cpu_num_blocks else "CPU"
1215-
logger.info(
1216-
"Single-GPU-tensor swap mode: %d blocks (limited by %s), "
1217-
"%d layers sharing 1 GPU tensor of %s, CPU backing %s",
1218-
num_blocks,
1219-
bottleneck,
1220-
num_layers,
1221-
format_gib(page_size * num_blocks),
1222-
format_gib(cpu_bytes_to_use),
1223-
)
1224-
1225-
# One GPU tensor shared by ALL layers
1226-
single_tensor = KVCacheTensor(
1227-
size=page_size * num_blocks,
1228-
shared_by=all_layer_names,
1229-
)
1230-
1231-
return KVCacheConfig(
1232-
num_blocks=num_blocks,
1233-
kv_cache_tensors=[single_tensor],
1234-
kv_cache_groups=config.kv_cache_groups,
1235-
)
1236-
12371154

12381155
def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
12391156
"""

vllm/v1/core/sched/scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -939,7 +939,7 @@ def schedule(self) -> SchedulerOutput:
939939
)
940940

941941
with record_function_or_nullcontext(nvtx_label):
942-
self._update_after_schedule(scheduler_output)
942+
self._update_after_schedule(scheduler_output)
943943
return scheduler_output
944944

945945
def _preempt_request(self, request: Request, timestamp: float) -> None:

vllm/v1/kv_offload/cpu_swap.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,9 @@ def __init__(self, vllm_config: VllmConfig, kv_cache_config: KVCacheConfig):
5252
}
5353
assert len(page_sizes) == 1
5454
page_size_bytes = page_sizes.pop()
55-
56-
# Use actual layer count from groups (not len(kv_cache_tensors),
57-
# which is 1 in single-GPU-tensor mode).
58-
num_layers = sum(
59-
len(g.layer_names)
60-
for g in kv_cache_config.kv_cache_groups
61-
)
6255
kv_bytes_per_block = (
6356
page_size_bytes
64-
* num_layers
57+
* len(kv_cache_config.kv_cache_tensors)
6558
* vllm_config.parallel_config.world_size
6659
)
6760
kv_bytes_per_offloaded_block = kv_bytes_per_block * (
@@ -74,9 +67,8 @@ def __init__(self, vllm_config: VllmConfig, kv_cache_config: KVCacheConfig):
7467
else 0
7568
)
7669

77-
# In single-GPU-tensor mode, all layers share one GPU tensor,
78-
# so total_gpu_blocks is just the num_blocks from the config.
79-
# In normal mode, sum across tensors.
70+
# Calculate the total number of GPU blocks that could be allocated.
71+
# The CPU must be able to hold all of them.
8072
total_gpu_blocks = sum(
8173
kv_cache_tensor.size // page_size_bytes
8274
for kv_cache_tensor in kv_cache_config.kv_cache_tensors
@@ -87,9 +79,6 @@ def __init__(self, vllm_config: VllmConfig, kv_cache_config: KVCacheConfig):
8779
total_gpu_blocks + block_size_factor - 1
8880
) // block_size_factor
8981

90-
# In single-GPU-tensor mode, the GPU tensor is small (1 layer)
91-
# and the CPU holds all layers' data. The assertion is naturally
92-
# satisfied since num_blocks was computed from cpu_bytes_to_use.
9382
assert self.num_blocks >= required_offloaded_blocks, (
9483
f"CPU swap mode requires enough CPU memory to hold all KV cache. "
9584
f"CPU can hold {self.num_blocks} offloaded blocks but "

vllm/v1/worker/gpu/attn_utils.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,58 @@ def build_slot_mappings_by_layer(
167167
return slot_mappings_by_layer
168168

169169

170+
def init_kv_cache_with_offloading(
171+
runner_kv_caches: list[torch.Tensor],
172+
forward_context: dict[str, Any],
173+
kv_cache_config: KVCacheConfig,
174+
attn_backends: dict[str, AttentionBackend],
175+
gpu_device: torch.device,
176+
num_gpu_buffer_layers: int = 2,
177+
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
178+
"""Initialize KV cache with per-layer CPU offloading.
179+
180+
Returns:
181+
A tuple of (gpu_kv_caches, cpu_kv_caches), each mapping
182+
layer_name -> reshaped KV cache tensor.
183+
"""
184+
gpu_raw, cpu_raw = _allocate_kv_cache_with_offloading(
185+
kv_cache_config, gpu_device, num_gpu_buffer_layers,
186+
)
187+
gpu_kv_caches = _reshape_kv_cache(kv_cache_config, gpu_raw, attn_backends)
188+
cpu_kv_caches = _reshape_kv_cache(kv_cache_config, cpu_raw, attn_backends)
189+
190+
# Bind the GPU buffer tensors so attention layers can use them
191+
bind_kv_cache(gpu_kv_caches, forward_context, runner_kv_caches)
192+
193+
return gpu_kv_caches, cpu_kv_caches
194+
195+
196+
def init_kv_cache_with_offloading(
197+
runner_kv_caches: list[torch.Tensor],
198+
forward_context: dict[str, Any],
199+
kv_cache_config: KVCacheConfig,
200+
attn_backends: dict[str, AttentionBackend],
201+
gpu_device: torch.device,
202+
num_gpu_buffer_layers: int = 2,
203+
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
204+
"""Initialize KV cache with per-layer CPU offloading.
205+
206+
Returns:
207+
A tuple of (gpu_kv_caches, cpu_kv_caches), each mapping
208+
layer_name -> reshaped KV cache tensor.
209+
"""
210+
gpu_raw, cpu_raw = _allocate_kv_cache_with_offloading(
211+
kv_cache_config, gpu_device, num_gpu_buffer_layers,
212+
)
213+
gpu_kv_caches = _reshape_kv_cache(kv_cache_config, gpu_raw, attn_backends)
214+
cpu_kv_caches = _reshape_kv_cache(kv_cache_config, cpu_raw, attn_backends)
215+
216+
# Bind the GPU buffer tensors so attention layers can use them
217+
bind_kv_cache(gpu_kv_caches, forward_context, runner_kv_caches)
218+
219+
return gpu_kv_caches, cpu_kv_caches
220+
221+
170222
def build_attn_metadata(
171223
attn_groups: list[list[AttentionGroup]],
172224
num_reqs: int,

vllm_profile/vllm_profile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def main():
5858
profiler_config = None
5959

6060
ktc = KVTransferConfig(
61-
kv_connector="SwapConnector",
61+
kv_connector="LMCacheConnectorV1",
6262
kv_role="kv_both",
6363
kv_connector_extra_config={
6464
"cpu_bytes_to_use": 64 * 1024 ** 3, # 64 GiB of pinned CPU RAM for KV cache

Comments (0)