[Model] Update ColModernVBERT to support latest HF checkpoint (#39307)
Signed-off-by: Ilya Boytsov <ilyaboytsov1805@gmail.com>
This commit is contained in:
@@ -15,10 +15,6 @@ from vllm.entrypoints.pooling.scoring.utils import compute_maxsim_score
|
||||
MODEL_NAME = "ModernVBERT/colmodernvbert-merged"
|
||||
COLBERT_DIM = 128
|
||||
DTYPE = "half"
|
||||
# Fixme:
|
||||
# Update colmodernvbert code to support the latest HF version
|
||||
# and remove revision set.
|
||||
REVISION = "4a0a9f3ac7a7992fec410bfa8e3d080ac9a5bcee"
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
@@ -30,7 +26,6 @@ def test_colmodernvbert_text_token_embed(vllm_runner):
|
||||
"""Text query produces per-token embeddings with shape (seq_len, 128)."""
|
||||
with vllm_runner(
|
||||
MODEL_NAME,
|
||||
revision=REVISION,
|
||||
runner="pooling",
|
||||
dtype=DTYPE,
|
||||
enforce_eager=True,
|
||||
@@ -54,7 +49,6 @@ def test_colmodernvbert_text_relevance_ordering(vllm_runner):
|
||||
|
||||
with vllm_runner(
|
||||
MODEL_NAME,
|
||||
revision=REVISION,
|
||||
runner="pooling",
|
||||
dtype=DTYPE,
|
||||
enforce_eager=True,
|
||||
@@ -72,7 +66,6 @@ def test_colmodernvbert_text_late_interaction(vllm_runner):
|
||||
|
||||
with vllm_runner(
|
||||
MODEL_NAME,
|
||||
revision=REVISION,
|
||||
runner="pooling",
|
||||
dtype=DTYPE,
|
||||
enforce_eager=True,
|
||||
@@ -99,7 +92,6 @@ def test_colmodernvbert_image_token_embed(vllm_runner, image_assets):
|
||||
"""Image input produces per-token embeddings including vision tokens."""
|
||||
with vllm_runner(
|
||||
MODEL_NAME,
|
||||
revision=REVISION,
|
||||
runner="pooling",
|
||||
dtype=DTYPE,
|
||||
enforce_eager=True,
|
||||
|
||||
@@ -648,7 +648,6 @@ _LATE_INTERACTION_EXAMPLE_MODELS = {
|
||||
# [Multimodal]
|
||||
"ColModernVBertForRetrieval": _HfExamplesInfo(
|
||||
"ModernVBERT/colmodernvbert-merged",
|
||||
revision="4a0a9f3ac7a7992fec410bfa8e3d080ac9a5bcee",
|
||||
),
|
||||
"ColPaliForRetrieval": _HfExamplesInfo("vidore/colpali-v1.3-hf"),
|
||||
"ColQwen3": _HfExamplesInfo(
|
||||
|
||||
@@ -18,7 +18,6 @@ from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.inputs import MultiModalDataDict
|
||||
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalFieldConfig,
|
||||
@@ -358,70 +357,23 @@ class ColModernVBertForRetrieval(
|
||||
"model.text_model.layers.": "text_layers.",
|
||||
"model.text_model.embeddings.": "text_embeddings.",
|
||||
"model.text_model.final_norm.": "text_final_norm.",
|
||||
"model.connector.modality_projection.": "connector.",
|
||||
"model.connector.modality_projection.": "connector.proj.",
|
||||
"model.custom_text_proj.": "custom_text_proj.",
|
||||
"model.vision_model.": "vision_model.vision_model.",
|
||||
"model.vision_model.vision_model.": "vision_model.vision_model.",
|
||||
"model.": "",
|
||||
},
|
||||
)
|
||||
|
||||
# Checkpoint names for DecoupledEmbedding parts
|
||||
_BASE_EMB = "model.text_model.embeddings.tok_embeddings.weight"
|
||||
_EXTRA_EMB = (
|
||||
"model.text_model.embeddings.tok_embeddings.additional_embedding.weight"
|
||||
)
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
weights: Iterable[tuple[str, torch.Tensor]],
|
||||
) -> set[str]:
|
||||
# DecoupledEmbedding requires concatenating base + additional
|
||||
# embedding tensors before loading, so we extract them first.
|
||||
base_embedding_weight: torch.Tensor | None = None
|
||||
additional_embedding_weight: torch.Tensor | None = None
|
||||
remaining: list[tuple[str, torch.Tensor]] = []
|
||||
|
||||
for name, tensor in weights:
|
||||
if name == self._BASE_EMB:
|
||||
base_embedding_weight = tensor
|
||||
elif name == self._EXTRA_EMB:
|
||||
additional_embedding_weight = tensor
|
||||
else:
|
||||
remaining.append((name, tensor))
|
||||
|
||||
# Load all non-embedding weights via AutoWeightsLoader
|
||||
loader = AutoWeightsLoader(self)
|
||||
loaded_params = loader.load_weights(
|
||||
remaining,
|
||||
weights,
|
||||
mapper=self.hf_to_vllm_mapper,
|
||||
)
|
||||
|
||||
# Concatenate and load DecoupledEmbedding weights
|
||||
if base_embedding_weight is not None:
|
||||
combined = base_embedding_weight
|
||||
if additional_embedding_weight is not None:
|
||||
combined = torch.cat(
|
||||
[base_embedding_weight, additional_embedding_weight],
|
||||
dim=0,
|
||||
)
|
||||
param_name = "text_embeddings.tok_embeddings.weight"
|
||||
params_dict = dict(self.named_parameters())
|
||||
if param_name in params_dict:
|
||||
param = params_dict[param_name]
|
||||
weight_loader = getattr(
|
||||
param,
|
||||
"weight_loader",
|
||||
default_weight_loader,
|
||||
)
|
||||
weight_loader(param, combined)
|
||||
loaded_params.add(param_name)
|
||||
elif additional_embedding_weight is not None:
|
||||
raise ValueError(
|
||||
"Found 'text_model.embeddings.tok_embeddings"
|
||||
".additional_embedding.weight' but not "
|
||||
"'text_model.embeddings.tok_embeddings.weight'"
|
||||
)
|
||||
|
||||
# The pooler wraps ``custom_text_proj`` as its head projector.
|
||||
# Mark those params as loaded under the pooler path too.
|
||||
if hasattr(self, "pooler") and hasattr(self.pooler, "head"):
|
||||
|
||||
@@ -82,7 +82,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
|
||||
bagel="BagelConfig",
|
||||
umm="CheersConfig",
|
||||
chatglm="ChatGLMConfig",
|
||||
colmodernvbert="ColModernVBertConfig",
|
||||
modernvbert="ColModernVBertConfig",
|
||||
colpali="ColPaliConfig",
|
||||
colqwen3="ColQwen3Config",
|
||||
ops_colqwen3="OpsColQwen3Config",
|
||||
|
||||
@@ -17,43 +17,41 @@ class ColModernVBertConfig(PretrainedConfig):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int = 128,
|
||||
vlm_config: dict | None = None,
|
||||
image_token_id: int = 50407,
|
||||
pixel_shuffle_factor: int = 4,
|
||||
text_config: dict | None = None,
|
||||
vision_config: dict | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.embedding_dim = embedding_dim
|
||||
self.image_token_id = image_token_id
|
||||
self.pixel_shuffle_factor = pixel_shuffle_factor
|
||||
|
||||
if vlm_config is None:
|
||||
vlm_config = {}
|
||||
text_config = text_config or {}
|
||||
self.hidden_size = text_config.get("hidden_size", 768)
|
||||
|
||||
# Top-level VLM fields
|
||||
self.image_token_id = vlm_config.get("image_token_id", 50407)
|
||||
self.pixel_shuffle_factor = vlm_config.get("pixel_shuffle_factor", 4)
|
||||
self.hidden_size = vlm_config.get("hidden_size", 768)
|
||||
additional_vocab_size = vlm_config.get("additional_vocab_size", 40)
|
||||
|
||||
# Text config (ModernBERT)
|
||||
text_cfg = vlm_config.get("text_config", {})
|
||||
base_vocab = text_cfg.get("vocab_size", 50368)
|
||||
self.text_config = ModernBertConfig(
|
||||
vocab_size=base_vocab + additional_vocab_size,
|
||||
hidden_size=text_cfg.get("hidden_size", 768),
|
||||
intermediate_size=text_cfg.get("intermediate_size", 1152),
|
||||
num_hidden_layers=text_cfg.get("num_hidden_layers", 22),
|
||||
num_attention_heads=text_cfg.get("num_attention_heads", 12),
|
||||
mlp_bias=text_cfg.get("mlp_bias", False),
|
||||
max_position_embeddings=vlm_config.get("max_position_embeddings", 8192),
|
||||
vocab_size=text_config.get("vocab_size", 50408),
|
||||
hidden_size=text_config.get("hidden_size", 768),
|
||||
intermediate_size=text_config.get("intermediate_size", 1152),
|
||||
num_hidden_layers=text_config.get("num_hidden_layers", 22),
|
||||
num_attention_heads=text_config.get("num_attention_heads", 12),
|
||||
mlp_bias=text_config.get("mlp_bias", False),
|
||||
max_position_embeddings=text_config.get("max_position_embeddings", 8192),
|
||||
)
|
||||
|
||||
# Vision config (SigLIP)
|
||||
vis_cfg = vlm_config.get("vision_config", {})
|
||||
vision_config = vision_config or {}
|
||||
self.vision_config = SiglipVisionConfig(
|
||||
hidden_size=vis_cfg.get("embed_dim", 768),
|
||||
image_size=vis_cfg.get("image_size", 512),
|
||||
patch_size=vis_cfg.get("patch_size", 16),
|
||||
num_hidden_layers=vis_cfg.get("num_hidden_layers", 12),
|
||||
intermediate_size=vis_cfg.get("intermediate_size", 3072),
|
||||
num_attention_heads=vis_cfg.get("num_attention_heads", 12),
|
||||
hidden_size=vision_config.get("hidden_size", 768),
|
||||
image_size=vision_config.get("image_size", 512),
|
||||
patch_size=vision_config.get("patch_size", 16),
|
||||
num_hidden_layers=vision_config.get("num_hidden_layers", 12),
|
||||
intermediate_size=vision_config.get("intermediate_size", 3072),
|
||||
num_attention_heads=vision_config.get("num_attention_heads", 12),
|
||||
)
|
||||
|
||||
# Ensure architectures is set so vLLM routes to our model class
|
||||
kwargs.setdefault("architectures", ["ColModernVBertForRetrieval"])
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user