Support Cross encoder models (#10400)
Signed-off-by: Max de Bayser <maxdebayser@gmail.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Flavia Beo <flavia.beo@ibm.com> Co-authored-by: Flavia Beo <flavia.beo@ibm.com>
This commit is contained in:
committed by
GitHub
parent
49628fe13e
commit
214efc2c3c
@@ -21,7 +21,8 @@ from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .interfaces import (has_inner_state, is_attention_free,
|
||||
supports_multimodal, supports_pp)
|
||||
supports_cross_encoding, supports_multimodal,
|
||||
supports_pp)
|
||||
from .interfaces_base import is_embedding_model, is_text_generation_model
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -100,6 +101,7 @@ _EMBEDDING_MODELS = {
|
||||
# [Text-only]
|
||||
"BertModel": ("bert", "BertEmbeddingModel"),
|
||||
"RobertaModel": ("roberta", "RobertaEmbeddingModel"),
|
||||
"RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
|
||||
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
|
||||
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
|
||||
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
|
||||
@@ -121,6 +123,14 @@ _EMBEDDING_MODELS = {
|
||||
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration") # noqa: E501,
|
||||
}
|
||||
|
||||
_CROSS_ENCODER_MODELS = {
|
||||
"BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
|
||||
"RobertaForSequenceClassification": ("roberta",
|
||||
"RobertaForSequenceClassification"),
|
||||
"XLMRobertaForSequenceClassification": ("roberta",
|
||||
"RobertaForSequenceClassification"),
|
||||
}
|
||||
|
||||
_MULTIMODAL_MODELS = {
|
||||
# [Decoder-only]
|
||||
"Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
|
||||
@@ -159,6 +169,7 @@ _SPECULATIVE_DECODING_MODELS = {
|
||||
_VLLM_MODELS = {
|
||||
**_TEXT_GENERATION_MODELS,
|
||||
**_EMBEDDING_MODELS,
|
||||
**_CROSS_ENCODER_MODELS,
|
||||
**_MULTIMODAL_MODELS,
|
||||
**_SPECULATIVE_DECODING_MODELS,
|
||||
}
|
||||
@@ -193,6 +204,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
|
||||
class _ModelInfo:
|
||||
is_text_generation_model: bool
|
||||
is_embedding_model: bool
|
||||
supports_cross_encoding: bool
|
||||
supports_multimodal: bool
|
||||
supports_pp: bool
|
||||
has_inner_state: bool
|
||||
@@ -203,6 +215,7 @@ class _ModelInfo:
|
||||
return _ModelInfo(
|
||||
is_text_generation_model=is_text_generation_model(model),
|
||||
is_embedding_model=is_embedding_model(model),
|
||||
supports_cross_encoding=supports_cross_encoding(model),
|
||||
supports_multimodal=supports_multimodal(model),
|
||||
supports_pp=supports_pp(model),
|
||||
has_inner_state=has_inner_state(model),
|
||||
@@ -415,6 +428,12 @@ class _ModelRegistry:
|
||||
) -> bool:
|
||||
return self.inspect_model_cls(architectures).is_embedding_model
|
||||
|
||||
def is_cross_encoder_model(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> bool:
|
||||
return self.inspect_model_cls(architectures).supports_cross_encoding
|
||||
|
||||
def is_multimodal_model(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
@@ -489,4 +508,4 @@ def _run() -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_run()
|
||||
_run()
|
||||
|
||||
Reference in New Issue
Block a user