[Model] Consolidate score logic by introduce score_type (#36479)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
@@ -546,15 +546,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
_EMBEDDING_EXAMPLE_MODELS = {
|
||||
# [Text-only]
|
||||
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
|
||||
"HF_ColBERT": _HfExamplesInfo("answerdotai/answerai-colbert-small-v1"),
|
||||
"ColBERTModernBertModel": _HfExamplesInfo(
|
||||
"lightonai/GTE-ModernColBERT-v1",
|
||||
hf_overrides={"architectures": ["ColBERTModernBertModel"]},
|
||||
),
|
||||
"ColBERTJinaRobertaModel": _HfExamplesInfo(
|
||||
"jinaai/jina-colbert-v2",
|
||||
trust_remote_code=True,
|
||||
hf_overrides={"architectures": ["ColBERTJinaRobertaModel"]},
|
||||
"BertSpladeSparseEmbeddingModel": _HfExamplesInfo(
|
||||
"naver/splade-v3",
|
||||
hf_overrides={"architectures": ["BertSpladeSparseEmbeddingModel"]},
|
||||
),
|
||||
"BgeM3EmbeddingModel": _HfExamplesInfo("BAAI/bge-m3"),
|
||||
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
|
||||
@@ -568,10 +562,6 @@ _EMBEDDING_EXAMPLE_MODELS = {
|
||||
trust_remote_code=True,
|
||||
hf_overrides={"architectures": ["GteNewModel"]},
|
||||
),
|
||||
"InternLM2ForRewardModel": _HfExamplesInfo(
|
||||
"internlm/internlm2-1_8b-reward", trust_remote_code=True
|
||||
),
|
||||
"JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"),
|
||||
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
|
||||
"LlamaBidirectionalModel": _HfExamplesInfo(
|
||||
"nvidia/llama-nemotron-embed-1b-v2", trust_remote_code=True
|
||||
@@ -584,35 +574,14 @@ _EMBEDDING_EXAMPLE_MODELS = {
|
||||
"nomic-ai/nomic-embed-text-v2-moe", trust_remote_code=True
|
||||
),
|
||||
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
|
||||
"Qwen2ForRewardModel": _HfExamplesInfo(
|
||||
"Qwen/Qwen2.5-Math-RM-72B",
|
||||
max_transformers_version="4.53",
|
||||
transformers_version_reason={
|
||||
"hf": "HF model uses remote code that is not compatible with latest Transformers" # noqa: E501
|
||||
},
|
||||
),
|
||||
"Qwen2ForProcessRewardModel": _HfExamplesInfo(
|
||||
"Qwen/Qwen2.5-Math-PRM-7B",
|
||||
max_transformers_version="4.53",
|
||||
transformers_version_reason={
|
||||
"hf": "HF model uses remote code that is not compatible with latest Transformers" # noqa: E501
|
||||
},
|
||||
),
|
||||
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"),
|
||||
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"),
|
||||
"VoyageQwen3BidirectionalEmbedModel": _HfExamplesInfo(
|
||||
"voyageai/voyage-4-nano", trust_remote_code=True
|
||||
),
|
||||
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"),
|
||||
"BertSpladeSparseEmbeddingModel": _HfExamplesInfo(
|
||||
"naver/splade-v3",
|
||||
hf_overrides={"architectures": ["BertSpladeSparseEmbeddingModel"]},
|
||||
),
|
||||
# [Multimodal]
|
||||
"CLIPModel": _HfExamplesInfo("openai/clip-vit-base-patch32"),
|
||||
"ColModernVBertForRetrieval": _HfExamplesInfo(
|
||||
"ModernVBERT/colmodernvbert-merged",
|
||||
),
|
||||
"LlamaNemotronVLModel": _HfExamplesInfo(
|
||||
"nvidia/llama-nemotron-embed-vl-1b-v2", trust_remote_code=True
|
||||
),
|
||||
@@ -621,15 +590,6 @@ _EMBEDDING_EXAMPLE_MODELS = {
|
||||
"TIGER-Lab/VLM2Vec-Full", trust_remote_code=True
|
||||
),
|
||||
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"),
|
||||
"ColQwen3": _HfExamplesInfo(
|
||||
"TomoroAI/tomoro-colqwen3-embed-4b", trust_remote_code=True
|
||||
),
|
||||
"OpsColQwen3Model": _HfExamplesInfo(
|
||||
"OpenSearch-AI/Ops-Colqwen3-4B", trust_remote_code=True
|
||||
),
|
||||
"Qwen3VLNemotronEmbedModel": _HfExamplesInfo(
|
||||
"nvidia/nemotron-colembed-vl-4b-v2",
|
||||
),
|
||||
"SiglipModel": _HfExamplesInfo("google/siglip-base-patch16-224"),
|
||||
"PrithviGeoSpatialMAE": _HfExamplesInfo(
|
||||
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
|
||||
@@ -649,21 +609,74 @@ _EMBEDDING_EXAMPLE_MODELS = {
|
||||
),
|
||||
}
|
||||
|
||||
_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
|
||||
# [Decoder-only]
|
||||
"GPT2ForSequenceClassification": _HfExamplesInfo(
|
||||
"nie3e/sentiment-polish-gpt2-small"
|
||||
_LATE_INTERACTION_EXAMPLE_MODELS = {
|
||||
# [Text-only]
|
||||
"HF_ColBERT": _HfExamplesInfo("answerdotai/answerai-colbert-small-v1"),
|
||||
"ColBERTModernBertModel": _HfExamplesInfo(
|
||||
"lightonai/GTE-ModernColBERT-v1",
|
||||
hf_overrides={"architectures": ["ColBERTModernBertModel"]},
|
||||
),
|
||||
# [Cross-encoder]
|
||||
"ColBERTJinaRobertaModel": _HfExamplesInfo(
|
||||
"jinaai/jina-colbert-v2",
|
||||
trust_remote_code=True,
|
||||
hf_overrides={"architectures": ["ColBERTJinaRobertaModel"]},
|
||||
),
|
||||
# [Multimodal]
|
||||
"ColModernVBertForRetrieval": _HfExamplesInfo(
|
||||
"ModernVBERT/colmodernvbert-merged",
|
||||
),
|
||||
"ColQwen3": _HfExamplesInfo(
|
||||
"TomoroAI/tomoro-colqwen3-embed-4b", trust_remote_code=True
|
||||
),
|
||||
"OpsColQwen3Model": _HfExamplesInfo(
|
||||
"OpenSearch-AI/Ops-Colqwen3-4B", trust_remote_code=True
|
||||
),
|
||||
"Qwen3VLNemotronEmbedModel": _HfExamplesInfo(
|
||||
"nvidia/nemotron-colembed-vl-4b-v2",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
_REWARD_EXAMPLE_MODELS = {
|
||||
"InternLM2ForRewardModel": _HfExamplesInfo(
|
||||
"internlm/internlm2-1_8b-reward", trust_remote_code=True
|
||||
),
|
||||
"Qwen2ForRewardModel": _HfExamplesInfo(
|
||||
"Qwen/Qwen2.5-Math-RM-72B",
|
||||
max_transformers_version="4.53",
|
||||
transformers_version_reason={
|
||||
"hf": "HF model uses remote code that is not compatible with latest Transformers" # noqa: E501
|
||||
},
|
||||
),
|
||||
"Qwen2ForProcessRewardModel": _HfExamplesInfo(
|
||||
"Qwen/Qwen2.5-Math-PRM-7B",
|
||||
max_transformers_version="4.53",
|
||||
transformers_version_reason={
|
||||
"hf": "HF model uses remote code that is not compatible with latest Transformers" # noqa: E501
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
_TOKEN_CLASSIFICATION_EXAMPLE_MODELS = {
|
||||
"BertForTokenClassification": _HfExamplesInfo("boltuix/NeuroBERT-NER"),
|
||||
"ModernBertForTokenClassification": _HfExamplesInfo(
|
||||
"disham993/electrical-ner-ModernBERT-base"
|
||||
),
|
||||
}
|
||||
|
||||
_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
|
||||
"BertForSequenceClassification": _HfExamplesInfo(
|
||||
"cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||
),
|
||||
"BertForTokenClassification": _HfExamplesInfo("boltuix/NeuroBERT-NER"),
|
||||
"GPT2ForSequenceClassification": _HfExamplesInfo(
|
||||
"nie3e/sentiment-polish-gpt2-small"
|
||||
),
|
||||
"GteNewForSequenceClassification": _HfExamplesInfo(
|
||||
"Alibaba-NLP/gte-multilingual-reranker-base",
|
||||
trust_remote_code=True,
|
||||
hf_overrides={"architectures": ["GteNewForSequenceClassification"]},
|
||||
),
|
||||
"JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"),
|
||||
"LlamaBidirectionalForSequenceClassification": _HfExamplesInfo(
|
||||
"nvidia/llama-nemotron-rerank-1b-v2", trust_remote_code=True
|
||||
),
|
||||
@@ -673,9 +686,6 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
|
||||
"ModernBertForSequenceClassification": _HfExamplesInfo(
|
||||
"Alibaba-NLP/gte-reranker-modernbert-base"
|
||||
),
|
||||
"ModernBertForTokenClassification": _HfExamplesInfo(
|
||||
"disham993/electrical-ner-ModernBERT-base"
|
||||
),
|
||||
"RobertaForSequenceClassification": _HfExamplesInfo(
|
||||
"cross-encoder/quora-roberta-base"
|
||||
),
|
||||
@@ -1273,6 +1283,9 @@ _TRANSFORMERS_BACKEND_MODELS = {
|
||||
_EXAMPLE_MODELS = {
|
||||
**_TEXT_GENERATION_EXAMPLE_MODELS,
|
||||
**_EMBEDDING_EXAMPLE_MODELS,
|
||||
**_LATE_INTERACTION_EXAMPLE_MODELS,
|
||||
**_REWARD_EXAMPLE_MODELS,
|
||||
**_TOKEN_CLASSIFICATION_EXAMPLE_MODELS,
|
||||
**_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS,
|
||||
**_MULTIMODAL_EXAMPLE_MODELS,
|
||||
**_SPECULATIVE_DECODING_EXAMPLE_MODELS,
|
||||
|
||||
@@ -56,21 +56,24 @@ def test_registry_imports(model_arch):
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
@pytest.mark.parametrize(
|
||||
"model_arch,is_mm,init_cuda,is_ce",
|
||||
"model_arch,is_mm,init_cuda,score_type",
|
||||
[
|
||||
("LlamaForCausalLM", False, False, False),
|
||||
("LlavaForConditionalGeneration", True, True, False),
|
||||
("BertForSequenceClassification", False, False, True),
|
||||
("RobertaForSequenceClassification", False, False, True),
|
||||
("XLMRobertaForSequenceClassification", False, False, True),
|
||||
("LlamaForCausalLM", False, False, "bi-encoder"),
|
||||
("LlavaForConditionalGeneration", True, True, "bi-encoder"),
|
||||
("BertForSequenceClassification", False, False, "cross-encoder"),
|
||||
("RobertaForSequenceClassification", False, False, "cross-encoder"),
|
||||
("XLMRobertaForSequenceClassification", False, False, "cross-encoder"),
|
||||
("GteNewModel", False, False, "bi-encoder"),
|
||||
("GteNewForSequenceClassification", False, False, "cross-encoder"),
|
||||
("HF_ColBERT", False, False, "late-interaction"),
|
||||
],
|
||||
)
|
||||
def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce):
|
||||
def test_registry_model_property(model_arch, is_mm, init_cuda, score_type):
|
||||
model_info = ModelRegistry._try_inspect_model_cls(model_arch)
|
||||
assert model_info is not None
|
||||
|
||||
assert model_info.supports_multimodal is is_mm
|
||||
assert model_info.supports_cross_encoding is is_ce
|
||||
assert model_info.score_type == score_type
|
||||
|
||||
if init_cuda and current_platform.is_cuda_alike():
|
||||
assert not torch.cuda.is_initialized()
|
||||
|
||||
@@ -20,6 +20,7 @@ from vllm.config.scheduler import RunnerType
|
||||
from vllm.config.utils import config, getattr_iter
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.tasks import ScoreType
|
||||
from vllm.transformers_utils.config import (
|
||||
ConfigFormat,
|
||||
get_config,
|
||||
@@ -1412,16 +1413,23 @@ class ModelConfig:
|
||||
return self._model_info.requires_raw_input_tokens
|
||||
|
||||
@property
|
||||
def is_cross_encoder(self) -> bool:
|
||||
def score_type(self) -> ScoreType:
|
||||
"""
|
||||
Score API handles score/rerank for:
|
||||
- "score" task (score_type: cross-encoder models)
|
||||
- "embed" task (score_type: bi-encoder models)
|
||||
- "token_embed" task (score_type: late interaction models)
|
||||
"""
|
||||
# fixme: self._model_info.score_type is the score type before
|
||||
# as_seq_cls_model, which is "bi-encoder", rather than the
|
||||
# score type after as_seq_cls_model, which is "cross-encoder".
|
||||
# Therefore, the following logic is required.
|
||||
return (
|
||||
self._model_info.supports_cross_encoding or self.convert_type == "classify"
|
||||
"cross-encoder"
|
||||
if self.convert_type == "classify"
|
||||
else self._model_info.score_type
|
||||
)
|
||||
|
||||
@property
|
||||
def is_late_interaction(self) -> bool:
|
||||
"""Check if model uses late interaction (ColBERT-style) scoring."""
|
||||
return self._model_info.supports_late_interaction
|
||||
|
||||
@property
|
||||
def is_pp_supported(self) -> bool:
|
||||
return self._model_info.supports_pp
|
||||
|
||||
@@ -1584,8 +1584,11 @@ class LLM:
|
||||
)
|
||||
|
||||
supported_tasks = self.supported_tasks
|
||||
score_type = self.model_config.score_type
|
||||
is_late_interaction = score_type == "late-interaction"
|
||||
is_cross_encoder = score_type == "cross-encoder"
|
||||
|
||||
# Late interaction models (e.g., ColBERT) use token_embed for scoring
|
||||
is_late_interaction = model_config.is_late_interaction
|
||||
if not is_late_interaction and all(
|
||||
t not in supported_tasks for t in ("embed", "classify")
|
||||
):
|
||||
@@ -1595,13 +1598,10 @@ class LLM:
|
||||
"`--convert embed` or `--convert classify`."
|
||||
)
|
||||
|
||||
if (
|
||||
model_config.is_cross_encoder
|
||||
and getattr(model_config.hf_config, "num_labels", 0) != 1
|
||||
):
|
||||
if is_cross_encoder and getattr(model_config.hf_config, "num_labels", 0) != 1:
|
||||
raise ValueError("Score API is only enabled for num_labels == 1.")
|
||||
|
||||
if not model_config.is_cross_encoder and chat_template is not None:
|
||||
if not is_cross_encoder and chat_template is not None:
|
||||
raise ValueError(
|
||||
"chat_template is only supported for cross-encoder models."
|
||||
)
|
||||
@@ -1622,7 +1622,7 @@ class LLM:
|
||||
)
|
||||
encode_kwargs = tok_params.get_encode_kwargs()
|
||||
|
||||
if model_config.is_cross_encoder:
|
||||
if is_cross_encoder:
|
||||
return self._cross_encoding_score(
|
||||
score_data_1,
|
||||
score_data_2,
|
||||
|
||||
@@ -37,10 +37,10 @@ def register_pooling_api_routers(
|
||||
|
||||
app.include_router(embed_router)
|
||||
|
||||
# Score/rerank endpoints are available for:
|
||||
# - "score" task (cross-encoder models)
|
||||
# - "embed" task (bi-encoder models)
|
||||
# - "token_embed" task (late interaction models like ColBERT)
|
||||
# Score API handles score/rerank for:
|
||||
# - "score" task (score_type: cross-encoder models)
|
||||
# - "embed" task (score_type: bi-encoder models)
|
||||
# - "token_embed" task (score_type: late interaction models)
|
||||
if any(t in supported_tasks for t in ("score", "embed", "token_embed")):
|
||||
from vllm.entrypoints.pooling.score.api_router import router as score_router
|
||||
|
||||
@@ -101,10 +101,10 @@ def init_pooling_state(
|
||||
if "classify" in supported_tasks
|
||||
else None
|
||||
)
|
||||
# ServingScores handles score/rerank for:
|
||||
# - "score" task (cross-encoder models)
|
||||
# - "embed" task (bi-encoder models)
|
||||
# - "token_embed" task (late interaction models like ColBERT)
|
||||
# Score API handles score/rerank for:
|
||||
# - "score" task (score_type: cross-encoder models)
|
||||
# - "embed" task (score_type: bi-encoder models)
|
||||
# - "token_embed" task (score_type: late interaction models)
|
||||
state.serving_scores = (
|
||||
ServingScores(
|
||||
engine_client,
|
||||
|
||||
@@ -69,16 +69,15 @@ class ServingScores(OpenAIServing):
|
||||
|
||||
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
self.is_cross_encoder = self.model_config.is_cross_encoder
|
||||
self.is_multimodal_model = self.model_config.is_multimodal_model
|
||||
self.score_type = self.model_config.score_type
|
||||
self.architecture = self.model_config.architecture
|
||||
self.is_late_interaction = self.model_config.is_late_interaction
|
||||
self.is_multimodal_model = self.model_config.is_multimodal_model
|
||||
|
||||
if self.is_cross_encoder:
|
||||
if self.score_type == "cross-encoder":
|
||||
self._score_func = self._cross_encoding_score
|
||||
elif self.is_late_interaction:
|
||||
elif self.score_type == "late-interaction":
|
||||
self._score_func = self._late_interaction_score
|
||||
else:
|
||||
else: # "bi-encoder"
|
||||
self._score_func = self._embedding_score
|
||||
|
||||
async def _embedding_score(
|
||||
|
||||
@@ -30,8 +30,11 @@ from vllm.lora.utils import (
|
||||
replace_submodule,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
|
||||
from vllm.model_executor.models.interfaces import is_pooling_model
|
||||
from vllm.model_executor.models import (
|
||||
SupportsLoRA,
|
||||
is_pooling_model,
|
||||
supports_multimodal,
|
||||
)
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.utils import PPMissingLayer
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
@@ -18,7 +18,6 @@ Reference: https://arxiv.org/abs/2004.12832
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import ClassVar, Literal
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -28,16 +27,16 @@ from vllm.model_executor.layers.pooler import Pooler
|
||||
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
|
||||
|
||||
from .bert import BertEmbeddingModel, BertModel
|
||||
from .interfaces import SupportsLateInteraction
|
||||
from .interfaces_base import default_pooling_type
|
||||
|
||||
|
||||
class ColBERTMixin:
|
||||
class ColBERTMixin(nn.Module, SupportsLateInteraction):
|
||||
"""Mixin that adds ColBERT late interaction support to any embedding model.
|
||||
|
||||
ColBERT (Contextualized Late Interaction over BERT) uses per-token
|
||||
embeddings with a linear projection layer. This mixin provides:
|
||||
|
||||
- ``supports_late_interaction`` class-var
|
||||
- ColBERT linear projection initialisation / lazy creation
|
||||
- Weight loading helpers for the projection layer
|
||||
- A builder for the token-embedding pooler
|
||||
@@ -52,8 +51,6 @@ class ColBERTMixin:
|
||||
the ColBERT projection weight, then delegate the rest to the backbone.
|
||||
"""
|
||||
|
||||
supports_late_interaction: ClassVar[Literal[True]] = True
|
||||
|
||||
# Set during _init_colbert_components
|
||||
colbert_dim: int | None
|
||||
colbert_linear: nn.Linear | None
|
||||
|
||||
@@ -9,7 +9,6 @@ Reference: https://huggingface.co/ModernVBERT/colmodernvbert-merged
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import ClassVar, Literal
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -37,7 +36,11 @@ from vllm.multimodal.processing import (
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs.colmodernvbert import ColModernVBertConfig
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsLateInteraction,
|
||||
SupportsMultiModal,
|
||||
)
|
||||
from .interfaces_base import default_pooling_type
|
||||
from .modernbert import ModernBertEmbeddings, ModernBertLayer
|
||||
from .siglip import SiglipVisionModel
|
||||
@@ -234,7 +237,9 @@ class ColModernVBertMultiModalProcessor(
|
||||
dummy_inputs=ColModernVBertDummyInputsBuilder,
|
||||
)
|
||||
@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
|
||||
class ColModernVBertForRetrieval(nn.Module, SupportsMultiModal):
|
||||
class ColModernVBertForRetrieval(
|
||||
nn.Module, SupportsMultiModal, SupportsLateInteraction
|
||||
):
|
||||
"""ColModernVBERT multimodal late-interaction retrieval model.
|
||||
|
||||
Architecture:
|
||||
@@ -248,7 +253,6 @@ class ColModernVBertForRetrieval(nn.Module, SupportsMultiModal):
|
||||
"""
|
||||
|
||||
is_pooling_model = True
|
||||
supports_late_interaction: ClassVar[Literal[True]] = True
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
@@ -20,7 +20,6 @@ Target models:
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable, Mapping
|
||||
from typing import ClassVar, Literal
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -31,6 +30,7 @@ from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
from .interfaces import SupportsLateInteraction
|
||||
from .interfaces_base import default_pooling_type
|
||||
from .qwen2_vl import Qwen2VLMultiModalDataParser
|
||||
from .qwen3_vl import (
|
||||
@@ -113,9 +113,7 @@ class ColQwen3ProcessingInfo(Qwen3VLProcessingInfo):
|
||||
info=ColQwen3ProcessingInfo,
|
||||
dummy_inputs=Qwen3VLDummyInputsBuilder,
|
||||
)
|
||||
class ColQwen3Model(
|
||||
Qwen3VLForConditionalGeneration,
|
||||
):
|
||||
class ColQwen3Model(Qwen3VLForConditionalGeneration, SupportsLateInteraction):
|
||||
"""ColQwen3 late interaction model for multi-modal retrieval/reranking.
|
||||
|
||||
This model extends Qwen3VLForConditionalGeneration with a ColBERT-style
|
||||
@@ -132,16 +130,11 @@ class ColQwen3Model(
|
||||
|
||||
Attributes:
|
||||
custom_text_proj: Linear projection from hidden_size to embed_dim
|
||||
supports_late_interaction: Flag indicating this model uses late
|
||||
interaction scoring
|
||||
"""
|
||||
|
||||
# Mark this as a pooling model so vLLM routes to pooler path
|
||||
is_pooling_model = True
|
||||
|
||||
# Mark this model as supporting late interaction scoring
|
||||
supports_late_interaction: ClassVar[Literal[True]] = True
|
||||
|
||||
# Override hf_to_vllm_mapper to handle ColQwen3 weight naming.
|
||||
# NOTE: WeightsMapper applies ALL matching prefix rules sequentially
|
||||
# (no early exit), so more-specific prefixes must come first.
|
||||
|
||||
@@ -34,10 +34,11 @@ from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.tasks import ScoreType
|
||||
from vllm.utils.collection_utils import common_prefix
|
||||
from vllm.utils.func_utils import supports_kw
|
||||
|
||||
from .interfaces_base import VllmModel, is_pooling_model
|
||||
from .interfaces_base import VllmModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
@@ -969,29 +970,7 @@ def supports_mamba_prefix_caching(
|
||||
class SupportsCrossEncoding(Protocol):
|
||||
"""The interface required for all models that support cross encoding."""
|
||||
|
||||
supports_cross_encoding: ClassVar[Literal[True]] = True
|
||||
|
||||
|
||||
@overload
|
||||
def supports_cross_encoding(
|
||||
model: type[object],
|
||||
) -> TypeIs[type[SupportsCrossEncoding]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]: ...
|
||||
|
||||
|
||||
def _supports_cross_encoding(
|
||||
model: type[object] | object,
|
||||
) -> TypeIs[type[SupportsCrossEncoding]] | TypeIs[SupportsCrossEncoding]:
|
||||
return getattr(model, "supports_cross_encoding", False)
|
||||
|
||||
|
||||
def supports_cross_encoding(
|
||||
model: type[object] | object,
|
||||
) -> TypeIs[type[SupportsCrossEncoding]] | TypeIs[SupportsCrossEncoding]:
|
||||
return is_pooling_model(model) and _supports_cross_encoding(model)
|
||||
score_type: ClassVar[ScoreType] = "cross-encoder"
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@@ -1003,29 +982,7 @@ class SupportsLateInteraction(Protocol):
|
||||
MaxSim (max over document tokens, sum over query tokens).
|
||||
"""
|
||||
|
||||
supports_late_interaction: ClassVar[Literal[True]] = True
|
||||
|
||||
|
||||
@overload
|
||||
def supports_late_interaction(
|
||||
model: type[object],
|
||||
) -> TypeIs[type[SupportsLateInteraction]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def supports_late_interaction(model: object) -> TypeIs[SupportsLateInteraction]: ...
|
||||
|
||||
|
||||
def _supports_late_interaction(
|
||||
model: type[object] | object,
|
||||
) -> TypeIs[type[SupportsLateInteraction]] | TypeIs[SupportsLateInteraction]:
|
||||
return getattr(model, "supports_late_interaction", False)
|
||||
|
||||
|
||||
def supports_late_interaction(
|
||||
model: type[object] | object,
|
||||
) -> TypeIs[type[SupportsLateInteraction]] | TypeIs[SupportsLateInteraction]:
|
||||
return is_pooling_model(model) and _supports_late_interaction(model)
|
||||
score_type: ClassVar[ScoreType] = "late-interaction"
|
||||
|
||||
|
||||
class SupportsQuant:
|
||||
|
||||
@@ -15,6 +15,7 @@ import torch.nn as nn
|
||||
from typing_extensions import TypeIs, TypeVar
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tasks import ScoreType
|
||||
from vllm.utils.func_utils import supports_kw
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -187,6 +188,26 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
|
||||
decorator to conveniently set this field.
|
||||
"""
|
||||
|
||||
score_type: ClassVar[ScoreType] = "bi-encoder"
|
||||
"""
|
||||
Indicates the
|
||||
[vllm.config.model.ModelConfig.score_type][]
|
||||
to use by default.
|
||||
|
||||
Score API handles score/rerank for:
|
||||
- "score" task (score_type: cross-encoder models)
|
||||
- "embed" task (score_type: bi-encoder models)
|
||||
- "token_embed" task (score_type: late interaction models)
|
||||
|
||||
score_type defaults to bi-encoder, then the Score API uses the "embed" task.
|
||||
If you set score_type to cross-encoder via
|
||||
[vllm.model_executor.models.interfaces.SupportsCrossEncoding][],
|
||||
then the Score API uses the "score" task.
|
||||
If you set score_type to late-interaction via
|
||||
[vllm.model_executor.models.interfaces.SupportsLateInteraction][],
|
||||
then the Score API uses the "token_embed" task.
|
||||
"""
|
||||
|
||||
pooler: Pooler
|
||||
"""The pooler is only called on TP rank 0."""
|
||||
|
||||
@@ -250,3 +271,13 @@ def attn_type(attn_type: AttnTypeStr):
|
||||
|
||||
def get_attn_type(model: type[object] | object) -> AttnTypeStr:
|
||||
return getattr(model, "attn_type", "decoder")
|
||||
|
||||
|
||||
def get_score_type(model: type[object] | object) -> ScoreType:
|
||||
score_types = set()
|
||||
for m in model.__mro__:
|
||||
score_type = getattr(m, "score_type", "bi-encoder")
|
||||
if score_type != "bi-encoder":
|
||||
score_types.add(score_type)
|
||||
assert len(score_types) < 2
|
||||
return "bi-encoder" if not score_types else list(score_types)[0]
|
||||
|
||||
@@ -30,6 +30,7 @@ from vllm.config import (
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logging_utils import logtime
|
||||
from vllm.tasks import ScoreType
|
||||
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
@@ -48,8 +49,6 @@ from .interfaces import (
|
||||
is_attention_free,
|
||||
is_hybrid,
|
||||
requires_raw_input_tokens,
|
||||
supports_cross_encoding,
|
||||
supports_late_interaction,
|
||||
supports_mamba_prefix_caching,
|
||||
supports_multimodal,
|
||||
supports_multimodal_encoder_tp_data,
|
||||
@@ -61,6 +60,7 @@ from .interfaces_base import (
|
||||
get_attn_type,
|
||||
get_default_seq_pooling_type,
|
||||
get_default_tok_pooling_type,
|
||||
get_score_type,
|
||||
is_pooling_model,
|
||||
is_text_generation_model,
|
||||
)
|
||||
@@ -214,19 +214,14 @@ _EMBEDDING_MODELS = {
|
||||
# [Text-only]
|
||||
"BertModel": ("bert", "BertEmbeddingModel"),
|
||||
"BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
|
||||
"HF_ColBERT": ("colbert", "ColBERTModel"),
|
||||
"ColBERTModernBertModel": ("colbert", "ColBERTModernBertModel"),
|
||||
"ColBERTJinaRobertaModel": ("colbert", "ColBERTJinaRobertaModel"),
|
||||
"BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
|
||||
"DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
|
||||
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
|
||||
"Gemma3TextModel": ("gemma3", "Gemma3Model"),
|
||||
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
|
||||
"GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
|
||||
"GritLM": ("gritlm", "GritLM"),
|
||||
"GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
|
||||
"GteNewModel": ("bert_with_rope", "GteNewModel"),
|
||||
"InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
|
||||
"JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501
|
||||
"LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
|
||||
"LlamaModel": ("llama", "LlamaForCausalLM"),
|
||||
**{
|
||||
@@ -241,8 +236,6 @@ _EMBEDDING_MODELS = {
|
||||
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
|
||||
"Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
|
||||
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
|
||||
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
|
||||
"Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
|
||||
"RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
|
||||
"RobertaModel": ("roberta", "RobertaEmbeddingModel"),
|
||||
"TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
|
||||
@@ -252,19 +245,14 @@ _EMBEDDING_MODELS = {
|
||||
"VoyageQwen3BidirectionalEmbedModel",
|
||||
),
|
||||
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
|
||||
"BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
|
||||
# [Multimodal]
|
||||
"CLIPModel": ("clip", "CLIPEmbeddingModel"),
|
||||
"ColModernVBertForRetrieval": ("colmodernvbert", "ColModernVBertForRetrieval"),
|
||||
"LlavaNextForConditionalGeneration": (
|
||||
"llava_next",
|
||||
"LlavaNextForConditionalGeneration",
|
||||
),
|
||||
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
|
||||
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
|
||||
"ColQwen3": ("colqwen3", "ColQwen3Model"),
|
||||
"OpsColQwen3Model": ("colqwen3", "ColQwen3Model"),
|
||||
"Qwen3VLNemotronEmbedModel": ("colqwen3", "ColQwen3Model"),
|
||||
"SiglipModel": ("siglip", "SiglipEmbeddingModel"),
|
||||
"LlamaNemotronVLModel": (
|
||||
"nemotron_vl",
|
||||
@@ -277,35 +265,59 @@ _EMBEDDING_MODELS = {
|
||||
"Terratorch": ("terratorch", "Terratorch"),
|
||||
}
|
||||
|
||||
_CROSS_ENCODER_MODELS = {
|
||||
"BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
|
||||
_LATE_INTERACTION_MODELS = {
|
||||
# [Text-only]
|
||||
"HF_ColBERT": ("colbert", "ColBERTModel"),
|
||||
"ColBERTModernBertModel": ("colbert", "ColBERTModernBertModel"),
|
||||
"ColBERTJinaRobertaModel": ("colbert", "ColBERTJinaRobertaModel"),
|
||||
# [Multimodal]
|
||||
"ColModernVBertForRetrieval": ("colmodernvbert", "ColModernVBertForRetrieval"),
|
||||
"ColQwen3": ("colqwen3", "ColQwen3Model"),
|
||||
"OpsColQwen3Model": ("colqwen3", "ColQwen3Model"),
|
||||
"Qwen3VLNemotronEmbedModel": ("colqwen3", "ColQwen3Model"),
|
||||
}
|
||||
|
||||
_REWARD_MODELS = {
|
||||
"InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
|
||||
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
|
||||
"Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
|
||||
}
|
||||
|
||||
_TOKEN_CLASSIFICATION_MODELS = {
|
||||
"BertForTokenClassification": ("bert", "BertForTokenClassification"),
|
||||
"ModernBertForTokenClassification": (
|
||||
"modernbert",
|
||||
"ModernBertForTokenClassification",
|
||||
),
|
||||
}
|
||||
|
||||
_SEQUENCE_CLASSIFICATION_MODELS = {
|
||||
"BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
|
||||
"GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
|
||||
"GteNewForSequenceClassification": (
|
||||
"bert_with_rope",
|
||||
"GteNewForSequenceClassification",
|
||||
),
|
||||
"JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
|
||||
"JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501
|
||||
"LlamaBidirectionalForSequenceClassification": (
|
||||
"llama",
|
||||
"LlamaBidirectionalForSequenceClassification",
|
||||
),
|
||||
"LlamaNemotronVLForSequenceClassification": (
|
||||
"nemotron_vl",
|
||||
"LlamaNemotronVLForSequenceClassification",
|
||||
),
|
||||
"ModernBertForSequenceClassification": (
|
||||
"modernbert",
|
||||
"ModernBertForSequenceClassification",
|
||||
),
|
||||
"ModernBertForTokenClassification": (
|
||||
"modernbert",
|
||||
"ModernBertForTokenClassification",
|
||||
),
|
||||
"RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
|
||||
"XLMRobertaForSequenceClassification": (
|
||||
"roberta",
|
||||
"RobertaForSequenceClassification",
|
||||
),
|
||||
# [Multimodal]
|
||||
"JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
|
||||
"LlamaNemotronVLForSequenceClassification": (
|
||||
"nemotron_vl",
|
||||
"LlamaNemotronVLForSequenceClassification",
|
||||
),
|
||||
}
|
||||
|
||||
_MULTIMODAL_MODELS = {
|
||||
@@ -606,7 +618,10 @@ _TRANSFORMERS_BACKEND_MODELS = {
|
||||
_VLLM_MODELS = {
|
||||
**_TEXT_GENERATION_MODELS,
|
||||
**_EMBEDDING_MODELS,
|
||||
**_CROSS_ENCODER_MODELS,
|
||||
**_LATE_INTERACTION_MODELS,
|
||||
**_REWARD_MODELS,
|
||||
**_TOKEN_CLASSIFICATION_MODELS,
|
||||
**_SEQUENCE_CLASSIFICATION_MODELS,
|
||||
**_MULTIMODAL_MODELS,
|
||||
**_SPECULATIVE_DECODING_MODELS,
|
||||
**_TRANSFORMERS_SUPPORTED_MODELS,
|
||||
@@ -643,8 +658,7 @@ class _ModelInfo:
|
||||
attn_type: AttnTypeStr
|
||||
default_seq_pooling_type: SequencePoolingType
|
||||
default_tok_pooling_type: TokenPoolingType
|
||||
supports_cross_encoding: bool
|
||||
supports_late_interaction: bool
|
||||
score_type: ScoreType
|
||||
supports_multimodal: bool
|
||||
supports_multimodal_raw_input_only: bool
|
||||
requires_raw_input_tokens: bool
|
||||
@@ -667,8 +681,7 @@ class _ModelInfo:
|
||||
default_seq_pooling_type=get_default_seq_pooling_type(model),
|
||||
default_tok_pooling_type=get_default_tok_pooling_type(model),
|
||||
attn_type=get_attn_type(model),
|
||||
supports_cross_encoding=supports_cross_encoding(model),
|
||||
supports_late_interaction=supports_late_interaction(model),
|
||||
score_type=get_score_type(model),
|
||||
supports_multimodal=supports_multimodal(model),
|
||||
supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
|
||||
model
|
||||
@@ -1166,14 +1179,6 @@ class _ModelRegistry:
|
||||
model_cls, _ = self.inspect_model_cls(architectures, model_config)
|
||||
return model_cls.is_pooling_model
|
||||
|
||||
def is_cross_encoder_model(
|
||||
self,
|
||||
architectures: str | list[str],
|
||||
model_config: ModelConfig,
|
||||
) -> bool:
|
||||
model_cls, _ = self.inspect_model_cls(architectures, model_config)
|
||||
return model_cls.supports_cross_encoding
|
||||
|
||||
def is_multimodal_model(
|
||||
self,
|
||||
architectures: str | list[str],
|
||||
|
||||
@@ -10,6 +10,12 @@ PoolingTask = Literal[
|
||||
]
|
||||
POOLING_TASKS: tuple[PoolingTask, ...] = get_args(PoolingTask)
|
||||
|
||||
# Score API handles score/rerank for:
|
||||
# - "score" task (score_type: cross-encoder models)
|
||||
# - "embed" task (score_type: bi-encoder models)
|
||||
# - "token_embed" task (score_type: late interaction models)
|
||||
ScoreType = Literal["bi-encoder", "cross-encoder", "late-interaction"]
|
||||
|
||||
FrontendTask = Literal["render"]
|
||||
FRONTEND_TASKS: tuple[FrontendTask, ...] = get_args(FrontendTask)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user