Skip to content

Commit 31d9540

Browse files
DarkLight1337 authored and dbyoung18 committed
[Bugfix] Multi-modal caches not acting like LRU caches (vllm-project#16593)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent e7b996e commit 31d9540

File tree

4 files changed

+187
-126
lines changed

4 files changed

+187
-126
lines changed

tests/lora/test_utils.py

Lines changed: 0 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
from vllm.lora.utils import (get_adapter_absolute_path,
1111
parse_fine_tuned_lora_name, replace_submodule)
12-
from vllm.utils import LRUCache
1312

1413

1514
def test_parse_fine_tuned_lora_name_valid():
@@ -85,114 +84,6 @@ def test_replace_submodule():
8584
assert dict(model.named_modules())["seq1.dense2"] == dense2
8685

8786

88-
class TestLRUCache(LRUCache):
89-
90-
def _on_remove(self, key, value):
91-
if not hasattr(self, "_remove_counter"):
92-
self._remove_counter = 0
93-
self._remove_counter += 1
94-
95-
96-
def test_lru_cache():
97-
cache = TestLRUCache(3)
98-
99-
cache.put(1, 1)
100-
assert len(cache) == 1
101-
102-
cache.put(1, 1)
103-
assert len(cache) == 1
104-
105-
cache.put(2, 2)
106-
assert len(cache) == 2
107-
108-
cache.put(3, 3)
109-
assert len(cache) == 3
110-
assert set(cache.cache) == {1, 2, 3}
111-
112-
cache.put(4, 4)
113-
assert len(cache) == 3
114-
assert set(cache.cache) == {2, 3, 4}
115-
assert cache._remove_counter == 1
116-
assert cache.get(2) == 2
117-
118-
cache.put(5, 5)
119-
assert set(cache.cache) == {2, 4, 5}
120-
assert cache._remove_counter == 2
121-
122-
assert cache.pop(5) == 5
123-
assert len(cache) == 2
124-
assert set(cache.cache) == {2, 4}
125-
assert cache._remove_counter == 3
126-
127-
cache.pop(10)
128-
assert len(cache) == 2
129-
assert set(cache.cache) == {2, 4}
130-
assert cache._remove_counter == 3
131-
132-
cache.get(10)
133-
assert len(cache) == 2
134-
assert set(cache.cache) == {2, 4}
135-
assert cache._remove_counter == 3
136-
137-
cache.put(6, 6)
138-
assert len(cache) == 3
139-
assert set(cache.cache) == {2, 4, 6}
140-
assert 2 in cache
141-
assert 4 in cache
142-
assert 6 in cache
143-
144-
cache.remove_oldest()
145-
assert len(cache) == 2
146-
assert set(cache.cache) == {2, 6}
147-
assert cache._remove_counter == 4
148-
149-
cache.clear()
150-
assert len(cache) == 0
151-
assert cache._remove_counter == 6
152-
153-
cache._remove_counter = 0
154-
155-
cache[1] = 1
156-
assert len(cache) == 1
157-
158-
cache[1] = 1
159-
assert len(cache) == 1
160-
161-
cache[2] = 2
162-
assert len(cache) == 2
163-
164-
cache[3] = 3
165-
assert len(cache) == 3
166-
assert set(cache.cache) == {1, 2, 3}
167-
168-
cache[4] = 4
169-
assert len(cache) == 3
170-
assert set(cache.cache) == {2, 3, 4}
171-
assert cache._remove_counter == 1
172-
assert cache[2] == 2
173-
174-
cache[5] = 5
175-
assert set(cache.cache) == {2, 4, 5}
176-
assert cache._remove_counter == 2
177-
178-
del cache[5]
179-
assert len(cache) == 2
180-
assert set(cache.cache) == {2, 4}
181-
assert cache._remove_counter == 3
182-
183-
cache.pop(10)
184-
assert len(cache) == 2
185-
assert set(cache.cache) == {2, 4}
186-
assert cache._remove_counter == 3
187-
188-
cache[6] = 6
189-
assert len(cache) == 3
190-
assert set(cache.cache) == {2, 4, 6}
191-
assert 2 in cache
192-
assert 4 in cache
193-
assert 6 in cache
194-
195-
19687
# Unit tests for get_adapter_absolute_path
19788
@patch('os.path.isabs')
19889
def test_get_adapter_absolute_path_absolute(mock_isabs):

tests/test_utils.py

Lines changed: 128 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313
from vllm_test_utils.monitor import monitor
1414

1515
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
16-
from vllm.utils import (FlexibleArgumentParser, MemorySnapshot,
17-
PlaceholderModule, StoreBoolean, bind_kv_cache,
18-
deprecate_kwargs, get_open_port, memory_profiling,
19-
merge_async_iterators, sha256, supports_kw,
20-
swap_dict_values)
16+
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
17+
MemorySnapshot, PlaceholderModule, StoreBoolean,
18+
bind_kv_cache, deprecate_kwargs, get_open_port,
19+
memory_profiling, merge_async_iterators, sha256,
20+
supports_kw, swap_dict_values)
2121

2222
from .utils import create_new_process_for_each_test, error_on_warning
2323

@@ -417,6 +417,129 @@ def test_bind_kv_cache_pp():
417417
assert ctx['layers.0.self_attn'].kv_cache[1] is kv_cache[1][0]
418418

419419

420+
class TestLRUCache(LRUCache):
421+
422+
def _on_remove(self, key, value):
423+
if not hasattr(self, "_remove_counter"):
424+
self._remove_counter = 0
425+
self._remove_counter += 1
426+
427+
428+
def test_lru_cache():
429+
cache = TestLRUCache(3)
430+
assert cache.stat() == CacheInfo(hits=0, total=0)
431+
assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)
432+
433+
cache.put(1, 1)
434+
assert len(cache) == 1
435+
436+
cache.put(1, 1)
437+
assert len(cache) == 1
438+
439+
cache.put(2, 2)
440+
assert len(cache) == 2
441+
442+
cache.put(3, 3)
443+
assert len(cache) == 3
444+
assert set(cache.cache) == {1, 2, 3}
445+
446+
cache.put(4, 4)
447+
assert len(cache) == 3
448+
assert set(cache.cache) == {2, 3, 4}
449+
assert cache._remove_counter == 1
450+
451+
assert cache.get(2) == 2
452+
assert cache.stat() == CacheInfo(hits=1, total=1)
453+
assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)
454+
455+
assert cache[2] == 2
456+
assert cache.stat() == CacheInfo(hits=2, total=2)
457+
assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)
458+
459+
cache.put(5, 5)
460+
assert set(cache.cache) == {2, 4, 5}
461+
assert cache._remove_counter == 2
462+
463+
assert cache.pop(5) == 5
464+
assert len(cache) == 2
465+
assert set(cache.cache) == {2, 4}
466+
assert cache._remove_counter == 3
467+
468+
assert cache.get(-1) is None
469+
assert cache.stat() == CacheInfo(hits=2, total=3)
470+
assert cache.stat(delta=True) == CacheInfo(hits=0, total=1)
471+
472+
cache.pop(10)
473+
assert len(cache) == 2
474+
assert set(cache.cache) == {2, 4}
475+
assert cache._remove_counter == 3
476+
477+
cache.get(10)
478+
assert len(cache) == 2
479+
assert set(cache.cache) == {2, 4}
480+
assert cache._remove_counter == 3
481+
482+
cache.put(6, 6)
483+
assert len(cache) == 3
484+
assert set(cache.cache) == {2, 4, 6}
485+
assert 2 in cache
486+
assert 4 in cache
487+
assert 6 in cache
488+
489+
cache.remove_oldest()
490+
assert len(cache) == 2
491+
assert set(cache.cache) == {2, 6}
492+
assert cache._remove_counter == 4
493+
494+
cache.clear()
495+
assert len(cache) == 0
496+
assert cache._remove_counter == 6
497+
assert cache.stat() == CacheInfo(hits=0, total=0)
498+
assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)
499+
500+
cache._remove_counter = 0
501+
502+
cache[1] = 1
503+
assert len(cache) == 1
504+
505+
cache[1] = 1
506+
assert len(cache) == 1
507+
508+
cache[2] = 2
509+
assert len(cache) == 2
510+
511+
cache[3] = 3
512+
assert len(cache) == 3
513+
assert set(cache.cache) == {1, 2, 3}
514+
515+
cache[4] = 4
516+
assert len(cache) == 3
517+
assert set(cache.cache) == {2, 3, 4}
518+
assert cache._remove_counter == 1
519+
assert cache[2] == 2
520+
521+
cache[5] = 5
522+
assert set(cache.cache) == {2, 4, 5}
523+
assert cache._remove_counter == 2
524+
525+
del cache[5]
526+
assert len(cache) == 2
527+
assert set(cache.cache) == {2, 4}
528+
assert cache._remove_counter == 3
529+
530+
cache.pop(10)
531+
assert len(cache) == 2
532+
assert set(cache.cache) == {2, 4}
533+
assert cache._remove_counter == 3
534+
535+
cache[6] = 6
536+
assert len(cache) == 3
537+
assert set(cache.cache) == {2, 4, 6}
538+
assert 2 in cache
539+
assert 4 in cache
540+
assert 6 in cache
541+
542+
420543
def test_placeholder_module_error_handling():
421544
placeholder = PlaceholderModule("placeholder_1234")
422545

vllm/utils.py

Lines changed: 58 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -236,22 +236,39 @@ def hit_ratio(self) -> float:
236236

237237
return self.hits / self.total
238238

239+
def __sub__(self, other: CacheInfo):
240+
return CacheInfo(
241+
hits=self.hits - other.hits,
242+
total=self.total - other.total,
243+
)
244+
239245

240246
class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
241247

242248
def __init__(self,
243249
capacity: float,
244250
getsizeof: Optional[Callable[[_V], float]] = None):
245251
super().__init__(capacity, getsizeof)
252+
246253
self.pinned_items = set[_K]()
247-
self.capacity = capacity
248254

249255
self._hits = 0
250256
self._total = 0
257+
self._last_info = CacheInfo(hits=0, total=0)
258+
259+
def __getitem__(self, key: _K, *, update_info: bool = True) -> _V:
260+
value = super().__getitem__(key)
261+
262+
if update_info:
263+
self._hits += 1
264+
self._total += 1
265+
266+
return value
251267

252268
def __delitem__(self, key: _K) -> None:
253269
run_on_remove = key in self
254-
value = self.__getitem__(key)
270+
value = self.__getitem__(key,
271+
update_info=False) # type: ignore[call-arg]
255272
super().__delitem__(key)
256273
if key in self.pinned_items:
257274
# Todo: add warning to inform that del pinned item
@@ -271,8 +288,32 @@ def order(self) -> Mapping[_K, None]:
271288
"""Return the internal order dictionary (read-only)."""
272289
return MappingProxyType(self._LRUCache__order) # type: ignore
273290

274-
def stat(self) -> CacheInfo:
275-
return CacheInfo(hits=self._hits, total=self._total)
291+
@property
292+
def capacity(self) -> float:
293+
return self.maxsize
294+
295+
@property
296+
def usage(self) -> float:
297+
if self.maxsize == 0:
298+
return 0
299+
300+
return self.currsize / self.maxsize
301+
302+
def stat(self, *, delta: bool = False) -> CacheInfo:
303+
"""
304+
Gets the cumulative number of hits and queries against this cache.
305+
306+
If :code:`delta=True`, instead gets these statistics
307+
since the last call that also passed :code:`delta=True`.
308+
"""
309+
info = CacheInfo(hits=self._hits, total=self._total)
310+
311+
if delta:
312+
info_delta = info - self._last_info
313+
self._last_info = info
314+
info = info_delta
315+
316+
return info
276317

277318
def touch(self, key: _K) -> None:
278319
self._LRUCache__update(key) # type: ignore
@@ -292,7 +333,8 @@ def get(self,
292333
_T]] = None) -> Optional[Union[_V, _T]]:
293334
value: Optional[Union[_V, _T]]
294335
if key in self:
295-
value = self.__getitem__(key)
336+
value = self.__getitem__(
337+
key, update_info=False) # type: ignore[call-arg]
296338

297339
self._hits += 1
298340
else:
@@ -317,8 +359,9 @@ def pop(self,
317359
if key not in self:
318360
return default
319361

320-
value = self[key]
321-
del self[key]
362+
value = self.__getitem__(key,
363+
update_info=False) # type: ignore[call-arg]
364+
self.__delitem__(key)
322365
return value
323366

324367
def put(self, key: _K, value: _V) -> None:
@@ -353,10 +396,6 @@ def _remove_old_if_needed(self) -> None:
353396
while self.currsize > self.capacity:
354397
self.remove_oldest()
355398

356-
def clear(self) -> None:
357-
while len(self) > 0:
358-
self.remove_oldest(remove_pinned=True)
359-
360399
def popitem(self, remove_pinned: bool = False):
361400
"""Remove and return the `(key, value)` pair least recently used."""
362401
if not remove_pinned:
@@ -372,6 +411,14 @@ def popitem(self, remove_pinned: bool = False):
372411
value = self.pop(cast(_K, lru_key))
373412
return (lru_key, value)
374413

414+
def clear(self) -> None:
415+
while len(self) > 0:
416+
self.remove_oldest(remove_pinned=True)
417+
418+
self._hits = 0
419+
self._total = 0
420+
self._last_info = CacheInfo(hits=0, total=0)
421+
375422

376423
class PyObjectCache:
377424
"""Used to cache python objects to avoid object allocations

0 commit comments

Comments
 (0)