[Model] Replace embedding models with pooling adapter (#10769)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -20,7 +20,7 @@ import uuid
|
||||
import warnings
|
||||
import weakref
|
||||
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
|
||||
from collections import defaultdict
|
||||
from collections import UserDict, defaultdict
|
||||
from collections.abc import Iterable, Mapping
|
||||
from functools import lru_cache, partial, wraps
|
||||
from platform import uname
|
||||
@@ -1517,13 +1517,13 @@ class AtomicCounter:
|
||||
|
||||
|
||||
# Adapted from: https://stackoverflow.com/a/47212782/5082708
|
||||
class LazyDict(Mapping, Generic[T]):
|
||||
class LazyDict(Mapping[str, T], Generic[T]):
|
||||
|
||||
def __init__(self, factory: Dict[str, Callable[[], T]]):
|
||||
self._factory = factory
|
||||
self._dict: Dict[str, T] = {}
|
||||
|
||||
def __getitem__(self, key) -> T:
|
||||
def __getitem__(self, key: str) -> T:
|
||||
if key not in self._dict:
|
||||
if key not in self._factory:
|
||||
raise KeyError(key)
|
||||
@@ -1540,6 +1540,22 @@ class LazyDict(Mapping, Generic[T]):
|
||||
return len(self._factory)
|
||||
|
||||
|
||||
class ClassRegistry(UserDict[type[T], _V]):
|
||||
|
||||
def __getitem__(self, key: type[T]) -> _V:
|
||||
for cls in key.mro():
|
||||
if cls in self.data:
|
||||
return self.data[cls]
|
||||
|
||||
raise KeyError(key)
|
||||
|
||||
def __contains__(self, key: object) -> bool:
|
||||
if not isinstance(key, type):
|
||||
return False
|
||||
|
||||
return any(cls in self.data for cls in key.mro())
|
||||
|
||||
|
||||
def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Create a weak reference to a tensor.
|
||||
|
||||
Reference in New Issue
Block a user