[Misc] Move LRUCache into its own file (#26342)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
125
tests/utils_/test_cache.py
Normal file
125
tests/utils_/test_cache.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from vllm.utils.cache import CacheInfo, LRUCache
|
||||||
|
|
||||||
|
|
||||||
|
class TestLRUCache(LRUCache):
|
||||||
|
def _on_remove(self, key, value):
|
||||||
|
if not hasattr(self, "_remove_counter"):
|
||||||
|
self._remove_counter = 0
|
||||||
|
self._remove_counter += 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_lru_cache():
|
||||||
|
cache = TestLRUCache(3)
|
||||||
|
assert cache.stat() == CacheInfo(hits=0, total=0)
|
||||||
|
assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)
|
||||||
|
|
||||||
|
cache.put(1, 1)
|
||||||
|
assert len(cache) == 1
|
||||||
|
|
||||||
|
cache.put(1, 1)
|
||||||
|
assert len(cache) == 1
|
||||||
|
|
||||||
|
cache.put(2, 2)
|
||||||
|
assert len(cache) == 2
|
||||||
|
|
||||||
|
cache.put(3, 3)
|
||||||
|
assert len(cache) == 3
|
||||||
|
assert set(cache.cache) == {1, 2, 3}
|
||||||
|
|
||||||
|
cache.put(4, 4)
|
||||||
|
assert len(cache) == 3
|
||||||
|
assert set(cache.cache) == {2, 3, 4}
|
||||||
|
assert cache._remove_counter == 1
|
||||||
|
|
||||||
|
assert cache.get(2) == 2
|
||||||
|
assert cache.stat() == CacheInfo(hits=1, total=1)
|
||||||
|
assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)
|
||||||
|
|
||||||
|
assert cache[2] == 2
|
||||||
|
assert cache.stat() == CacheInfo(hits=2, total=2)
|
||||||
|
assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)
|
||||||
|
|
||||||
|
cache.put(5, 5)
|
||||||
|
assert set(cache.cache) == {2, 4, 5}
|
||||||
|
assert cache._remove_counter == 2
|
||||||
|
|
||||||
|
assert cache.pop(5) == 5
|
||||||
|
assert len(cache) == 2
|
||||||
|
assert set(cache.cache) == {2, 4}
|
||||||
|
assert cache._remove_counter == 3
|
||||||
|
|
||||||
|
assert cache.get(-1) is None
|
||||||
|
assert cache.stat() == CacheInfo(hits=2, total=3)
|
||||||
|
assert cache.stat(delta=True) == CacheInfo(hits=0, total=1)
|
||||||
|
|
||||||
|
cache.pop(10)
|
||||||
|
assert len(cache) == 2
|
||||||
|
assert set(cache.cache) == {2, 4}
|
||||||
|
assert cache._remove_counter == 3
|
||||||
|
|
||||||
|
cache.get(10)
|
||||||
|
assert len(cache) == 2
|
||||||
|
assert set(cache.cache) == {2, 4}
|
||||||
|
assert cache._remove_counter == 3
|
||||||
|
|
||||||
|
cache.put(6, 6)
|
||||||
|
assert len(cache) == 3
|
||||||
|
assert set(cache.cache) == {2, 4, 6}
|
||||||
|
assert 2 in cache
|
||||||
|
assert 4 in cache
|
||||||
|
assert 6 in cache
|
||||||
|
|
||||||
|
cache.remove_oldest()
|
||||||
|
assert len(cache) == 2
|
||||||
|
assert set(cache.cache) == {2, 6}
|
||||||
|
assert cache._remove_counter == 4
|
||||||
|
|
||||||
|
cache.clear()
|
||||||
|
assert len(cache) == 0
|
||||||
|
assert cache._remove_counter == 6
|
||||||
|
assert cache.stat() == CacheInfo(hits=0, total=0)
|
||||||
|
assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)
|
||||||
|
|
||||||
|
cache._remove_counter = 0
|
||||||
|
|
||||||
|
cache[1] = 1
|
||||||
|
assert len(cache) == 1
|
||||||
|
|
||||||
|
cache[1] = 1
|
||||||
|
assert len(cache) == 1
|
||||||
|
|
||||||
|
cache[2] = 2
|
||||||
|
assert len(cache) == 2
|
||||||
|
|
||||||
|
cache[3] = 3
|
||||||
|
assert len(cache) == 3
|
||||||
|
assert set(cache.cache) == {1, 2, 3}
|
||||||
|
|
||||||
|
cache[4] = 4
|
||||||
|
assert len(cache) == 3
|
||||||
|
assert set(cache.cache) == {2, 3, 4}
|
||||||
|
assert cache._remove_counter == 1
|
||||||
|
assert cache[2] == 2
|
||||||
|
|
||||||
|
cache[5] = 5
|
||||||
|
assert set(cache.cache) == {2, 4, 5}
|
||||||
|
assert cache._remove_counter == 2
|
||||||
|
|
||||||
|
del cache[5]
|
||||||
|
assert len(cache) == 2
|
||||||
|
assert set(cache.cache) == {2, 4}
|
||||||
|
assert cache._remove_counter == 3
|
||||||
|
|
||||||
|
cache.pop(10)
|
||||||
|
assert len(cache) == 2
|
||||||
|
assert set(cache.cache) == {2, 4}
|
||||||
|
assert cache._remove_counter == 3
|
||||||
|
|
||||||
|
cache[6] = 6
|
||||||
|
assert len(cache) == 3
|
||||||
|
assert set(cache.cache) == {2, 4, 6}
|
||||||
|
assert 2 in cache
|
||||||
|
assert 4 in cache
|
||||||
|
assert 6 in cache
|
||||||
@@ -23,11 +23,8 @@ from vllm_test_utils.monitor import monitor
|
|||||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||||
from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens
|
from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens
|
||||||
|
|
||||||
# isort: off
|
|
||||||
from vllm.utils import (
|
from vllm.utils import (
|
||||||
CacheInfo,
|
|
||||||
FlexibleArgumentParser,
|
FlexibleArgumentParser,
|
||||||
LRUCache,
|
|
||||||
MemorySnapshot,
|
MemorySnapshot,
|
||||||
PlaceholderModule,
|
PlaceholderModule,
|
||||||
bind_kv_cache,
|
bind_kv_cache,
|
||||||
@@ -50,7 +47,6 @@ from vllm.utils import (
|
|||||||
unique_filepath,
|
unique_filepath,
|
||||||
)
|
)
|
||||||
|
|
||||||
# isort: on
|
|
||||||
from ..utils import create_new_process_for_each_test, error_on_warning
|
from ..utils import create_new_process_for_each_test, error_on_warning
|
||||||
|
|
||||||
|
|
||||||
@@ -557,128 +553,6 @@ def test_bind_kv_cache_pp():
|
|||||||
assert ctx["layers.0.self_attn"].kv_cache[1] is kv_cache[1][0]
|
assert ctx["layers.0.self_attn"].kv_cache[1] is kv_cache[1][0]
|
||||||
|
|
||||||
|
|
||||||
class TestLRUCache(LRUCache):
|
|
||||||
def _on_remove(self, key, value):
|
|
||||||
if not hasattr(self, "_remove_counter"):
|
|
||||||
self._remove_counter = 0
|
|
||||||
self._remove_counter += 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_lru_cache():
|
|
||||||
cache = TestLRUCache(3)
|
|
||||||
assert cache.stat() == CacheInfo(hits=0, total=0)
|
|
||||||
assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)
|
|
||||||
|
|
||||||
cache.put(1, 1)
|
|
||||||
assert len(cache) == 1
|
|
||||||
|
|
||||||
cache.put(1, 1)
|
|
||||||
assert len(cache) == 1
|
|
||||||
|
|
||||||
cache.put(2, 2)
|
|
||||||
assert len(cache) == 2
|
|
||||||
|
|
||||||
cache.put(3, 3)
|
|
||||||
assert len(cache) == 3
|
|
||||||
assert set(cache.cache) == {1, 2, 3}
|
|
||||||
|
|
||||||
cache.put(4, 4)
|
|
||||||
assert len(cache) == 3
|
|
||||||
assert set(cache.cache) == {2, 3, 4}
|
|
||||||
assert cache._remove_counter == 1
|
|
||||||
|
|
||||||
assert cache.get(2) == 2
|
|
||||||
assert cache.stat() == CacheInfo(hits=1, total=1)
|
|
||||||
assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)
|
|
||||||
|
|
||||||
assert cache[2] == 2
|
|
||||||
assert cache.stat() == CacheInfo(hits=2, total=2)
|
|
||||||
assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)
|
|
||||||
|
|
||||||
cache.put(5, 5)
|
|
||||||
assert set(cache.cache) == {2, 4, 5}
|
|
||||||
assert cache._remove_counter == 2
|
|
||||||
|
|
||||||
assert cache.pop(5) == 5
|
|
||||||
assert len(cache) == 2
|
|
||||||
assert set(cache.cache) == {2, 4}
|
|
||||||
assert cache._remove_counter == 3
|
|
||||||
|
|
||||||
assert cache.get(-1) is None
|
|
||||||
assert cache.stat() == CacheInfo(hits=2, total=3)
|
|
||||||
assert cache.stat(delta=True) == CacheInfo(hits=0, total=1)
|
|
||||||
|
|
||||||
cache.pop(10)
|
|
||||||
assert len(cache) == 2
|
|
||||||
assert set(cache.cache) == {2, 4}
|
|
||||||
assert cache._remove_counter == 3
|
|
||||||
|
|
||||||
cache.get(10)
|
|
||||||
assert len(cache) == 2
|
|
||||||
assert set(cache.cache) == {2, 4}
|
|
||||||
assert cache._remove_counter == 3
|
|
||||||
|
|
||||||
cache.put(6, 6)
|
|
||||||
assert len(cache) == 3
|
|
||||||
assert set(cache.cache) == {2, 4, 6}
|
|
||||||
assert 2 in cache
|
|
||||||
assert 4 in cache
|
|
||||||
assert 6 in cache
|
|
||||||
|
|
||||||
cache.remove_oldest()
|
|
||||||
assert len(cache) == 2
|
|
||||||
assert set(cache.cache) == {2, 6}
|
|
||||||
assert cache._remove_counter == 4
|
|
||||||
|
|
||||||
cache.clear()
|
|
||||||
assert len(cache) == 0
|
|
||||||
assert cache._remove_counter == 6
|
|
||||||
assert cache.stat() == CacheInfo(hits=0, total=0)
|
|
||||||
assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)
|
|
||||||
|
|
||||||
cache._remove_counter = 0
|
|
||||||
|
|
||||||
cache[1] = 1
|
|
||||||
assert len(cache) == 1
|
|
||||||
|
|
||||||
cache[1] = 1
|
|
||||||
assert len(cache) == 1
|
|
||||||
|
|
||||||
cache[2] = 2
|
|
||||||
assert len(cache) == 2
|
|
||||||
|
|
||||||
cache[3] = 3
|
|
||||||
assert len(cache) == 3
|
|
||||||
assert set(cache.cache) == {1, 2, 3}
|
|
||||||
|
|
||||||
cache[4] = 4
|
|
||||||
assert len(cache) == 3
|
|
||||||
assert set(cache.cache) == {2, 3, 4}
|
|
||||||
assert cache._remove_counter == 1
|
|
||||||
assert cache[2] == 2
|
|
||||||
|
|
||||||
cache[5] = 5
|
|
||||||
assert set(cache.cache) == {2, 4, 5}
|
|
||||||
assert cache._remove_counter == 2
|
|
||||||
|
|
||||||
del cache[5]
|
|
||||||
assert len(cache) == 2
|
|
||||||
assert set(cache.cache) == {2, 4}
|
|
||||||
assert cache._remove_counter == 3
|
|
||||||
|
|
||||||
cache.pop(10)
|
|
||||||
assert len(cache) == 2
|
|
||||||
assert set(cache.cache) == {2, 4}
|
|
||||||
assert cache._remove_counter == 3
|
|
||||||
|
|
||||||
cache[6] = 6
|
|
||||||
assert len(cache) == 3
|
|
||||||
assert set(cache.cache) == {2, 4, 6}
|
|
||||||
assert 2 in cache
|
|
||||||
assert 4 in cache
|
|
||||||
assert 6 in cache
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("src_dtype", "tgt_dtype", "expected_result"),
|
("src_dtype", "tgt_dtype", "expected_result"),
|
||||||
[
|
[
|
||||||
|
|||||||
@@ -32,7 +32,8 @@ from vllm.model_executor.models.interfaces import is_pooling_model
|
|||||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||||
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
|
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
|
||||||
from vllm.model_executor.utils import get_packed_modules_mapping
|
from vllm.model_executor.utils import get_packed_modules_mapping
|
||||||
from vllm.utils import LRUCache, is_pin_memory_available
|
from vllm.utils import is_pin_memory_available
|
||||||
|
from vllm.utils.cache import LRUCache
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,8 @@ from vllm.distributed.device_communicators.shm_object_storage import (
|
|||||||
)
|
)
|
||||||
from vllm.envs import VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME
|
from vllm.envs import VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import GiB_bytes, LRUCache, MiB_bytes
|
from vllm.utils import GiB_bytes, MiB_bytes
|
||||||
|
from vllm.utils.cache import LRUCache
|
||||||
from vllm.utils.jsontree import json_count_leaves, json_map_leaves, json_reduce_leaves
|
from vllm.utils.jsontree import json_count_leaves, json_map_leaves, json_reduce_leaves
|
||||||
|
|
||||||
from .inputs import (
|
from .inputs import (
|
||||||
|
|||||||
@@ -51,7 +51,6 @@ from collections.abc import (
|
|||||||
Hashable,
|
Hashable,
|
||||||
Iterable,
|
Iterable,
|
||||||
Iterator,
|
Iterator,
|
||||||
KeysView,
|
|
||||||
Mapping,
|
Mapping,
|
||||||
Sequence,
|
Sequence,
|
||||||
)
|
)
|
||||||
@@ -60,24 +59,19 @@ from concurrent.futures.process import ProcessPoolExecutor
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from functools import cache, lru_cache, partial, wraps
|
from functools import cache, lru_cache, partial, wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import MappingProxyType
|
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
Generic,
|
Generic,
|
||||||
Literal,
|
Literal,
|
||||||
NamedTuple,
|
|
||||||
TextIO,
|
TextIO,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
|
||||||
overload,
|
|
||||||
)
|
)
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import cachetools
|
|
||||||
import cbor2
|
import cbor2
|
||||||
import cloudpickle
|
import cloudpickle
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -183,13 +177,6 @@ U = TypeVar("U")
|
|||||||
|
|
||||||
_K = TypeVar("_K", bound=Hashable)
|
_K = TypeVar("_K", bound=Hashable)
|
||||||
_V = TypeVar("_V")
|
_V = TypeVar("_V")
|
||||||
_T = TypeVar("_T")
|
|
||||||
|
|
||||||
|
|
||||||
class _Sentinel: ...
|
|
||||||
|
|
||||||
|
|
||||||
ALL_PINNED_SENTINEL = _Sentinel()
|
|
||||||
|
|
||||||
|
|
||||||
class Device(enum.Enum):
|
class Device(enum.Enum):
|
||||||
@@ -215,243 +202,6 @@ class Counter:
|
|||||||
self.counter = 0
|
self.counter = 0
|
||||||
|
|
||||||
|
|
||||||
class _MappingOrderCacheView(UserDict[_K, _V]):
|
|
||||||
def __init__(self, data: Mapping[_K, _V], ordered_keys: Mapping[_K, None]):
|
|
||||||
super().__init__(data)
|
|
||||||
self.ordered_keys = ordered_keys
|
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[_K]:
|
|
||||||
return iter(self.ordered_keys)
|
|
||||||
|
|
||||||
def keys(self) -> KeysView[_K]:
|
|
||||||
return KeysView(self.ordered_keys)
|
|
||||||
|
|
||||||
|
|
||||||
class CacheInfo(NamedTuple):
|
|
||||||
hits: int
|
|
||||||
total: int
|
|
||||||
|
|
||||||
@property
|
|
||||||
def hit_ratio(self) -> float:
|
|
||||||
if self.total == 0:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
return self.hits / self.total
|
|
||||||
|
|
||||||
def __sub__(self, other: CacheInfo):
|
|
||||||
return CacheInfo(
|
|
||||||
hits=self.hits - other.hits,
|
|
||||||
total=self.total - other.total,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
|
|
||||||
def __init__(self, capacity: float, getsizeof: Callable[[_V], float] | None = None):
|
|
||||||
super().__init__(capacity, getsizeof)
|
|
||||||
|
|
||||||
self.pinned_items = set[_K]()
|
|
||||||
|
|
||||||
self._hits = 0
|
|
||||||
self._total = 0
|
|
||||||
self._last_info = CacheInfo(hits=0, total=0)
|
|
||||||
|
|
||||||
def __getitem__(self, key: _K, *, update_info: bool = True) -> _V:
|
|
||||||
value = super().__getitem__(key)
|
|
||||||
|
|
||||||
if update_info:
|
|
||||||
self._hits += 1
|
|
||||||
self._total += 1
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
def __delitem__(self, key: _K) -> None:
|
|
||||||
run_on_remove = key in self
|
|
||||||
value = self.__getitem__(key, update_info=False) # type: ignore[call-arg]
|
|
||||||
super().__delitem__(key)
|
|
||||||
if key in self.pinned_items:
|
|
||||||
# Todo: add warning to inform that del pinned item
|
|
||||||
self._unpin(key)
|
|
||||||
if run_on_remove:
|
|
||||||
self._on_remove(key, value)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def cache(self) -> Mapping[_K, _V]:
|
|
||||||
"""Return the internal cache dictionary in order (read-only)."""
|
|
||||||
return _MappingOrderCacheView(
|
|
||||||
self._Cache__data, # type: ignore
|
|
||||||
self.order,
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def order(self) -> Mapping[_K, None]:
|
|
||||||
"""Return the internal order dictionary (read-only)."""
|
|
||||||
return MappingProxyType(self._LRUCache__order) # type: ignore
|
|
||||||
|
|
||||||
@property
|
|
||||||
def capacity(self) -> float:
|
|
||||||
return self.maxsize
|
|
||||||
|
|
||||||
@property
|
|
||||||
def usage(self) -> float:
|
|
||||||
if self.maxsize == 0:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
return self.currsize / self.maxsize
|
|
||||||
|
|
||||||
def stat(self, *, delta: bool = False) -> CacheInfo:
|
|
||||||
"""
|
|
||||||
Gets the cumulative number of hits and queries against this cache.
|
|
||||||
|
|
||||||
If `delta=True`, instead gets these statistics
|
|
||||||
since the last call that also passed `delta=True`.
|
|
||||||
"""
|
|
||||||
info = CacheInfo(hits=self._hits, total=self._total)
|
|
||||||
|
|
||||||
if delta:
|
|
||||||
info_delta = info - self._last_info
|
|
||||||
self._last_info = info
|
|
||||||
info = info_delta
|
|
||||||
|
|
||||||
return info
|
|
||||||
|
|
||||||
def touch(self, key: _K) -> None:
|
|
||||||
try:
|
|
||||||
self._LRUCache__order.move_to_end(key) # type: ignore
|
|
||||||
except KeyError:
|
|
||||||
self._LRUCache__order[key] = None # type: ignore
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def get(self, key: _K, /) -> _V | None: ...
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def get(self, key: _K, /, default: Union[_V, _T]) -> Union[_V, _T]: ...
|
|
||||||
|
|
||||||
def get(
|
|
||||||
self, key: _K, /, default: Union[_V, _T] | None = None
|
|
||||||
) -> Union[_V, _T] | None:
|
|
||||||
value: Union[_V, _T] | None
|
|
||||||
if key in self:
|
|
||||||
value = self.__getitem__(key, update_info=False) # type: ignore[call-arg]
|
|
||||||
|
|
||||||
self._hits += 1
|
|
||||||
else:
|
|
||||||
value = default
|
|
||||||
|
|
||||||
self._total += 1
|
|
||||||
return value
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def pop(self, key: _K) -> _V: ...
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def pop(self, key: _K, default: Union[_V, _T]) -> Union[_V, _T]: ...
|
|
||||||
|
|
||||||
def pop(
|
|
||||||
self, key: _K, default: Union[_V, _T] | None = None
|
|
||||||
) -> Union[_V, _T] | None:
|
|
||||||
value: Union[_V, _T] | None
|
|
||||||
if key not in self:
|
|
||||||
return default
|
|
||||||
|
|
||||||
value = self.__getitem__(key, update_info=False) # type: ignore[call-arg]
|
|
||||||
self.__delitem__(key)
|
|
||||||
return value
|
|
||||||
|
|
||||||
def put(self, key: _K, value: _V) -> None:
|
|
||||||
self.__setitem__(key, value)
|
|
||||||
|
|
||||||
def pin(self, key: _K) -> None:
|
|
||||||
"""
|
|
||||||
Pins a key in the cache preventing it from being
|
|
||||||
evicted in the LRU order.
|
|
||||||
"""
|
|
||||||
if key not in self:
|
|
||||||
raise ValueError(f"Cannot pin key: {key} not in cache.")
|
|
||||||
self.pinned_items.add(key)
|
|
||||||
|
|
||||||
def _unpin(self, key: _K) -> None:
|
|
||||||
"""
|
|
||||||
Unpins a key in the cache allowing it to be
|
|
||||||
evicted in the LRU order.
|
|
||||||
"""
|
|
||||||
self.pinned_items.remove(key)
|
|
||||||
|
|
||||||
def _on_remove(self, key: _K, value: _V | None) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def remove_oldest(self, *, remove_pinned: bool = False) -> None:
|
|
||||||
if len(self) == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.popitem(remove_pinned=remove_pinned)
|
|
||||||
|
|
||||||
def _remove_old_if_needed(self) -> None:
|
|
||||||
while self.currsize > self.capacity:
|
|
||||||
self.remove_oldest()
|
|
||||||
|
|
||||||
def popitem(self, remove_pinned: bool = False):
|
|
||||||
"""Remove and return the `(key, value)` pair least recently used."""
|
|
||||||
if not remove_pinned:
|
|
||||||
# pop the oldest item in the cache that is not pinned
|
|
||||||
lru_key = next(
|
|
||||||
(key for key in self.order if key not in self.pinned_items),
|
|
||||||
ALL_PINNED_SENTINEL,
|
|
||||||
)
|
|
||||||
if lru_key is ALL_PINNED_SENTINEL:
|
|
||||||
raise RuntimeError(
|
|
||||||
"All items are pinned, cannot remove oldest from the cache."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
lru_key = next(iter(self.order))
|
|
||||||
value = self.pop(cast(_K, lru_key))
|
|
||||||
return (lru_key, value)
|
|
||||||
|
|
||||||
def clear(self) -> None:
|
|
||||||
while len(self) > 0:
|
|
||||||
self.remove_oldest(remove_pinned=True)
|
|
||||||
|
|
||||||
self._hits = 0
|
|
||||||
self._total = 0
|
|
||||||
self._last_info = CacheInfo(hits=0, total=0)
|
|
||||||
|
|
||||||
|
|
||||||
class PyObjectCache:
|
|
||||||
"""Used to cache python objects to avoid object allocations
|
|
||||||
across scheduler iterations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, obj_builder):
|
|
||||||
self._obj_builder = obj_builder
|
|
||||||
self._index = 0
|
|
||||||
|
|
||||||
self._obj_cache = []
|
|
||||||
for _ in range(128):
|
|
||||||
self._obj_cache.append(self._obj_builder())
|
|
||||||
|
|
||||||
def _grow_cache(self):
|
|
||||||
# Double the size of the cache
|
|
||||||
num_objs = len(self._obj_cache)
|
|
||||||
for _ in range(num_objs):
|
|
||||||
self._obj_cache.append(self._obj_builder())
|
|
||||||
|
|
||||||
def get_object(self):
|
|
||||||
"""Returns a pre-allocated cached object. If there is not enough
|
|
||||||
objects, then the cache size will double.
|
|
||||||
"""
|
|
||||||
if self._index >= len(self._obj_cache):
|
|
||||||
self._grow_cache()
|
|
||||||
assert self._index < len(self._obj_cache)
|
|
||||||
|
|
||||||
obj = self._obj_cache[self._index]
|
|
||||||
self._index += 1
|
|
||||||
|
|
||||||
return obj
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
"""Makes all cached-objects available for the next scheduler iteration."""
|
|
||||||
self._index = 0
|
|
||||||
|
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
|
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
|
||||||
"""Returns the maximum shared memory per thread block in bytes."""
|
"""Returns the maximum shared memory per thread block in bytes."""
|
||||||
|
|||||||
220
vllm/utils/cache.py
Normal file
220
vllm/utils/cache.py
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections import UserDict
|
||||||
|
from collections.abc import Hashable, Iterator, KeysView, Mapping
|
||||||
|
from types import MappingProxyType
|
||||||
|
from typing import Callable, Generic, NamedTuple, TypeVar, Union, cast, overload
|
||||||
|
|
||||||
|
import cachetools
|
||||||
|
|
||||||
|
_K = TypeVar("_K", bound=Hashable)
|
||||||
|
_V = TypeVar("_V")
|
||||||
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
|
|
||||||
|
class _Sentinel: ...
|
||||||
|
|
||||||
|
|
||||||
|
ALL_PINNED_SENTINEL = _Sentinel()
|
||||||
|
|
||||||
|
|
||||||
|
class _MappingOrderCacheView(UserDict[_K, _V]):
|
||||||
|
def __init__(self, data: Mapping[_K, _V], ordered_keys: Mapping[_K, None]):
|
||||||
|
super().__init__(data)
|
||||||
|
self.ordered_keys = ordered_keys
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[_K]:
|
||||||
|
return iter(self.ordered_keys)
|
||||||
|
|
||||||
|
def keys(self) -> KeysView[_K]:
|
||||||
|
return KeysView(self.ordered_keys)
|
||||||
|
|
||||||
|
|
||||||
|
class CacheInfo(NamedTuple):
|
||||||
|
hits: int
|
||||||
|
total: int
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hit_ratio(self) -> float:
|
||||||
|
if self.total == 0:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
return self.hits / self.total
|
||||||
|
|
||||||
|
def __sub__(self, other: CacheInfo):
|
||||||
|
return CacheInfo(
|
||||||
|
hits=self.hits - other.hits,
|
||||||
|
total=self.total - other.total,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
|
||||||
|
def __init__(self, capacity: float, getsizeof: Callable[[_V], float] | None = None):
|
||||||
|
super().__init__(capacity, getsizeof)
|
||||||
|
|
||||||
|
self.pinned_items = set[_K]()
|
||||||
|
|
||||||
|
self._hits = 0
|
||||||
|
self._total = 0
|
||||||
|
self._last_info = CacheInfo(hits=0, total=0)
|
||||||
|
|
||||||
|
def __getitem__(self, key: _K, *, update_info: bool = True) -> _V:
|
||||||
|
value = super().__getitem__(key)
|
||||||
|
|
||||||
|
if update_info:
|
||||||
|
self._hits += 1
|
||||||
|
self._total += 1
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
def __delitem__(self, key: _K) -> None:
|
||||||
|
run_on_remove = key in self
|
||||||
|
value = self.__getitem__(key, update_info=False) # type: ignore[call-arg]
|
||||||
|
super().__delitem__(key)
|
||||||
|
if key in self.pinned_items:
|
||||||
|
# Todo: add warning to inform that del pinned item
|
||||||
|
self._unpin(key)
|
||||||
|
if run_on_remove:
|
||||||
|
self._on_remove(key, value)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cache(self) -> Mapping[_K, _V]:
|
||||||
|
"""Return the internal cache dictionary in order (read-only)."""
|
||||||
|
return _MappingOrderCacheView(
|
||||||
|
self._Cache__data, # type: ignore
|
||||||
|
self.order,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def order(self) -> Mapping[_K, None]:
|
||||||
|
"""Return the internal order dictionary (read-only)."""
|
||||||
|
return MappingProxyType(self._LRUCache__order) # type: ignore
|
||||||
|
|
||||||
|
@property
|
||||||
|
def capacity(self) -> float:
|
||||||
|
return self.maxsize
|
||||||
|
|
||||||
|
@property
|
||||||
|
def usage(self) -> float:
|
||||||
|
if self.maxsize == 0:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
return self.currsize / self.maxsize
|
||||||
|
|
||||||
|
def stat(self, *, delta: bool = False) -> CacheInfo:
|
||||||
|
"""
|
||||||
|
Gets the cumulative number of hits and queries against this cache.
|
||||||
|
|
||||||
|
If `delta=True`, instead gets these statistics
|
||||||
|
since the last call that also passed `delta=True`.
|
||||||
|
"""
|
||||||
|
info = CacheInfo(hits=self._hits, total=self._total)
|
||||||
|
|
||||||
|
if delta:
|
||||||
|
info_delta = info - self._last_info
|
||||||
|
self._last_info = info
|
||||||
|
info = info_delta
|
||||||
|
|
||||||
|
return info
|
||||||
|
|
||||||
|
def touch(self, key: _K) -> None:
|
||||||
|
try:
|
||||||
|
self._LRUCache__order.move_to_end(key) # type: ignore
|
||||||
|
except KeyError:
|
||||||
|
self._LRUCache__order[key] = None # type: ignore
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get(self, key: _K, /) -> _V | None: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get(self, key: _K, /, default: Union[_V, _T]) -> Union[_V, _T]: ...
|
||||||
|
|
||||||
|
def get(
|
||||||
|
self, key: _K, /, default: Union[_V, _T] | None = None
|
||||||
|
) -> Union[_V, _T] | None:
|
||||||
|
value: Union[_V, _T] | None
|
||||||
|
if key in self:
|
||||||
|
value = self.__getitem__(key, update_info=False) # type: ignore[call-arg]
|
||||||
|
|
||||||
|
self._hits += 1
|
||||||
|
else:
|
||||||
|
value = default
|
||||||
|
|
||||||
|
self._total += 1
|
||||||
|
return value
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def pop(self, key: _K) -> _V: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def pop(self, key: _K, default: Union[_V, _T]) -> Union[_V, _T]: ...
|
||||||
|
|
||||||
|
def pop(
|
||||||
|
self, key: _K, default: Union[_V, _T] | None = None
|
||||||
|
) -> Union[_V, _T] | None:
|
||||||
|
value: Union[_V, _T] | None
|
||||||
|
if key not in self:
|
||||||
|
return default
|
||||||
|
|
||||||
|
value = self.__getitem__(key, update_info=False) # type: ignore[call-arg]
|
||||||
|
self.__delitem__(key)
|
||||||
|
return value
|
||||||
|
|
||||||
|
def put(self, key: _K, value: _V) -> None:
|
||||||
|
self.__setitem__(key, value)
|
||||||
|
|
||||||
|
def pin(self, key: _K) -> None:
|
||||||
|
"""
|
||||||
|
Pins a key in the cache preventing it from being
|
||||||
|
evicted in the LRU order.
|
||||||
|
"""
|
||||||
|
if key not in self:
|
||||||
|
raise ValueError(f"Cannot pin key: {key} not in cache.")
|
||||||
|
self.pinned_items.add(key)
|
||||||
|
|
||||||
|
def _unpin(self, key: _K) -> None:
|
||||||
|
"""
|
||||||
|
Unpins a key in the cache allowing it to be
|
||||||
|
evicted in the LRU order.
|
||||||
|
"""
|
||||||
|
self.pinned_items.remove(key)
|
||||||
|
|
||||||
|
def _on_remove(self, key: _K, value: _V | None) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def remove_oldest(self, *, remove_pinned: bool = False) -> None:
|
||||||
|
if len(self) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.popitem(remove_pinned=remove_pinned)
|
||||||
|
|
||||||
|
def _remove_old_if_needed(self) -> None:
|
||||||
|
while self.currsize > self.capacity:
|
||||||
|
self.remove_oldest()
|
||||||
|
|
||||||
|
def popitem(self, remove_pinned: bool = False):
|
||||||
|
"""Remove and return the `(key, value)` pair least recently used."""
|
||||||
|
if not remove_pinned:
|
||||||
|
# pop the oldest item in the cache that is not pinned
|
||||||
|
lru_key = next(
|
||||||
|
(key for key in self.order if key not in self.pinned_items),
|
||||||
|
ALL_PINNED_SENTINEL,
|
||||||
|
)
|
||||||
|
if lru_key is ALL_PINNED_SENTINEL:
|
||||||
|
raise RuntimeError(
|
||||||
|
"All items are pinned, cannot remove oldest from the cache."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
lru_key = next(iter(self.order))
|
||||||
|
value = self.pop(cast(_K, lru_key))
|
||||||
|
return (lru_key, value)
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
while len(self) > 0:
|
||||||
|
self.remove_oldest(remove_pinned=True)
|
||||||
|
|
||||||
|
self._hits = 0
|
||||||
|
self._total = 0
|
||||||
|
self._last_info = CacheInfo(hits=0, total=0)
|
||||||
Reference in New Issue
Block a user