diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 80dd275a90b..669c0423692 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -751,3 +751,25 @@ def test_reset_prefix_cache(): assert manager.reset_prefix_cache() assert not manager.block_pool.cached_block_hash_to_block assert all([blk.block_hash is None for blk in manager.block_pool.blocks]) + + + def test_prefix_cache_stats_disabled(): + """Test that prefix_cache_stats is None when log_stats is False.""" + manager = KVCacheManager( + make_kv_cache_config(16, 11), + max_model_len=8192, + enable_caching=True, + log_stats=False, # Disable logging stats + ) + assert manager.prefix_cache_stats is None + + # Call all functions that check whether log_stats is disabled. + req = make_request("0", list(range(16))) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) + assert not computed_blocks + assert num_computed_tokens == 0 + manager.allocate_slots(req, 16, computed_blocks) + manager.reset_prefix_cache() + + # Ensure prefix_cache_stats remains None + assert manager.prefix_cache_stats is None diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 4e74c20d366..ee2dd0cb14d 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -39,8 +39,9 @@ def __init__( self.enable_caching = enable_caching self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash - # FIXME: make prefix cache stats conditional on log_stats self.log_stats = log_stats + # Prefix cache stats are tracked only when log_stats is enabled. + self.prefix_cache_stats = PrefixCacheStats() if log_stats else None # NOTE(woosuk): To avoid frequent block allocation, we preallocate some # blocks for each request. For example, when a request reaches the end # of its block table, we preallocate N blocks in advance. 
This way, we @@ -79,7 +80,6 @@ def __init__( # This is only used to track the RUNNING requests, we do not track the # data for reempted ones. self.num_cached_block: dict[str, int] = {} - self.prefix_cache_stats = PrefixCacheStats() @property def usage(self) -> float: @@ -90,12 +90,14 @@ def usage(self) -> float: """ return self.block_pool.get_usage() - def make_prefix_cache_stats(self) -> PrefixCacheStats: + def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]: """Get (and reset) the prefix cache stats. Returns: - The current prefix caching stats. + The current prefix caching stats, or None if logging is disabled. """ + if not self.log_stats: + return None stats = self.prefix_cache_stats self.prefix_cache_stats = PrefixCacheStats() return stats @@ -125,7 +127,9 @@ def get_computed_blocks( self.block_size, request) self.req_to_block_hashes[request.request_id] = block_hashes - self.prefix_cache_stats.requests += 1 + if self.log_stats: + assert self.prefix_cache_stats is not None + self.prefix_cache_stats.requests += 1 # When the request requires prompt logprobs, we skip prefix caching. if request.sampling_params.prompt_logprobs is not None: return [], 0 @@ -145,8 +149,10 @@ def get_computed_blocks( computed_blocks = ( self.specialized_manager.find_longest_cache_hit(block_hashes)) - self.prefix_cache_stats.queries += len(block_hashes) - self.prefix_cache_stats.hits += len(computed_blocks) + if self.log_stats: + assert self.prefix_cache_stats is not None + self.prefix_cache_stats.queries += len(block_hashes) + self.prefix_cache_stats.hits += len(computed_blocks) if last_block_hash is not None: # Add back the last block hash if it was removed. @@ -308,17 +314,19 @@ def free(self, request: Request) -> None: def reset_prefix_cache(self) -> bool: """Reset prefix cache. 
This function may be used in RLHF - flows to invalid prefix caching after the weights are updated, + flows to invalidate prefix caching after the weights are updated, or used for resetting prefix caching status for benchmarking. Returns: bool: True if the prefix cache is successfully reset, False otherwise. """ - if self.block_pool.reset_prefix_cache(): + if not self.block_pool.reset_prefix_cache(): + return False + if self.log_stats: + assert self.prefix_cache_stats is not None self.prefix_cache_stats.reset = True - return True - return False + return True def get_num_common_prefix_blocks( self, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 488d32cb82c..5c850c25602 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -747,11 +747,13 @@ def make_stats( ) -> Optional[SchedulerStats]: if not self.log_stats: return None + prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats() + assert prefix_cache_stats is not None return SchedulerStats( num_running_reqs=len(self.running), num_waiting_reqs=len(self.waiting), gpu_cache_usage=self.kv_cache_manager.usage, - prefix_cache_stats=self.kv_cache_manager.make_prefix_cache_stats(), + prefix_cache_stats=prefix_cache_stats, spec_decoding_stats=spec_decoding_stats, )