Support Cross encoder models (#10400)

Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Flavia Beo <flavia.beo@ibm.com>
Co-authored-by: Flavia Beo <flavia.beo@ibm.com>
This commit is contained in:
Maximilien de Bayser
2024-11-24 23:56:20 -03:00
committed by GitHub
parent 49628fe13e
commit 214efc2c3c
28 changed files with 1370 additions and 62 deletions

View File

@@ -21,7 +21,8 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform
from .interfaces import (has_inner_state, is_attention_free,
supports_multimodal, supports_pp)
supports_cross_encoding, supports_multimodal,
supports_pp)
from .interfaces_base import is_embedding_model, is_text_generation_model
logger = init_logger(__name__)
@@ -100,6 +101,7 @@ _EMBEDDING_MODELS = {
# [Text-only]
"BertModel": ("bert", "BertEmbeddingModel"),
"RobertaModel": ("roberta", "RobertaEmbeddingModel"),
"RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
@@ -121,6 +123,14 @@ _EMBEDDING_MODELS = {
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration") # noqa: E501,
}
# Registry of sequence-classification architectures used as cross-encoders.
# NOTE(review): each value looks like a (module name, class name) pair, in
# line with the neighbouring registries -- confirm against the loader.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": (
        "bert",
        "BertForSequenceClassification",
    ),
    "RobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
    # The XLM-RoBERTa key maps to the same pair as the RoBERTa key above.
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
}
_MULTIMODAL_MODELS = {
# [Decoder-only]
"Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
@@ -159,6 +169,7 @@ _SPECULATIVE_DECODING_MODELS = {
# Combined registry, built by unpacking every per-category table in the
# order listed. With dict unpacking, a key that appears in more than one
# table keeps the value from the table listed last.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
}
@@ -193,6 +204,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
class _ModelInfo:
is_text_generation_model: bool
is_embedding_model: bool
supports_cross_encoding: bool
supports_multimodal: bool
supports_pp: bool
has_inner_state: bool
@@ -203,6 +215,7 @@ class _ModelInfo:
return _ModelInfo(
is_text_generation_model=is_text_generation_model(model),
is_embedding_model=is_embedding_model(model),
supports_cross_encoding=supports_cross_encoding(model),
supports_multimodal=supports_multimodal(model),
supports_pp=supports_pp(model),
has_inner_state=has_inner_state(model),
@@ -415,6 +428,12 @@ class _ModelRegistry:
) -> bool:
return self.inspect_model_cls(architectures).is_embedding_model
def is_cross_encoder_model(
    self,
    architectures: Union[str, List[str]],
) -> bool:
    """Report whether the inspected model class supports cross-encoding.

    Args:
        architectures: A single architecture name or a list of names,
            forwarded unchanged to ``inspect_model_cls``.

    Returns:
        The ``supports_cross_encoding`` attribute of the object returned
        by ``inspect_model_cls``.
    """
    model_info = self.inspect_model_cls(architectures)
    return model_info.supports_cross_encoding
def is_multimodal_model(
self,
architectures: Union[str, List[str]],
@@ -489,4 +508,4 @@ def _run() -> None:
# Invoke _run() exactly once, and only when this file is executed as a
# script rather than imported.
if __name__ == "__main__":
    _run()