[LoRA] Add support for pinning lora adapters in the LRU cache (#5603)

This commit is contained in:
rohithkrn
2024-06-21 15:42:46 -07:00
committed by GitHub
parent 7187507301
commit f5dda63eb5
13 changed files with 171 additions and 5 deletions

View File

@@ -15,7 +15,7 @@ from collections import defaultdict
from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
Hashable, List, Optional, OrderedDict, Tuple, TypeVar,
Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
Union)
import numpy as np
@@ -44,6 +44,13 @@ K = TypeVar("K")
T = TypeVar("T")
class _Sentinel:
    """Private marker type; instances carry no state or behavior."""


# Unique marker object. The LRU eviction code uses it as the "nothing found"
# default when searching for an unpinned key, so it can never collide with a
# real cache key.
ALL_PINNED_SENTINEL = _Sentinel()
class Device(enum.Enum):
GPU = enum.auto()
CPU = enum.auto()
@@ -67,6 +74,7 @@ class LRUCache(Generic[T]):
def __init__(self, capacity: int):
    """Set up an empty cache.

    Args:
        capacity: Entry-count limit stored on the instance; the cache is
            trimmed back to this size after insertions.
    """
    self.capacity = capacity
    # Keys that pin-aware eviction must skip.
    self.pinned_items: Set[Hashable] = set()
    # Ordered key -> value storage; iteration starts at the oldest entry.
    self.cache: OrderedDict[Hashable, T] = OrderedDict()
def __contains__(self, key: Hashable) -> bool:
@@ -102,14 +110,36 @@ class LRUCache(Generic[T]):
self.cache.move_to_end(key)
self._remove_old_if_needed()
def pin(self, key: Hashable) -> None:
    """Mark ``key`` as pinned so that LRU-order eviction skips it.

    Args:
        key: A key that must already be present in the cache.

    Raises:
        ValueError: if ``key`` is not currently stored in the cache.
    """
    if key in self.cache:
        self.pinned_items.add(key)
        return
    raise ValueError(f"Cannot pin key: {key} not in cache.")
def _unpin(self, key: Hashable) -> None:
    """Remove ``key`` from the set of pinned keys.

    Raises:
        KeyError: if ``key`` is not pinned (``set.remove`` semantics).
    """
    pinned = self.pinned_items
    pinned.remove(key)
def _on_remove(self, key: Hashable, value: Optional[T]):
    """Removal hook; this implementation does nothing.

    Args:
        key: The key that was removed from the cache.
        value: The value that was stored under ``key``.
    """
    return None
def remove_oldest(self, remove_pinned: bool = False) -> None:
    """Evict the oldest entry from the cache.

    Args:
        remove_pinned: When False (the default), pinned keys are skipped
            and the oldest *unpinned* entry is evicted. When True, the
            oldest entry is evicted even if it is pinned.

    Raises:
        RuntimeError: if ``remove_pinned`` is False and every cached key
            is pinned.

    Does nothing when the cache is empty.
    """
    if not self.cache:
        return

    if remove_pinned:
        # Oldest entry overall, pinned or not.
        lru_key = next(iter(self.cache))
    else:
        # Walk from oldest to newest and take the first key that is not
        # pinned; the ``else`` arm runs only if no such key exists.
        for candidate in self.cache:
            if candidate not in self.pinned_items:
                lru_key = candidate
                break
        else:
            raise RuntimeError("All items are pinned, "
                               "cannot remove oldest from the cache.")

    # Remove via pop() rather than touching the dict directly, so every
    # eviction goes through the same removal path exactly once.
    self.pop(lru_key)
def _remove_old_if_needed(self) -> None:
while len(self.cache) > self.capacity:
@@ -120,13 +150,16 @@ class LRUCache(Generic[T]):
default_value: Optional[T] = None) -> Optional[T]:
run_on_remove = key in self.cache
value: Optional[T] = self.cache.pop(key, default_value)
# remove from pinned items
if key in self.pinned_items:
self._unpin(key)
if run_on_remove:
self._on_remove(key, value)
return value
def clear(self) -> None:
    """Remove every entry from the cache, including pinned ones.

    Entries are evicted one at a time through ``remove_oldest`` with
    ``remove_pinned=True`` so that pinned keys do not block the clear.
    """
    while self.cache:
        self.remove_oldest(remove_pinned=True)
    # The loop already empties the mapping; this final clear is kept as a
    # cheap safeguard.
    self.cache.clear()