1919import torch
2020
2121from vllm import _custom_ops as ops
22- from vllm .model_executor . layers . attention import Attention
22+ from vllm .attention . layer import Attention
2323from vllm .config import VllmConfig , get_layers_from_vllm_config
2424from vllm .distributed .kv_events import BlockRemoved , BlockStored , KVCacheEvent
2525from vllm .distributed .kv_transfer .kv_connector .utils import yield_req_data
@@ -511,9 +511,6 @@ def __init__(self, spec: OffloadingSpec):
511511 self ._gpu_tensors : dict [str , torch .Tensor ] = {}
512512 self ._cpu_tensors : dict [str , torch .Tensor ] = {}
513513 self ._kv_dim_before_num_blocks : dict [str , bool ] = {}
514- self ._block_size_bytes : dict [str , int ] = {}
515- # True when all layers share a single GPU tensor
516- self ._single_tensor_mode : bool = False
517514
518515 # CUDA streams for per-layer transfers
519516 self ._load_stream : torch .cuda .Stream | None = None
@@ -530,16 +527,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
530527 Create per-layer CPU tensors and register handlers for bulk transfers.
531528 """
532529 layer_names = list (kv_caches .keys ())
533-
534- # Detect single-tensor mode: all layers share the same GPU tensor
535- gpu_ptrs = {t .data_ptr () for t in kv_caches .values ()}
536- self ._single_tensor_mode = (len (gpu_ptrs ) == 1 and len (layer_names ) > 1 )
537- if self ._single_tensor_mode :
538- logger .info (
539- "Single-tensor swap mode: %d layers share 1 GPU KV buffer" ,
540- len (layer_names ),
541- )
542-
543530 layers = get_layers_from_vllm_config (
544531 self .spec .vllm_config , Attention , layer_names
545532 )
@@ -548,15 +535,11 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
548535 for layer_name in layer_names
549536 }
550537
551- # Register handlers for bulk transfers (prefix cache loads/stores).
552- # Skip in single-tensor mode: bulk handlers assume per-layer GPU
553- # tensors and won't work with a shared buffer. The per-layer
554- # load/store cycle handles all transfers instead.
555- if not self ._single_tensor_mode :
556- for src_cls , dst_cls , handler in self .spec .get_handlers (
557- kv_caches , attn_backends
558- ):
559- self .worker .register_handler (src_cls , dst_cls , handler )
538+ # Register handlers for bulk transfers (prefix cache loads/stores)
539+ for src_cls , dst_cls , handler in self .spec .get_handlers (
540+ kv_caches , attn_backends
541+ ):
542+ self .worker .register_handler (src_cls , dst_cls , handler )
560543
561544 pin_memory = is_pin_memory_available ()
562545 num_cpu_blocks = self .spec .num_blocks
@@ -597,24 +580,12 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
597580 layer_name ,
598581 cpu_shape ,
599582 )
600- cpu_tensor = torch .zeros (
583+ self . _cpu_tensors [ layer_name ] = torch .zeros (
601584 cpu_shape ,
602585 dtype = gpu_tensor .dtype ,
603586 device = "cpu" ,
604587 pin_memory = pin_memory ,
605588 )
606- self ._cpu_tensors [layer_name ] = cpu_tensor
607-
608- # Compute block size in bytes for ops.swap_blocks (new 4-arg API).
609- # When kv_dim=True, shape is (2, num_blocks, ...) and we swap
610- # each K/V half separately via tensor[0]/tensor[1].
611- if self ._kv_dim_before_num_blocks [layer_name ]:
612- ref = cpu_tensor [0 ]
613- else :
614- ref = cpu_tensor
615- self ._block_size_bytes [layer_name ] = (
616- ref .element_size () * ref .stride (0 )
617- )
618589
619590 # Summary log after all CPU tensors are registered
620591 total_cpu_bytes = sum (
@@ -640,9 +611,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
640611 )
641612
642613 def handle_preemptions (self , preempted_req_ids : set [str ]):
643- if self ._single_tensor_mode :
644- return
645-
646614 for job_id , transfer_spec in self ._unsubmitted_store_jobs :
647615 success = self .worker .transfer_async (job_id , transfer_spec )
648616 assert success
@@ -655,11 +623,6 @@ def handle_preemptions(self, preempted_req_ids: set[str]):
655623
656624 def start_kv_transfers (self , metadata : SwapConnectorMetadata ):
657625 """Submit deferred stores and start prefix cache loads."""
658- # In single-tensor mode, skip bulk transfers entirely.
659- # Per-layer load/store handles everything.
660- if self ._single_tensor_mode :
661- return
662-
663626 # Submit deferred store jobs from the previous step
664627 for job_id , transfer_spec in self ._unsubmitted_store_jobs :
665628 success = self .worker .transfer_async (job_id , transfer_spec )
@@ -740,14 +703,13 @@ def load_layer_from_cpu(
740703 cpu_tensor = self ._cpu_tensors [layer_name ]
741704 gpu_tensor = self ._gpu_tensors [layer_name ]
742705 kv_dim = self ._kv_dim_before_num_blocks [layer_name ]
743- block_size_bytes = self ._block_size_bytes [layer_name ]
744706
745707 with torch .cuda .stream (self ._load_stream ):
746708 if kv_dim :
747- ops .swap_blocks (cpu_tensor [0 ], gpu_tensor [0 ], block_size_bytes , src_to_dst_tensor )
748- ops .swap_blocks (cpu_tensor [1 ], gpu_tensor [1 ], block_size_bytes , src_to_dst_tensor )
709+ ops .swap_blocks (cpu_tensor [0 ], gpu_tensor [0 ], src_to_dst_tensor )
710+ ops .swap_blocks (cpu_tensor [1 ], gpu_tensor [1 ], src_to_dst_tensor )
749711 else :
750- ops .swap_blocks (cpu_tensor , gpu_tensor , block_size_bytes , src_to_dst_tensor )
712+ ops .swap_blocks (cpu_tensor , gpu_tensor , src_to_dst_tensor )
751713
752714 # Must synchronize: attention needs the data to be ready
753715 self ._load_stream .synchronize ()
@@ -814,14 +776,13 @@ def store_layer_to_cpu(
814776 cpu_tensor = self ._cpu_tensors [layer_name ]
815777 gpu_tensor = self ._gpu_tensors [layer_name ]
816778 kv_dim = self ._kv_dim_before_num_blocks [layer_name ]
817- block_size_bytes = self ._block_size_bytes [layer_name ]
818779
819780 with torch .cuda .stream (self ._store_stream ):
820781 if kv_dim :
821- ops .swap_blocks (gpu_tensor [0 ], cpu_tensor [0 ], block_size_bytes , src_to_dst_tensor )
822- ops .swap_blocks (gpu_tensor [1 ], cpu_tensor [1 ], block_size_bytes , src_to_dst_tensor )
782+ ops .swap_blocks (gpu_tensor [0 ], cpu_tensor [0 ], src_to_dst_tensor )
783+ ops .swap_blocks (gpu_tensor [1 ], cpu_tensor [1 ], src_to_dst_tensor )
823784 else :
824- ops .swap_blocks (gpu_tensor , cpu_tensor , block_size_bytes , src_to_dst_tensor )
785+ ops .swap_blocks (gpu_tensor , cpu_tensor , src_to_dst_tensor )
825786 # Record event for the load stream to wait on
826787 if self ._store_event is None :
827788 self ._store_event = torch .Event ()
@@ -836,9 +797,6 @@ def wait_for_all_stores(self):
836797
837798 def prepare_store_kv (self , metadata : SwapConnectorMetadata ):
838799 """Prepare bulk store jobs for the scheduler's reqs_to_store."""
839- if self ._single_tensor_mode :
840- return
841-
842800 for req_id , transfer_spec in metadata .reqs_to_store .items ():
843801 job_id = self ._generate_job_id ()
844802 self ._jobs [job_id ] = (req_id , True )
@@ -848,15 +806,10 @@ def prepare_store_kv(self, metadata: SwapConnectorMetadata):
848806 def get_finished (
849807 self , finished_req_ids : set [str ]
850808 ) -> tuple [set [str ], set [str ]]:
851- # In single-tensor mode, no bulk jobs are submitted
852- if self ._single_tensor_mode :
853- return set (), set ()
854-
855809 finished_sending = set ()
856810 finished_recving = set ()
857- for result in self .worker .get_finished ():
858- assert result .success
859- job_id = result .job_id
811+ for job_id , success in self .worker .get_finished ():
812+ assert success
860813 req_id , store = self ._jobs .pop (job_id )
861814 if store :
862815 req_jobs = self ._store_jobs [req_id ]
0 commit comments