@@ -39,8 +39,9 @@ def __init__(
39
39
40
40
self .enable_caching = enable_caching
41
41
self .caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
42
- # FIXME: make prefix cache stats conditional on log_stats
43
42
self .log_stats = log_stats
43
+ # Prefix cache stats are only collected when log_stats is enabled.
44
+ self .prefix_cache_stats = PrefixCacheStats () if log_stats else None
44
45
# NOTE(woosuk): To avoid frequent block allocation, we preallocate some
45
46
# blocks for each request. For example, when a request reaches the end
46
47
# of its block table, we preallocate N blocks in advance. This way, we
@@ -79,7 +80,6 @@ def __init__(
79
80
# This is only used to track the RUNNING requests, we do not track the
80
81
# data for preempted ones.
81
82
self .num_cached_block : dict [str , int ] = {}
82
- self .prefix_cache_stats = PrefixCacheStats ()
83
83
84
84
@property
85
85
def usage (self ) -> float :
@@ -90,12 +90,14 @@ def usage(self) -> float:
90
90
"""
91
91
return self .block_pool .get_usage ()
92
92
93
- def make_prefix_cache_stats (self ) -> PrefixCacheStats :
93
+ def make_prefix_cache_stats (self ) -> Optional [ PrefixCacheStats ] :
94
94
"""Get (and reset) the prefix cache stats.
95
95
96
96
Returns:
97
- The current prefix caching stats.
97
+ The current prefix caching stats, or None if logging is disabled.
98
98
"""
99
+ if not self .log_stats :
100
+ return None
99
101
stats = self .prefix_cache_stats
100
102
self .prefix_cache_stats = PrefixCacheStats ()
101
103
return stats
@@ -125,7 +127,9 @@ def get_computed_blocks(
125
127
self .block_size , request )
126
128
self .req_to_block_hashes [request .request_id ] = block_hashes
127
129
128
- self .prefix_cache_stats .requests += 1
130
+ if self .log_stats :
131
+ assert self .prefix_cache_stats is not None
132
+ self .prefix_cache_stats .requests += 1
129
133
# When the request requires prompt logprobs, we skip prefix caching.
130
134
if request .sampling_params .prompt_logprobs is not None :
131
135
return [], 0
@@ -145,8 +149,10 @@ def get_computed_blocks(
145
149
146
150
computed_blocks = (
147
151
self .specialized_manager .find_longest_cache_hit (block_hashes ))
148
- self .prefix_cache_stats .queries += len (block_hashes )
149
- self .prefix_cache_stats .hits += len (computed_blocks )
152
+ if self .log_stats :
153
+ assert self .prefix_cache_stats is not None
154
+ self .prefix_cache_stats .queries += len (block_hashes )
155
+ self .prefix_cache_stats .hits += len (computed_blocks )
150
156
151
157
if last_block_hash is not None :
152
158
# Add back the last block hash if it was removed.
@@ -317,17 +323,19 @@ def free(self, request: Request) -> None:
317
323
318
324
def reset_prefix_cache (self ) -> bool :
319
325
"""Reset prefix cache. This function may be used in RLHF
320
- flows to invalid prefix caching after the weights are updated,
326
+ flows to invalidate prefix caching after the weights are updated,
321
327
or used for resetting prefix caching status for benchmarking.
322
328
323
329
Returns:
324
330
bool: True if the prefix cache is successfully reset,
325
331
False otherwise.
326
332
"""
327
- if self .block_pool .reset_prefix_cache ():
333
+ if not self .block_pool .reset_prefix_cache ():
334
+ return False
335
+ if self .log_stats :
336
+ assert self .prefix_cache_stats is not None
328
337
self .prefix_cache_stats .reset = True
329
- return True
330
- return False
338
+ return True
331
339
332
340
def get_num_common_prefix_blocks (
333
341
self ,
0 commit comments