[New Model] Add ColModernVBERT (#34558)

Signed-off-by: Athrael Soju <athrael.soju@gmail.com>
Signed-off-by: athrael-soju <athrael-soju@users.noreply.github.com>
Athrael Soju
2026-02-22 04:23:41 +00:00
committed by GitHub
parent d24bdd7c4b
commit 970861ac0c
9 changed files with 784 additions and 0 deletions

View File

@@ -821,6 +821,7 @@ The following table lists those that are tested in vLLM.
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|--------------|--------|--------|-------------------|----------------------|---------------------------|
| `CLIPModel` | CLIP | T / I | `openai/clip-vit-base-patch32`, `openai/clip-vit-large-patch14`, etc. | | |
| `ColModernVBertForRetrieval` | ColModernVBERT | T / I | `ModernVBERT/colmodernvbert-merged` | | |
| `LlavaNextForConditionalGeneration`<sup>C</sup> | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | ✅︎ |
| `Phi3VForCausalLM`<sup>C</sup> | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | | ✅︎ |
| `Qwen3VLForConditionalGeneration`<sup>C</sup> | Qwen3-VL | T + I + V | `Qwen/Qwen3-VL-Embedding-2B`, etc. | ✅︎ | ✅︎ |

View File

@@ -0,0 +1,166 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example of using ColModernVBERT late interaction model for reranking.
ColModernVBERT is a multi-modal ColBERT-style model combining a SigLIP
vision encoder with a ModernBERT text encoder. It produces per-token
embeddings and uses MaxSim scoring for retrieval and reranking.
Supports both text and image inputs.
Start the server with:
vllm serve ModernVBERT/colmodernvbert-merged --max-model-len 8192
Then run this script:
python colmodernvbert_rerank_online.py
"""
import requests
MODEL = "ModernVBERT/colmodernvbert-merged"
BASE_URL = "http://127.0.0.1:8000"
headers = {"accept": "application/json", "Content-Type": "application/json"}
IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/PNG_transparency_demonstration_1.png/300px-PNG_transparency_demonstration_1.png" # noqa: E501
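# For reference, the late-interaction (MaxSim) score the server computes
# for each query/document pair is the sum, over query tokens, of the best
# similarity against any document token. A minimal sketch with torch
# (illustrative only; assumes L2-normalized (num_tokens, 128) embeddings):
#
#   import torch
#
#   def maxsim(q_emb: torch.Tensor, d_emb: torch.Tensor) -> float:
#       sim = q_emb @ d_emb.T  # (num_query_tokens, num_doc_tokens)
#       return sim.max(dim=1).values.sum().item()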
def rerank_text():
"""Text-only reranking via /rerank endpoint."""
print("=" * 60)
print("1. Text reranking (/rerank)")
print("=" * 60)
data = {
"model": MODEL,
"query": "What is machine learning?",
"documents": [
"Machine learning is a subset of artificial intelligence.",
"Python is a programming language.",
"Deep learning uses neural networks for complex tasks.",
"The weather today is sunny.",
],
}
response = requests.post(f"{BASE_URL}/rerank", headers=headers, json=data)
if response.status_code == 200:
result = response.json()
print("\n Ranked documents (most relevant first):")
for item in result["results"]:
doc_idx = item["index"]
score = item["relevance_score"]
print(f" [{score:.4f}] {data['documents'][doc_idx]}")
else:
print(f" Request failed: {response.status_code}")
print(f" {response.text[:300]}")
def score_text():
"""Text-only scoring via /score endpoint."""
print()
print("=" * 60)
print("2. Text scoring (/score)")
print("=" * 60)
query = "What is the capital of France?"
documents = [
"The capital of France is Paris.",
"Berlin is the capital of Germany.",
"Python is a programming language.",
]
data = {
"model": MODEL,
"text_1": query,
"text_2": documents,
}
response = requests.post(f"{BASE_URL}/score", headers=headers, json=data)
if response.status_code == 200:
result = response.json()
print(f"\n Query: {query}\n")
for item in result["data"]:
idx = item["index"]
score = item["score"]
print(f" Doc {idx} (score={score:.4f}): {documents[idx]}")
else:
print(f" Request failed: {response.status_code}")
print(f" {response.text[:300]}")
def score_text_top_n():
"""Text reranking with top_n filtering via /rerank endpoint."""
print()
print("=" * 60)
print("3. Text reranking with top_n=2 (/rerank)")
print("=" * 60)
data = {
"model": MODEL,
"query": "What is the capital of France?",
"documents": [
"The capital of France is Paris.",
"Berlin is the capital of Germany.",
"Python is a programming language.",
"The Eiffel Tower is in Paris.",
],
"top_n": 2,
}
response = requests.post(f"{BASE_URL}/rerank", headers=headers, json=data)
if response.status_code == 200:
result = response.json()
print(f"\n Top {data['top_n']} results:")
for item in result["results"]:
doc_idx = item["index"]
score = item["relevance_score"]
print(f" [{score:.4f}] {data['documents'][doc_idx]}")
else:
print(f" Request failed: {response.status_code}")
print(f" {response.text[:300]}")
def rerank_multimodal():
"""Multimodal reranking with text and image documents via /rerank."""
print()
print("=" * 60)
print("4. Multimodal reranking: text query vs image document (/rerank)")
print("=" * 60)
data = {
"model": MODEL,
"query": "A colorful logo with transparency",
"documents": [
{"content": [{"type": "image_url", "image_url": {"url": IMAGE_URL}}]},
"Python is a programming language.",
"The weather today is sunny.",
],
}
response = requests.post(f"{BASE_URL}/rerank", headers=headers, json=data)
if response.status_code == 200:
result = response.json()
print("\n Ranked documents (most relevant first):")
labels = ["[image]", "Python doc", "Weather doc"]
for item in result["results"]:
doc_idx = item["index"]
score = item["relevance_score"]
print(f" [{score:.4f}] {labels[doc_idx]}")
else:
print(f" Request failed: {response.status_code}")
print(f" {response.text[:300]}")
def main():
rerank_text()
score_text()
score_text_top_n()
rerank_multimodal()
if __name__ == "__main__":
main()
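# Offline equivalent (a sketch, not a documented recipe; it mirrors the
# pooling-runner setup used by this PR's tests):
#
#   from vllm import LLM
#
#   llm = LLM(model=MODEL, runner="pooling")
#   outputs = llm.score(
#       "What is machine learning?",
#       ["Machine learning is a subset of AI.", "The weather is sunny."],
#   )
#   print([o.outputs.score for o in outputs])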

View File

@@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for ColModernVBERT multimodal late-interaction model.
ColModernVBERT combines SigLIP vision encoder + ModernBERT text encoder
with a pixel shuffle connector and ColBERT-style 128-dim per-token
embeddings for visual document retrieval.
"""
import pytest
import torch
from vllm.entrypoints.pooling.score.utils import compute_maxsim_score
MODEL_NAME = "ModernVBERT/colmodernvbert-merged"
COLBERT_DIM = 128
DTYPE = "half"
# -----------------------------------------------------------------------
# Text-only tests
# -----------------------------------------------------------------------
def test_colmodernvbert_text_token_embed(vllm_runner):
"""Text query produces per-token embeddings with shape (seq_len, 128)."""
with vllm_runner(
MODEL_NAME,
runner="pooling",
dtype=DTYPE,
enforce_eager=True,
) as vllm_model:
outputs = vllm_model.token_embed(["What is machine learning?"])
assert len(outputs) == 1
emb = torch.tensor(outputs[0])
assert emb.dim() == 2
assert emb.shape[1] == COLBERT_DIM
assert emb.shape[0] > 1
def test_colmodernvbert_text_relevance_ordering(vllm_runner):
"""Relevant documents score higher than irrelevant ones."""
query = "What is machine learning?"
documents = [
"Machine learning is a subset of artificial intelligence.",
"The weather in Paris is mild in spring.",
]
with vllm_runner(
MODEL_NAME,
runner="pooling",
dtype=DTYPE,
enforce_eager=True,
) as vllm_model:
scores = vllm_model.score(query, documents)
assert len(scores) == 2
assert scores[0] > scores[1], "ML doc should score higher than weather doc"
def test_colmodernvbert_text_late_interaction(vllm_runner):
"""MaxSim scoring via vLLM matches manual computation."""
query = "What is the capital of France?"
doc = "The capital of France is Paris."
with vllm_runner(
MODEL_NAME,
runner="pooling",
dtype=DTYPE,
enforce_eager=True,
) as vllm_model:
q_out = vllm_model.token_embed([query])
d_out = vllm_model.token_embed([doc])
q_emb = torch.tensor(q_out[0])
d_emb = torch.tensor(d_out[0])
manual_score = compute_maxsim_score(q_emb, d_emb).item()
vllm_scores = vllm_model.score(query, doc)
assert len(vllm_scores) == 1
assert vllm_scores[0] == pytest.approx(manual_score, rel=0.01)
# -----------------------------------------------------------------------
# Image tests
# -----------------------------------------------------------------------
def test_colmodernvbert_image_token_embed(vllm_runner, image_assets):
"""Image input produces per-token embeddings including vision tokens."""
with vllm_runner(
MODEL_NAME,
runner="pooling",
dtype=DTYPE,
enforce_eager=True,
) as vllm_model:
image = image_assets[0].pil_image
inputs = vllm_model.get_inputs(
[""],
images=[image],
)
req_outputs = vllm_model.llm.encode(
inputs,
pooling_task="token_embed",
)
outputs = [req_output.outputs.data for req_output in req_outputs]
assert len(outputs) == 1
emb = torch.tensor(outputs[0])
assert emb.dim() == 2
assert emb.shape[1] == COLBERT_DIM
# Should have at least the image tokens (64 after pixel shuffle)
assert emb.shape[0] >= 64

View File

@@ -592,6 +592,9 @@ _EMBEDDING_EXAMPLE_MODELS = {
),
# [Multimodal]
"CLIPModel": _HfExamplesInfo("openai/clip-vit-base-patch32"),
"ColModernVBertForRetrieval": _HfExamplesInfo(
"ModernVBERT/colmodernvbert-merged",
),
"LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
"Phi3VForCausalLM": _HfExamplesInfo(
"TIGER-Lab/VLM2Vec-Full", trust_remote_code=True

View File

@@ -0,0 +1,430 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""ColModernVBERT: multimodal late-interaction retrieval model.
Combines SigLIP vision encoder + ModernBERT text encoder with a pixel
shuffle connector and ColBERT-style 128-dim per-token embeddings.
Reference: https://huggingface.co/ModernVBERT/colmodernvbert-merged
"""
from collections.abc import Iterable, Mapping, Sequence
from typing import ClassVar, Literal
import torch
from torch import nn
from transformers import BatchFeature
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptIndexTargets,
PromptReplacement,
PromptUpdate,
)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.colmodernvbert import ColModernVBertConfig
from .interfaces import MultiModalEmbeddings, SupportsMultiModal
from .interfaces_base import default_pooling_type
from .modernbert import ModernBertEmbeddings, ModernBertLayer
from .siglip import SiglipVisionModel
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
# ---------------------------------------------------------------------------
# Connector: pixel shuffle + simple linear projection
# ---------------------------------------------------------------------------
class ColModernVBertConnector(nn.Module):
"""Pixel shuffle spatial reduction followed by a linear projection.
Reduces the vision encoder's token count by ``factor^2`` via pixel-shuffle
spatial rearrangement, then projects the concatenated channels to the text
encoder's hidden size with a single bias-free linear layer.
"""
def __init__(self, config: ColModernVBertConfig):
super().__init__()
self.pixel_shuffle_factor = config.pixel_shuffle_factor
vision_hidden_size = config.vision_config.hidden_size
input_size = vision_hidden_size * (self.pixel_shuffle_factor**2)
output_size = config.hidden_size
self.proj = nn.Linear(input_size, output_size, bias=False)
def pixel_shuffle(self, features: torch.Tensor) -> torch.Tensor:
"""Spatial rearrangement that reduces seq length by factor^2."""
batch_size, seq_length, hidden_size = features.shape
height = width = int(seq_length**0.5)
factor = self.pixel_shuffle_factor
# Reshape to (B, H, W, C)
features = features.view(batch_size, height, width, hidden_size)
# Reshape to (B, H/f, f, W/f, f, C)
features = features.view(
batch_size, height // factor, factor, width // factor, factor, hidden_size
)
# Permute to (B, H/f, W/f, f, f, C)
features = features.permute(0, 1, 3, 2, 4, 5)
# Reshape to (B, H/f, W/f, C * f^2)
new_hidden_size = hidden_size * (factor**2)
features = features.reshape(
batch_size, height // factor, width // factor, new_hidden_size
)
return features
def forward(self, features: torch.Tensor) -> torch.Tensor:
features = self.pixel_shuffle(features)
batch_size = features.shape[0]
features = features.reshape(batch_size, -1, features.shape[-1])
return self.proj(features)
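# Worked example with this PR's default config (illustrative): SigLIP at
# 512x512 with patch size 16 emits 32x32 = 1024 vision tokens of dim 768.
# pixel_shuffle with factor 4 regroups them into 8x8 = 64 tokens of dim
# 768 * 4^2 = 12288, which self.proj maps back down to hidden_size = 768.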
# ---------------------------------------------------------------------------
# Multimodal processing
# ---------------------------------------------------------------------------
class ColModernVBertProcessingInfo(BaseProcessingInfo):
def get_hf_config(self) -> ColModernVBertConfig:
return self.ctx.get_hf_config(ColModernVBertConfig)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None}
def get_image_size_with_most_features(self) -> ImageSize:
config = self.get_hf_config()
size = config.vision_config.image_size
return ImageSize(width=size, height=size)
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
return self.get_hf_config().image_seq_len
class ColModernVBertDummyInputsBuilder(
BaseDummyInputsBuilder[ColModernVBertProcessingInfo],
):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
return ""
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
target_width, target_height = self.info.get_image_size_with_most_features()
image_overrides = mm_options.get("image") if mm_options else None
return {
"image": self._get_dummy_images(
width=target_width,
height=target_height,
num_images=num_images,
overrides=image_overrides,
)
}
class ColModernVBertMultiModalProcessor(
BaseMultiModalProcessor[ColModernVBertProcessingInfo],
):
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
tokenizer = self.info.get_tokenizer()
text_encoding = tokenizer(
prompt,
return_tensors="pt",
**tok_kwargs,
)
result = BatchFeature(data=dict(text_encoding))
images = mm_data.get("images")
if images:
from transformers import Idefics3ImageProcessor
image_processor = Idefics3ImageProcessor.from_pretrained(
self.info.ctx.model_config.model,
revision=self.info.ctx.model_config.revision,
)
image_outputs = image_processor(
images=images,
do_image_splitting=False,
return_tensors="pt",
)
result.update(image_outputs)
return result
def _hf_processor_applies_updates(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> bool:
return False
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
config = self.info.get_hf_config()
image_token_id = config.image_token_id
num_tokens = config.image_seq_len
def get_replacement(item_idx: int):
return [image_token_id] * num_tokens
return [
PromptReplacement(
modality="image",
target=PromptIndexTargets.start(),
replacement=get_replacement,
),
]
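# Net effect with the default config (illustrative): each image expands to
# image_seq_len = 64 copies of the image placeholder token (id 50407 by
# default), inserted at the start of the prompt ahead of the text tokens.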
# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
@MULTIMODAL_REGISTRY.register_processor(
ColModernVBertMultiModalProcessor,
info=ColModernVBertProcessingInfo,
dummy_inputs=ColModernVBertDummyInputsBuilder,
)
@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
class ColModernVBertForRetrieval(nn.Module, SupportsMultiModal):
"""ColModernVBERT multimodal late-interaction retrieval model.
Architecture:
Image -> SiglipVisionModel -> ColModernVBertConnector
Text -> ModernBertEmbeddings → [merge] → ModernBertLayers → norm
custom_text_proj → L2 norm
per-token 128-d embeddings
"""
is_pooling_model = True
supports_late_interaction: ClassVar[Literal[True]] = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: ColModernVBertConfig = vllm_config.model_config.hf_config
self.config = config
text_config = config.text_config
quant_config = vllm_config.quant_config
# --- Vision encoder (reuses SiglipVisionModel from siglip.py) ---
self.vision_model = SiglipVisionModel(
config.vision_config,
quant_config,
prefix=maybe_prefix(prefix, "vision_model"),
)
# --- Connector (pixel shuffle + linear projection) ---
self.connector = ColModernVBertConnector(config)
# --- Text encoder (built from ModernBERT components directly) ---
# We build the components individually rather than wrapping
# ``ModernBertModel``, since that wrapper reads its config from
# ``vllm_config.model_config.hf_config``, which here would be
# ``ColModernVBertConfig`` rather than ``ModernBertConfig``.
self.text_embeddings = ModernBertEmbeddings(text_config)
self.text_layers = nn.ModuleList(
[
ModernBertLayer(
config=text_config,
layer_id=i,
prefix=f"{prefix}.text_layers.{i}",
)
for i in range(text_config.num_hidden_layers)
]
)
self.text_final_norm = nn.LayerNorm(
text_config.hidden_size,
eps=text_config.norm_eps,
bias=text_config.norm_bias,
)
# --- ColBERT projection (768 -> 128, with bias) ---
self.custom_text_proj = nn.Linear(
text_config.hidden_size,
config.embedding_dim,
bias=True,
dtype=vllm_config.model_config.head_dtype,
)
# --- Pooler (applies projection + L2 normalize) ---
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = pooler_for_token_embed(
pooler_config,
projector=self.custom_text_proj,
)
# ---- multimodal ---------------------------------------------------------
def _get_image_features(
self,
pixel_values: torch.Tensor,
) -> torch.Tensor:
# Idefics3ImageProcessor may return (batch, tiles, C, H, W);
# flatten to (batch*tiles, C, H, W) for SiglipVisionModel.
if pixel_values.dim() == 5:
b, t, c, h, w = pixel_values.shape
pixel_values = pixel_values.reshape(b * t, c, h, w)
vision_outputs = self.vision_model(
pixel_values.to(dtype=self.vision_model.dtype),
)
return self.connector(vision_outputs)
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
pixel_values = kwargs.pop("pixel_values", None)
if pixel_values is None:
return []
assert isinstance(pixel_values, torch.Tensor)
image_features = self._get_image_features(pixel_values)
return list(image_features)
# ---- forward ------------------------------------------------------------
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor:
hidden_states = self.text_embeddings(input_ids, inputs_embeds=inputs_embeds)
for layer in self.text_layers:
hidden_states = layer(hidden_states, positions)
return self.text_final_norm(hidden_states)
# ---- weight loading -----------------------------------------------------
# Checkpoint prefix → vLLM param prefix.
# More-specific prefixes must appear before shorter ones.
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"model.text_model.layers.": "text_layers.",
"model.text_model.embeddings.": "text_embeddings.",
"model.text_model.final_norm.": "text_final_norm.",
"model.connector.modality_projection.": "connector.",
"model.custom_text_proj.": "custom_text_proj.",
"model.vision_model.": "vision_model.vision_model.",
"model.": "",
},
)
# Checkpoint names for DecoupledEmbedding parts
_BASE_EMB = "model.text_model.embeddings.tok_embeddings.weight"
_EXTRA_EMB = (
"model.text_model.embeddings.tok_embeddings.additional_embedding.weight"
)
def load_weights(
self,
weights: Iterable[tuple[str, torch.Tensor]],
) -> set[str]:
# DecoupledEmbedding requires concatenating base + additional
# embedding tensors before loading, so we extract them first.
base_embedding_weight: torch.Tensor | None = None
additional_embedding_weight: torch.Tensor | None = None
remaining: list[tuple[str, torch.Tensor]] = []
for name, tensor in weights:
if name == self._BASE_EMB:
base_embedding_weight = tensor
elif name == self._EXTRA_EMB:
additional_embedding_weight = tensor
else:
remaining.append((name, tensor))
# Load all non-embedding weights via AutoWeightsLoader
loader = AutoWeightsLoader(self)
loaded_params = loader.load_weights(
remaining,
mapper=self.hf_to_vllm_mapper,
)
# Concatenate and load DecoupledEmbedding weights
if base_embedding_weight is not None:
combined = base_embedding_weight
if additional_embedding_weight is not None:
combined = torch.cat(
[base_embedding_weight, additional_embedding_weight],
dim=0,
)
param_name = "text_embeddings.tok_embeddings.weight"
params_dict = dict(self.named_parameters())
if param_name in params_dict:
param = params_dict[param_name]
weight_loader = getattr(
param,
"weight_loader",
default_weight_loader,
)
weight_loader(param, combined)
loaded_params.add(param_name)
elif additional_embedding_weight is not None:
raise ValueError(
"Found 'text_model.embeddings.tok_embeddings"
".additional_embedding.weight' but not "
"'text_model.embeddings.tok_embeddings.weight'"
)
# The pooler wraps ``custom_text_proj`` as its head projector.
# Mark those params as loaded under the pooler path too.
if hasattr(self, "pooler") and hasattr(self.pooler, "head"):
head = self.pooler.head
projector = getattr(head, "projector", None)
if projector is not None and isinstance(projector, nn.Module):
for pname, _ in projector.named_parameters():
loaded_params.add(f"pooler.head.projector.{pname}")
return loaded_params
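# Illustrative walk-through of the prefix ordering above (parameter name is
# hypothetical): "model.text_model.layers.0.attn.Wqkv.weight" matches the
# specific prefix "model.text_model.layers." first and is rewritten to
# "text_layers.0.attn.Wqkv.weight". If the bare "model." rule ran first, it
# would strip the prefix too early and the specific rules would never match.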

View File

@@ -248,6 +248,7 @@ _EMBEDDING_MODELS = {
"BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
# [Multimodal]
"CLIPModel": ("clip", "CLIPEmbeddingModel"),
"ColModernVBertForRetrieval": ("colmodernvbert", "ColModernVBertForRetrieval"),
"LlavaNextForConditionalGeneration": (
"llava_next",
"LlavaNextForConditionalGeneration",

View File

@@ -74,6 +74,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
afmoe="AfmoeConfig",
bagel="BagelConfig",
chatglm="ChatGLMConfig",
colmodernvbert="ColModernVBertConfig",
colqwen3="ColQwen3Config",
ops_colqwen3="OpsColQwen3Config",
qwen3_vl_nemotron_embed="Qwen3VLNemotronEmbedConfig",

View File

@@ -18,6 +18,7 @@ _CLASS_TO_MODULE: dict[str, str] = {
"AfmoeConfig": "vllm.transformers_utils.configs.afmoe",
"BagelConfig": "vllm.transformers_utils.configs.bagel",
"ChatGLMConfig": "vllm.transformers_utils.configs.chatglm",
"ColModernVBertConfig": "vllm.transformers_utils.configs.colmodernvbert",
"ColQwen3Config": "vllm.transformers_utils.configs.colqwen3",
"OpsColQwen3Config": "vllm.transformers_utils.configs.colqwen3",
"Qwen3VLNemotronEmbedConfig": "vllm.transformers_utils.configs.colqwen3",
@@ -71,6 +72,7 @@ __all__ = [
"AfmoeConfig",
"BagelConfig",
"ChatGLMConfig",
"ColModernVBertConfig",
"ColQwen3Config",
"OpsColQwen3Config",
"Qwen3VLNemotronEmbedConfig",

View File

@@ -0,0 +1,65 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Configuration for ColModernVBERT visual document retrieval model.
ColModernVBERT combines SigLIP vision encoder + ModernBERT text encoder
with a pixel shuffle connector and ColBERT-style 128-dim per-token embeddings.
Reference: https://huggingface.co/ModernVBERT/colmodernvbert-merged
"""
from transformers import ModernBertConfig, PretrainedConfig, SiglipVisionConfig
class ColModernVBertConfig(PretrainedConfig):
model_type = "colmodernvbert"
def __init__(
self,
embedding_dim: int = 128,
vlm_config: dict | None = None,
**kwargs,
):
super().__init__(**kwargs)
self.embedding_dim = embedding_dim
if vlm_config is None:
vlm_config = {}
# Top-level VLM fields
self.image_token_id = vlm_config.get("image_token_id", 50407)
self.pixel_shuffle_factor = vlm_config.get("pixel_shuffle_factor", 4)
self.hidden_size = vlm_config.get("hidden_size", 768)
additional_vocab_size = vlm_config.get("additional_vocab_size", 40)
# Text config (ModernBERT)
text_cfg = vlm_config.get("text_config", {})
base_vocab = text_cfg.get("vocab_size", 50368)
self.text_config = ModernBertConfig(
vocab_size=base_vocab + additional_vocab_size,
hidden_size=text_cfg.get("hidden_size", 768),
intermediate_size=text_cfg.get("intermediate_size", 1152),
num_hidden_layers=text_cfg.get("num_hidden_layers", 22),
num_attention_heads=text_cfg.get("num_attention_heads", 12),
mlp_bias=text_cfg.get("mlp_bias", False),
max_position_embeddings=vlm_config.get("max_position_embeddings", 8192),
)
# Vision config (SigLIP)
vis_cfg = vlm_config.get("vision_config", {})
self.vision_config = SiglipVisionConfig(
hidden_size=vis_cfg.get("embed_dim", 768),
image_size=vis_cfg.get("image_size", 512),
patch_size=vis_cfg.get("patch_size", 16),
num_hidden_layers=vis_cfg.get("num_hidden_layers", 12),
intermediate_size=vis_cfg.get("intermediate_size", 3072),
num_attention_heads=vis_cfg.get("num_attention_heads", 12),
)
@property
def image_seq_len(self) -> int:
ps = self.vision_config.image_size // self.vision_config.patch_size
return (ps * ps) // (self.pixel_shuffle_factor**2)
def get_text_config(self, **kwargs):
return self.text_config
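# Sanity check of the defaults (illustrative arithmetic): patches per side
# ps = 512 // 16 = 32, so ps * ps = 1024 raw vision tokens, and
# 1024 // (4 ** 2) = 64 = image_seq_len -- the 64 post-shuffle image tokens
# that the model tests assert on.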