[LoRA] Add support for pinning lora adapters in the LRU cache (#5603)
This commit is contained in:
@@ -15,7 +15,7 @@ from collections import defaultdict
|
||||
from functools import lru_cache, partial, wraps
|
||||
from platform import uname
|
||||
from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
|
||||
Hashable, List, Optional, OrderedDict, Tuple, TypeVar,
|
||||
Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
|
||||
Union)
|
||||
|
||||
import numpy as np
|
||||
@@ -44,6 +44,13 @@ K = TypeVar("K")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class _Sentinel:
|
||||
...
|
||||
|
||||
|
||||
ALL_PINNED_SENTINEL = _Sentinel()
|
||||
|
||||
|
||||
class Device(enum.Enum):
|
||||
GPU = enum.auto()
|
||||
CPU = enum.auto()
|
||||
@@ -67,6 +74,7 @@ class LRUCache(Generic[T]):
|
||||
|
||||
def __init__(self, capacity: int):
|
||||
self.cache: OrderedDict[Hashable, T] = OrderedDict()
|
||||
self.pinned_items: Set[Hashable] = set()
|
||||
self.capacity = capacity
|
||||
|
||||
def __contains__(self, key: Hashable) -> bool:
|
||||
@@ -102,14 +110,36 @@ class LRUCache(Generic[T]):
|
||||
self.cache.move_to_end(key)
|
||||
self._remove_old_if_needed()
|
||||
|
||||
def pin(self, key: Hashable) -> None:
|
||||
"""
|
||||
Pins a key in the cache preventing it from being
|
||||
evicted in the LRU order.
|
||||
"""
|
||||
if key not in self.cache:
|
||||
raise ValueError(f"Cannot pin key: {key} not in cache.")
|
||||
self.pinned_items.add(key)
|
||||
|
||||
def _unpin(self, key: Hashable) -> None:
|
||||
self.pinned_items.remove(key)
|
||||
|
||||
def _on_remove(self, key: Hashable, value: Optional[T]):
|
||||
pass
|
||||
|
||||
def remove_oldest(self):
|
||||
def remove_oldest(self, remove_pinned=False):
|
||||
if not self.cache:
|
||||
return
|
||||
key, value = self.cache.popitem(last=False)
|
||||
self._on_remove(key, value)
|
||||
|
||||
if not remove_pinned:
|
||||
# pop the oldest item in the cache that is not pinned
|
||||
lru_key = next(
|
||||
(key for key in self.cache if key not in self.pinned_items),
|
||||
ALL_PINNED_SENTINEL)
|
||||
if lru_key is ALL_PINNED_SENTINEL:
|
||||
raise RuntimeError("All items are pinned, "
|
||||
"cannot remove oldest from the cache.")
|
||||
else:
|
||||
lru_key = next(iter(self.cache))
|
||||
self.pop(lru_key)
|
||||
|
||||
def _remove_old_if_needed(self) -> None:
|
||||
while len(self.cache) > self.capacity:
|
||||
@@ -120,13 +150,16 @@ class LRUCache(Generic[T]):
|
||||
default_value: Optional[T] = None) -> Optional[T]:
|
||||
run_on_remove = key in self.cache
|
||||
value: Optional[T] = self.cache.pop(key, default_value)
|
||||
# remove from pinned items
|
||||
if key in self.pinned_items:
|
||||
self._unpin(key)
|
||||
if run_on_remove:
|
||||
self._on_remove(key, value)
|
||||
return value
|
||||
|
||||
def clear(self):
|
||||
while len(self.cache) > 0:
|
||||
self.remove_oldest()
|
||||
self.remove_oldest(remove_pinned=True)
|
||||
self.cache.clear()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user