[Misc] Rename embedding classes to pooling (#10801)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -24,7 +24,7 @@ from .adapters import as_embedding_model
|
||||
from .interfaces import (has_inner_state, is_attention_free,
|
||||
supports_cross_encoding, supports_multimodal,
|
||||
supports_pp)
|
||||
from .interfaces_base import is_embedding_model, is_text_generation_model
|
||||
from .interfaces_base import is_pooling_model, is_text_generation_model
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -211,7 +211,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
|
||||
class _ModelInfo:
|
||||
architecture: str
|
||||
is_text_generation_model: bool
|
||||
is_embedding_model: bool
|
||||
is_pooling_model: bool
|
||||
supports_cross_encoding: bool
|
||||
supports_multimodal: bool
|
||||
supports_pp: bool
|
||||
@@ -220,19 +220,19 @@ class _ModelInfo:
|
||||
|
||||
@staticmethod
|
||||
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
|
||||
is_embedding_model_ = is_embedding_model(model)
|
||||
if not is_embedding_model_:
|
||||
is_pooling_model_ = is_pooling_model(model)
|
||||
if not is_pooling_model_:
|
||||
try:
|
||||
as_embedding_model(model)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
is_embedding_model_ = True
|
||||
is_pooling_model_ = True
|
||||
|
||||
return _ModelInfo(
|
||||
architecture=model.__name__,
|
||||
is_text_generation_model=is_text_generation_model(model),
|
||||
is_embedding_model=is_embedding_model_,
|
||||
is_pooling_model=is_pooling_model_,
|
||||
supports_cross_encoding=supports_cross_encoding(model),
|
||||
supports_multimodal=supports_multimodal(model),
|
||||
supports_pp=supports_pp(model),
|
||||
@@ -441,12 +441,12 @@ class _ModelRegistry:
|
||||
model_cls, _ = self.inspect_model_cls(architectures)
|
||||
return model_cls.is_text_generation_model
|
||||
|
||||
def is_embedding_model(
|
||||
def is_pooling_model(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> bool:
|
||||
model_cls, _ = self.inspect_model_cls(architectures)
|
||||
return model_cls.is_embedding_model
|
||||
return model_cls.is_pooling_model
|
||||
|
||||
def is_cross_encoder_model(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user