[Misc] Rename embedding classes to pooling (#10801)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-12-01 14:36:51 +08:00
committed by GitHub
parent f877a7d12a
commit d2f058e76c
25 changed files with 166 additions and 123 deletions

View File

@@ -24,7 +24,7 @@ from .adapters import as_embedding_model
from .interfaces import (has_inner_state, is_attention_free,
supports_cross_encoding, supports_multimodal,
supports_pp)
from .interfaces_base import is_embedding_model, is_text_generation_model
from .interfaces_base import is_pooling_model, is_text_generation_model
logger = init_logger(__name__)
@@ -211,7 +211,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
class _ModelInfo:
architecture: str
is_text_generation_model: bool
is_embedding_model: bool
is_pooling_model: bool
supports_cross_encoding: bool
supports_multimodal: bool
supports_pp: bool
@@ -220,19 +220,19 @@ class _ModelInfo:
@staticmethod
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
is_embedding_model_ = is_embedding_model(model)
if not is_embedding_model_:
is_pooling_model_ = is_pooling_model(model)
if not is_pooling_model_:
try:
as_embedding_model(model)
except Exception:
pass
else:
is_embedding_model_ = True
is_pooling_model_ = True
return _ModelInfo(
architecture=model.__name__,
is_text_generation_model=is_text_generation_model(model),
is_embedding_model=is_embedding_model_,
is_pooling_model=is_pooling_model_,
supports_cross_encoding=supports_cross_encoding(model),
supports_multimodal=supports_multimodal(model),
supports_pp=supports_pp(model),
@@ -441,12 +441,12 @@ class _ModelRegistry:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.is_text_generation_model
def is_embedding_model(
def is_pooling_model(
self,
architectures: Union[str, List[str]],
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.is_embedding_model
return model_cls.is_pooling_model
def is_cross_encoder_model(
self,