[Model] Replace embedding models with pooling adapter (#10769)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -20,6 +20,7 @@ import torch.nn as nn
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .adapters import as_embedding_model
|
||||
from .interfaces import (has_inner_state, is_attention_free,
|
||||
supports_cross_encoding, supports_multimodal,
|
||||
supports_pp)
|
||||
@@ -107,15 +108,15 @@ _EMBEDDING_MODELS = {
|
||||
"RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
|
||||
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
|
||||
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
|
||||
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
|
||||
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
|
||||
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
|
||||
"LlamaModel": ("llama", "LlamaEmbeddingModel"),
|
||||
"LlamaModel": ("llama", "LlamaForCausalLM"),
|
||||
**{
|
||||
# Multiple models share the same architecture, so we include them all
|
||||
k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
|
||||
if arch == "LlamaForCausalLM"
|
||||
},
|
||||
"MistralModel": ("llama", "LlamaEmbeddingModel"),
|
||||
"MistralModel": ("llama", "LlamaForCausalLM"),
|
||||
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
|
||||
"Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
|
||||
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
|
||||
@@ -125,7 +126,7 @@ _EMBEDDING_MODELS = {
|
||||
# [Multimodal]
|
||||
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
|
||||
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
|
||||
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration") # noqa: E501,
|
||||
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
|
||||
}
|
||||
|
||||
_CROSS_ENCODER_MODELS = {
|
||||
@@ -208,6 +209,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _ModelInfo:
|
||||
architecture: str
|
||||
is_text_generation_model: bool
|
||||
is_embedding_model: bool
|
||||
supports_cross_encoding: bool
|
||||
@@ -218,9 +220,19 @@ class _ModelInfo:
|
||||
|
||||
@staticmethod
|
||||
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
|
||||
is_embedding_model_ = is_embedding_model(model)
|
||||
if not is_embedding_model_:
|
||||
try:
|
||||
as_embedding_model(model)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
is_embedding_model_ = True
|
||||
|
||||
return _ModelInfo(
|
||||
architecture=model.__name__,
|
||||
is_text_generation_model=is_text_generation_model(model),
|
||||
is_embedding_model=is_embedding_model(model),
|
||||
is_embedding_model=is_embedding_model_,
|
||||
supports_cross_encoding=supports_cross_encoding(model),
|
||||
supports_multimodal=supports_multimodal(model),
|
||||
supports_pp=supports_pp(model),
|
||||
@@ -399,13 +411,13 @@ class _ModelRegistry:
|
||||
def inspect_model_cls(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> _ModelInfo:
|
||||
) -> Tuple[_ModelInfo, str]:
|
||||
architectures = self._normalize_archs(architectures)
|
||||
|
||||
for arch in architectures:
|
||||
model_info = self._try_inspect_model_cls(arch)
|
||||
if model_info is not None:
|
||||
return model_info
|
||||
return (model_info, arch)
|
||||
|
||||
return self._raise_for_unsupported(architectures)
|
||||
|
||||
@@ -426,39 +438,50 @@ class _ModelRegistry:
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> bool:
|
||||
return self.inspect_model_cls(architectures).is_text_generation_model
|
||||
model_cls, _ = self.inspect_model_cls(architectures)
|
||||
return model_cls.is_text_generation_model
|
||||
|
||||
def is_embedding_model(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> bool:
|
||||
return self.inspect_model_cls(architectures).is_embedding_model
|
||||
model_cls, _ = self.inspect_model_cls(architectures)
|
||||
return model_cls.is_embedding_model
|
||||
|
||||
def is_cross_encoder_model(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> bool:
|
||||
return self.inspect_model_cls(architectures).supports_cross_encoding
|
||||
model_cls, _ = self.inspect_model_cls(architectures)
|
||||
return model_cls.supports_cross_encoding
|
||||
|
||||
def is_multimodal_model(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> bool:
|
||||
return self.inspect_model_cls(architectures).supports_multimodal
|
||||
model_cls, _ = self.inspect_model_cls(architectures)
|
||||
return model_cls.supports_multimodal
|
||||
|
||||
def is_pp_supported_model(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> bool:
|
||||
return self.inspect_model_cls(architectures).supports_pp
|
||||
model_cls, _ = self.inspect_model_cls(architectures)
|
||||
return model_cls.supports_pp
|
||||
|
||||
def model_has_inner_state(self, architectures: Union[str,
|
||||
List[str]]) -> bool:
|
||||
return self.inspect_model_cls(architectures).has_inner_state
|
||||
def model_has_inner_state(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> bool:
|
||||
model_cls, _ = self.inspect_model_cls(architectures)
|
||||
return model_cls.has_inner_state
|
||||
|
||||
def is_attention_free_model(self, architectures: Union[str,
|
||||
List[str]]) -> bool:
|
||||
return self.inspect_model_cls(architectures).is_attention_free
|
||||
def is_attention_free_model(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> bool:
|
||||
model_cls, _ = self.inspect_model_cls(architectures)
|
||||
return model_cls.is_attention_free
|
||||
|
||||
|
||||
ModelRegistry = _ModelRegistry({
|
||||
|
||||
Reference in New Issue
Block a user