[Bugfix] Multi-modal caches not acting like LRU caches (#16593)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-04-15 00:24:16 +08:00
committed by GitHub
parent 6bf27affb6
commit aa29841ede
4 changed files with 187 additions and 126 deletions

View File

@@ -236,6 +236,12 @@ class CacheInfo(NamedTuple):
return self.hits / self.total
def __sub__(self, other: CacheInfo):
    """Return the field-wise difference between this record and *other*.

    Both ``hits`` and ``total`` of *other* are subtracted from the
    corresponding fields of ``self``; the result is a new ``CacheInfo``.
    """
    hit_diff = self.hits - other.hits
    total_diff = self.total - other.total
    return CacheInfo(hits=hit_diff, total=total_diff)
class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
@@ -243,15 +249,26 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
capacity: float,
getsizeof: Optional[Callable[[_V], float]] = None):
super().__init__(capacity, getsizeof)
self.pinned_items = set[_K]()
self.capacity = capacity
self._hits = 0
self._total = 0
self._last_info = CacheInfo(hits=0, total=0)
def __getitem__(self, key: _K, *, update_info: bool = True) -> _V:
    """Fetch *key* through the parent class, optionally recording the access.

    When ``update_info`` is true (the default), a successful lookup
    increments both the hit counter and the total-query counter.
    Internal callers pass ``update_info=False`` to read a value without
    affecting the statistics. If the parent lookup raises, the exception
    propagates before either counter is modified.
    """
    found = super().__getitem__(key)
    if not update_info:
        return found

    self._hits += 1
    self._total += 1
    return found
def __delitem__(self, key: _K) -> None:
run_on_remove = key in self
value = self.__getitem__(key)
value = self.__getitem__(key,
update_info=False) # type: ignore[call-arg]
super().__delitem__(key)
if key in self.pinned_items:
# Todo: add warning to inform that del pinned item
@@ -271,8 +288,32 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
"""Return the internal order dictionary (read-only)."""
return MappingProxyType(self._LRUCache__order) # type: ignore
def stat(self) -> CacheInfo:
return CacheInfo(hits=self._hits, total=self._total)
@property
def capacity(self) -> float:
    """Read-only alias for ``self.maxsize``."""
    return self.maxsize
@property
def usage(self) -> float:
    """Ratio of ``currsize`` to ``maxsize``.

    Returns 0 when ``maxsize`` is 0 so that the division is never
    attempted on a zero-sized cache.
    """
    limit = self.maxsize
    return self.currsize / limit if limit != 0 else 0
def stat(self, *, delta: bool = False) -> CacheInfo:
    """
    Report the hit and query counters recorded by this cache.

    By default the cumulative counters are returned. If
    :code:`delta=True`, the counters accumulated since the previous
    call that also passed :code:`delta=True` are returned instead, and
    the current totals are remembered as the new baseline.
    """
    current = CacheInfo(hits=self._hits, total=self._total)
    if not delta:
        return current

    # Compute the difference first so the baseline only advances once
    # the subtraction has succeeded.
    since_last = current - self._last_info
    self._last_info = current
    return since_last
def touch(self, key: _K) -> None:
    """Invoke the parent LRU cache's private ``__update`` hook for *key*.

    The hit/total counters are not modified here.
    NOTE(review): presumably this marks *key* as most recently used
    without reading its value -- confirm against the installed
    ``cachetools`` version, since this relies on a name-mangled private
    method.
    """
    self._LRUCache__update(key)  # type: ignore
@@ -292,7 +333,8 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
_T]] = None) -> Optional[Union[_V, _T]]:
value: Optional[Union[_V, _T]]
if key in self:
value = self.__getitem__(key)
value = self.__getitem__(
key, update_info=False) # type: ignore[call-arg]
self._hits += 1
else:
@@ -317,8 +359,9 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
if key not in self:
return default
value = self[key]
del self[key]
value = self.__getitem__(key,
update_info=False) # type: ignore[call-arg]
self.__delitem__(key)
return value
def put(self, key: _K, value: _V) -> None:
@@ -353,10 +396,6 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
while self.currsize > self.capacity:
self.remove_oldest()
def clear(self) -> None:
while len(self) > 0:
self.remove_oldest(remove_pinned=True)
def popitem(self, remove_pinned: bool = False):
"""Remove and return the `(key, value)` pair least recently used."""
if not remove_pinned:
@@ -372,6 +411,14 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
value = self.pop(cast(_K, lru_key))
return (lru_key, value)
def clear(self) -> None:
    """Empty the cache and reset its statistics.

    Entries are removed one at a time via ``remove_oldest`` with
    ``remove_pinned=True``, so pinned items are evicted as well.
    Afterwards the hit counter, the total counter and the delta
    baseline are all reset to zero.
    """
    while len(self):
        self.remove_oldest(remove_pinned=True)

    self._hits = self._total = 0
    self._last_info = CacheInfo(hits=0, total=0)
class PyObjectCache:
"""Used to cache python objects to avoid object allocations

View File

@@ -50,7 +50,7 @@ class MirroredProcessingCache:
full_mm_inputs = list[Optional[MultiModalKwargs]]()
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
if mm_hash in self.mm_cache:
if self.mm_cache.get(mm_hash) is not None:
mm_input = None
else:
self.mm_cache[mm_hash] = mm_input