[Model] Add ColPali late interaction model for multi-modal retrieval (#36818)
Signed-off-by: Nikita Sukharev <kaonael@gmail.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -828,6 +828,7 @@ The following table lists those that are tested in vLLM.
|
||||
| ------------ | ------ | ------ | ----------------- | -------------------- | ------------------------- |
|
||||
| `CLIPModel` | CLIP | T / I | `openai/clip-vit-base-patch32`, `openai/clip-vit-large-patch14`, etc. | | |
|
||||
| `ColModernVBertForRetrieval` | ColModernVBERT | T / I | `ModernVBERT/colmodernvbert-merged` | | |
|
||||
| `ColPaliForRetrieval` | ColPali | T / I | `vidore/colpali-v1.3-hf` | | |
|
||||
| `LlamaNemotronVLModel` | Llama Nemotron Embedding + SigLIP | T + I | `nvidia/llama-nemotron-embed-vl-1b-v2` | | |
|
||||
| `LlavaNextForConditionalGeneration`<sup>C</sup> | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | ✅︎ |
|
||||
| `Phi3VForCausalLM`<sup>C</sup> | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | | ✅︎ |
|
||||
|
||||
323
tests/models/multimodal/pooling/test_colpali.py
Normal file
323
tests/models/multimodal/pooling/test_colpali.py
Normal file
@@ -0,0 +1,323 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for ColPali late interaction model for multi-modal retrieval.
|
||||
|
||||
ColPali is a multi-vector retrieval model based on PaliGemma backbone
|
||||
(SigLIP + Gemma) with ColBERT-style late interaction scoring (MaxSim).
|
||||
It produces per-token embeddings for both text and image inputs.
|
||||
"""
|
||||
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionContentPartImageParam,
|
||||
ChatCompletionContentPartTextParam,
|
||||
)
|
||||
from vllm.entrypoints.pooling.score.utils import ScoreMultiModalParam
|
||||
|
||||
from ....conftest import VllmRunner
|
||||
|
||||
MODELS = [
|
||||
"vidore/colpali-v1.3-hf",
|
||||
]
|
||||
|
||||
EMBED_DIMS = {
|
||||
"vidore/colpali-v1.3-hf": 128,
|
||||
}
|
||||
|
||||
TEXT_QUERIES = [
|
||||
"What is the capital of France?",
|
||||
"Describe the contents of the document.",
|
||||
]
|
||||
|
||||
TEXT_DOCUMENTS = [
|
||||
"The capital of France is Paris.",
|
||||
"This document contains important financial data.",
|
||||
]
|
||||
|
||||
DTYPE = "half"
|
||||
GPU_MEMORY_UTILIZATION = 0.7
|
||||
|
||||
|
||||
def _make_base64_image(
|
||||
width: int = 64, height: int = 64, color: tuple[int, int, int] = (255, 0, 0)
|
||||
) -> str:
|
||||
"""Create a small solid-color PNG image and return its base64 data URI."""
|
||||
img = Image.new("RGB", (width, height), color)
|
||||
buf = BytesIO()
|
||||
img.save(buf, format="PNG")
|
||||
b64 = base64.b64encode(buf.getvalue()).decode()
|
||||
return f"data:image/png;base64,{b64}"
|
||||
|
||||
|
||||
def _make_image_mm_param(
|
||||
image_uri: str,
|
||||
text: str | None = None,
|
||||
) -> ScoreMultiModalParam:
|
||||
"""Build a ScoreMultiModalParam containing an image (and optional text)."""
|
||||
content: list = [
|
||||
ChatCompletionContentPartImageParam(
|
||||
type="image_url",
|
||||
image_url={"url": image_uri},
|
||||
),
|
||||
]
|
||||
if text is not None:
|
||||
content.append(
|
||||
ChatCompletionContentPartTextParam(type="text", text=text),
|
||||
)
|
||||
return ScoreMultiModalParam(content=content)
|
||||
|
||||
|
||||
def _run_token_embed_test(
|
||||
vllm_runner: type[VllmRunner],
|
||||
model: str,
|
||||
*,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
"""Verify per-token embedding shape and L2 normalization."""
|
||||
with vllm_runner(
|
||||
model,
|
||||
runner="pooling",
|
||||
dtype=dtype,
|
||||
max_model_len=4096,
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
|
||||
) as vllm_model:
|
||||
outputs = vllm_model.token_embed([TEXT_QUERIES[0]])
|
||||
|
||||
assert len(outputs) == 1
|
||||
emb = torch.tensor(outputs[0])
|
||||
# Token embeddings should be 2D: [num_tokens, embed_dim]
|
||||
assert emb.dim() == 2
|
||||
assert emb.shape[1] == EMBED_DIMS[model]
|
||||
assert emb.shape[0] > 1
|
||||
|
||||
# Verify L2 normalization
|
||||
norms = torch.norm(emb, p=2, dim=-1)
|
||||
torch.testing.assert_close(
|
||||
norms,
|
||||
torch.ones_like(norms),
|
||||
rtol=1e-2,
|
||||
atol=1e-2,
|
||||
)
|
||||
|
||||
|
||||
def _run_late_interaction_test(
|
||||
vllm_runner: type[VllmRunner],
|
||||
model: str,
|
||||
*,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
"""Verify MaxSim scoring matches manual computation."""
|
||||
from vllm.entrypoints.pooling.score.utils import compute_maxsim_score
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
runner="pooling",
|
||||
dtype=dtype,
|
||||
max_model_len=4096,
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
|
||||
) as vllm_model:
|
||||
q_outputs = vllm_model.token_embed([TEXT_QUERIES[0]])
|
||||
d_outputs = vllm_model.token_embed([TEXT_DOCUMENTS[0]])
|
||||
|
||||
q_emb = torch.tensor(q_outputs[0])
|
||||
d_emb = torch.tensor(d_outputs[0])
|
||||
|
||||
manual_score = compute_maxsim_score(q_emb, d_emb).item()
|
||||
|
||||
vllm_scores = vllm_model.score(TEXT_QUERIES[0], TEXT_DOCUMENTS[0])
|
||||
|
||||
assert len(vllm_scores) == 1
|
||||
assert vllm_scores[0] == pytest.approx(manual_score, rel=0.01)
|
||||
|
||||
|
||||
def _run_relevance_test(
|
||||
vllm_runner: type[VllmRunner],
|
||||
model: str,
|
||||
*,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
"""Verify that relevant documents score higher than irrelevant ones."""
|
||||
query = "What is machine learning?"
|
||||
documents = [
|
||||
"Machine learning is a subset of artificial intelligence.",
|
||||
"The weather forecast shows rain tomorrow.",
|
||||
"Deep learning uses neural networks for complex tasks.",
|
||||
]
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
runner="pooling",
|
||||
dtype=dtype,
|
||||
max_model_len=4096,
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
|
||||
) as vllm_model:
|
||||
scores = vllm_model.score(query, documents)
|
||||
|
||||
assert len(scores) == 3
|
||||
assert scores[0] > scores[1], "ML doc should score higher than weather doc"
|
||||
assert scores[2] > scores[1], "DL doc should score higher than weather doc"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", [DTYPE])
|
||||
def test_colpali_token_embed(
|
||||
vllm_runner,
|
||||
model: str,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
_run_token_embed_test(vllm_runner, model, dtype=dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", [DTYPE])
|
||||
def test_colpali_late_interaction_scoring(
|
||||
vllm_runner,
|
||||
model: str,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
_run_late_interaction_test(vllm_runner, model, dtype=dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", [DTYPE])
|
||||
def test_colpali_relevance_ordering(
|
||||
vllm_runner,
|
||||
model: str,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
_run_relevance_test(vllm_runner, model, dtype=dtype)
|
||||
|
||||
|
||||
# ── Multimodal scoring tests ────────────────────────────────
|
||||
|
||||
|
||||
def _run_multimodal_text_query_image_docs_test(
|
||||
vllm_runner: type[VllmRunner],
|
||||
model: str,
|
||||
*,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
"""Score a text query against image documents via the multimodal path."""
|
||||
red_image = _make_base64_image(64, 64, color=(255, 0, 0))
|
||||
blue_image = _make_base64_image(64, 64, color=(0, 0, 255))
|
||||
|
||||
query = "Describe the red object"
|
||||
image_docs = [
|
||||
_make_image_mm_param(red_image),
|
||||
_make_image_mm_param(blue_image),
|
||||
]
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
runner="pooling",
|
||||
dtype=dtype,
|
||||
max_model_len=4096,
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
|
||||
) as vllm_model:
|
||||
scores = vllm_model.llm.score(query, image_docs)
|
||||
|
||||
assert len(scores) == 2
|
||||
for s in scores:
|
||||
assert isinstance(s.outputs.score, float)
|
||||
|
||||
|
||||
def _run_multimodal_mixed_docs_test(
|
||||
vllm_runner: type[VllmRunner],
|
||||
model: str,
|
||||
*,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
"""Score a text query against a mix of text and image documents."""
|
||||
red_image = _make_base64_image(64, 64, color=(255, 0, 0))
|
||||
|
||||
query = "What is the capital of France?"
|
||||
documents: list = [
|
||||
"The capital of France is Paris.",
|
||||
_make_image_mm_param(red_image),
|
||||
]
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
runner="pooling",
|
||||
dtype=dtype,
|
||||
max_model_len=4096,
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
|
||||
) as vllm_model:
|
||||
scores = vllm_model.llm.score(query, documents)
|
||||
|
||||
assert len(scores) == 2
|
||||
for s in scores:
|
||||
assert isinstance(s.outputs.score, float)
|
||||
# Text document about France should score higher than a random image
|
||||
assert scores[0].outputs.score > scores[1].outputs.score
|
||||
|
||||
|
||||
def _run_multimodal_image_query_text_docs_test(
|
||||
vllm_runner: type[VllmRunner],
|
||||
model: str,
|
||||
*,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
"""Score an image query against text documents."""
|
||||
red_image = _make_base64_image(64, 64, color=(255, 0, 0))
|
||||
image_query = _make_image_mm_param(red_image, text="red color")
|
||||
|
||||
documents = [
|
||||
"A bright red sports car.",
|
||||
"The weather forecast shows rain tomorrow.",
|
||||
]
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
runner="pooling",
|
||||
dtype=dtype,
|
||||
max_model_len=4096,
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
|
||||
) as vllm_model:
|
||||
scores = vllm_model.llm.score(image_query, documents)
|
||||
|
||||
assert len(scores) == 2
|
||||
for s in scores:
|
||||
assert isinstance(s.outputs.score, float)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", [DTYPE])
|
||||
def test_colpali_multimodal_text_query_image_docs(
|
||||
vllm_runner,
|
||||
model: str,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
_run_multimodal_text_query_image_docs_test(vllm_runner, model, dtype=dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", [DTYPE])
|
||||
def test_colpali_multimodal_mixed_docs(
|
||||
vllm_runner,
|
||||
model: str,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
_run_multimodal_mixed_docs_test(vllm_runner, model, dtype=dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", [DTYPE])
|
||||
def test_colpali_multimodal_image_query_text_docs(
|
||||
vllm_runner,
|
||||
model: str,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
_run_multimodal_image_query_text_docs_test(vllm_runner, model, dtype=dtype)
|
||||
@@ -631,6 +631,7 @@ _LATE_INTERACTION_EXAMPLE_MODELS = {
|
||||
"ColModernVBertForRetrieval": _HfExamplesInfo(
|
||||
"ModernVBERT/colmodernvbert-merged",
|
||||
),
|
||||
"ColPaliForRetrieval": _HfExamplesInfo("vidore/colpali-v1.3-hf"),
|
||||
"ColQwen3": _HfExamplesInfo(
|
||||
"TomoroAI/tomoro-colqwen3-embed-4b", trust_remote_code=True
|
||||
),
|
||||
|
||||
245
vllm/model_executor/models/colpali.py
Normal file
245
vllm/model_executor/models/colpali.py
Normal file
@@ -0,0 +1,245 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
ColPali late interaction model for multi-modal retrieval and reranking.
|
||||
|
||||
ColPali extends PaliGemma with a ColBERT-style late interaction head,
|
||||
producing per-token embeddings for both text and image inputs. It uses
|
||||
MaxSim scoring for retrieval/reranking tasks.
|
||||
|
||||
This model supports the "token_embed" pooling task and is designed for
|
||||
multi-vector retrieval of documents containing both text and images.
|
||||
|
||||
Reference: https://arxiv.org/abs/2407.01449 (ColPali)
|
||||
Based on: PaliGemma backbone (SigLIP + Gemma) with custom text projection
|
||||
|
||||
Target models:
|
||||
- vidore/colpali-v1.3-hf
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable, Mapping
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import BatchFeature, PaliGemmaProcessor
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
from .interfaces import SupportsLateInteraction
|
||||
from .interfaces_base import default_pooling_type
|
||||
from .paligemma import (
|
||||
PaliGemmaDummyInputsBuilder,
|
||||
PaliGemmaForConditionalGeneration,
|
||||
PaliGemmaMultiModalProcessor,
|
||||
PaliGemmaProcessingInfo,
|
||||
)
|
||||
from .utils import AutoWeightsLoader, WeightsMapper
|
||||
|
||||
|
||||
class ColPaliProcessingInfo(PaliGemmaProcessingInfo):
|
||||
"""Processing info for ColPali models.
|
||||
|
||||
ColPali models use a custom HuggingFace config (ColPaliConfig) that is
|
||||
not an instance of PaliGemmaConfig. We override get_hf_config() and
|
||||
get_hf_processor() to skip the strict type check.
|
||||
"""
|
||||
|
||||
def get_hf_config(self):
|
||||
return self.ctx.get_hf_config()
|
||||
|
||||
def get_hf_processor(self, **kwargs: object) -> PaliGemmaProcessor:
|
||||
# Force standard PaliGemmaProcessor even when trust_remote_code=True.
|
||||
return self.ctx.get_hf_processor(PaliGemmaProcessor, **kwargs)
|
||||
|
||||
|
||||
class ColPaliMultiModalProcessor(PaliGemmaMultiModalProcessor):
|
||||
"""Multimodal processor for ColPali."""
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
if mm_data:
|
||||
# The ColPali tokenizer_config.json ships with a small default
|
||||
# max_length (50) that truncates the 1024 image tokens inserted
|
||||
# by PaliGemmaProcessor, causing a token-count mismatch.
|
||||
# vLLM enforces its own max_model_len, so we disable HF
|
||||
# truncation to keep all image + text tokens intact.
|
||||
tok_kwargs = dict(tok_kwargs, truncation=False)
|
||||
return super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
tok_kwargs=tok_kwargs,
|
||||
)
|
||||
|
||||
|
||||
@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
ColPaliMultiModalProcessor,
|
||||
info=ColPaliProcessingInfo,
|
||||
dummy_inputs=PaliGemmaDummyInputsBuilder,
|
||||
)
|
||||
class ColPaliModel(
|
||||
PaliGemmaForConditionalGeneration,
|
||||
SupportsLateInteraction,
|
||||
):
|
||||
"""ColPali late interaction model for multi-modal retrieval/reranking.
|
||||
|
||||
This model extends PaliGemmaForConditionalGeneration with a ColBERT-style
|
||||
linear projection layer for per-token embeddings. It supports:
|
||||
- "token_embed" task: Per-token embeddings for late interaction scoring
|
||||
|
||||
The model produces L2-normalized per-token embeddings by:
|
||||
1. Running the PaliGemma backbone (vision + language) to get hidden states
|
||||
2. Projecting hidden states through a linear layer (hidden_size -> embed_dim)
|
||||
3. L2-normalizing the projected embeddings
|
||||
"""
|
||||
|
||||
# Mark this as a pooling model so vLLM routes to pooler path
|
||||
is_pooling_model = True
|
||||
|
||||
# Override hf_to_vllm_mapper to handle ColPali weight naming.
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
# HF transformers checkpoint (vidore/colpali-v1.3-hf)
|
||||
# Weights: vlm.vision_tower.*, vlm.language_model.*,
|
||||
# vlm.multi_modal_projector.*
|
||||
"vlm.vision_tower.": "vision_tower.",
|
||||
"vlm.language_model.": "language_model.",
|
||||
"vlm.multi_modal_projector.": "multi_modal_projector.",
|
||||
# colpali-engine checkpoint naming
|
||||
"model.vision_tower.": "vision_tower.",
|
||||
"model.language_model.": "language_model.",
|
||||
"model.multi_modal_projector.": "multi_modal_projector.",
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
}
|
||||
)
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
head_dtype = vllm_config.model_config.head_dtype
|
||||
|
||||
hidden_size = getattr(config, "hidden_size", None)
|
||||
if hidden_size is None and hasattr(config, "text_config"):
|
||||
hidden_size = config.text_config.hidden_size
|
||||
if hidden_size is None:
|
||||
raise ValueError(
|
||||
"Unable to determine text hidden size from config. "
|
||||
"Expected 'hidden_size' or 'text_config.hidden_size'."
|
||||
)
|
||||
self._proj_hidden_size = hidden_size
|
||||
|
||||
# ColPali uses embedding_dim=128, but also check other naming variants
|
||||
self.embed_dim: int | None = (
|
||||
getattr(config, "embedding_dim", None)
|
||||
or getattr(config, "embed_dim", None)
|
||||
or getattr(config, "dim", None)
|
||||
or getattr(config, "projection_dim", None)
|
||||
or getattr(config, "colbert_dim", None)
|
||||
)
|
||||
|
||||
# Build the projection layer if embed_dim is known
|
||||
if self.embed_dim is not None:
|
||||
self.custom_text_proj = nn.Linear(
|
||||
hidden_size,
|
||||
self.embed_dim,
|
||||
bias=False,
|
||||
dtype=head_dtype,
|
||||
)
|
||||
else:
|
||||
# Will be created during load_weights when dim is inferred
|
||||
self.custom_text_proj = None
|
||||
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
self.pooler = pooler_for_token_embed(
|
||||
pooler_config,
|
||||
projector=self.custom_text_proj,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors=None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
**kwargs: object,
|
||||
) -> torch.Tensor:
|
||||
return super().forward(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Names used for the projection layer across different ColPali variants
|
||||
_PROJ_LAYER_NAMES = {
|
||||
"custom_text_proj", # vLLM internal naming
|
||||
"embedding_proj_layer", # colpali-engine / HF naming
|
||||
}
|
||||
|
||||
def _is_proj_weight(self, name: str) -> bool:
|
||||
"""Check if a weight name belongs to the projection layer."""
|
||||
return any(proj_name in name for proj_name in self._PROJ_LAYER_NAMES)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
"""Load weights with special handling for ColPali projection layer."""
|
||||
weights_list = list(weights)
|
||||
proj_weights: list[tuple[str, torch.Tensor]] = []
|
||||
model_weights: list[tuple[str, torch.Tensor]] = []
|
||||
|
||||
for name, weight in weights_list:
|
||||
if self._is_proj_weight(name):
|
||||
proj_weights.append((name, weight))
|
||||
else:
|
||||
model_weights.append((name, weight))
|
||||
|
||||
loader = AutoWeightsLoader(self)
|
||||
loaded = loader.load_weights(model_weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
if proj_weights:
|
||||
model_dtype = next(self.language_model.parameters()).dtype
|
||||
model_device = next(self.language_model.parameters()).device
|
||||
|
||||
for name, weight in proj_weights:
|
||||
if self.embed_dim is None and "weight" in name:
|
||||
self.embed_dim = weight.shape[0]
|
||||
has_bias = any("bias" in n for n, _ in proj_weights)
|
||||
self.custom_text_proj = nn.Linear(
|
||||
self._proj_hidden_size,
|
||||
self.embed_dim,
|
||||
bias=has_bias,
|
||||
dtype=model_dtype,
|
||||
)
|
||||
self.custom_text_proj.to(model_device)
|
||||
|
||||
if self.custom_text_proj is not None:
|
||||
param_name = name.split(".")[-1]
|
||||
param = getattr(self.custom_text_proj, param_name, None)
|
||||
if param is not None:
|
||||
weight = weight.to(device=param.device, dtype=param.dtype)
|
||||
default_weight_loader(param, weight)
|
||||
loaded.add(f"custom_text_proj.{param_name}")
|
||||
|
||||
# Update pooler projector for the lazy-creation path
|
||||
self.pooler.head.projector = self.custom_text_proj
|
||||
|
||||
# Mark pooler projector params as loaded
|
||||
if hasattr(self, "pooler") and hasattr(self.pooler, "head"):
|
||||
head = self.pooler.head
|
||||
projector = getattr(head, "projector", None)
|
||||
if projector is not None and isinstance(projector, nn.Module):
|
||||
for pname, _ in projector.named_parameters():
|
||||
loaded.add(f"pooler.head.projector.{pname}")
|
||||
|
||||
return loaded
|
||||
@@ -247,6 +247,7 @@ _EMBEDDING_MODELS = {
|
||||
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
|
||||
# [Multimodal]
|
||||
"CLIPModel": ("clip", "CLIPEmbeddingModel"),
|
||||
"ColPaliForRetrieval": ("colpali", "ColPaliModel"),
|
||||
"LlavaNextForConditionalGeneration": (
|
||||
"llava_next",
|
||||
"LlavaNextForConditionalGeneration",
|
||||
|
||||
@@ -33,6 +33,7 @@ _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = {
|
||||
"blip-2": CHAT_TEMPLATES_DIR / "template_blip2.jinja",
|
||||
"chameleon": CHAT_TEMPLATES_DIR / "template_basic.jinja",
|
||||
"clip": CHAT_TEMPLATES_DIR / "template_basic.jinja",
|
||||
"colpali": CHAT_TEMPLATES_DIR / "template_basic.jinja",
|
||||
"deepseek_ocr": CHAT_TEMPLATES_DIR / "template_deepseek_ocr.jinja",
|
||||
"deepseek_ocr2": CHAT_TEMPLATES_DIR / "template_deepseek_ocr.jinja",
|
||||
"deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja",
|
||||
|
||||
@@ -78,6 +78,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
|
||||
bagel="BagelConfig",
|
||||
chatglm="ChatGLMConfig",
|
||||
colmodernvbert="ColModernVBertConfig",
|
||||
colpali="ColPaliConfig",
|
||||
colqwen3="ColQwen3Config",
|
||||
ops_colqwen3="OpsColQwen3Config",
|
||||
qwen3_vl_nemotron_embed="Qwen3VLNemotronEmbedConfig",
|
||||
|
||||
@@ -20,6 +20,7 @@ _CLASS_TO_MODULE: dict[str, str] = {
|
||||
"BagelConfig": "vllm.transformers_utils.configs.bagel",
|
||||
"ChatGLMConfig": "vllm.transformers_utils.configs.chatglm",
|
||||
"ColModernVBertConfig": "vllm.transformers_utils.configs.colmodernvbert",
|
||||
"ColPaliConfig": "vllm.transformers_utils.configs.colpali",
|
||||
"ColQwen3Config": "vllm.transformers_utils.configs.colqwen3",
|
||||
"OpsColQwen3Config": "vllm.transformers_utils.configs.colqwen3",
|
||||
"Qwen3VLNemotronEmbedConfig": "vllm.transformers_utils.configs.colqwen3",
|
||||
@@ -76,6 +77,7 @@ __all__ = [
|
||||
"BagelConfig",
|
||||
"ChatGLMConfig",
|
||||
"ColModernVBertConfig",
|
||||
"ColPaliConfig",
|
||||
"ColQwen3Config",
|
||||
"OpsColQwen3Config",
|
||||
"Qwen3VLNemotronEmbedConfig",
|
||||
|
||||
59
vllm/transformers_utils/configs/colpali.py
Normal file
59
vllm/transformers_utils/configs/colpali.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
ColPali configuration that extends PaliGemmaConfig with embedding projection
|
||||
fields. This allows ColPali models to be loaded without trust_remote_code
|
||||
by mapping their custom model_type (colpali) to a standard config class
|
||||
that vLLM understands.
|
||||
|
||||
Supported model_types:
|
||||
- colpali (vidore/colpali-v1.3-hf)
|
||||
"""
|
||||
|
||||
from transformers import PaliGemmaConfig
|
||||
|
||||
|
||||
class ColPaliConfig(PaliGemmaConfig):
|
||||
"""Configuration class for ColPali models.
|
||||
|
||||
Extends PaliGemmaConfig with additional fields used by ColPali variants
|
||||
for the embedding projection layer.
|
||||
"""
|
||||
|
||||
model_type = "colpali"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int | None = None,
|
||||
embed_dim: int | None = None,
|
||||
dim: int | None = None,
|
||||
projection_dim: int | None = None,
|
||||
colbert_dim: int | None = None,
|
||||
pooling: str | None = None,
|
||||
vlm_config: dict | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
# Store embedding projection config fields
|
||||
self.embedding_dim = embedding_dim
|
||||
self.embed_dim = embed_dim
|
||||
self.dim = dim
|
||||
self.projection_dim = projection_dim
|
||||
self.colbert_dim = colbert_dim
|
||||
self.pooling = pooling
|
||||
|
||||
# The HF checkpoint nests PaliGemma config inside "vlm_config".
|
||||
# Flatten it so PaliGemmaConfig receives vision_config, text_config,
|
||||
# image_token_index, etc. directly.
|
||||
# Use setdefault to avoid overwriting keys already set (e.g.
|
||||
# model_type="colpali" would be clobbered by "paligemma" from
|
||||
# vlm_config).
|
||||
if vlm_config is not None:
|
||||
vlm_dict = (
|
||||
vlm_config if isinstance(vlm_config, dict) else vlm_config.to_dict()
|
||||
)
|
||||
_conflicting = {"model_type", "_name_or_path"}
|
||||
for key, value in vlm_dict.items():
|
||||
if key not in _conflicting:
|
||||
kwargs.setdefault(key, value)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
Reference in New Issue
Block a user