[Model] Replace embedding models with pooling adapter (#10769)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-12-01 08:02:54 +08:00
committed by GitHub
parent 7e4bbda573
commit 133707123e
32 changed files with 383 additions and 319 deletions

View File

@@ -20,6 +20,7 @@ import torch.nn as nn
from vllm.logger import init_logger
from vllm.platforms import current_platform
from .adapters import as_embedding_model
from .interfaces import (has_inner_state, is_attention_free,
supports_cross_encoding, supports_multimodal,
supports_pp)
@@ -107,15 +108,15 @@ _EMBEDDING_MODELS = {
"RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
"LlamaModel": ("llama", "LlamaEmbeddingModel"),
"LlamaModel": ("llama", "LlamaForCausalLM"),
**{
# Multiple models share the same architecture, so we include them all
k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
if arch == "LlamaForCausalLM"
},
"MistralModel": ("llama", "LlamaEmbeddingModel"),
"MistralModel": ("llama", "LlamaForCausalLM"),
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
@@ -125,7 +126,7 @@ _EMBEDDING_MODELS = {
# [Multimodal]
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration") # noqa: E501,
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
}
_CROSS_ENCODER_MODELS = {
@@ -208,6 +209,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
@dataclass(frozen=True)
class _ModelInfo:
architecture: str
is_text_generation_model: bool
is_embedding_model: bool
supports_cross_encoding: bool
@@ -218,9 +220,19 @@ class _ModelInfo:
@staticmethod
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
is_embedding_model_ = is_embedding_model(model)
if not is_embedding_model_:
try:
as_embedding_model(model)
except Exception:
pass
else:
is_embedding_model_ = True
return _ModelInfo(
architecture=model.__name__,
is_text_generation_model=is_text_generation_model(model),
is_embedding_model=is_embedding_model(model),
is_embedding_model=is_embedding_model_,
supports_cross_encoding=supports_cross_encoding(model),
supports_multimodal=supports_multimodal(model),
supports_pp=supports_pp(model),
@@ -399,13 +411,13 @@ class _ModelRegistry:
def inspect_model_cls(
    self,
    architectures: Union[str, List[str]],
) -> Tuple["_ModelInfo", str]:
    """Return the info record of the first architecture that can be inspected.

    Args:
        architectures: A single architecture name or a list of candidate
            names. They are passed through ``self._normalize_archs`` and
            then tried in order.

    Returns:
        A ``(model_info, arch)`` pair, where ``arch`` is the candidate for
        which ``self._try_inspect_model_cls`` returned a non-``None`` result.

    If no candidate can be inspected, the outcome of
    ``self._raise_for_unsupported`` is propagated (presumably it raises;
    confirm against its definition).
    """
    architectures = self._normalize_archs(architectures)

    for arch in architectures:
        model_info = self._try_inspect_model_cls(arch)
        if model_info is not None:
            # Callers unpack a pair, so hand back the matched architecture
            # name together with the info record.
            return (model_info, arch)

    return self._raise_for_unsupported(architectures)
@@ -426,39 +438,50 @@ class _ModelRegistry:
self,
architectures: Union[str, List[str]],
) -> bool:
return self.inspect_model_cls(architectures).is_text_generation_model
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.is_text_generation_model
def is_embedding_model(
    self,
    architectures: Union[str, List[str]],
) -> bool:
    """Return the ``is_embedding_model`` flag of the inspected architecture.

    Args:
        architectures: A single architecture name or a list of candidates,
            forwarded unchanged to ``self.inspect_model_cls``.
    """
    # inspect_model_cls yields (model_info, matched_arch); only the info
    # record is needed here.
    model_info, _ = self.inspect_model_cls(architectures)
    return model_info.is_embedding_model
def is_cross_encoder_model(
    self,
    architectures: Union[str, List[str]],
) -> bool:
    """Return the ``supports_cross_encoding`` flag of the inspected architecture.

    Args:
        architectures: A single architecture name or a list of candidates,
            forwarded unchanged to ``self.inspect_model_cls``.
    """
    # inspect_model_cls yields (model_info, matched_arch); only the info
    # record is needed here.
    model_info, _ = self.inspect_model_cls(architectures)
    return model_info.supports_cross_encoding
def is_multimodal_model(
    self,
    architectures: Union[str, List[str]],
) -> bool:
    """Return the ``supports_multimodal`` flag of the inspected architecture.

    Args:
        architectures: A single architecture name or a list of candidates,
            forwarded unchanged to ``self.inspect_model_cls``.
    """
    # inspect_model_cls yields (model_info, matched_arch); only the info
    # record is needed here.
    model_info, _ = self.inspect_model_cls(architectures)
    return model_info.supports_multimodal
def is_pp_supported_model(
    self,
    architectures: Union[str, List[str]],
) -> bool:
    """Return the ``supports_pp`` flag of the inspected architecture.

    Args:
        architectures: A single architecture name or a list of candidates,
            forwarded unchanged to ``self.inspect_model_cls``.
    """
    # inspect_model_cls yields (model_info, matched_arch); only the info
    # record is needed here.
    model_info, _ = self.inspect_model_cls(architectures)
    return model_info.supports_pp
def model_has_inner_state(
    self,
    architectures: Union[str, List[str]],
) -> bool:
    """Return the ``has_inner_state`` flag of the inspected architecture.

    Args:
        architectures: A single architecture name or a list of candidates,
            forwarded unchanged to ``self.inspect_model_cls``.
    """
    # inspect_model_cls yields (model_info, matched_arch); only the info
    # record is needed here.
    model_info, _ = self.inspect_model_cls(architectures)
    return model_info.has_inner_state
def is_attention_free_model(
    self,
    architectures: Union[str, List[str]],
) -> bool:
    """Return the ``is_attention_free`` flag of the inspected architecture.

    Args:
        architectures: A single architecture name or a list of candidates,
            forwarded unchanged to ``self.inspect_model_cls``.
    """
    # inspect_model_cls yields (model_info, matched_arch); only the info
    # record is needed here.
    model_info, _ = self.inspect_model_cls(architectures)
    return model_info.is_attention_free
ModelRegistry = _ModelRegistry({