[Misc] Move LRUCache into its own file (#26342)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-10-07 23:08:40 +08:00
committed by GitHub
parent 6f59beaf0b
commit c0a7b89d8e
6 changed files with 349 additions and 378 deletions

125
tests/utils_/test_cache.py Normal file
View File

@@ -0,0 +1,125 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.utils.cache import CacheInfo, LRUCache
class TestLRUCache(LRUCache):
def _on_remove(self, key, value):
if not hasattr(self, "_remove_counter"):
self._remove_counter = 0
self._remove_counter += 1
def test_lru_cache():
cache = TestLRUCache(3)
assert cache.stat() == CacheInfo(hits=0, total=0)
assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)
cache.put(1, 1)
assert len(cache) == 1
cache.put(1, 1)
assert len(cache) == 1
cache.put(2, 2)
assert len(cache) == 2
cache.put(3, 3)
assert len(cache) == 3
assert set(cache.cache) == {1, 2, 3}
cache.put(4, 4)
assert len(cache) == 3
assert set(cache.cache) == {2, 3, 4}
assert cache._remove_counter == 1
assert cache.get(2) == 2
assert cache.stat() == CacheInfo(hits=1, total=1)
assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)
assert cache[2] == 2
assert cache.stat() == CacheInfo(hits=2, total=2)
assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)
cache.put(5, 5)
assert set(cache.cache) == {2, 4, 5}
assert cache._remove_counter == 2
assert cache.pop(5) == 5
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
assert cache.get(-1) is None
assert cache.stat() == CacheInfo(hits=2, total=3)
assert cache.stat(delta=True) == CacheInfo(hits=0, total=1)
cache.pop(10)
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
cache.get(10)
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
cache.put(6, 6)
assert len(cache) == 3
assert set(cache.cache) == {2, 4, 6}
assert 2 in cache
assert 4 in cache
assert 6 in cache
cache.remove_oldest()
assert len(cache) == 2
assert set(cache.cache) == {2, 6}
assert cache._remove_counter == 4
cache.clear()
assert len(cache) == 0
assert cache._remove_counter == 6
assert cache.stat() == CacheInfo(hits=0, total=0)
assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)
cache._remove_counter = 0
cache[1] = 1
assert len(cache) == 1
cache[1] = 1
assert len(cache) == 1
cache[2] = 2
assert len(cache) == 2
cache[3] = 3
assert len(cache) == 3
assert set(cache.cache) == {1, 2, 3}
cache[4] = 4
assert len(cache) == 3
assert set(cache.cache) == {2, 3, 4}
assert cache._remove_counter == 1
assert cache[2] == 2
cache[5] = 5
assert set(cache.cache) == {2, 4, 5}
assert cache._remove_counter == 2
del cache[5]
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
cache.pop(10)
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
cache[6] = 6
assert len(cache) == 3
assert set(cache.cache) == {2, 4, 6}
assert 2 in cache
assert 4 in cache
assert 6 in cache

View File

@@ -23,11 +23,8 @@ from vllm_test_utils.monitor import monitor
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens
# isort: off
from vllm.utils import (
CacheInfo,
FlexibleArgumentParser,
LRUCache,
MemorySnapshot,
PlaceholderModule,
bind_kv_cache,
@@ -50,7 +47,6 @@ from vllm.utils import (
unique_filepath,
)
# isort: on
from ..utils import create_new_process_for_each_test, error_on_warning
@@ -557,128 +553,6 @@ def test_bind_kv_cache_pp():
assert ctx["layers.0.self_attn"].kv_cache[1] is kv_cache[1][0]
class TestLRUCache(LRUCache):
def _on_remove(self, key, value):
if not hasattr(self, "_remove_counter"):
self._remove_counter = 0
self._remove_counter += 1
def test_lru_cache():
cache = TestLRUCache(3)
assert cache.stat() == CacheInfo(hits=0, total=0)
assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)
cache.put(1, 1)
assert len(cache) == 1
cache.put(1, 1)
assert len(cache) == 1
cache.put(2, 2)
assert len(cache) == 2
cache.put(3, 3)
assert len(cache) == 3
assert set(cache.cache) == {1, 2, 3}
cache.put(4, 4)
assert len(cache) == 3
assert set(cache.cache) == {2, 3, 4}
assert cache._remove_counter == 1
assert cache.get(2) == 2
assert cache.stat() == CacheInfo(hits=1, total=1)
assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)
assert cache[2] == 2
assert cache.stat() == CacheInfo(hits=2, total=2)
assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)
cache.put(5, 5)
assert set(cache.cache) == {2, 4, 5}
assert cache._remove_counter == 2
assert cache.pop(5) == 5
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
assert cache.get(-1) is None
assert cache.stat() == CacheInfo(hits=2, total=3)
assert cache.stat(delta=True) == CacheInfo(hits=0, total=1)
cache.pop(10)
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
cache.get(10)
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
cache.put(6, 6)
assert len(cache) == 3
assert set(cache.cache) == {2, 4, 6}
assert 2 in cache
assert 4 in cache
assert 6 in cache
cache.remove_oldest()
assert len(cache) == 2
assert set(cache.cache) == {2, 6}
assert cache._remove_counter == 4
cache.clear()
assert len(cache) == 0
assert cache._remove_counter == 6
assert cache.stat() == CacheInfo(hits=0, total=0)
assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)
cache._remove_counter = 0
cache[1] = 1
assert len(cache) == 1
cache[1] = 1
assert len(cache) == 1
cache[2] = 2
assert len(cache) == 2
cache[3] = 3
assert len(cache) == 3
assert set(cache.cache) == {1, 2, 3}
cache[4] = 4
assert len(cache) == 3
assert set(cache.cache) == {2, 3, 4}
assert cache._remove_counter == 1
assert cache[2] == 2
cache[5] = 5
assert set(cache.cache) == {2, 4, 5}
assert cache._remove_counter == 2
del cache[5]
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
cache.pop(10)
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
cache[6] = 6
assert len(cache) == 3
assert set(cache.cache) == {2, 4, 6}
assert 2 in cache
assert 4 in cache
assert 6 in cache
@pytest.mark.parametrize(
("src_dtype", "tgt_dtype", "expected_result"),
[