diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 1cad8c4a1..0551d4670 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -821,6 +821,7 @@ The following table lists those that are tested in vLLM. | Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | |--------------|--------|--------|-------------------|----------------------|---------------------------| | `CLIPModel` | CLIP | T / I | `openai/clip-vit-base-patch32`, `openai/clip-vit-large-patch14`, etc. | | | +| `ColModernVBertForRetrieval` | ColModernVBERT | T / I | `ModernVBERT/colmodernvbert-merged` | | | | `LlavaNextForConditionalGeneration`C | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | ✅︎ | | `Phi3VForCausalLM`C | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | | ✅︎ | | `Qwen3VLForConditionalGeneration`C | Qwen3-VL | T + I + V | `Qwen/Qwen3-VL-Embedding-2B`, etc. | ✅︎ | ✅︎ | diff --git a/examples/pooling/score/colmodernvbert_rerank_online.py b/examples/pooling/score/colmodernvbert_rerank_online.py new file mode 100644 index 000000000..de827ae06 --- /dev/null +++ b/examples/pooling/score/colmodernvbert_rerank_online.py @@ -0,0 +1,166 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Example of using ColModernVBERT late interaction model for reranking. + +ColModernVBERT is a multi-modal ColBERT-style model combining a SigLIP +vision encoder with a ModernBERT text encoder. It produces per-token +embeddings and uses MaxSim scoring for retrieval and reranking. +Supports both text and image inputs. + +Start the server with: + vllm serve ModernVBERT/colmodernvbert-merged --max-model-len 8192 + +Then run this script: + python colmodernvbert_rerank_online.py +""" + +import requests + +MODEL = "ModernVBERT/colmodernvbert-merged" +BASE_URL = "http://127.0.0.1:8000" + +headers = {"accept": "application/json", "Content-Type": "application/json"} + +IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/PNG_transparency_demonstration_1.png/300px-PNG_transparency_demonstration_1.png" # noqa: E501 + + +def rerank_text(): + """Text-only reranking via /rerank endpoint.""" + print("=" * 60) + print("1. Text reranking (/rerank)") + print("=" * 60) + + data = { + "model": MODEL, + "query": "What is machine learning?", + "documents": [ + "Machine learning is a subset of artificial intelligence.", + "Python is a programming language.", + "Deep learning uses neural networks for complex tasks.", + "The weather today is sunny.", + ], + } + + response = requests.post(f"{BASE_URL}/rerank", headers=headers, json=data) + + if response.status_code == 200: + result = response.json() + print("\n Ranked documents (most relevant first):") + for item in result["results"]: + doc_idx = item["index"] + score = item["relevance_score"] + print(f" [{score:.4f}] {data['documents'][doc_idx]}") + else: + print(f" Request failed: {response.status_code}") + print(f" {response.text[:300]}") + + +def score_text(): + """Text-only scoring via /score endpoint.""" + print() + print("=" * 60) + print("2. Text scoring (/score)") + print("=" * 60) + + query = "What is the capital of France?" 
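+    # /score pairs text_1 with every entry of text_2 and returns one
+    # relevance score per (query, document) pair, computed with the same
+    # MaxSim late interaction that /rerank uses.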
+ documents = [ + "The capital of France is Paris.", + "Berlin is the capital of Germany.", + "Python is a programming language.", + ] + + data = { + "model": MODEL, + "text_1": query, + "text_2": documents, + } + + response = requests.post(f"{BASE_URL}/score", headers=headers, json=data) + + if response.status_code == 200: + result = response.json() + print(f"\n Query: {query}\n") + for item in result["data"]: + idx = item["index"] + score = item["score"] + print(f" Doc {idx} (score={score:.4f}): {documents[idx]}") + else: + print(f" Request failed: {response.status_code}") + print(f" {response.text[:300]}") + + +def score_text_top_n(): + """Text reranking with top_n filtering via /rerank endpoint.""" + print() + print("=" * 60) + print("3. Text reranking with top_n=2 (/rerank)") + print("=" * 60) + + data = { + "model": MODEL, + "query": "What is the capital of France?", + "documents": [ + "The capital of France is Paris.", + "Berlin is the capital of Germany.", + "Python is a programming language.", + "The Eiffel Tower is in Paris.", + ], + "top_n": 2, + } + + response = requests.post(f"{BASE_URL}/rerank", headers=headers, json=data) + + if response.status_code == 200: + result = response.json() + print(f"\n Top {data['top_n']} results:") + for item in result["results"]: + doc_idx = item["index"] + score = item["relevance_score"] + print(f" [{score:.4f}] {data['documents'][doc_idx]}") + else: + print(f" Request failed: {response.status_code}") + print(f" {response.text[:300]}") + + +def rerank_multimodal(): + """Multimodal reranking with text and image documents via /rerank.""" + print() + print("=" * 60) + print("4. Multimodal reranking: text query vs image document (/rerank)") + print("=" * 60) + + data = { + "model": MODEL, + "query": "A colorful logo with transparency", + "documents": [ + {"content": [{"type": "image_url", "image_url": {"url": IMAGE_URL}}]}, + "Python is a programming language.", + "The weather today is sunny.", + ], + } + + response = requests.post(f"{BASE_URL}/rerank", headers=headers, json=data) + + if response.status_code == 200: + result = response.json() + print("\n Ranked documents (most relevant first):") + labels = ["[image]", "Python doc", "Weather doc"] + for item in result["results"]: + doc_idx = item["index"] + score = item["relevance_score"] + print(f" [{score:.4f}] {labels[doc_idx]}") + else: + print(f" Request failed: {response.status_code}") + print(f" {response.text[:300]}") + + +def main(): + rerank_text() + score_text() + score_text_top_n() + rerank_multimodal() + + +if __name__ == "__main__": + main() diff --git a/tests/models/multimodal/pooling/test_colmodernvbert.py b/tests/models/multimodal/pooling/test_colmodernvbert.py new file mode 100644 index 000000000..01f3843c3 --- /dev/null +++ b/tests/models/multimodal/pooling/test_colmodernvbert.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for ColModernVBERT multimodal late-interaction model. + +ColModernVBERT combines SigLIP vision encoder + ModernBERT text encoder +with a pixel shuffle connector and ColBERT-style 128-dim per-token +embeddings for visual document retrieval. 
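+
+MaxSim scoring: for each query token, take the maximum similarity over all
+document tokens, then sum those maxima across the query tokens.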
+""" + +import pytest +import torch + +from vllm.entrypoints.pooling.score.utils import compute_maxsim_score + +MODEL_NAME = "ModernVBERT/colmodernvbert-merged" +COLBERT_DIM = 128 +DTYPE = "half" + + +# ----------------------------------------------------------------------- +# Text-only tests +# ----------------------------------------------------------------------- + + +def test_colmodernvbert_text_token_embed(vllm_runner): + """Text query produces per-token embeddings with shape (seq_len, 128).""" + with vllm_runner( + MODEL_NAME, + runner="pooling", + dtype=DTYPE, + enforce_eager=True, + ) as vllm_model: + outputs = vllm_model.token_embed(["What is machine learning?"]) + + assert len(outputs) == 1 + emb = torch.tensor(outputs[0]) + assert emb.dim() == 2 + assert emb.shape[1] == COLBERT_DIM + assert emb.shape[0] > 1 + + +def test_colmodernvbert_text_relevance_ordering(vllm_runner): + """Relevant documents score higher than irrelevant ones.""" + query = "What is machine learning?" + documents = [ + "Machine learning is a subset of artificial intelligence.", + "The weather in Paris is mild in spring.", + ] + + with vllm_runner( + MODEL_NAME, + runner="pooling", + dtype=DTYPE, + enforce_eager=True, + ) as vllm_model: + scores = vllm_model.score(query, documents) + + assert len(scores) == 2 + assert scores[0] > scores[1], "ML doc should score higher than weather doc" + + +def test_colmodernvbert_text_late_interaction(vllm_runner): + """MaxSim scoring via vLLM matches manual computation.""" + query = "What is the capital of France?" + doc = "The capital of France is Paris." + + with vllm_runner( + MODEL_NAME, + runner="pooling", + dtype=DTYPE, + enforce_eager=True, + ) as vllm_model: + q_out = vllm_model.token_embed([query]) + d_out = vllm_model.token_embed([doc]) + + q_emb = torch.tensor(q_out[0]) + d_emb = torch.tensor(d_out[0]) + manual_score = compute_maxsim_score(q_emb, d_emb).item() + + vllm_scores = vllm_model.score(query, doc) + + assert len(vllm_scores) == 1 + assert vllm_scores[0] == pytest.approx(manual_score, rel=0.01) + + +# ----------------------------------------------------------------------- +# Image tests +# ----------------------------------------------------------------------- + + +def test_colmodernvbert_image_token_embed(vllm_runner, image_assets): + """Image input produces per-token embeddings including vision tokens.""" + with vllm_runner( + MODEL_NAME, + runner="pooling", + dtype=DTYPE, + enforce_eager=True, + ) as vllm_model: + image = image_assets[0].pil_image + inputs = vllm_model.get_inputs( + [""], + images=[image], + ) + req_outputs = vllm_model.llm.encode( + inputs, + pooling_task="token_embed", + ) + outputs = [req_output.outputs.data for req_output in req_outputs] + + assert len(outputs) == 1 + emb = torch.tensor(outputs[0]) + assert emb.dim() == 2 + assert emb.shape[1] == COLBERT_DIM + # Should have at least the image tokens (64 after pixel shuffle) + assert emb.shape[0] >= 64 diff --git a/tests/models/registry.py b/tests/models/registry.py index b37dfb6d8..64a0794b8 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -592,6 +592,9 @@ _EMBEDDING_EXAMPLE_MODELS = { ), # [Multimodal] "CLIPModel": _HfExamplesInfo("openai/clip-vit-base-patch32"), + "ColModernVBertForRetrieval": _HfExamplesInfo( + "ModernVBERT/colmodernvbert-merged", + ), "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"), "Phi3VForCausalLM": _HfExamplesInfo( "TIGER-Lab/VLM2Vec-Full", trust_remote_code=True diff --git 
a/vllm/model_executor/models/colmodernvbert.py b/vllm/model_executor/models/colmodernvbert.py new file mode 100644 index 000000000..29efb4a5f --- /dev/null +++ b/vllm/model_executor/models/colmodernvbert.py @@ -0,0 +1,430 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""ColModernVBERT: multimodal late-interaction retrieval model. + +Combines SigLIP vision encoder + ModernBERT text encoder with a pixel +shuffle connector and ColBERT-style 128-dim per-token embeddings. + +Reference: https://huggingface.co/ModernVBERT/colmodernvbert-merged +""" + +from collections.abc import Iterable, Mapping, Sequence +from typing import ClassVar, Literal + +import torch +from torch import nn +from transformers import BatchFeature + +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ImageSize, MultiModalDataItems +from vllm.multimodal.processing import ( + BaseDummyInputsBuilder, + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptIndexTargets, + PromptReplacement, + PromptUpdate, +) +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.colmodernvbert import ColModernVBertConfig + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal +from .interfaces_base import default_pooling_type +from .modernbert import ModernBertEmbeddings, ModernBertLayer +from .siglip import SiglipVisionModel +from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix + +# --------------------------------------------------------------------------- +# Connector: pixel shuffle + simple linear projection +# --------------------------------------------------------------------------- + + +class ColModernVBertConnector(nn.Module): + """Pixel shuffle spatial reduction followed by a linear projection. + + Reduces the vision encoder's token count by ``factor^2`` via pixel-shuffle + spatial rearrangement, then projects the concatenated channels to the text + encoder's hidden size with a single bias-free linear layer. 
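+
+    With the reference checkpoint's geometry (512x512 images, 16-pixel
+    patches, factor 4), the 32x32 = 1024 vision tokens are reduced to 64
+    connector tokens whose channel width grows 16x before the projection.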
+ """ + + def __init__(self, config: ColModernVBertConfig): + super().__init__() + self.pixel_shuffle_factor = config.pixel_shuffle_factor + vision_hidden_size = config.vision_config.hidden_size + input_size = vision_hidden_size * (self.pixel_shuffle_factor**2) + output_size = config.hidden_size + self.proj = nn.Linear(input_size, output_size, bias=False) + + def pixel_shuffle(self, features: torch.Tensor) -> torch.Tensor: + """Spatial rearrangement that reduces seq length by factor^2.""" + batch_size, seq_length, hidden_size = features.shape + height = width = int(seq_length**0.5) + factor = self.pixel_shuffle_factor + + # Reshape to (B, H, W, C) + features = features.view(batch_size, height, width, hidden_size) + + # Reshape to (B, H/f, f, W/f, f, C) + features = features.view( + batch_size, height // factor, factor, width // factor, factor, hidden_size + ) + + # Permute to (B, H/f, W/f, f, f, C) + features = features.permute(0, 1, 3, 2, 4, 5) + + # Reshape to (B, H/f, W/f, C * f^2) + new_hidden_size = hidden_size * (factor**2) + features = features.reshape( + batch_size, height // factor, width // factor, new_hidden_size + ) + + return features + + def forward(self, features: torch.Tensor) -> torch.Tensor: + features = self.pixel_shuffle(features) + batch_size = features.shape[0] + features = features.reshape(batch_size, -1, features.shape[-1]) + return self.proj(features) + + +# --------------------------------------------------------------------------- +# Multimodal processing +# --------------------------------------------------------------------------- + + +class ColModernVBertProcessingInfo(BaseProcessingInfo): + def get_hf_config(self) -> ColModernVBertConfig: + return self.ctx.get_hf_config(ColModernVBertConfig) + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"image": None} + + def get_image_size_with_most_features(self) -> ImageSize: + config = self.get_hf_config() + size = config.vision_config.image_size + return ImageSize(width=size, height=size) + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + return self.get_hf_config().image_seq_len + + +class ColModernVBertDummyInputsBuilder( + BaseDummyInputsBuilder[ColModernVBertProcessingInfo], +): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + mm_processor_kwargs: Mapping[str, object] | None = None, + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + target_width, target_height = self.info.get_image_size_with_most_features() + image_overrides = mm_options.get("image") if mm_options else None + return { + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) + } + + +class ColModernVBertMultiModalProcessor( + BaseMultiModalProcessor[ColModernVBertProcessingInfo], +): + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + tokenizer = self.info.get_tokenizer() + text_encoding = tokenizer( + prompt, + return_tensors="pt", + **tok_kwargs, + ) + result = BatchFeature(data=dict(text_encoding)) + + images = mm_data.get("images") + if images: + from transformers import Idefics3ImageProcessor + + image_processor = Idefics3ImageProcessor.from_pretrained( + 
self.info.ctx.model_config.model, + revision=self.info.ctx.model_config.revision, + ) + image_outputs = image_processor( + images=images, + do_image_splitting=False, + return_tensors="pt", + ) + result.update(image_outputs) + + return result + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> bool: + return False + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + config = self.info.get_hf_config() + image_token_id = config.image_token_id + num_tokens = config.image_seq_len + + def get_replacement(item_idx: int): + return [image_token_id] * num_tokens + + return [ + PromptReplacement( + modality="image", + target=PromptIndexTargets.start(), + replacement=get_replacement, + ), + ] + + +# --------------------------------------------------------------------------- +# Model +# --------------------------------------------------------------------------- + + +@MULTIMODAL_REGISTRY.register_processor( + ColModernVBertMultiModalProcessor, + info=ColModernVBertProcessingInfo, + dummy_inputs=ColModernVBertDummyInputsBuilder, +) +@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL") +class ColModernVBertForRetrieval(nn.Module, SupportsMultiModal): + """ColModernVBERT multimodal late-interaction retrieval model. + + Architecture: + Image -> SiglipVisionModel -> ColModernVBertConnector + ↓ + Text -> ModernBertEmbeddings → [merge] → ModernBertLayers → norm + ↓ + custom_text_proj → L2 norm + ↓ + per-token 128-d embeddings + """ + + is_pooling_model = True + supports_late_interaction: ClassVar[Literal[True]] = True + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config: ColModernVBertConfig = vllm_config.model_config.hf_config + self.config = config + text_config = config.text_config + quant_config = vllm_config.quant_config + + # --- Vision encoder (reuses SiglipVisionModel from siglip.py) --- + self.vision_model = SiglipVisionModel( + config.vision_config, + quant_config, + prefix=maybe_prefix(prefix, "vision_model"), + ) + + # --- Connector (pixel shuffle + linear projection) --- + self.connector = ColModernVBertConnector(config) + + # --- Text encoder (built from ModernBERT components directly) --- + # We build the components individually rather than wrapping + # ``ModernBertModel`` because ``ModernBertEncoderLayer`` reads + # ``vllm_config.model_config.hf_config`` which would be + # ``ColModernVBertConfig``, not ``ModernBertConfig``. 
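+        # The stack below mirrors ModernBertModel: token embeddings, then
+        # ``num_hidden_layers`` ModernBertLayer blocks, then a final
+        # LayerNorm, all configured from ``text_config``.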
+ self.text_embeddings = ModernBertEmbeddings(text_config) + self.text_layers = nn.ModuleList( + [ + ModernBertLayer( + config=text_config, + layer_id=i, + prefix=f"{prefix}.text_layers.{i}", + ) + for i in range(text_config.num_hidden_layers) + ] + ) + self.text_final_norm = nn.LayerNorm( + text_config.hidden_size, + eps=text_config.norm_eps, + bias=text_config.norm_bias, + ) + + # --- ColBERT projection (768 -> 128, with bias) --- + self.custom_text_proj = nn.Linear( + text_config.hidden_size, + config.embedding_dim, + bias=True, + dtype=vllm_config.model_config.head_dtype, + ) + + # --- Pooler (applies projection + L2 normalize) --- + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + self.pooler = pooler_for_token_embed( + pooler_config, + projector=self.custom_text_proj, + ) + + # ---- multimodal --------------------------------------------------------- + + def _get_image_features( + self, + pixel_values: torch.Tensor, + ) -> torch.Tensor: + # Idefics3ImageProcessor may return (batch, tiles, C, H, W); + # flatten to (batch*tiles, C, H, W) for SiglipVisionModel. + if pixel_values.dim() == 5: + b, t, c, h, w = pixel_values.shape + pixel_values = pixel_values.reshape(b * t, c, h, w) + vision_outputs = self.vision_model( + pixel_values.to(dtype=self.vision_model.dtype), + ) + return self.connector(vision_outputs) + + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: + pixel_values = kwargs.pop("pixel_values", None) + if pixel_values is None: + return [] + assert isinstance(pixel_values, torch.Tensor) + image_features = self._get_image_features(pixel_values) + return list(image_features) + + # ---- forward ------------------------------------------------------------ + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor: + hidden_states = self.text_embeddings(input_ids, inputs_embeds=inputs_embeds) + + for layer in self.text_layers: + hidden_states = layer(hidden_states, positions) + + return self.text_final_norm(hidden_states) + + # ---- weight loading ----------------------------------------------------- + + # Checkpoint prefix → vLLM param prefix. + # More-specific prefixes must appear before shorter ones. + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.text_model.layers.": "text_layers.", + "model.text_model.embeddings.": "text_embeddings.", + "model.text_model.final_norm.": "text_final_norm.", + "model.connector.modality_projection.": "connector.", + "model.custom_text_proj.": "custom_text_proj.", + "model.vision_model.": "vision_model.vision_model.", + "model.": "", + }, + ) + + # Checkpoint names for DecoupledEmbedding parts + _BASE_EMB = "model.text_model.embeddings.tok_embeddings.weight" + _EXTRA_EMB = ( + "model.text_model.embeddings.tok_embeddings.additional_embedding.weight" + ) + + def load_weights( + self, + weights: Iterable[tuple[str, torch.Tensor]], + ) -> set[str]: + # DecoupledEmbedding requires concatenating base + additional + # embedding tensors before loading, so we extract them first. 
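+        # With the reference checkpoint this concatenates the 50368-row base
+        # table and the 40-row additional table into one (50408, 768) weight,
+        # matching the vocab size built in ColModernVBertConfig (defaults;
+        # other checkpoints may differ).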
+ base_embedding_weight: torch.Tensor | None = None + additional_embedding_weight: torch.Tensor | None = None + remaining: list[tuple[str, torch.Tensor]] = [] + + for name, tensor in weights: + if name == self._BASE_EMB: + base_embedding_weight = tensor + elif name == self._EXTRA_EMB: + additional_embedding_weight = tensor + else: + remaining.append((name, tensor)) + + # Load all non-embedding weights via AutoWeightsLoader + loader = AutoWeightsLoader(self) + loaded_params = loader.load_weights( + remaining, + mapper=self.hf_to_vllm_mapper, + ) + + # Concatenate and load DecoupledEmbedding weights + if base_embedding_weight is not None: + combined = base_embedding_weight + if additional_embedding_weight is not None: + combined = torch.cat( + [base_embedding_weight, additional_embedding_weight], + dim=0, + ) + param_name = "text_embeddings.tok_embeddings.weight" + params_dict = dict(self.named_parameters()) + if param_name in params_dict: + param = params_dict[param_name] + weight_loader = getattr( + param, + "weight_loader", + default_weight_loader, + ) + weight_loader(param, combined) + loaded_params.add(param_name) + elif additional_embedding_weight is not None: + raise ValueError( + "Found 'text_model.embeddings.tok_embeddings" + ".additional_embedding.weight' but not " + "'text_model.embeddings.tok_embeddings.weight'" + ) + + # The pooler wraps ``custom_text_proj`` as its head projector. + # Mark those params as loaded under the pooler path too. + if hasattr(self, "pooler") and hasattr(self.pooler, "head"): + head = self.pooler.head + projector = getattr(head, "projector", None) + if projector is not None and isinstance(projector, nn.Module): + for pname, _ in projector.named_parameters(): + loaded_params.add(f"pooler.head.projector.{pname}") + + return loaded_params diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 598df91d9..329411d62 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -248,6 +248,7 @@ _EMBEDDING_MODELS = { "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"), # [Multimodal] "CLIPModel": ("clip", "CLIPEmbeddingModel"), + "ColModernVBertForRetrieval": ("colmodernvbert", "ColModernVBertForRetrieval"), "LlavaNextForConditionalGeneration": ( "llava_next", "LlavaNextForConditionalGeneration", diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 852e1d2a3..00129d52e 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -74,6 +74,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( afmoe="AfmoeConfig", bagel="BagelConfig", chatglm="ChatGLMConfig", + colmodernvbert="ColModernVBertConfig", colqwen3="ColQwen3Config", ops_colqwen3="OpsColQwen3Config", qwen3_vl_nemotron_embed="Qwen3VLNemotronEmbedConfig", diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index d02ab01d7..541bc4de6 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -18,6 +18,7 @@ _CLASS_TO_MODULE: dict[str, str] = { "AfmoeConfig": "vllm.transformers_utils.configs.afmoe", "BagelConfig": "vllm.transformers_utils.configs.bagel", "ChatGLMConfig": "vllm.transformers_utils.configs.chatglm", + "ColModernVBertConfig": "vllm.transformers_utils.configs.colmodernvbert", "ColQwen3Config": "vllm.transformers_utils.configs.colqwen3", "OpsColQwen3Config": "vllm.transformers_utils.configs.colqwen3", 
"Qwen3VLNemotronEmbedConfig": "vllm.transformers_utils.configs.colqwen3", @@ -71,6 +72,7 @@ __all__ = [ "AfmoeConfig", "BagelConfig", "ChatGLMConfig", + "ColModernVBertConfig", "ColQwen3Config", "OpsColQwen3Config", "Qwen3VLNemotronEmbedConfig", diff --git a/vllm/transformers_utils/configs/colmodernvbert.py b/vllm/transformers_utils/configs/colmodernvbert.py new file mode 100644 index 000000000..97fad16bc --- /dev/null +++ b/vllm/transformers_utils/configs/colmodernvbert.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Configuration for ColModernVBERT visual document retrieval model. + +ColModernVBERT combines SigLIP vision encoder + ModernBERT text encoder +with a pixel shuffle connector and ColBERT-style 128-dim per-token embeddings. + +Reference: https://huggingface.co/ModernVBERT/colmodernvbert-merged +""" + +from transformers import ModernBertConfig, PretrainedConfig, SiglipVisionConfig + + +class ColModernVBertConfig(PretrainedConfig): + model_type = "colmodernvbert" + + def __init__( + self, + embedding_dim: int = 128, + vlm_config: dict | None = None, + **kwargs, + ): + super().__init__(**kwargs) + self.embedding_dim = embedding_dim + + if vlm_config is None: + vlm_config = {} + + # Top-level VLM fields + self.image_token_id = vlm_config.get("image_token_id", 50407) + self.pixel_shuffle_factor = vlm_config.get("pixel_shuffle_factor", 4) + self.hidden_size = vlm_config.get("hidden_size", 768) + additional_vocab_size = vlm_config.get("additional_vocab_size", 40) + + # Text config (ModernBERT) + text_cfg = vlm_config.get("text_config", {}) + base_vocab = text_cfg.get("vocab_size", 50368) + self.text_config = ModernBertConfig( + vocab_size=base_vocab + additional_vocab_size, + hidden_size=text_cfg.get("hidden_size", 768), + intermediate_size=text_cfg.get("intermediate_size", 1152), + num_hidden_layers=text_cfg.get("num_hidden_layers", 22), + num_attention_heads=text_cfg.get("num_attention_heads", 12), + mlp_bias=text_cfg.get("mlp_bias", False), + max_position_embeddings=vlm_config.get("max_position_embeddings", 8192), + ) + + # Vision config (SigLIP) + vis_cfg = vlm_config.get("vision_config", {}) + self.vision_config = SiglipVisionConfig( + hidden_size=vis_cfg.get("embed_dim", 768), + image_size=vis_cfg.get("image_size", 512), + patch_size=vis_cfg.get("patch_size", 16), + num_hidden_layers=vis_cfg.get("num_hidden_layers", 12), + intermediate_size=vis_cfg.get("intermediate_size", 3072), + num_attention_heads=vis_cfg.get("num_attention_heads", 12), + ) + + @property + def image_seq_len(self) -> int: + ps = self.vision_config.image_size // self.vision_config.patch_size + return (ps * ps) // (self.pixel_shuffle_factor**2) + + def get_text_config(self, **kwargs): + return self.text_config