[New Model] Add ColModernVBERT (#34558)
Signed-off-by: Athrael Soju <athrael.soju@gmail.com>
Signed-off-by: athrael-soju <athrael-soju@users.noreply.github.com>
@@ -821,6 +821,7 @@ The following table lists those that are tested in vLLM.
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|--------------|--------|--------|-------------------|----------------------|---------------------------|
| `CLIPModel` | CLIP | T / I | `openai/clip-vit-base-patch32`, `openai/clip-vit-large-patch14`, etc. | | |
| `ColModernVBertForRetrieval` | ColModernVBERT | T / I | `ModernVBERT/colmodernvbert-merged` | | |
| `LlavaNextForConditionalGeneration`<sup>C</sup> | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | ✅︎ |
| `Phi3VForCausalLM`<sup>C</sup> | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | | ✅︎ |
| `Qwen3VLForConditionalGeneration`<sup>C</sup> | Qwen3-VL | T + I + V | `Qwen/Qwen3-VL-Embedding-2B`, etc. | ✅︎ | ✅︎ |

examples/pooling/score/colmodernvbert_rerank_online.py (new file, 166 lines)
@@ -0,0 +1,166 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example of using the ColModernVBERT late-interaction model for reranking.

ColModernVBERT is a multimodal ColBERT-style model that combines a SigLIP
vision encoder with a ModernBERT text encoder. It produces per-token
embeddings and uses MaxSim scoring for retrieval and reranking.
It supports both text and image inputs.

Start the server with:
    vllm serve ModernVBERT/colmodernvbert-merged --max-model-len 8192

Then run this script:
    python colmodernvbert_rerank_online.py
"""

import requests

MODEL = "ModernVBERT/colmodernvbert-merged"
BASE_URL = "http://127.0.0.1:8000"

headers = {"accept": "application/json", "Content-Type": "application/json"}

IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/PNG_transparency_demonstration_1.png/300px-PNG_transparency_demonstration_1.png"  # noqa: E501
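
# For reference, a minimal sketch of the MaxSim late-interaction scoring that
# /rerank and /score perform server-side (illustrative only; this script never
# calls it). It assumes L2-normalized per-token embedding matrices as torch
# tensors, so the dot product equals cosine similarity; compare the
# ``compute_maxsim_score`` helper exercised in this PR's tests:
#
#     def maxsim_score(q_emb, d_emb):
#         sim = q_emb @ d_emb.T  # (num_query_tokens, num_doc_tokens)
#         # Best-matching doc token per query token, summed over the query.
#         return sim.max(dim=1).values.sum().item()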


def rerank_text():
    """Text-only reranking via /rerank endpoint."""
    print("=" * 60)
    print("1. Text reranking (/rerank)")
    print("=" * 60)

    data = {
        "model": MODEL,
        "query": "What is machine learning?",
        "documents": [
            "Machine learning is a subset of artificial intelligence.",
            "Python is a programming language.",
            "Deep learning uses neural networks for complex tasks.",
            "The weather today is sunny.",
        ],
    }

    response = requests.post(f"{BASE_URL}/rerank", headers=headers, json=data)

    if response.status_code == 200:
        result = response.json()
        print("\n Ranked documents (most relevant first):")
        for item in result["results"]:
            doc_idx = item["index"]
            score = item["relevance_score"]
            print(f" [{score:.4f}] {data['documents'][doc_idx]}")
    else:
        print(f" Request failed: {response.status_code}")
        print(f" {response.text[:300]}")


def score_text():
    """Text-only scoring via /score endpoint."""
    print()
    print("=" * 60)
    print("2. Text scoring (/score)")
    print("=" * 60)

    query = "What is the capital of France?"
    documents = [
        "The capital of France is Paris.",
        "Berlin is the capital of Germany.",
        "Python is a programming language.",
    ]

    data = {
        "model": MODEL,
        "text_1": query,
        "text_2": documents,
    }

    response = requests.post(f"{BASE_URL}/score", headers=headers, json=data)

    if response.status_code == 200:
        result = response.json()
        print(f"\n Query: {query}\n")
        for item in result["data"]:
            idx = item["index"]
            score = item["score"]
            print(f" Doc {idx} (score={score:.4f}): {documents[idx]}")
    else:
        print(f" Request failed: {response.status_code}")
        print(f" {response.text[:300]}")


def rerank_text_top_n():
    """Text reranking with top_n filtering via /rerank endpoint."""
    print()
    print("=" * 60)
    print("3. Text reranking with top_n=2 (/rerank)")
    print("=" * 60)

    data = {
        "model": MODEL,
        "query": "What is the capital of France?",
        "documents": [
            "The capital of France is Paris.",
            "Berlin is the capital of Germany.",
            "Python is a programming language.",
            "The Eiffel Tower is in Paris.",
        ],
        "top_n": 2,
    }

    response = requests.post(f"{BASE_URL}/rerank", headers=headers, json=data)

    if response.status_code == 200:
        result = response.json()
        print(f"\n Top {data['top_n']} results:")
        for item in result["results"]:
            doc_idx = item["index"]
            score = item["relevance_score"]
            print(f" [{score:.4f}] {data['documents'][doc_idx]}")
    else:
        print(f" Request failed: {response.status_code}")
        print(f" {response.text[:300]}")


def rerank_multimodal():
    """Multimodal reranking with text and image documents via /rerank."""
    print()
    print("=" * 60)
    print("4. Multimodal reranking: text query vs image document (/rerank)")
    print("=" * 60)

    data = {
        "model": MODEL,
        "query": "A colorful logo with transparency",
        "documents": [
            {"content": [{"type": "image_url", "image_url": {"url": IMAGE_URL}}]},
            "Python is a programming language.",
            "The weather today is sunny.",
        ],
    }

    response = requests.post(f"{BASE_URL}/rerank", headers=headers, json=data)

    if response.status_code == 200:
        result = response.json()
        print("\n Ranked documents (most relevant first):")
        labels = ["[image]", "Python doc", "Weather doc"]
        for item in result["results"]:
            doc_idx = item["index"]
            score = item["relevance_score"]
            print(f" [{score:.4f}] {labels[doc_idx]}")
    else:
        print(f" Request failed: {response.status_code}")
        print(f" {response.text[:300]}")


def main():
    rerank_text()
    score_text()
    rerank_text_top_n()
    rerank_multimodal()


if __name__ == "__main__":
    main()

tests/models/multimodal/pooling/test_colmodernvbert.py (new file, 115 lines)
@@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the ColModernVBERT multimodal late-interaction model.

ColModernVBERT combines a SigLIP vision encoder and a ModernBERT text encoder
with a pixel shuffle connector and ColBERT-style 128-dim per-token
embeddings for visual document retrieval.
"""

import pytest
import torch

from vllm.entrypoints.pooling.score.utils import compute_maxsim_score

MODEL_NAME = "ModernVBERT/colmodernvbert-merged"
COLBERT_DIM = 128
DTYPE = "half"


# -----------------------------------------------------------------------
# Text-only tests
# -----------------------------------------------------------------------


def test_colmodernvbert_text_token_embed(vllm_runner):
    """Text query produces per-token embeddings with shape (seq_len, 128)."""
    with vllm_runner(
        MODEL_NAME,
        runner="pooling",
        dtype=DTYPE,
        enforce_eager=True,
    ) as vllm_model:
        outputs = vllm_model.token_embed(["What is machine learning?"])

    assert len(outputs) == 1
    emb = torch.tensor(outputs[0])
    assert emb.dim() == 2
    assert emb.shape[1] == COLBERT_DIM
    assert emb.shape[0] > 1


def test_colmodernvbert_text_relevance_ordering(vllm_runner):
    """Relevant documents score higher than irrelevant ones."""
    query = "What is machine learning?"
    documents = [
        "Machine learning is a subset of artificial intelligence.",
        "The weather in Paris is mild in spring.",
    ]

    with vllm_runner(
        MODEL_NAME,
        runner="pooling",
        dtype=DTYPE,
        enforce_eager=True,
    ) as vllm_model:
        scores = vllm_model.score(query, documents)

    assert len(scores) == 2
    assert scores[0] > scores[1], "ML doc should score higher than weather doc"


def test_colmodernvbert_text_late_interaction(vllm_runner):
    """MaxSim scoring via vLLM matches manual computation."""
    query = "What is the capital of France?"
    doc = "The capital of France is Paris."

    with vllm_runner(
        MODEL_NAME,
        runner="pooling",
        dtype=DTYPE,
        enforce_eager=True,
    ) as vllm_model:
        q_out = vllm_model.token_embed([query])
        d_out = vllm_model.token_embed([doc])

        q_emb = torch.tensor(q_out[0])
        d_emb = torch.tensor(d_out[0])
        manual_score = compute_maxsim_score(q_emb, d_emb).item()

        vllm_scores = vllm_model.score(query, doc)

    assert len(vllm_scores) == 1
    assert vllm_scores[0] == pytest.approx(manual_score, rel=0.01)


# -----------------------------------------------------------------------
# Image tests
# -----------------------------------------------------------------------


def test_colmodernvbert_image_token_embed(vllm_runner, image_assets):
    """Image input produces per-token embeddings including vision tokens."""
    with vllm_runner(
        MODEL_NAME,
        runner="pooling",
        dtype=DTYPE,
        enforce_eager=True,
    ) as vllm_model:
        image = image_assets[0].pil_image
        inputs = vllm_model.get_inputs(
            [""],
            images=[image],
        )
        req_outputs = vllm_model.llm.encode(
            inputs,
            pooling_task="token_embed",
        )
        outputs = [req_output.outputs.data for req_output in req_outputs]

    assert len(outputs) == 1
    emb = torch.tensor(outputs[0])
    assert emb.dim() == 2
    assert emb.shape[1] == COLBERT_DIM
    # Should have at least the image tokens (64 after pixel shuffle)
    assert emb.shape[0] >= 64
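    # Arithmetic for the 64-token bound, from this PR's default config:
    # a 512px image with 16px patches gives (512 // 16) ** 2 = 1024 vision
    # tokens, and pixel_shuffle_factor = 4 reduces them by 4 ** 2 = 16x,
    # leaving 1024 // 16 = 64 image tokens.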
@@ -592,6 +592,9 @@ _EMBEDDING_EXAMPLE_MODELS = {
    ),
    # [Multimodal]
    "CLIPModel": _HfExamplesInfo("openai/clip-vit-base-patch32"),
    "ColModernVBertForRetrieval": _HfExamplesInfo(
        "ModernVBERT/colmodernvbert-merged",
    ),
    "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
    "Phi3VForCausalLM": _HfExamplesInfo(
        "TIGER-Lab/VLM2Vec-Full", trust_remote_code=True

vllm/model_executor/models/colmodernvbert.py (new file, 430 lines)
@@ -0,0 +1,430 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""ColModernVBERT: multimodal late-interaction retrieval model.

Combines a SigLIP vision encoder and a ModernBERT text encoder with a pixel
shuffle connector and ColBERT-style 128-dim per-token embeddings.

Reference: https://huggingface.co/ModernVBERT/colmodernvbert-merged
"""

from collections.abc import Iterable, Mapping, Sequence
from typing import ClassVar, Literal

import torch
from torch import nn
from transformers import BatchFeature

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
    BaseDummyInputsBuilder,
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptIndexTargets,
    PromptReplacement,
    PromptUpdate,
)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.colmodernvbert import ColModernVBertConfig

from .interfaces import MultiModalEmbeddings, SupportsMultiModal
from .interfaces_base import default_pooling_type
from .modernbert import ModernBertEmbeddings, ModernBertLayer
from .siglip import SiglipVisionModel
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix

# ---------------------------------------------------------------------------
# Connector: pixel shuffle + simple linear projection
# ---------------------------------------------------------------------------


class ColModernVBertConnector(nn.Module):
    """Pixel shuffle spatial reduction followed by a linear projection.

    Reduces the vision encoder's token count by ``factor^2`` via pixel-shuffle
    spatial rearrangement, then projects the concatenated channels to the text
    encoder's hidden size with a single bias-free linear layer.
    """

    def __init__(self, config: ColModernVBertConfig):
        super().__init__()
        self.pixel_shuffle_factor = config.pixel_shuffle_factor
        vision_hidden_size = config.vision_config.hidden_size
        input_size = vision_hidden_size * (self.pixel_shuffle_factor**2)
        output_size = config.hidden_size
        self.proj = nn.Linear(input_size, output_size, bias=False)

    def pixel_shuffle(self, features: torch.Tensor) -> torch.Tensor:
        """Spatial rearrangement that reduces seq length by factor^2."""
        batch_size, seq_length, hidden_size = features.shape
        height = width = int(seq_length**0.5)
        factor = self.pixel_shuffle_factor

        # Reshape to (B, H, W, C)
        features = features.view(batch_size, height, width, hidden_size)

        # Reshape to (B, H/f, f, W/f, f, C)
        features = features.view(
            batch_size, height // factor, factor, width // factor, factor, hidden_size
        )

        # Permute to (B, H/f, W/f, f, f, C)
        features = features.permute(0, 1, 3, 2, 4, 5)

        # Reshape to (B, H/f, W/f, C * f^2)
        new_hidden_size = hidden_size * (factor**2)
        features = features.reshape(
            batch_size, height // factor, width // factor, new_hidden_size
        )

        return features
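
    # Shape walkthrough, assuming this PR's defaults (pixel_shuffle_factor=4,
    # 512px SigLIP images with 16px patches -> 1024 tokens of width 768):
    #   pixel_shuffle: (B, 1024, 768) -> (B, 8, 8, 768 * 16) = (B, 8, 8, 12288)
    #   forward:       flatten to (B, 64, 12288), then proj to (B, 64, 768)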

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        features = self.pixel_shuffle(features)
        batch_size = features.shape[0]
        features = features.reshape(batch_size, -1, features.shape[-1])
        return self.proj(features)


# ---------------------------------------------------------------------------
# Multimodal processing
# ---------------------------------------------------------------------------


class ColModernVBertProcessingInfo(BaseProcessingInfo):
    def get_hf_config(self) -> ColModernVBertConfig:
        return self.ctx.get_hf_config(ColModernVBertConfig)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"image": None}

    def get_image_size_with_most_features(self) -> ImageSize:
        config = self.get_hf_config()
        size = config.vision_config.image_size
        return ImageSize(width=size, height=size)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        return self.get_hf_config().image_seq_len


class ColModernVBertDummyInputsBuilder(
    BaseDummyInputsBuilder[ColModernVBertProcessingInfo],
):
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
        mm_processor_kwargs: Mapping[str, object] | None = None,
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)
        target_width, target_height = self.info.get_image_size_with_most_features()
        image_overrides = mm_options.get("image") if mm_options else None
        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            )
        }


class ColModernVBertMultiModalProcessor(
    BaseMultiModalProcessor[ColModernVBertProcessingInfo],
):
    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        tokenizer = self.info.get_tokenizer()
        text_encoding = tokenizer(
            prompt,
            return_tensors="pt",
            **tok_kwargs,
        )
        result = BatchFeature(data=dict(text_encoding))

        images = mm_data.get("images")
        if images:
            from transformers import Idefics3ImageProcessor

            image_processor = Idefics3ImageProcessor.from_pretrained(
                self.info.ctx.model_config.model,
                revision=self.info.ctx.model_config.revision,
            )
            image_outputs = image_processor(
                images=images,
                do_image_splitting=False,
                return_tensors="pt",
            )
            result.update(image_outputs)

        return result

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        return False

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        config = self.info.get_hf_config()
        image_token_id = config.image_token_id
        num_tokens = config.image_seq_len

        def get_replacement(item_idx: int):
            return [image_token_id] * num_tokens

        return [
            PromptReplacement(
                modality="image",
                target=PromptIndexTargets.start(),
                replacement=get_replacement,
            ),
        ]


# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------


@MULTIMODAL_REGISTRY.register_processor(
    ColModernVBertMultiModalProcessor,
    info=ColModernVBertProcessingInfo,
    dummy_inputs=ColModernVBertDummyInputsBuilder,
)
@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
class ColModernVBertForRetrieval(nn.Module, SupportsMultiModal):
    """ColModernVBERT multimodal late-interaction retrieval model.

    Architecture:
        Image -> SiglipVisionModel -> ColModernVBertConnector
                                            ↓
        Text -> ModernBertEmbeddings -> [merge] -> ModernBertLayers -> norm
                                            ↓
                                   custom_text_proj -> L2 norm
                                            ↓
                               per-token 128-d embeddings
    """

    is_pooling_model = True
    supports_late_interaction: ClassVar[Literal[True]] = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: ColModernVBertConfig = vllm_config.model_config.hf_config
        self.config = config
        text_config = config.text_config
        quant_config = vllm_config.quant_config

        # --- Vision encoder (reuses SiglipVisionModel from siglip.py) ---
        self.vision_model = SiglipVisionModel(
            config.vision_config,
            quant_config,
            prefix=maybe_prefix(prefix, "vision_model"),
        )

        # --- Connector (pixel shuffle + linear projection) ---
        self.connector = ColModernVBertConnector(config)

        # --- Text encoder (built from ModernBERT components directly) ---
        # We build the components individually rather than wrapping
        # ``ModernBertModel``, because ``ModernBertEncoderLayer`` reads
        # ``vllm_config.model_config.hf_config``, which here would be
        # ``ColModernVBertConfig``, not ``ModernBertConfig``.
        self.text_embeddings = ModernBertEmbeddings(text_config)
        self.text_layers = nn.ModuleList(
            [
                ModernBertLayer(
                    config=text_config,
                    layer_id=i,
                    prefix=f"{prefix}.text_layers.{i}",
                )
                for i in range(text_config.num_hidden_layers)
            ]
        )
        self.text_final_norm = nn.LayerNorm(
            text_config.hidden_size,
            eps=text_config.norm_eps,
            bias=text_config.norm_bias,
        )

        # --- ColBERT projection (768 -> 128, with bias) ---
        self.custom_text_proj = nn.Linear(
            text_config.hidden_size,
            config.embedding_dim,
            bias=True,
            dtype=vllm_config.model_config.head_dtype,
        )

        # --- Pooler (applies projection + L2 normalize) ---
        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None
        self.pooler = pooler_for_token_embed(
            pooler_config,
            projector=self.custom_text_proj,
        )

    # ---- multimodal ---------------------------------------------------------

    def _get_image_features(
        self,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        # Idefics3ImageProcessor may return (batch, tiles, C, H, W);
        # flatten to (batch*tiles, C, H, W) for SiglipVisionModel.
        if pixel_values.dim() == 5:
            b, t, c, h, w = pixel_values.shape
            pixel_values = pixel_values.reshape(b * t, c, h, w)
        vision_outputs = self.vision_model(
            pixel_values.to(dtype=self.vision_model.dtype),
        )
        return self.connector(vision_outputs)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return []
        assert isinstance(pixel_values, torch.Tensor)
        image_features = self._get_image_features(pixel_values)
        return list(image_features)

    # ---- forward ------------------------------------------------------------

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        hidden_states = self.text_embeddings(input_ids, inputs_embeds=inputs_embeds)

        for layer in self.text_layers:
            hidden_states = layer(hidden_states, positions)

        return self.text_final_norm(hidden_states)

    # ---- weight loading -----------------------------------------------------

    # Checkpoint prefix -> vLLM param prefix.
    # More-specific prefixes must appear before shorter ones.
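    # (For example, "model.text_model.layers." has to match before the bare
    # "model." catch-all below, or layer weights would be mis-renamed.)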
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.text_model.layers.": "text_layers.",
            "model.text_model.embeddings.": "text_embeddings.",
            "model.text_model.final_norm.": "text_final_norm.",
            "model.connector.modality_projection.": "connector.",
            "model.custom_text_proj.": "custom_text_proj.",
            "model.vision_model.": "vision_model.vision_model.",
            "model.": "",
        },
    )

    # Checkpoint names for the DecoupledEmbedding parts
    _BASE_EMB = "model.text_model.embeddings.tok_embeddings.weight"
    _EXTRA_EMB = (
        "model.text_model.embeddings.tok_embeddings.additional_embedding.weight"
    )

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        # DecoupledEmbedding requires concatenating base + additional
        # embedding tensors before loading, so we extract them first.
        base_embedding_weight: torch.Tensor | None = None
        additional_embedding_weight: torch.Tensor | None = None
        remaining: list[tuple[str, torch.Tensor]] = []

        for name, tensor in weights:
            if name == self._BASE_EMB:
                base_embedding_weight = tensor
            elif name == self._EXTRA_EMB:
                additional_embedding_weight = tensor
            else:
                remaining.append((name, tensor))

        # Load all non-embedding weights via AutoWeightsLoader
        loader = AutoWeightsLoader(self)
        loaded_params = loader.load_weights(
            remaining,
            mapper=self.hf_to_vllm_mapper,
        )

        # Concatenate and load the DecoupledEmbedding weights
        if base_embedding_weight is not None:
            combined = base_embedding_weight
            if additional_embedding_weight is not None:
                combined = torch.cat(
                    [base_embedding_weight, additional_embedding_weight],
                    dim=0,
                )
            param_name = "text_embeddings.tok_embeddings.weight"
            params_dict = dict(self.named_parameters())
            if param_name in params_dict:
                param = params_dict[param_name]
                weight_loader = getattr(
                    param,
                    "weight_loader",
                    default_weight_loader,
                )
                weight_loader(param, combined)
                loaded_params.add(param_name)
        elif additional_embedding_weight is not None:
            raise ValueError(
                "Found 'text_model.embeddings.tok_embeddings"
                ".additional_embedding.weight' but not "
                "'text_model.embeddings.tok_embeddings.weight'"
            )

        # The pooler wraps ``custom_text_proj`` as its head projector.
        # Mark those params as loaded under the pooler path too.
        if hasattr(self, "pooler") and hasattr(self.pooler, "head"):
            head = self.pooler.head
            projector = getattr(head, "projector", None)
            if projector is not None and isinstance(projector, nn.Module):
                for pname, _ in projector.named_parameters():
                    loaded_params.add(f"pooler.head.projector.{pname}")

        return loaded_params

@@ -248,6 +248,7 @@ _EMBEDDING_MODELS = {
    "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "ColModernVBertForRetrieval": ("colmodernvbert", "ColModernVBertForRetrieval"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",

@@ -74,6 +74,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
    afmoe="AfmoeConfig",
    bagel="BagelConfig",
    chatglm="ChatGLMConfig",
    colmodernvbert="ColModernVBertConfig",
    colqwen3="ColQwen3Config",
    ops_colqwen3="OpsColQwen3Config",
    qwen3_vl_nemotron_embed="Qwen3VLNemotronEmbedConfig",

@@ -18,6 +18,7 @@ _CLASS_TO_MODULE: dict[str, str] = {
    "AfmoeConfig": "vllm.transformers_utils.configs.afmoe",
    "BagelConfig": "vllm.transformers_utils.configs.bagel",
    "ChatGLMConfig": "vllm.transformers_utils.configs.chatglm",
    "ColModernVBertConfig": "vllm.transformers_utils.configs.colmodernvbert",
    "ColQwen3Config": "vllm.transformers_utils.configs.colqwen3",
    "OpsColQwen3Config": "vllm.transformers_utils.configs.colqwen3",
    "Qwen3VLNemotronEmbedConfig": "vllm.transformers_utils.configs.colqwen3",

@@ -71,6 +72,7 @@ __all__ = [
    "AfmoeConfig",
    "BagelConfig",
    "ChatGLMConfig",
    "ColModernVBertConfig",
    "ColQwen3Config",
    "OpsColQwen3Config",
    "Qwen3VLNemotronEmbedConfig",

vllm/transformers_utils/configs/colmodernvbert.py (new file, 65 lines)
@@ -0,0 +1,65 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Configuration for the ColModernVBERT visual document retrieval model.

ColModernVBERT combines a SigLIP vision encoder and a ModernBERT text encoder
with a pixel shuffle connector and ColBERT-style 128-dim per-token embeddings.

Reference: https://huggingface.co/ModernVBERT/colmodernvbert-merged
"""

from transformers import ModernBertConfig, PretrainedConfig, SiglipVisionConfig


class ColModernVBertConfig(PretrainedConfig):
    model_type = "colmodernvbert"

    def __init__(
        self,
        embedding_dim: int = 128,
        vlm_config: dict | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim

        if vlm_config is None:
            vlm_config = {}

        # Top-level VLM fields
        self.image_token_id = vlm_config.get("image_token_id", 50407)
        self.pixel_shuffle_factor = vlm_config.get("pixel_shuffle_factor", 4)
        self.hidden_size = vlm_config.get("hidden_size", 768)
        additional_vocab_size = vlm_config.get("additional_vocab_size", 40)

        # Text config (ModernBERT)
        text_cfg = vlm_config.get("text_config", {})
        base_vocab = text_cfg.get("vocab_size", 50368)
        self.text_config = ModernBertConfig(
            vocab_size=base_vocab + additional_vocab_size,
            hidden_size=text_cfg.get("hidden_size", 768),
            intermediate_size=text_cfg.get("intermediate_size", 1152),
            num_hidden_layers=text_cfg.get("num_hidden_layers", 22),
            num_attention_heads=text_cfg.get("num_attention_heads", 12),
            mlp_bias=text_cfg.get("mlp_bias", False),
            max_position_embeddings=vlm_config.get("max_position_embeddings", 8192),
        )
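        # With the defaults, the merged vocabulary is 50368 + 40 = 50408
        # entries, which covers image_token_id = 50407 (the last of the
        # additional-vocab tokens).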

        # Vision config (SigLIP)
        vis_cfg = vlm_config.get("vision_config", {})
        self.vision_config = SiglipVisionConfig(
            hidden_size=vis_cfg.get("embed_dim", 768),
            image_size=vis_cfg.get("image_size", 512),
            patch_size=vis_cfg.get("patch_size", 16),
            num_hidden_layers=vis_cfg.get("num_hidden_layers", 12),
            intermediate_size=vis_cfg.get("intermediate_size", 3072),
            num_attention_heads=vis_cfg.get("num_attention_heads", 12),
        )

    @property
    def image_seq_len(self) -> int:
        ps = self.vision_config.image_size // self.vision_config.patch_size
        return (ps * ps) // (self.pixel_shuffle_factor**2)
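        # e.g. with the defaults above: ps = 512 // 16 = 32, so 32 * 32 = 1024
        # patches, reduced by 4 ** 2 = 16 to 64 image tokens.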

    def get_text_config(self, **kwargs):
        return self.text_config