diff --git a/tests/models/registry.py b/tests/models/registry.py index cf8e5032d..3927b3ac0 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -546,15 +546,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { _EMBEDDING_EXAMPLE_MODELS = { # [Text-only] "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"), - "HF_ColBERT": _HfExamplesInfo("answerdotai/answerai-colbert-small-v1"), - "ColBERTModernBertModel": _HfExamplesInfo( - "lightonai/GTE-ModernColBERT-v1", - hf_overrides={"architectures": ["ColBERTModernBertModel"]}, - ), - "ColBERTJinaRobertaModel": _HfExamplesInfo( - "jinaai/jina-colbert-v2", - trust_remote_code=True, - hf_overrides={"architectures": ["ColBERTJinaRobertaModel"]}, + "BertSpladeSparseEmbeddingModel": _HfExamplesInfo( + "naver/splade-v3", + hf_overrides={"architectures": ["BertSpladeSparseEmbeddingModel"]}, ), "BgeM3EmbeddingModel": _HfExamplesInfo("BAAI/bge-m3"), "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), @@ -568,10 +562,6 @@ _EMBEDDING_EXAMPLE_MODELS = { trust_remote_code=True, hf_overrides={"architectures": ["GteNewModel"]}, ), - "InternLM2ForRewardModel": _HfExamplesInfo( - "internlm/internlm2-1_8b-reward", trust_remote_code=True - ), - "JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), "LlamaModel": _HfExamplesInfo("llama", is_available_online=False), "LlamaBidirectionalModel": _HfExamplesInfo( "nvidia/llama-nemotron-embed-1b-v2", trust_remote_code=True @@ -584,35 +574,14 @@ _EMBEDDING_EXAMPLE_MODELS = { "nomic-ai/nomic-embed-text-v2-moe", trust_remote_code=True ), "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"), - "Qwen2ForRewardModel": _HfExamplesInfo( - "Qwen/Qwen2.5-Math-RM-72B", - max_transformers_version="4.53", - transformers_version_reason={ - "hf": "HF model uses remote code that is not compatible with latest Transformers" # noqa: E501 - }, - ), - "Qwen2ForProcessRewardModel": _HfExamplesInfo( - "Qwen/Qwen2.5-Math-PRM-7B", - max_transformers_version="4.53", - transformers_version_reason={ - "hf": "HF model uses remote code that is not compatible with latest Transformers" # noqa: E501 - }, - ), "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), "VoyageQwen3BidirectionalEmbedModel": _HfExamplesInfo( "voyageai/voyage-4-nano", trust_remote_code=True ), "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"), - "BertSpladeSparseEmbeddingModel": _HfExamplesInfo( - "naver/splade-v3", - hf_overrides={"architectures": ["BertSpladeSparseEmbeddingModel"]}, - ), # [Multimodal] "CLIPModel": _HfExamplesInfo("openai/clip-vit-base-patch32"), - "ColModernVBertForRetrieval": _HfExamplesInfo( - "ModernVBERT/colmodernvbert-merged", - ), "LlamaNemotronVLModel": _HfExamplesInfo( "nvidia/llama-nemotron-embed-vl-1b-v2", trust_remote_code=True ), @@ -621,15 +590,6 @@ _EMBEDDING_EXAMPLE_MODELS = { "TIGER-Lab/VLM2Vec-Full", trust_remote_code=True ), "Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), - "ColQwen3": _HfExamplesInfo( - "TomoroAI/tomoro-colqwen3-embed-4b", trust_remote_code=True - ), - "OpsColQwen3Model": _HfExamplesInfo( - "OpenSearch-AI/Ops-Colqwen3-4B", trust_remote_code=True - ), - "Qwen3VLNemotronEmbedModel": _HfExamplesInfo( - "nvidia/nemotron-colembed-vl-4b-v2", - ), "SiglipModel": _HfExamplesInfo("google/siglip-base-patch16-224"), "PrithviGeoSpatialMAE": _HfExamplesInfo( "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", @@ -649,21 +609,74 @@ _EMBEDDING_EXAMPLE_MODELS = { ), } -_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = { - # [Decoder-only] - "GPT2ForSequenceClassification": _HfExamplesInfo( - "nie3e/sentiment-polish-gpt2-small" +_LATE_INTERACTION_EXAMPLE_MODELS = { + # [Text-only] + "HF_ColBERT": _HfExamplesInfo("answerdotai/answerai-colbert-small-v1"), + "ColBERTModernBertModel": _HfExamplesInfo( + "lightonai/GTE-ModernColBERT-v1", + hf_overrides={"architectures": ["ColBERTModernBertModel"]}, ), - # [Cross-encoder] + "ColBERTJinaRobertaModel": _HfExamplesInfo( + "jinaai/jina-colbert-v2", + trust_remote_code=True, + hf_overrides={"architectures": ["ColBERTJinaRobertaModel"]}, + ), + # [Multimodal] + "ColModernVBertForRetrieval": _HfExamplesInfo( + "ModernVBERT/colmodernvbert-merged", + ), + "ColQwen3": _HfExamplesInfo( + "TomoroAI/tomoro-colqwen3-embed-4b", trust_remote_code=True + ), + "OpsColQwen3Model": _HfExamplesInfo( + "OpenSearch-AI/Ops-Colqwen3-4B", trust_remote_code=True + ), + "Qwen3VLNemotronEmbedModel": _HfExamplesInfo( + "nvidia/nemotron-colembed-vl-4b-v2", + ), +} + + +_REWARD_EXAMPLE_MODELS = { + "InternLM2ForRewardModel": _HfExamplesInfo( + "internlm/internlm2-1_8b-reward", trust_remote_code=True + ), + "Qwen2ForRewardModel": _HfExamplesInfo( + "Qwen/Qwen2.5-Math-RM-72B", + max_transformers_version="4.53", + transformers_version_reason={ + "hf": "HF model uses remote code that is not compatible with latest Transformers" # noqa: E501 + }, + ), + "Qwen2ForProcessRewardModel": _HfExamplesInfo( + "Qwen/Qwen2.5-Math-PRM-7B", + max_transformers_version="4.53", + transformers_version_reason={ + "hf": "HF model uses remote code that is not compatible with latest Transformers" # noqa: E501 + }, + ), +} + +_TOKEN_CLASSIFICATION_EXAMPLE_MODELS = { + "BertForTokenClassification": _HfExamplesInfo("boltuix/NeuroBERT-NER"), + "ModernBertForTokenClassification": _HfExamplesInfo( + "disham993/electrical-ner-ModernBERT-base" + ), +} + +_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = { "BertForSequenceClassification": _HfExamplesInfo( "cross-encoder/ms-marco-MiniLM-L-6-v2" ), - "BertForTokenClassification": _HfExamplesInfo("boltuix/NeuroBERT-NER"), + "GPT2ForSequenceClassification": _HfExamplesInfo( + "nie3e/sentiment-polish-gpt2-small" + ), "GteNewForSequenceClassification": _HfExamplesInfo( "Alibaba-NLP/gte-multilingual-reranker-base", trust_remote_code=True, hf_overrides={"architectures": ["GteNewForSequenceClassification"]}, ), + "JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), "LlamaBidirectionalForSequenceClassification": _HfExamplesInfo( "nvidia/llama-nemotron-rerank-1b-v2", trust_remote_code=True ), @@ -673,9 +686,6 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = { "ModernBertForSequenceClassification": _HfExamplesInfo( "Alibaba-NLP/gte-reranker-modernbert-base" ), - "ModernBertForTokenClassification": _HfExamplesInfo( - "disham993/electrical-ner-ModernBERT-base" - ), "RobertaForSequenceClassification": _HfExamplesInfo( "cross-encoder/quora-roberta-base" ), @@ -1273,6 +1283,9 @@ _TRANSFORMERS_BACKEND_MODELS = { _EXAMPLE_MODELS = { **_TEXT_GENERATION_EXAMPLE_MODELS, **_EMBEDDING_EXAMPLE_MODELS, + **_LATE_INTERACTION_EXAMPLE_MODELS, + **_REWARD_EXAMPLE_MODELS, + **_TOKEN_CLASSIFICATION_EXAMPLE_MODELS, **_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS, **_MULTIMODAL_EXAMPLE_MODELS, **_SPECULATIVE_DECODING_EXAMPLE_MODELS, diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index fa273527b..81fae02ef 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -56,21 +56,24 @@ def test_registry_imports(model_arch): @create_new_process_for_each_test() @pytest.mark.parametrize( - "model_arch,is_mm,init_cuda,is_ce", + "model_arch,is_mm,init_cuda,score_type", [ - ("LlamaForCausalLM", False, False, False), - ("LlavaForConditionalGeneration", True, True, False), - ("BertForSequenceClassification", False, False, True), - ("RobertaForSequenceClassification", False, False, True), - ("XLMRobertaForSequenceClassification", False, False, True), + ("LlamaForCausalLM", False, False, "bi-encoder"), + ("LlavaForConditionalGeneration", True, True, "bi-encoder"), + ("BertForSequenceClassification", False, False, "cross-encoder"), + ("RobertaForSequenceClassification", False, False, "cross-encoder"), + ("XLMRobertaForSequenceClassification", False, False, "cross-encoder"), + ("GteNewModel", False, False, "bi-encoder"), + ("GteNewForSequenceClassification", False, False, "cross-encoder"), + ("HF_ColBERT", False, False, "late-interaction"), ], ) -def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce): +def test_registry_model_property(model_arch, is_mm, init_cuda, score_type): model_info = ModelRegistry._try_inspect_model_cls(model_arch) assert model_info is not None assert model_info.supports_multimodal is is_mm - assert model_info.supports_cross_encoding is is_ce + assert model_info.score_type == score_type if init_cuda and current_platform.is_cuda_alike(): assert not torch.cuda.is_initialized() diff --git a/vllm/config/model.py b/vllm/config/model.py index 6c48bfde6..bd35e491d 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -20,6 +20,7 @@ from vllm.config.scheduler import RunnerType from vllm.config.utils import config, getattr_iter from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.tasks import ScoreType from vllm.transformers_utils.config import ( ConfigFormat, get_config, @@ -1412,16 +1413,23 @@ class ModelConfig: return self._model_info.requires_raw_input_tokens @property - def is_cross_encoder(self) -> bool: + def score_type(self) -> ScoreType: + """ + Score API handles score/rerank for: + - "score" task (score_type: cross-encoder models) + - "embed" task (score_type: bi-encoder models) + - "token_embed" task (score_type: late interaction models) + """ + # fixme: self._model_info.score_type is the score type before + # as_seq_cls_model, which is "bi-encoder", rather than the + # score type after as_seq_cls_model, which is "cross-encoder". + # Therefore, the following logic is required. return ( - self._model_info.supports_cross_encoding or self.convert_type == "classify" + "cross-encoder" + if self.convert_type == "classify" + else self._model_info.score_type ) - @property - def is_late_interaction(self) -> bool: - """Check if model uses late interaction (ColBERT-style) scoring.""" - return self._model_info.supports_late_interaction - @property def is_pp_supported(self) -> bool: return self._model_info.supports_pp diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b5fc270ff..5909b3043 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1584,8 +1584,11 @@ class LLM: ) supported_tasks = self.supported_tasks + score_type = self.model_config.score_type + is_late_interaction = score_type == "late-interaction" + is_cross_encoder = score_type == "cross-encoder" + # Late interaction models (e.g., ColBERT) use token_embed for scoring - is_late_interaction = model_config.is_late_interaction if not is_late_interaction and all( t not in supported_tasks for t in ("embed", "classify") ): @@ -1595,13 +1598,10 @@ class LLM: "`--convert embed` or `--convert classify`." ) - if ( - model_config.is_cross_encoder - and getattr(model_config.hf_config, "num_labels", 0) != 1 - ): + if is_cross_encoder and getattr(model_config.hf_config, "num_labels", 0) != 1: raise ValueError("Score API is only enabled for num_labels == 1.") - if not model_config.is_cross_encoder and chat_template is not None: + if not is_cross_encoder and chat_template is not None: raise ValueError( "chat_template is only supported for cross-encoder models." ) @@ -1622,7 +1622,7 @@ class LLM: ) encode_kwargs = tok_params.get_encode_kwargs() - if model_config.is_cross_encoder: + if is_cross_encoder: return self._cross_encoding_score( score_data_1, score_data_2, diff --git a/vllm/entrypoints/pooling/__init__.py b/vllm/entrypoints/pooling/__init__.py index 7844ed16e..f64675e56 100644 --- a/vllm/entrypoints/pooling/__init__.py +++ b/vllm/entrypoints/pooling/__init__.py @@ -37,10 +37,10 @@ def register_pooling_api_routers( app.include_router(embed_router) - # Score/rerank endpoints are available for: - # - "score" task (cross-encoder models) - # - "embed" task (bi-encoder models) - # - "token_embed" task (late interaction models like ColBERT) + # Score API handles score/rerank for: + # - "score" task (score_type: cross-encoder models) + # - "embed" task (score_type: bi-encoder models) + # - "token_embed" task (score_type: late interaction models) if any(t in supported_tasks for t in ("score", "embed", "token_embed")): from vllm.entrypoints.pooling.score.api_router import router as score_router @@ -101,10 +101,10 @@ def init_pooling_state( if "classify" in supported_tasks else None ) - # ServingScores handles score/rerank for: - # - "score" task (cross-encoder models) - # - "embed" task (bi-encoder models) - # - "token_embed" task (late interaction models like ColBERT) + # Score API handles score/rerank for: + # - "score" task (score_type: cross-encoder models) + # - "embed" task (score_type: bi-encoder models) + # - "token_embed" task (score_type: late interaction models) state.serving_scores = ( ServingScores( engine_client, diff --git a/vllm/entrypoints/pooling/score/serving.py b/vllm/entrypoints/pooling/score/serving.py index 546ad7698..c58fe6d36 100644 --- a/vllm/entrypoints/pooling/score/serving.py +++ b/vllm/entrypoints/pooling/score/serving.py @@ -69,16 +69,15 @@ class ServingScores(OpenAIServing): self._tokenizer_executor = ThreadPoolExecutor(max_workers=1) - self.is_cross_encoder = self.model_config.is_cross_encoder - self.is_multimodal_model = self.model_config.is_multimodal_model + self.score_type = self.model_config.score_type self.architecture = self.model_config.architecture - self.is_late_interaction = self.model_config.is_late_interaction + self.is_multimodal_model = self.model_config.is_multimodal_model - if self.is_cross_encoder: + if self.score_type == "cross-encoder": self._score_func = self._cross_encoding_score - elif self.is_late_interaction: + elif self.score_type == "late-interaction": self._score_func = self._late_interaction_score - else: + else: # "bi-encoder" self._score_func = self._embedding_score async def _embedding_score( diff --git a/vllm/lora/model_manager.py b/vllm/lora/model_manager.py index 7611d2d71..2209704ff 100644 --- a/vllm/lora/model_manager.py +++ b/vllm/lora/model_manager.py @@ -30,8 +30,11 @@ from vllm.lora.utils import ( replace_submodule, ) from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.models import SupportsLoRA, supports_multimodal -from vllm.model_executor.models.interfaces import is_pooling_model +from vllm.model_executor.models import ( + SupportsLoRA, + is_pooling_model, + supports_multimodal, +) from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.utils import PPMissingLayer from vllm.multimodal import MULTIMODAL_REGISTRY diff --git a/vllm/model_executor/models/colbert.py b/vllm/model_executor/models/colbert.py index b876d451b..66def505f 100644 --- a/vllm/model_executor/models/colbert.py +++ b/vllm/model_executor/models/colbert.py @@ -18,7 +18,6 @@ Reference: https://arxiv.org/abs/2004.12832 """ from collections.abc import Iterable -from typing import ClassVar, Literal import torch from torch import nn @@ -28,16 +27,16 @@ from vllm.model_executor.layers.pooler import Pooler from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed from .bert import BertEmbeddingModel, BertModel +from .interfaces import SupportsLateInteraction from .interfaces_base import default_pooling_type -class ColBERTMixin: +class ColBERTMixin(nn.Module, SupportsLateInteraction): """Mixin that adds ColBERT late interaction support to any embedding model. ColBERT (Contextualized Late Interaction over BERT) uses per-token embeddings with a linear projection layer. This mixin provides: - - ``supports_late_interaction`` class-var - ColBERT linear projection initialisation / lazy creation - Weight loading helpers for the projection layer - A builder for the token-embedding pooler @@ -52,8 +51,6 @@ class ColBERTMixin: the ColBERT projection weight, then delegate the rest to the backbone. """ - supports_late_interaction: ClassVar[Literal[True]] = True - # Set during _init_colbert_components colbert_dim: int | None colbert_linear: nn.Linear | None diff --git a/vllm/model_executor/models/colmodernvbert.py b/vllm/model_executor/models/colmodernvbert.py index ecb243ced..39dca6edd 100644 --- a/vllm/model_executor/models/colmodernvbert.py +++ b/vllm/model_executor/models/colmodernvbert.py @@ -9,7 +9,6 @@ Reference: https://huggingface.co/ModernVBERT/colmodernvbert-merged """ from collections.abc import Iterable, Mapping, Sequence -from typing import ClassVar, Literal import torch from torch import nn @@ -37,7 +36,11 @@ from vllm.multimodal.processing import ( from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.colmodernvbert import ColModernVBertConfig -from .interfaces import MultiModalEmbeddings, SupportsMultiModal +from .interfaces import ( + MultiModalEmbeddings, + SupportsLateInteraction, + SupportsMultiModal, +) from .interfaces_base import default_pooling_type from .modernbert import ModernBertEmbeddings, ModernBertLayer from .siglip import SiglipVisionModel @@ -234,7 +237,9 @@ class ColModernVBertMultiModalProcessor( dummy_inputs=ColModernVBertDummyInputsBuilder, ) @default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL") -class ColModernVBertForRetrieval(nn.Module, SupportsMultiModal): +class ColModernVBertForRetrieval( + nn.Module, SupportsMultiModal, SupportsLateInteraction +): """ColModernVBERT multimodal late-interaction retrieval model. Architecture: @@ -248,7 +253,6 @@ class ColModernVBertForRetrieval(nn.Module, SupportsMultiModal): """ is_pooling_model = True - supports_late_interaction: ClassVar[Literal[True]] = True def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/colqwen3.py b/vllm/model_executor/models/colqwen3.py index 7513c01e8..1db5e0742 100644 --- a/vllm/model_executor/models/colqwen3.py +++ b/vllm/model_executor/models/colqwen3.py @@ -20,7 +20,6 @@ Target models: """ from collections.abc import Iterable, Mapping -from typing import ClassVar, Literal import torch import torch.nn as nn @@ -31,6 +30,7 @@ from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY +from .interfaces import SupportsLateInteraction from .interfaces_base import default_pooling_type from .qwen2_vl import Qwen2VLMultiModalDataParser from .qwen3_vl import ( @@ -113,9 +113,7 @@ class ColQwen3ProcessingInfo(Qwen3VLProcessingInfo): info=ColQwen3ProcessingInfo, dummy_inputs=Qwen3VLDummyInputsBuilder, ) -class ColQwen3Model( - Qwen3VLForConditionalGeneration, -): +class ColQwen3Model(Qwen3VLForConditionalGeneration, SupportsLateInteraction): """ColQwen3 late interaction model for multi-modal retrieval/reranking. This model extends Qwen3VLForConditionalGeneration with a ColBERT-style @@ -132,16 +130,11 @@ class ColQwen3Model( Attributes: custom_text_proj: Linear projection from hidden_size to embed_dim - supports_late_interaction: Flag indicating this model uses late - interaction scoring """ # Mark this as a pooling model so vLLM routes to pooler path is_pooling_model = True - # Mark this model as supporting late interaction scoring - supports_late_interaction: ClassVar[Literal[True]] = True - # Override hf_to_vllm_mapper to handle ColQwen3 weight naming. # NOTE: WeightsMapper applies ALL matching prefix rules sequentially # (no early exit), so more-specific prefixes must come first. diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 3e90578f8..ac35b3157 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -34,10 +34,11 @@ from vllm.inputs.data import PromptType from vllm.logger import init_logger from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.tasks import ScoreType from vllm.utils.collection_utils import common_prefix from vllm.utils.func_utils import supports_kw -from .interfaces_base import VllmModel, is_pooling_model +from .interfaces_base import VllmModel if TYPE_CHECKING: from vllm.config import VllmConfig @@ -969,29 +970,7 @@ def supports_mamba_prefix_caching( class SupportsCrossEncoding(Protocol): """The interface required for all models that support cross encoding.""" - supports_cross_encoding: ClassVar[Literal[True]] = True - - -@overload -def supports_cross_encoding( - model: type[object], -) -> TypeIs[type[SupportsCrossEncoding]]: ... - - -@overload -def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]: ... - - -def _supports_cross_encoding( - model: type[object] | object, -) -> TypeIs[type[SupportsCrossEncoding]] | TypeIs[SupportsCrossEncoding]: - return getattr(model, "supports_cross_encoding", False) - - -def supports_cross_encoding( - model: type[object] | object, -) -> TypeIs[type[SupportsCrossEncoding]] | TypeIs[SupportsCrossEncoding]: - return is_pooling_model(model) and _supports_cross_encoding(model) + score_type: ClassVar[ScoreType] = "cross-encoder" @runtime_checkable @@ -1003,29 +982,7 @@ class SupportsLateInteraction(Protocol): MaxSim (max over document tokens, sum over query tokens). """ - supports_late_interaction: ClassVar[Literal[True]] = True - - -@overload -def supports_late_interaction( - model: type[object], -) -> TypeIs[type[SupportsLateInteraction]]: ... - - -@overload -def supports_late_interaction(model: object) -> TypeIs[SupportsLateInteraction]: ... - - -def _supports_late_interaction( - model: type[object] | object, -) -> TypeIs[type[SupportsLateInteraction]] | TypeIs[SupportsLateInteraction]: - return getattr(model, "supports_late_interaction", False) - - -def supports_late_interaction( - model: type[object] | object, -) -> TypeIs[type[SupportsLateInteraction]] | TypeIs[SupportsLateInteraction]: - return is_pooling_model(model) and _supports_late_interaction(model) + score_type: ClassVar[ScoreType] = "late-interaction" class SupportsQuant: diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index e658825e1..55c42e5fa 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -15,6 +15,7 @@ import torch.nn as nn from typing_extensions import TypeIs, TypeVar from vllm.logger import init_logger +from vllm.tasks import ScoreType from vllm.utils.func_utils import supports_kw if TYPE_CHECKING: @@ -187,6 +188,26 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]): decorator to conveniently set this field. """ + score_type: ClassVar[ScoreType] = "bi-encoder" + """ + Indicates the + [vllm.config.model.ModelConfig.score_type][] + to use by default. + + Score API handles score/rerank for: + - "score" task (score_type: cross-encoder models) + - "embed" task (score_type: bi-encoder models) + - "token_embed" task (score_type: late interaction models) + + score_type defaults to bi-encoder, then the Score API uses the "embed" task. + If you set score_type to cross-encoder via + [vllm.model_executor.models.interfaces.SupportsCrossEncoding][], + then the Score API uses the "score" task. + If you set score_type to late-interaction via + [vllm.model_executor.models.interfaces.SupportsLateInteraction][], + then the Score API uses the "token_embed" task. + """ + pooler: Pooler """The pooler is only called on TP rank 0.""" @@ -250,3 +271,13 @@ def attn_type(attn_type: AttnTypeStr): def get_attn_type(model: type[object] | object) -> AttnTypeStr: return getattr(model, "attn_type", "decoder") + + +def get_score_type(model: type[object] | object) -> ScoreType: + score_types = set() + for m in model.__mro__: + score_type = getattr(m, "score_type", "bi-encoder") + if score_type != "bi-encoder": + score_types.add(score_type) + assert len(score_types) < 2 + return "bi-encoder" if not score_types else list(score_types)[0] diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 46437adf4..34dda9b38 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -30,6 +30,7 @@ from vllm.config import ( ) from vllm.logger import init_logger from vllm.logging_utils import logtime +from vllm.tasks import ScoreType from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module from vllm.utils.hashing import safe_hash @@ -48,8 +49,6 @@ from .interfaces import ( is_attention_free, is_hybrid, requires_raw_input_tokens, - supports_cross_encoding, - supports_late_interaction, supports_mamba_prefix_caching, supports_multimodal, supports_multimodal_encoder_tp_data, @@ -61,6 +60,7 @@ from .interfaces_base import ( get_attn_type, get_default_seq_pooling_type, get_default_tok_pooling_type, + get_score_type, is_pooling_model, is_text_generation_model, ) @@ -214,19 +214,14 @@ _EMBEDDING_MODELS = { # [Text-only] "BertModel": ("bert", "BertEmbeddingModel"), "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"), - "HF_ColBERT": ("colbert", "ColBERTModel"), - "ColBERTModernBertModel": ("colbert", "ColBERTModernBertModel"), - "ColBERTJinaRobertaModel": ("colbert", "ColBERTJinaRobertaModel"), + "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"), "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"), "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), "Gemma3TextModel": ("gemma3", "Gemma3Model"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), - "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"), "GritLM": ("gritlm", "GritLM"), "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"), "GteNewModel": ("bert_with_rope", "GteNewModel"), - "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"), - "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501 "LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"), "LlamaModel": ("llama", "LlamaForCausalLM"), **{ @@ -241,8 +236,6 @@ _EMBEDDING_MODELS = { "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), - "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), - "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"), "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"), "RobertaModel": ("roberta", "RobertaEmbeddingModel"), "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"), @@ -252,19 +245,14 @@ _EMBEDDING_MODELS = { "VoyageQwen3BidirectionalEmbedModel", ), "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"), - "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"), # [Multimodal] "CLIPModel": ("clip", "CLIPEmbeddingModel"), - "ColModernVBertForRetrieval": ("colmodernvbert", "ColModernVBertForRetrieval"), "LlavaNextForConditionalGeneration": ( "llava_next", "LlavaNextForConditionalGeneration", ), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 - "ColQwen3": ("colqwen3", "ColQwen3Model"), - "OpsColQwen3Model": ("colqwen3", "ColQwen3Model"), - "Qwen3VLNemotronEmbedModel": ("colqwen3", "ColQwen3Model"), "SiglipModel": ("siglip", "SiglipEmbeddingModel"), "LlamaNemotronVLModel": ( "nemotron_vl", @@ -277,35 +265,59 @@ _EMBEDDING_MODELS = { "Terratorch": ("terratorch", "Terratorch"), } -_CROSS_ENCODER_MODELS = { - "BertForSequenceClassification": ("bert", "BertForSequenceClassification"), +_LATE_INTERACTION_MODELS = { + # [Text-only] + "HF_ColBERT": ("colbert", "ColBERTModel"), + "ColBERTModernBertModel": ("colbert", "ColBERTModernBertModel"), + "ColBERTJinaRobertaModel": ("colbert", "ColBERTJinaRobertaModel"), + # [Multimodal] + "ColModernVBertForRetrieval": ("colmodernvbert", "ColModernVBertForRetrieval"), + "ColQwen3": ("colqwen3", "ColQwen3Model"), + "OpsColQwen3Model": ("colqwen3", "ColQwen3Model"), + "Qwen3VLNemotronEmbedModel": ("colqwen3", "ColQwen3Model"), +} + +_REWARD_MODELS = { + "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"), + "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), + "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"), +} + +_TOKEN_CLASSIFICATION_MODELS = { "BertForTokenClassification": ("bert", "BertForTokenClassification"), + "ModernBertForTokenClassification": ( + "modernbert", + "ModernBertForTokenClassification", + ), +} + +_SEQUENCE_CLASSIFICATION_MODELS = { + "BertForSequenceClassification": ("bert", "BertForSequenceClassification"), + "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"), "GteNewForSequenceClassification": ( "bert_with_rope", "GteNewForSequenceClassification", ), - "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), + "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501 "LlamaBidirectionalForSequenceClassification": ( "llama", "LlamaBidirectionalForSequenceClassification", ), - "LlamaNemotronVLForSequenceClassification": ( - "nemotron_vl", - "LlamaNemotronVLForSequenceClassification", - ), "ModernBertForSequenceClassification": ( "modernbert", "ModernBertForSequenceClassification", ), - "ModernBertForTokenClassification": ( - "modernbert", - "ModernBertForTokenClassification", - ), "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"), "XLMRobertaForSequenceClassification": ( "roberta", "RobertaForSequenceClassification", ), + # [Multimodal] + "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), + "LlamaNemotronVLForSequenceClassification": ( + "nemotron_vl", + "LlamaNemotronVLForSequenceClassification", + ), } _MULTIMODAL_MODELS = { @@ -606,7 +618,10 @@ _TRANSFORMERS_BACKEND_MODELS = { _VLLM_MODELS = { **_TEXT_GENERATION_MODELS, **_EMBEDDING_MODELS, - **_CROSS_ENCODER_MODELS, + **_LATE_INTERACTION_MODELS, + **_REWARD_MODELS, + **_TOKEN_CLASSIFICATION_MODELS, + **_SEQUENCE_CLASSIFICATION_MODELS, **_MULTIMODAL_MODELS, **_SPECULATIVE_DECODING_MODELS, **_TRANSFORMERS_SUPPORTED_MODELS, @@ -643,8 +658,7 @@ class _ModelInfo: attn_type: AttnTypeStr default_seq_pooling_type: SequencePoolingType default_tok_pooling_type: TokenPoolingType - supports_cross_encoding: bool - supports_late_interaction: bool + score_type: ScoreType supports_multimodal: bool supports_multimodal_raw_input_only: bool requires_raw_input_tokens: bool @@ -667,8 +681,7 @@ class _ModelInfo: default_seq_pooling_type=get_default_seq_pooling_type(model), default_tok_pooling_type=get_default_tok_pooling_type(model), attn_type=get_attn_type(model), - supports_cross_encoding=supports_cross_encoding(model), - supports_late_interaction=supports_late_interaction(model), + score_type=get_score_type(model), supports_multimodal=supports_multimodal(model), supports_multimodal_raw_input_only=supports_multimodal_raw_input_only( model @@ -1166,14 +1179,6 @@ class _ModelRegistry: model_cls, _ = self.inspect_model_cls(architectures, model_config) return model_cls.is_pooling_model - def is_cross_encoder_model( - self, - architectures: str | list[str], - model_config: ModelConfig, - ) -> bool: - model_cls, _ = self.inspect_model_cls(architectures, model_config) - return model_cls.supports_cross_encoding - def is_multimodal_model( self, architectures: str | list[str], diff --git a/vllm/tasks.py b/vllm/tasks.py index 3a64e462e..950993279 100644 --- a/vllm/tasks.py +++ b/vllm/tasks.py @@ -10,6 +10,12 @@ PoolingTask = Literal[ ] POOLING_TASKS: tuple[PoolingTask, ...] = get_args(PoolingTask) +# Score API handles score/rerank for: +# - "score" task (score_type: cross-encoder models) +# - "embed" task (score_type: bi-encoder models) +# - "token_embed" task (score_type: late interaction models) +ScoreType = Literal["bi-encoder", "cross-encoder", "late-interaction"] + FrontendTask = Literal["render"] FRONTEND_TASKS: tuple[FrontendTask, ...] = get_args(FrontendTask)