[Model] Consolidate score logic by introducing score_type (#36479)

Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
wang.yuqi
2026-03-10 21:32:25 +08:00
committed by GitHub
parent 409c4e632d
commit a3189a08b0
14 changed files with 213 additions and 194 deletions

View File

@@ -18,7 +18,6 @@ Reference: https://arxiv.org/abs/2004.12832
"""
from collections.abc import Iterable
from typing import ClassVar, Literal
import torch
from torch import nn
@@ -28,16 +27,16 @@ from vllm.model_executor.layers.pooler import Pooler
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
from .bert import BertEmbeddingModel, BertModel
from .interfaces import SupportsLateInteraction
from .interfaces_base import default_pooling_type
class ColBERTMixin:
class ColBERTMixin(nn.Module, SupportsLateInteraction):
"""Mixin that adds ColBERT late interaction support to any embedding model.
ColBERT (Contextualized Late Interaction over BERT) uses per-token
embeddings with a linear projection layer. This mixin provides:
- ``supports_late_interaction`` class-var
- ColBERT linear projection initialisation / lazy creation
- Weight loading helpers for the projection layer
- A builder for the token-embedding pooler
@@ -52,8 +51,6 @@ class ColBERTMixin:
the ColBERT projection weight, then delegate the rest to the backbone.
"""
supports_late_interaction: ClassVar[Literal[True]] = True
# Set during _init_colbert_components
colbert_dim: int | None
colbert_linear: nn.Linear | None

View File

@@ -9,7 +9,6 @@ Reference: https://huggingface.co/ModernVBERT/colmodernvbert-merged
"""
from collections.abc import Iterable, Mapping, Sequence
from typing import ClassVar, Literal
import torch
from torch import nn
@@ -37,7 +36,11 @@ from vllm.multimodal.processing import (
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.colmodernvbert import ColModernVBertConfig
from .interfaces import MultiModalEmbeddings, SupportsMultiModal
from .interfaces import (
MultiModalEmbeddings,
SupportsLateInteraction,
SupportsMultiModal,
)
from .interfaces_base import default_pooling_type
from .modernbert import ModernBertEmbeddings, ModernBertLayer
from .siglip import SiglipVisionModel
@@ -234,7 +237,9 @@ class ColModernVBertMultiModalProcessor(
dummy_inputs=ColModernVBertDummyInputsBuilder,
)
@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
class ColModernVBertForRetrieval(nn.Module, SupportsMultiModal):
class ColModernVBertForRetrieval(
nn.Module, SupportsMultiModal, SupportsLateInteraction
):
"""ColModernVBERT multimodal late-interaction retrieval model.
Architecture:
@@ -248,7 +253,6 @@ class ColModernVBertForRetrieval(nn.Module, SupportsMultiModal):
"""
is_pooling_model = True
supports_late_interaction: ClassVar[Literal[True]] = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

View File

@@ -20,7 +20,6 @@ Target models:
"""
from collections.abc import Iterable, Mapping
from typing import ClassVar, Literal
import torch
import torch.nn as nn
@@ -31,6 +30,7 @@ from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from .interfaces import SupportsLateInteraction
from .interfaces_base import default_pooling_type
from .qwen2_vl import Qwen2VLMultiModalDataParser
from .qwen3_vl import (
@@ -113,9 +113,7 @@ class ColQwen3ProcessingInfo(Qwen3VLProcessingInfo):
info=ColQwen3ProcessingInfo,
dummy_inputs=Qwen3VLDummyInputsBuilder,
)
class ColQwen3Model(
Qwen3VLForConditionalGeneration,
):
class ColQwen3Model(Qwen3VLForConditionalGeneration, SupportsLateInteraction):
"""ColQwen3 late interaction model for multi-modal retrieval/reranking.
This model extends Qwen3VLForConditionalGeneration with a ColBERT-style
@@ -132,16 +130,11 @@ class ColQwen3Model(
Attributes:
custom_text_proj: Linear projection from hidden_size to embed_dim
supports_late_interaction: Flag indicating this model uses late
interaction scoring
"""
# Mark this as a pooling model so vLLM routes to pooler path
is_pooling_model = True
# Mark this model as supporting late interaction scoring
supports_late_interaction: ClassVar[Literal[True]] = True
# Override hf_to_vllm_mapper to handle ColQwen3 weight naming.
# NOTE: WeightsMapper applies ALL matching prefix rules sequentially
# (no early exit), so more-specific prefixes must come first.

View File

@@ -34,10 +34,11 @@ from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.tasks import ScoreType
from vllm.utils.collection_utils import common_prefix
from vllm.utils.func_utils import supports_kw
from .interfaces_base import VllmModel, is_pooling_model
from .interfaces_base import VllmModel
if TYPE_CHECKING:
from vllm.config import VllmConfig
@@ -969,29 +970,7 @@ def supports_mamba_prefix_caching(
class SupportsCrossEncoding(Protocol):
"""The interface required for all models that support cross encoding."""
supports_cross_encoding: ClassVar[Literal[True]] = True
@overload
def supports_cross_encoding(
model: type[object],
) -> TypeIs[type[SupportsCrossEncoding]]: ...
@overload
def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]: ...
def _supports_cross_encoding(
model: type[object] | object,
) -> TypeIs[type[SupportsCrossEncoding]] | TypeIs[SupportsCrossEncoding]:
return getattr(model, "supports_cross_encoding", False)
def supports_cross_encoding(
model: type[object] | object,
) -> TypeIs[type[SupportsCrossEncoding]] | TypeIs[SupportsCrossEncoding]:
return is_pooling_model(model) and _supports_cross_encoding(model)
score_type: ClassVar[ScoreType] = "cross-encoder"
@runtime_checkable
@@ -1003,29 +982,7 @@ class SupportsLateInteraction(Protocol):
MaxSim (max over document tokens, sum over query tokens).
"""
supports_late_interaction: ClassVar[Literal[True]] = True
@overload
def supports_late_interaction(
model: type[object],
) -> TypeIs[type[SupportsLateInteraction]]: ...
@overload
def supports_late_interaction(model: object) -> TypeIs[SupportsLateInteraction]: ...
def _supports_late_interaction(
model: type[object] | object,
) -> TypeIs[type[SupportsLateInteraction]] | TypeIs[SupportsLateInteraction]:
return getattr(model, "supports_late_interaction", False)
def supports_late_interaction(
model: type[object] | object,
) -> TypeIs[type[SupportsLateInteraction]] | TypeIs[SupportsLateInteraction]:
return is_pooling_model(model) and _supports_late_interaction(model)
score_type: ClassVar[ScoreType] = "late-interaction"
class SupportsQuant:

View File

@@ -15,6 +15,7 @@ import torch.nn as nn
from typing_extensions import TypeIs, TypeVar
from vllm.logger import init_logger
from vllm.tasks import ScoreType
from vllm.utils.func_utils import supports_kw
if TYPE_CHECKING:
@@ -187,6 +188,26 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
decorator to conveniently set this field.
"""
score_type: ClassVar[ScoreType] = "bi-encoder"
"""
Indicates the
[vllm.config.model.ModelConfig.score_type][]
to use by default.
Score API handles score/rerank for:
- "score" task (score_type: cross-encoder models)
- "embed" task (score_type: bi-encoder models)
- "token_embed" task (score_type: late interaction models)
score_type defaults to bi-encoder, then the Score API uses the "embed" task.
If you set score_type to cross-encoder via
[vllm.model_executor.models.interfaces.SupportsCrossEncoding][],
then the Score API uses the "score" task.
If you set score_type to late-interaction via
[vllm.model_executor.models.interfaces.SupportsLateInteraction][],
then the Score API uses the "token_embed" task.
"""
pooler: Pooler
"""The pooler is only called on TP rank 0."""
@@ -250,3 +271,13 @@ def attn_type(attn_type: AttnTypeStr):
def get_attn_type(model: type[object] | object) -> AttnTypeStr:
    """Return the attention type declared on *model* (class or instance).

    Reads the ``attn_type`` class-var set by the ``@attn_type`` decorator;
    models that never declare one fall back to ``"decoder"``.
    """
    return getattr(model, "attn_type", "decoder")
def get_score_type(model: type[object] | object) -> ScoreType:
    """Resolve the score type declared by *model* (class or instance).

    Walks the MRO and collects every non-default ``score_type`` class-var
    (set by mixins such as ``SupportsCrossEncoding`` or
    ``SupportsLateInteraction``). A model may declare at most one
    non-default score type; anything else is an ambiguous definition.

    Returns:
        The single non-default score type found, or ``"bi-encoder"``
        when no mixin overrides the default.

    Raises:
        ValueError: If the MRO declares two or more different
            non-default score types.
    """
    # Accept both classes and instances (mirroring get_attn_type);
    # instances do not expose ``__mro__`` — it lives on the metaclass.
    cls = model if isinstance(model, type) else type(model)
    score_types = {
        score_type
        for m in cls.__mro__
        if (score_type := getattr(m, "score_type", "bi-encoder")) != "bi-encoder"
    }
    if len(score_types) > 1:
        raise ValueError(
            f"{cls.__name__} declares conflicting score types: "
            f"{sorted(score_types)}"
        )
    return next(iter(score_types), "bi-encoder")

View File

@@ -30,6 +30,7 @@ from vllm.config import (
)
from vllm.logger import init_logger
from vllm.logging_utils import logtime
from vllm.tasks import ScoreType
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
from vllm.utils.hashing import safe_hash
@@ -48,8 +49,6 @@ from .interfaces import (
is_attention_free,
is_hybrid,
requires_raw_input_tokens,
supports_cross_encoding,
supports_late_interaction,
supports_mamba_prefix_caching,
supports_multimodal,
supports_multimodal_encoder_tp_data,
@@ -61,6 +60,7 @@ from .interfaces_base import (
get_attn_type,
get_default_seq_pooling_type,
get_default_tok_pooling_type,
get_score_type,
is_pooling_model,
is_text_generation_model,
)
@@ -214,19 +214,14 @@ _EMBEDDING_MODELS = {
# [Text-only]
"BertModel": ("bert", "BertEmbeddingModel"),
"BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
"HF_ColBERT": ("colbert", "ColBERTModel"),
"ColBERTModernBertModel": ("colbert", "ColBERTModernBertModel"),
"ColBERTJinaRobertaModel": ("colbert", "ColBERTJinaRobertaModel"),
"BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
"DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
"Gemma3TextModel": ("gemma3", "Gemma3Model"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
"GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
"GritLM": ("gritlm", "GritLM"),
"GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
"GteNewModel": ("bert_with_rope", "GteNewModel"),
"InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
"JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501
"LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
"LlamaModel": ("llama", "LlamaForCausalLM"),
**{
@@ -241,8 +236,6 @@ _EMBEDDING_MODELS = {
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
"Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
"RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
"RobertaModel": ("roberta", "RobertaEmbeddingModel"),
"TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
@@ -252,19 +245,14 @@ _EMBEDDING_MODELS = {
"VoyageQwen3BidirectionalEmbedModel",
),
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
"BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
# [Multimodal]
"CLIPModel": ("clip", "CLIPEmbeddingModel"),
"ColModernVBertForRetrieval": ("colmodernvbert", "ColModernVBertForRetrieval"),
"LlavaNextForConditionalGeneration": (
"llava_next",
"LlavaNextForConditionalGeneration",
),
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
"ColQwen3": ("colqwen3", "ColQwen3Model"),
"OpsColQwen3Model": ("colqwen3", "ColQwen3Model"),
"Qwen3VLNemotronEmbedModel": ("colqwen3", "ColQwen3Model"),
"SiglipModel": ("siglip", "SiglipEmbeddingModel"),
"LlamaNemotronVLModel": (
"nemotron_vl",
@@ -277,35 +265,59 @@ _EMBEDDING_MODELS = {
"Terratorch": ("terratorch", "Terratorch"),
}
_CROSS_ENCODER_MODELS = {
"BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
_LATE_INTERACTION_MODELS = {
# [Text-only]
"HF_ColBERT": ("colbert", "ColBERTModel"),
"ColBERTModernBertModel": ("colbert", "ColBERTModernBertModel"),
"ColBERTJinaRobertaModel": ("colbert", "ColBERTJinaRobertaModel"),
# [Multimodal]
"ColModernVBertForRetrieval": ("colmodernvbert", "ColModernVBertForRetrieval"),
"ColQwen3": ("colqwen3", "ColQwen3Model"),
"OpsColQwen3Model": ("colqwen3", "ColQwen3Model"),
"Qwen3VLNemotronEmbedModel": ("colqwen3", "ColQwen3Model"),
}
_REWARD_MODELS = {
"InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
"Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
}
_TOKEN_CLASSIFICATION_MODELS = {
"BertForTokenClassification": ("bert", "BertForTokenClassification"),
"ModernBertForTokenClassification": (
"modernbert",
"ModernBertForTokenClassification",
),
}
_SEQUENCE_CLASSIFICATION_MODELS = {
"BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
"GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
"GteNewForSequenceClassification": (
"bert_with_rope",
"GteNewForSequenceClassification",
),
"JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
"JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501
"LlamaBidirectionalForSequenceClassification": (
"llama",
"LlamaBidirectionalForSequenceClassification",
),
"LlamaNemotronVLForSequenceClassification": (
"nemotron_vl",
"LlamaNemotronVLForSequenceClassification",
),
"ModernBertForSequenceClassification": (
"modernbert",
"ModernBertForSequenceClassification",
),
"ModernBertForTokenClassification": (
"modernbert",
"ModernBertForTokenClassification",
),
"RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
"XLMRobertaForSequenceClassification": (
"roberta",
"RobertaForSequenceClassification",
),
# [Multimodal]
"JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
"LlamaNemotronVLForSequenceClassification": (
"nemotron_vl",
"LlamaNemotronVLForSequenceClassification",
),
}
_MULTIMODAL_MODELS = {
@@ -606,7 +618,10 @@ _TRANSFORMERS_BACKEND_MODELS = {
_VLLM_MODELS = {
**_TEXT_GENERATION_MODELS,
**_EMBEDDING_MODELS,
**_CROSS_ENCODER_MODELS,
**_LATE_INTERACTION_MODELS,
**_REWARD_MODELS,
**_TOKEN_CLASSIFICATION_MODELS,
**_SEQUENCE_CLASSIFICATION_MODELS,
**_MULTIMODAL_MODELS,
**_SPECULATIVE_DECODING_MODELS,
**_TRANSFORMERS_SUPPORTED_MODELS,
@@ -643,8 +658,7 @@ class _ModelInfo:
attn_type: AttnTypeStr
default_seq_pooling_type: SequencePoolingType
default_tok_pooling_type: TokenPoolingType
supports_cross_encoding: bool
supports_late_interaction: bool
score_type: ScoreType
supports_multimodal: bool
supports_multimodal_raw_input_only: bool
requires_raw_input_tokens: bool
@@ -667,8 +681,7 @@ class _ModelInfo:
default_seq_pooling_type=get_default_seq_pooling_type(model),
default_tok_pooling_type=get_default_tok_pooling_type(model),
attn_type=get_attn_type(model),
supports_cross_encoding=supports_cross_encoding(model),
supports_late_interaction=supports_late_interaction(model),
score_type=get_score_type(model),
supports_multimodal=supports_multimodal(model),
supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
model
@@ -1166,14 +1179,6 @@ class _ModelRegistry:
model_cls, _ = self.inspect_model_cls(architectures, model_config)
return model_cls.is_pooling_model
def is_cross_encoder_model(
self,
architectures: str | list[str],
model_config: ModelConfig,
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures, model_config)
return model_cls.supports_cross_encoding
def is_multimodal_model(
self,
architectures: str | list[str],