[Model] Replace embedding models with pooling adapter (#10769)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-12-01 08:02:54 +08:00
committed by GitHub
parent 7e4bbda573
commit 133707123e
32 changed files with 383 additions and 319 deletions

View File

@@ -20,7 +20,7 @@ import uuid
import warnings
import weakref
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
from collections import defaultdict
from collections import UserDict, defaultdict
from collections.abc import Iterable, Mapping
from functools import lru_cache, partial, wraps
from platform import uname
@@ -1517,13 +1517,13 @@ class AtomicCounter:
# Adapted from: https://stackoverflow.com/a/47212782/5082708
class LazyDict(Mapping, Generic[T]):
class LazyDict(Mapping[str, T], Generic[T]):
def __init__(self, factory: Dict[str, Callable[[], T]]):
self._factory = factory
self._dict: Dict[str, T] = {}
def __getitem__(self, key) -> T:
def __getitem__(self, key: str) -> T:
if key not in self._dict:
if key not in self._factory:
raise KeyError(key)
@@ -1540,6 +1540,22 @@ class LazyDict(Mapping, Generic[T]):
return len(self._factory)
class ClassRegistry(UserDict[type[T], _V]):
def __getitem__(self, key: type[T]) -> _V:
for cls in key.mro():
if cls in self.data:
return self.data[cls]
raise KeyError(key)
def __contains__(self, key: object) -> bool:
if not isinstance(key, type):
return False
return any(cls in self.data for cls in key.mro())
def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor:
"""
Create a weak reference to a tensor.