diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 7e685181f..bfb341f5b 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -828,6 +828,7 @@ The following table lists those that are tested in vLLM. | ------------ | ------ | ------ | ----------------- | -------------------- | ------------------------- | | `CLIPModel` | CLIP | T / I | `openai/clip-vit-base-patch32`, `openai/clip-vit-large-patch14`, etc. | | | | `ColModernVBertForRetrieval` | ColModernVBERT | T / I | `ModernVBERT/colmodernvbert-merged` | | | +| `ColPaliForRetrieval` | ColPali | T / I | `vidore/colpali-v1.3-hf` | | | | `LlamaNemotronVLModel` | Llama Nemotron Embedding + SigLIP | T + I | `nvidia/llama-nemotron-embed-vl-1b-v2` | | | | `LlavaNextForConditionalGeneration`C | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | ✅︎ | | `Phi3VForCausalLM`C | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | | ✅︎ | diff --git a/tests/models/multimodal/pooling/test_colpali.py b/tests/models/multimodal/pooling/test_colpali.py new file mode 100644 index 000000000..e7c373d10 --- /dev/null +++ b/tests/models/multimodal/pooling/test_colpali.py @@ -0,0 +1,323 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for ColPali late interaction model for multi-modal retrieval. + +ColPali is a multi-vector retrieval model based on PaliGemma backbone +(SigLIP + Gemma) with ColBERT-style late interaction scoring (MaxSim). +It produces per-token embeddings for both text and image inputs. 
def _make_base64_image(
    width: int = 64, height: int = 64, color: tuple[int, int, int] = (255, 0, 0)
) -> str:
    """Render a solid-color RGB image as PNG and return it as a base64 data URI.

    Args:
        width: Image width in pixels.
        height: Image height in pixels.
        color: RGB fill color.

    Returns:
        A string of the form ``data:image/png;base64,<payload>``.
    """
    png_buffer = BytesIO()
    Image.new("RGB", (width, height), color).save(png_buffer, format="PNG")
    payload = base64.b64encode(png_buffer.getvalue()).decode()
    return "data:image/png;base64," + payload


def _make_image_mm_param(
    image_uri: str,
    text: str | None = None,
) -> ScoreMultiModalParam:
    """Wrap an image URI, optionally followed by text, in a ScoreMultiModalParam.

    Args:
        image_uri: URL (here a data URI) placed in the ``image_url`` part.
        text: Optional text part appended after the image part.

    Returns:
        A ``ScoreMultiModalParam`` whose content is ``[image]`` or
        ``[image, text]``.
    """
    image_part = ChatCompletionContentPartImageParam(
        type="image_url",
        image_url={"url": image_uri},
    )
    if text is None:
        return ScoreMultiModalParam(content=[image_part])

    text_part = ChatCompletionContentPartTextParam(type="text", text=text)
    return ScoreMultiModalParam(content=[image_part, text_part])
def _run_late_interaction_test(
    vllm_runner: type[VllmRunner],
    model: str,
    *,
    dtype: str,
) -> None:
    """Check that ``score()`` agrees with MaxSim computed from token embeddings.

    The first text query and first text document are embedded separately via
    ``token_embed``; ``compute_maxsim_score`` over those two embeddings is the
    reference value. The engine's own ``score()`` result for the same pair
    must match it within a 1% relative tolerance.
    """
    from vllm.entrypoints.pooling.score.utils import compute_maxsim_score

    query_text = TEXT_QUERIES[0]
    doc_text = TEXT_DOCUMENTS[0]

    with vllm_runner(
        model,
        runner="pooling",
        dtype=dtype,
        max_model_len=4096,
        enforce_eager=True,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    ) as vllm_model:
        query_out = vllm_model.token_embed([query_text])
        doc_out = vllm_model.token_embed([doc_text])

        # Reference score built from the embeddings the engine just produced.
        expected = compute_maxsim_score(
            torch.tensor(query_out[0]),
            torch.tensor(doc_out[0]),
        ).item()

        actual = vllm_model.score(query_text, doc_text)

    assert len(actual) == 1
    assert actual[0] == pytest.approx(expected, rel=0.01)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colpali_token_embed(
    vllm_runner,
    model: str,
    dtype: str,
) -> None:
    """Run the per-token embedding check for each parametrized model."""
    _run_token_embed_test(vllm_runner, model=model, dtype=dtype)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colpali_late_interaction_scoring(
    vllm_runner,
    model: str,
    dtype: str,
) -> None:
    """Run the MaxSim-vs-``score()`` comparison for each parametrized model."""
    _run_late_interaction_test(vllm_runner, model=model, dtype=dtype)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colpali_relevance_ordering(
    vllm_runner,
    model: str,
    dtype: str,
) -> None:
    """Run the relevance-ordering check for each parametrized model."""
    _run_relevance_test(vllm_runner, model=model, dtype=dtype)
def _run_multimodal_mixed_docs_test(
    vllm_runner: type[VllmRunner],
    model: str,
    *,
    dtype: str,
) -> None:
    """Score one text query against a text document and an image document.

    Asserts that two float scores come back and that the on-topic text
    document outranks the solid-red image.
    """
    query = "What is the capital of France?"
    documents: list = [
        "The capital of France is Paris.",
        _make_image_mm_param(_make_base64_image(64, 64, color=(255, 0, 0))),
    ]

    with vllm_runner(
        model,
        runner="pooling",
        dtype=dtype,
        max_model_len=4096,
        enforce_eager=True,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    ) as vllm_model:
        results = vllm_model.llm.score(query, documents)

    assert len(results) == 2
    for result in results:
        assert isinstance(result.outputs.score, float)

    text_result, image_result = results
    # Text document about France should score higher than a random image
    assert text_result.outputs.score > image_result.outputs.score


def _run_multimodal_image_query_text_docs_test(
    vllm_runner: type[VllmRunner],
    model: str,
    *,
    dtype: str,
) -> None:
    """Score an image-plus-text query against two plain-text documents.

    Only the result count and the score type are asserted; no ordering is
    required between the two documents.
    """
    image_query = _make_image_mm_param(
        _make_base64_image(64, 64, color=(255, 0, 0)),
        text="red color",
    )
    documents = [
        "A bright red sports car.",
        "The weather forecast shows rain tomorrow.",
    ]

    with vllm_runner(
        model,
        runner="pooling",
        dtype=dtype,
        max_model_len=4096,
        enforce_eager=True,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    ) as vllm_model:
        results = vllm_model.llm.score(image_query, documents)

    assert len(results) == 2
    for result in results:
        assert isinstance(result.outputs.score, float)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colpali_multimodal_mixed_docs(
    vllm_runner,
    model: str,
    dtype: str,
) -> None:
    """Run the mixed text/image document scoring check for each model."""
    _run_multimodal_mixed_docs_test(vllm_runner, model=model, dtype=dtype)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colpali_multimodal_image_query_text_docs(
    vllm_runner,
    model: str,
    dtype: str,
) -> None:
    """Run the image-query vs. text-documents scoring check for each model."""
    _run_multimodal_image_query_text_docs_test(vllm_runner, model=model, dtype=dtype)
class ColPaliProcessingInfo(PaliGemmaProcessingInfo):
    """Processing info for ColPali models.

    Both accessors are overridden so that the HF config is fetched from the
    context without naming a required config class, and the HF processor is
    always requested as a ``PaliGemmaProcessor``.

    NOTE(review): presumably the parent class asks the context for a specific
    PaliGemma config type and these overrides exist to relax that check —
    confirm against ``PaliGemmaProcessingInfo``.
    """

    def get_hf_config(self):
        """Return the HF config resolved by the context, with no type constraint."""
        hf_config = self.ctx.get_hf_config()
        return hf_config

    def get_hf_processor(self, **kwargs: object) -> PaliGemmaProcessor:
        """Return the HF processor, explicitly requested as ``PaliGemmaProcessor``.

        Extra keyword arguments are forwarded to the context unchanged.
        """
        # Force standard PaliGemmaProcessor even when trust_remote_code=True.
        processor = self.ctx.get_hf_processor(PaliGemmaProcessor, **kwargs)
        return processor
@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
@MULTIMODAL_REGISTRY.register_processor(
    ColPaliMultiModalProcessor,
    info=ColPaliProcessingInfo,
    dummy_inputs=PaliGemmaDummyInputsBuilder,
)
class ColPaliModel(
    PaliGemmaForConditionalGeneration,
    SupportsLateInteraction,
):
    """ColPali late interaction model for multi-modal retrieval/reranking.

    This model extends PaliGemmaForConditionalGeneration with a ColBERT-style
    linear projection layer for per-token embeddings. It supports:
    - "token_embed" task: Per-token embeddings for late interaction scoring

    The model produces L2-normalized per-token embeddings by:
    1. Running the PaliGemma backbone (vision + language) to get hidden states
    2. Projecting hidden states through a linear layer (hidden_size -> embed_dim)
    3. L2-normalizing the projected embeddings
    """

    # Mark this as a pooling model so vLLM routes to pooler path
    is_pooling_model = True

    # Override hf_to_vllm_mapper to handle ColPali weight naming.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # HF transformers checkpoint (vidore/colpali-v1.3-hf)
            # Weights: vlm.vision_tower.*, vlm.language_model.*,
            #          vlm.multi_modal_projector.*
            "vlm.vision_tower.": "vision_tower.",
            "vlm.language_model.": "language_model.",
            "vlm.multi_modal_projector.": "multi_modal_projector.",
            # colpali-engine checkpoint naming
            "model.vision_tower.": "vision_tower.",
            "model.language_model.": "language_model.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    # Names used for the projection layer across different ColPali variants.
    # Immutable so that one instance cannot alter matching for all others.
    _PROJ_LAYER_NAMES = frozenset(
        {
            "custom_text_proj",  # vLLM internal naming
            "embedding_proj_layer",  # colpali-engine / HF naming
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the PaliGemma backbone, the projection layer and the pooler.

        Args:
            vllm_config: Engine configuration; the HF config, head dtype and
                pooler config are read from ``vllm_config.model_config``.
            prefix: Module-name prefix forwarded to the parent constructor.

        Raises:
            ValueError: If no hidden size can be found on the HF config.
        """
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        config = vllm_config.model_config.hf_config
        head_dtype = vllm_config.model_config.head_dtype

        # The projection consumes language-model hidden states, so the text
        # config's hidden size takes precedence; the top-level attribute is
        # only a fallback for configs without a nested text_config.
        text_config = getattr(config, "text_config", None)
        hidden_size = getattr(text_config, "hidden_size", None)
        if hidden_size is None:
            hidden_size = getattr(config, "hidden_size", None)
        if hidden_size is None:
            raise ValueError(
                "Unable to determine text hidden size from config. "
                "Expected 'hidden_size' or 'text_config.hidden_size'."
            )
        self._proj_hidden_size = hidden_size

        # ColPali uses embedding_dim=128, but also check other naming variants.
        # NOTE(review): `projection_dim` may already be populated by the
        # PaliGemma base config with an unrelated width — confirm. Either
        # way, load_weights() reconciles the layer with the checkpoint tensor.
        self.embed_dim: int | None = (
            getattr(config, "embedding_dim", None)
            or getattr(config, "embed_dim", None)
            or getattr(config, "dim", None)
            or getattr(config, "projection_dim", None)
            or getattr(config, "colbert_dim", None)
        )

        # Build the projection layer if embed_dim is known
        if self.embed_dim is not None:
            self.custom_text_proj = nn.Linear(
                hidden_size,
                self.embed_dim,
                bias=False,
                dtype=head_dtype,
            )
        else:
            # Will be created during load_weights when dim is inferred
            self.custom_text_proj = None

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None
        self.pooler = pooler_for_token_embed(
            pooler_config,
            projector=self.custom_text_proj,
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors=None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Delegate to the parent forward pass with all arguments unchanged."""
        return super().forward(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

    def _is_proj_weight(self, name: str) -> bool:
        """Check if a weight name belongs to the projection layer."""
        return any(proj_name in name for proj_name in self._PROJ_LAYER_NAMES)

    def _match_projection_to_checkpoint(
        self, proj_tensors: dict[str, torch.Tensor]
    ) -> None:
        """Create or rebuild ``custom_text_proj`` so the checkpoint fits it.

        The layer is (re)built when it does not exist yet, when its output
        width differs from the checkpoint weight, or when the checkpoint
        carries a bias the current layer has no slot for. The input width
        stays ``self._proj_hidden_size`` so that a genuine hidden-size
        mismatch still fails loudly in the weight loader.
        """
        weight = proj_tensors.get("weight")
        if weight is None:
            # Nothing to infer a shape from; leave the layer as it is.
            return

        out_features = weight.shape[0]
        has_bias = "bias" in proj_tensors
        current = self.custom_text_proj

        if current is not None:
            fits = current.out_features == out_features and (
                not has_bias or current.bias is not None
            )
            if fits:
                return
            # Keep the dtype/device chosen when the layer was first built.
            dtype, device = current.weight.dtype, current.weight.device
        else:
            reference = next(self.language_model.parameters())
            dtype, device = reference.dtype, reference.device

        self.embed_dim = out_features
        self.custom_text_proj = nn.Linear(
            self._proj_hidden_size,
            out_features,
            bias=has_bias,
            dtype=dtype,
        )
        self.custom_text_proj.to(device)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights with special handling for ColPali projection layer.

        Backbone weights are streamed to ``AutoWeightsLoader`` through
        ``hf_to_vllm_mapper``; projection-layer tensors are set aside while
        streaming and loaded afterwards, once the layer has been matched to
        the checkpoint.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The set of parameter names that were loaded.
        """
        proj_tensors: dict[str, torch.Tensor] = {}

        def _backbone_weights() -> Iterable[tuple[str, torch.Tensor]]:
            # Stream instead of materializing every tensor in a list first.
            for name, tensor in weights:
                if self._is_proj_weight(name):
                    # Keyed by parameter name ("weight" / "bias").
                    proj_tensors[name.split(".")[-1]] = tensor
                else:
                    yield name, tensor

        loader = AutoWeightsLoader(self)
        loaded = loader.load_weights(
            _backbone_weights(), mapper=self.hf_to_vllm_mapper
        )

        if proj_tensors:
            # Settle the layer's shape before copying, so the result does not
            # depend on whether "bias" or "weight" arrived first.
            self._match_projection_to_checkpoint(proj_tensors)

            if self.custom_text_proj is not None:
                for param_name, tensor in proj_tensors.items():
                    param = getattr(self.custom_text_proj, param_name, None)
                    if param is None:
                        continue
                    tensor = tensor.to(device=param.device, dtype=param.dtype)
                    default_weight_loader(param, tensor)
                    loaded.add(f"custom_text_proj.{param_name}")

            # Point the pooler at the (possibly rebuilt) projection layer.
            self.pooler.head.projector = self.custom_text_proj

        # Mark pooler projector params as loaded
        if hasattr(self, "pooler") and hasattr(self.pooler, "head"):
            projector = getattr(self.pooler.head, "projector", None)
            if projector is not None and isinstance(projector, nn.Module):
                for pname, _ in projector.named_parameters():
                    loaded.add(f"pooler.head.projector.{pname}")

        return loaded
a/vllm/transformers_utils/chat_templates/registry.py b/vllm/transformers_utils/chat_templates/registry.py index 0064cc6d6..af9fc77f1 100644 --- a/vllm/transformers_utils/chat_templates/registry.py +++ b/vllm/transformers_utils/chat_templates/registry.py @@ -33,6 +33,7 @@ _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = { "blip-2": CHAT_TEMPLATES_DIR / "template_blip2.jinja", "chameleon": CHAT_TEMPLATES_DIR / "template_basic.jinja", "clip": CHAT_TEMPLATES_DIR / "template_basic.jinja", + "colpali": CHAT_TEMPLATES_DIR / "template_basic.jinja", "deepseek_ocr": CHAT_TEMPLATES_DIR / "template_deepseek_ocr.jinja", "deepseek_ocr2": CHAT_TEMPLATES_DIR / "template_deepseek_ocr.jinja", "deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja", diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index f03de6015..5aa984515 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -78,6 +78,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( bagel="BagelConfig", chatglm="ChatGLMConfig", colmodernvbert="ColModernVBertConfig", + colpali="ColPaliConfig", colqwen3="ColQwen3Config", ops_colqwen3="OpsColQwen3Config", qwen3_vl_nemotron_embed="Qwen3VLNemotronEmbedConfig", diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 7902515e2..a19a5ec0f 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -20,6 +20,7 @@ _CLASS_TO_MODULE: dict[str, str] = { "BagelConfig": "vllm.transformers_utils.configs.bagel", "ChatGLMConfig": "vllm.transformers_utils.configs.chatglm", "ColModernVBertConfig": "vllm.transformers_utils.configs.colmodernvbert", + "ColPaliConfig": "vllm.transformers_utils.configs.colpali", "ColQwen3Config": "vllm.transformers_utils.configs.colqwen3", "OpsColQwen3Config": "vllm.transformers_utils.configs.colqwen3", "Qwen3VLNemotronEmbedConfig": 
class ColPaliConfig(PaliGemmaConfig):
    """Configuration class for ColPali models.

    Adds the embedding-projection fields used by ColPali variants on top of
    ``PaliGemmaConfig`` and accepts the nested ``vlm_config`` layout, which is
    flattened into keyword arguments for the base class.
    """

    model_type = "colpali"

    def __init__(
        self,
        embedding_dim: int | None = None,
        embed_dim: int | None = None,
        dim: int | None = None,
        projection_dim: int | None = None,
        colbert_dim: int | None = None,
        pooling: str | None = None,
        vlm_config: dict | None = None,
        **kwargs,
    ):
        """Store the projection fields, flatten ``vlm_config``, init the base.

        Args:
            embedding_dim: Projection width (primary spelling).
            embed_dim: Alternative spelling of the projection width.
            dim: Alternative spelling of the projection width.
            projection_dim: Alternative spelling of the projection width.
            colbert_dim: Alternative spelling of the projection width.
            pooling: Stored as-is on the config.
            vlm_config: Nested backbone config, either a dict or an object
                exposing ``to_dict()``.
            **kwargs: Forwarded to ``PaliGemmaConfig.__init__``.
        """
        # Embedding projection fields are assigned before the base initializer
        # runs.
        # NOTE(review): if the base initializer also assigns an attribute
        # named `projection_dim`, it will replace the value stored here —
        # confirm against PaliGemmaConfig.
        self.embedding_dim = embedding_dim
        self.embed_dim = embed_dim
        self.dim = dim
        self.projection_dim = projection_dim
        self.colbert_dim = colbert_dim
        self.pooling = pooling

        # The HF checkpoint nests the PaliGemma config under "vlm_config".
        # Lift its entries to the top level so the base class receives
        # vision_config, text_config, image_token_index, etc. directly.
        if vlm_config is not None:
            if isinstance(vlm_config, dict):
                nested = vlm_config
            else:
                nested = vlm_config.to_dict()

            # These keys describe the nested config itself and must not
            # replace the outer ones (e.g. model_type="colpali").
            reserved = {"model_type", "_name_or_path"}
            for field_name, field_value in nested.items():
                if field_name in reserved:
                    continue
                # setdefault: explicitly supplied kwargs win over nested ones.
                kwargs.setdefault(field_name, field_value)

        super().__init__(**kwargs)