[Model] Update ColModernVBERT to support latest HF checkpoint (#39307)

Signed-off-by: Ilya Boytsov <ilyaboytsov1805@gmail.com>
Authored by Ilya Boytsov on 2026-04-09 04:48:51 +02:00, committed by GitHub
parent 92fbec391b
commit d37b378762
5 changed files with 29 additions and 88 deletions

View File

@@ -15,10 +15,6 @@ from vllm.entrypoints.pooling.scoring.utils import compute_maxsim_score
 MODEL_NAME = "ModernVBERT/colmodernvbert-merged"
 COLBERT_DIM = 128
 DTYPE = "half"
-# Fixme:
-# Update colmodernvbert code to support the latest HF version
-# and remove revision set.
-REVISION = "4a0a9f3ac7a7992fec410bfa8e3d080ac9a5bcee"
 # -----------------------------------------------------------------------
@@ -30,7 +26,6 @@ def test_colmodernvbert_text_token_embed(vllm_runner):
     """Text query produces per-token embeddings with shape (seq_len, 128)."""
     with vllm_runner(
         MODEL_NAME,
-        revision=REVISION,
         runner="pooling",
         dtype=DTYPE,
         enforce_eager=True,
@@ -54,7 +49,6 @@ def test_colmodernvbert_text_relevance_ordering(vllm_runner):
     with vllm_runner(
         MODEL_NAME,
-        revision=REVISION,
         runner="pooling",
         dtype=DTYPE,
         enforce_eager=True,
@@ -72,7 +66,6 @@ def test_colmodernvbert_text_late_interaction(vllm_runner):
     with vllm_runner(
         MODEL_NAME,
-        revision=REVISION,
         runner="pooling",
         dtype=DTYPE,
         enforce_eager=True,
@@ -99,7 +92,6 @@ def test_colmodernvbert_image_token_embed(vllm_runner, image_assets):
     """Image input produces per-token embeddings including vision tokens."""
     with vllm_runner(
         MODEL_NAME,
-        revision=REVISION,
         runner="pooling",
         dtype=DTYPE,
         enforce_eager=True,
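
For orientation, here is a minimal standalone sketch of what the updated tests now set up (hedged: the `runner`, `dtype`, and `enforce_eager` arguments are copied from the fixture calls above, but `LLM.encode` and the output-shape access are assumptions about vLLM's pooling API, not part of this diff):

```python
# Illustrative sketch only (not part of this diff): after the change, the
# checkpoint is pulled from the Hub's default branch with no pinned revision.
from vllm import LLM

llm = LLM(
    model="ModernVBERT/colmodernvbert-merged",
    runner="pooling",
    dtype="half",
    enforce_eager=True,
)

# Late-interaction models emit one 128-dim vector per token; the exact
# result-access pattern below is a hedged example.
outputs = llm.encode(["What objects appear in the image?"])
print(outputs[0].outputs.data.shape)  # expected: (seq_len, COLBERT_DIM=128)
```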

View File

@@ -648,7 +648,6 @@ _LATE_INTERACTION_EXAMPLE_MODELS = {
     # [Multimodal]
     "ColModernVBertForRetrieval": _HfExamplesInfo(
         "ModernVBERT/colmodernvbert-merged",
-        revision="4a0a9f3ac7a7992fec410bfa8e3d080ac9a5bcee",
     ),
     "ColPaliForRetrieval": _HfExamplesInfo("vidore/colpali-v1.3-hf"),
     "ColQwen3": _HfExamplesInfo(

View File

@@ -18,7 +18,6 @@ from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.inputs import MultiModalDataDict
 from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (
     MultiModalFieldConfig,
@@ -358,70 +357,23 @@ class ColModernVBertForRetrieval(
             "model.text_model.layers.": "text_layers.",
             "model.text_model.embeddings.": "text_embeddings.",
             "model.text_model.final_norm.": "text_final_norm.",
-            "model.connector.modality_projection.": "connector.",
+            "model.connector.modality_projection.": "connector.proj.",
             "model.custom_text_proj.": "custom_text_proj.",
-            "model.vision_model.": "vision_model.vision_model.",
+            "model.vision_model.vision_model.": "vision_model.vision_model.",
             "model.": "",
         },
     )
 
-    # Checkpoint names for DecoupledEmbedding parts
-    _BASE_EMB = "model.text_model.embeddings.tok_embeddings.weight"
-    _EXTRA_EMB = (
-        "model.text_model.embeddings.tok_embeddings.additional_embedding.weight"
-    )
-
     def load_weights(
         self,
         weights: Iterable[tuple[str, torch.Tensor]],
     ) -> set[str]:
-        # DecoupledEmbedding requires concatenating base + additional
-        # embedding tensors before loading, so we extract them first.
-        base_embedding_weight: torch.Tensor | None = None
-        additional_embedding_weight: torch.Tensor | None = None
-        remaining: list[tuple[str, torch.Tensor]] = []
-        for name, tensor in weights:
-            if name == self._BASE_EMB:
-                base_embedding_weight = tensor
-            elif name == self._EXTRA_EMB:
-                additional_embedding_weight = tensor
-            else:
-                remaining.append((name, tensor))
-
-        # Load all non-embedding weights via AutoWeightsLoader
         loader = AutoWeightsLoader(self)
         loaded_params = loader.load_weights(
-            remaining,
+            weights,
             mapper=self.hf_to_vllm_mapper,
         )
-
-        # Concatenate and load DecoupledEmbedding weights
-        if base_embedding_weight is not None:
-            combined = base_embedding_weight
-            if additional_embedding_weight is not None:
-                combined = torch.cat(
-                    [base_embedding_weight, additional_embedding_weight],
-                    dim=0,
-                )
-            param_name = "text_embeddings.tok_embeddings.weight"
-            params_dict = dict(self.named_parameters())
-            if param_name in params_dict:
-                param = params_dict[param_name]
-                weight_loader = getattr(
-                    param,
-                    "weight_loader",
-                    default_weight_loader,
-                )
-                weight_loader(param, combined)
-                loaded_params.add(param_name)
-        elif additional_embedding_weight is not None:
-            raise ValueError(
-                "Found 'text_model.embeddings.tok_embeddings"
-                ".additional_embedding.weight' but not "
-                "'text_model.embeddings.tok_embeddings.weight'"
-            )
 
         # The pooler wraps ``custom_text_proj`` as its head projector.
         # Mark those params as loaded under the pooler path too.
         if hasattr(self, "pooler") and hasattr(self.pooler, "head"):
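
The mapper change above is the heart of the checkpoint update: the new HF weights name the connector projection and vision tower differently, so the prefix table is adjusted. Below is a minimal standalone sketch of the prefix rewriting this table performs (assuming first-match semantics with the catch-all rule last; vLLM's actual `WeightsMapper` may differ in detail, and `remap`/`NEW_PREFIX_TABLE` are illustrative names, not vLLM API):

```python
# Hedged sketch of the prefix remapping done by `hf_to_vllm_mapper`.
NEW_PREFIX_TABLE = {
    "model.text_model.layers.": "text_layers.",
    "model.text_model.embeddings.": "text_embeddings.",
    "model.text_model.final_norm.": "text_final_norm.",
    "model.connector.modality_projection.": "connector.proj.",
    "model.custom_text_proj.": "custom_text_proj.",
    "model.vision_model.vision_model.": "vision_model.vision_model.",
    "model.": "",  # catch-all must come last
}

def remap(name: str, table: dict[str, str] = NEW_PREFIX_TABLE) -> str:
    """Rewrite a checkpoint weight name using the first matching prefix rule."""
    for old, new in table.items():
        if name.startswith(old):
            return new + name[len(old):]
    return name

# The renamed connector projection now lands under `connector.proj.`:
assert remap("model.connector.modality_projection.weight") == "connector.proj.weight"
# The doubled vision prefix in the new checkpoint maps straight through:
assert (remap("model.vision_model.vision_model.embeddings.patch_embedding.weight")
        == "vision_model.vision_model.embeddings.patch_embedding.weight")
```

The deleted `load_weights` machinery is consistent with this: presumably the updated checkpoint ships a single combined `tok_embeddings` tensor, so `AutoWeightsLoader` can consume the weight stream directly with no base/additional concatenation.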

View File

@@ -82,7 +82,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
     bagel="BagelConfig",
     umm="CheersConfig",
     chatglm="ChatGLMConfig",
-    colmodernvbert="ColModernVBertConfig",
+    modernvbert="ColModernVBertConfig",
     colpali="ColPaliConfig",
     colqwen3="ColQwen3Config",
     ops_colqwen3="OpsColQwen3Config",
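
The key on the left is matched against the `model_type` declared in the checkpoint's `config.json`, so the rename tracks the updated checkpoint's metadata. A hypothetical illustration (the `config.json` contents are assumed, not shown in this diff):

```python
# Hypothetical: the updated checkpoint's config.json declares
#   {"model_type": "modernvbert", ...}
# so the registry lookup resolves to the vLLM-side config class.
config_cls = _CONFIG_REGISTRY["modernvbert"]  # -> ColModernVBertConfig
```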

View File

@@ -17,43 +17,41 @@ class ColModernVBertConfig(PretrainedConfig):
     def __init__(
         self,
         embedding_dim: int = 128,
-        vlm_config: dict | None = None,
+        image_token_id: int = 50407,
+        pixel_shuffle_factor: int = 4,
+        text_config: dict | None = None,
+        vision_config: dict | None = None,
         **kwargs,
     ):
         self.embedding_dim = embedding_dim
-        if vlm_config is None:
-            vlm_config = {}
-        # Top-level VLM fields
-        self.image_token_id = vlm_config.get("image_token_id", 50407)
-        self.pixel_shuffle_factor = vlm_config.get("pixel_shuffle_factor", 4)
-        self.hidden_size = vlm_config.get("hidden_size", 768)
-        additional_vocab_size = vlm_config.get("additional_vocab_size", 40)
+        self.image_token_id = image_token_id
+        self.pixel_shuffle_factor = pixel_shuffle_factor
+        text_config = text_config or {}
+        self.hidden_size = text_config.get("hidden_size", 768)
 
         # Text config (ModernBERT)
-        text_cfg = vlm_config.get("text_config", {})
-        base_vocab = text_cfg.get("vocab_size", 50368)
         self.text_config = ModernBertConfig(
-            vocab_size=base_vocab + additional_vocab_size,
-            hidden_size=text_cfg.get("hidden_size", 768),
-            intermediate_size=text_cfg.get("intermediate_size", 1152),
-            num_hidden_layers=text_cfg.get("num_hidden_layers", 22),
-            num_attention_heads=text_cfg.get("num_attention_heads", 12),
-            mlp_bias=text_cfg.get("mlp_bias", False),
-            max_position_embeddings=vlm_config.get("max_position_embeddings", 8192),
+            vocab_size=text_config.get("vocab_size", 50408),
+            hidden_size=text_config.get("hidden_size", 768),
+            intermediate_size=text_config.get("intermediate_size", 1152),
+            num_hidden_layers=text_config.get("num_hidden_layers", 22),
+            num_attention_heads=text_config.get("num_attention_heads", 12),
+            mlp_bias=text_config.get("mlp_bias", False),
+            max_position_embeddings=text_config.get("max_position_embeddings", 8192),
         )
 
         # Vision config (SigLIP)
-        vis_cfg = vlm_config.get("vision_config", {})
+        vision_config = vision_config or {}
         self.vision_config = SiglipVisionConfig(
-            hidden_size=vis_cfg.get("embed_dim", 768),
-            image_size=vis_cfg.get("image_size", 512),
-            patch_size=vis_cfg.get("patch_size", 16),
-            num_hidden_layers=vis_cfg.get("num_hidden_layers", 12),
-            intermediate_size=vis_cfg.get("intermediate_size", 3072),
-            num_attention_heads=vis_cfg.get("num_attention_heads", 12),
+            hidden_size=vision_config.get("hidden_size", 768),
+            image_size=vision_config.get("image_size", 512),
+            patch_size=vision_config.get("patch_size", 16),
+            num_hidden_layers=vision_config.get("num_hidden_layers", 12),
+            intermediate_size=vision_config.get("intermediate_size", 3072),
+            num_attention_heads=vision_config.get("num_attention_heads", 12),
         )
         # Ensure architectures is set so vLLM routes to our model class
         kwargs.setdefault("architectures", ["ColModernVBertForRetrieval"])
         super().__init__(**kwargs)
 
     @property
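
To make the flattened signature concrete, here is a hedged construction example built only from defaults visible in the diff (the exact `config.json` layout of the updated checkpoint is an assumption):

```python
# Illustrative only: the new flat kwargs replace the nested `vlm_config` dict.
cfg = ColModernVBertConfig(
    embedding_dim=128,
    image_token_id=50407,
    pixel_shuffle_factor=4,
    text_config={"vocab_size": 50408, "hidden_size": 768},
    vision_config={"hidden_size": 768, "image_size": 512, "patch_size": 16},
)
assert cfg.text_config.vocab_size == 50408  # no additional_vocab_size arithmetic
assert cfg.vision_config.hidden_size == 768  # read via "hidden_size", not "embed_dim"
assert cfg.architectures == ["ColModernVBertForRetrieval"]
```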