[Misc] Clean up and consolidate LRUCache (#11339)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -21,14 +21,13 @@ import uuid
|
||||
import warnings
|
||||
import weakref
|
||||
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
|
||||
from collections import UserDict, defaultdict
|
||||
from collections import OrderedDict, UserDict, defaultdict
|
||||
from collections.abc import Iterable, Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from functools import lru_cache, partial, wraps
|
||||
from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
|
||||
Dict, Generator, Generic, Hashable, List, Literal,
|
||||
Optional, OrderedDict, Set, Tuple, Type, TypeVar, Union,
|
||||
overload)
|
||||
Optional, Tuple, Type, TypeVar, Union, overload)
|
||||
from uuid import uuid4
|
||||
|
||||
import numpy as np
|
||||
@@ -154,10 +153,12 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = {
|
||||
}
|
||||
|
||||
P = ParamSpec('P')
|
||||
K = TypeVar("K")
|
||||
T = TypeVar("T")
|
||||
U = TypeVar("U")
|
||||
|
||||
_K = TypeVar("_K", bound=Hashable)
|
||||
_V = TypeVar("_V")
|
||||
|
||||
|
||||
class _Sentinel:
|
||||
...
|
||||
@@ -190,50 +191,48 @@ class Counter:
|
||||
self.counter = 0
|
||||
|
||||
|
||||
class LRUCache(Generic[T]):
|
||||
class LRUCache(Generic[_K, _V]):
|
||||
|
||||
def __init__(self, capacity: int):
|
||||
self.cache: OrderedDict[Hashable, T] = OrderedDict()
|
||||
self.pinned_items: Set[Hashable] = set()
|
||||
def __init__(self, capacity: int) -> None:
|
||||
self.cache = OrderedDict[_K, _V]()
|
||||
self.pinned_items = set[_K]()
|
||||
self.capacity = capacity
|
||||
|
||||
def __contains__(self, key: Hashable) -> bool:
|
||||
def __contains__(self, key: _K) -> bool:
|
||||
return key in self.cache
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.cache)
|
||||
|
||||
def __getitem__(self, key: Hashable) -> T:
|
||||
def __getitem__(self, key: _K) -> _V:
|
||||
value = self.cache[key] # Raise KeyError if not exists
|
||||
self.cache.move_to_end(key)
|
||||
return value
|
||||
|
||||
def __setitem__(self, key: Hashable, value: T) -> None:
|
||||
def __setitem__(self, key: _K, value: _V) -> None:
|
||||
self.put(key, value)
|
||||
|
||||
def __delitem__(self, key: Hashable) -> None:
|
||||
def __delitem__(self, key: _K) -> None:
|
||||
self.pop(key)
|
||||
|
||||
def touch(self, key: Hashable) -> None:
|
||||
def touch(self, key: _K) -> None:
|
||||
self.cache.move_to_end(key)
|
||||
|
||||
def get(self,
|
||||
key: Hashable,
|
||||
default_value: Optional[T] = None) -> Optional[T]:
|
||||
value: Optional[T]
|
||||
def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
|
||||
value: Optional[_V]
|
||||
if key in self.cache:
|
||||
value = self.cache[key]
|
||||
self.cache.move_to_end(key)
|
||||
else:
|
||||
value = default_value
|
||||
value = default
|
||||
return value
|
||||
|
||||
def put(self, key: Hashable, value: T) -> None:
|
||||
def put(self, key: _K, value: _V) -> None:
|
||||
self.cache[key] = value
|
||||
self.cache.move_to_end(key)
|
||||
self._remove_old_if_needed()
|
||||
|
||||
def pin(self, key: Hashable) -> None:
|
||||
def pin(self, key: _K) -> None:
|
||||
"""
|
||||
Pins a key in the cache preventing it from being
|
||||
evicted in the LRU order.
|
||||
@@ -242,13 +241,13 @@ class LRUCache(Generic[T]):
|
||||
raise ValueError(f"Cannot pin key: {key} not in cache.")
|
||||
self.pinned_items.add(key)
|
||||
|
||||
def _unpin(self, key: Hashable) -> None:
|
||||
def _unpin(self, key: _K) -> None:
|
||||
self.pinned_items.remove(key)
|
||||
|
||||
def _on_remove(self, key: Hashable, value: Optional[T]):
|
||||
def _on_remove(self, key: _K, value: Optional[_V]) -> None:
|
||||
pass
|
||||
|
||||
def remove_oldest(self, remove_pinned=False):
|
||||
def remove_oldest(self, *, remove_pinned: bool = False) -> None:
|
||||
if not self.cache:
|
||||
return
|
||||
|
||||
@@ -262,17 +261,15 @@ class LRUCache(Generic[T]):
|
||||
"cannot remove oldest from the cache.")
|
||||
else:
|
||||
lru_key = next(iter(self.cache))
|
||||
self.pop(lru_key)
|
||||
self.pop(lru_key) # type: ignore
|
||||
|
||||
def _remove_old_if_needed(self) -> None:
|
||||
while len(self.cache) > self.capacity:
|
||||
self.remove_oldest()
|
||||
|
||||
def pop(self,
|
||||
key: Hashable,
|
||||
default_value: Optional[T] = None) -> Optional[T]:
|
||||
def pop(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
|
||||
run_on_remove = key in self.cache
|
||||
value: Optional[T] = self.cache.pop(key, default_value)
|
||||
value = self.cache.pop(key, default)
|
||||
# remove from pinned items
|
||||
if key in self.pinned_items:
|
||||
self._unpin(key)
|
||||
@@ -280,7 +277,7 @@ class LRUCache(Generic[T]):
|
||||
self._on_remove(key, value)
|
||||
return value
|
||||
|
||||
def clear(self):
|
||||
def clear(self) -> None:
|
||||
while len(self.cache) > 0:
|
||||
self.remove_oldest(remove_pinned=True)
|
||||
self.cache.clear()
|
||||
@@ -843,10 +840,6 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
|
||||
return [item for sublist in lists for item in sublist]
|
||||
|
||||
|
||||
_K = TypeVar("_K", bound=Hashable)
|
||||
_V = TypeVar("_V")
|
||||
|
||||
|
||||
def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
|
||||
"""
|
||||
Unlike :class:`itertools.groupby`, groups are not broken by
|
||||
|
||||
Reference in New Issue
Block a user