[Misc] Clean up and consolidate LRUCache (#11339)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-12-20 00:59:32 +08:00
committed by GitHub
parent e24113a8fe
commit cdf22afdda
5 changed files with 34 additions and 67 deletions

View File

@@ -21,14 +21,13 @@ import uuid
import warnings
import weakref
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
from collections import UserDict, defaultdict
from collections import OrderedDict, UserDict, defaultdict
from collections.abc import Iterable, Mapping
from dataclasses import dataclass, field
from functools import lru_cache, partial, wraps
from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
Dict, Generator, Generic, Hashable, List, Literal,
Optional, OrderedDict, Set, Tuple, Type, TypeVar, Union,
overload)
Optional, Tuple, Type, TypeVar, Union, overload)
from uuid import uuid4
import numpy as np
@@ -154,10 +153,12 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = {
}
P = ParamSpec('P')
K = TypeVar("K")
T = TypeVar("T")
U = TypeVar("U")
_K = TypeVar("_K", bound=Hashable)
_V = TypeVar("_V")
class _Sentinel:
...
@@ -190,50 +191,48 @@ class Counter:
self.counter = 0
class LRUCache(Generic[T]):
class LRUCache(Generic[_K, _V]):
def __init__(self, capacity: int):
self.cache: OrderedDict[Hashable, T] = OrderedDict()
self.pinned_items: Set[Hashable] = set()
def __init__(self, capacity: int) -> None:
self.cache = OrderedDict[_K, _V]()
self.pinned_items = set[_K]()
self.capacity = capacity
def __contains__(self, key: Hashable) -> bool:
def __contains__(self, key: _K) -> bool:
return key in self.cache
def __len__(self) -> int:
return len(self.cache)
def __getitem__(self, key: Hashable) -> T:
def __getitem__(self, key: _K) -> _V:
value = self.cache[key] # Raise KeyError if not exists
self.cache.move_to_end(key)
return value
def __setitem__(self, key: Hashable, value: T) -> None:
def __setitem__(self, key: _K, value: _V) -> None:
self.put(key, value)
def __delitem__(self, key: Hashable) -> None:
def __delitem__(self, key: _K) -> None:
self.pop(key)
def touch(self, key: Hashable) -> None:
def touch(self, key: _K) -> None:
self.cache.move_to_end(key)
def get(self,
key: Hashable,
default_value: Optional[T] = None) -> Optional[T]:
value: Optional[T]
def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
value: Optional[_V]
if key in self.cache:
value = self.cache[key]
self.cache.move_to_end(key)
else:
value = default_value
value = default
return value
def put(self, key: Hashable, value: T) -> None:
def put(self, key: _K, value: _V) -> None:
self.cache[key] = value
self.cache.move_to_end(key)
self._remove_old_if_needed()
def pin(self, key: Hashable) -> None:
def pin(self, key: _K) -> None:
"""
Pins a key in the cache preventing it from being
evicted in the LRU order.
@@ -242,13 +241,13 @@ class LRUCache(Generic[T]):
raise ValueError(f"Cannot pin key: {key} not in cache.")
self.pinned_items.add(key)
def _unpin(self, key: Hashable) -> None:
def _unpin(self, key: _K) -> None:
self.pinned_items.remove(key)
def _on_remove(self, key: Hashable, value: Optional[T]):
def _on_remove(self, key: _K, value: Optional[_V]) -> None:
pass
def remove_oldest(self, remove_pinned=False):
def remove_oldest(self, *, remove_pinned: bool = False) -> None:
if not self.cache:
return
@@ -262,17 +261,15 @@ class LRUCache(Generic[T]):
"cannot remove oldest from the cache.")
else:
lru_key = next(iter(self.cache))
self.pop(lru_key)
self.pop(lru_key) # type: ignore
def _remove_old_if_needed(self) -> None:
while len(self.cache) > self.capacity:
self.remove_oldest()
def pop(self,
key: Hashable,
default_value: Optional[T] = None) -> Optional[T]:
def pop(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
run_on_remove = key in self.cache
value: Optional[T] = self.cache.pop(key, default_value)
value = self.cache.pop(key, default)
# remove from pinned items
if key in self.pinned_items:
self._unpin(key)
@@ -280,7 +277,7 @@ class LRUCache(Generic[T]):
self._on_remove(key, value)
return value
def clear(self):
def clear(self) -> None:
while len(self.cache) > 0:
self.remove_oldest(remove_pinned=True)
self.cache.clear()
@@ -843,10 +840,6 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
return [item for sublist in lists for item in sublist]
_K = TypeVar("_K", bound=Hashable)
_V = TypeVar("_V")
def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
"""
Unlike :class:`itertools.groupby`, groups are not broken by