[Model] Add ColPali late interaction model for multi-modal retrieval (#36818)

Signed-off-by: Nikita Sukharev <kaonael@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Nikita
2026-03-13 03:18:57 +01:00
committed by GitHub
parent 5e1a373d2e
commit 10f08dedfa
9 changed files with 634 additions and 0 deletions

View File

@@ -828,6 +828,7 @@ The following table lists those that are tested in vLLM.
| ------------ | ------ | ------ | ----------------- | -------------------- | ------------------------- |
| `CLIPModel` | CLIP | T / I | `openai/clip-vit-base-patch32`, `openai/clip-vit-large-patch14`, etc. | | |
| `ColModernVBertForRetrieval` | ColModernVBERT | T / I | `ModernVBERT/colmodernvbert-merged` | | |
| `ColPaliForRetrieval` | ColPali | T / I | `vidore/colpali-v1.3-hf` | | |
| `LlamaNemotronVLModel` | Llama Nemotron Embedding + SigLIP | T + I | `nvidia/llama-nemotron-embed-vl-1b-v2` | | |
| `LlavaNextForConditionalGeneration`<sup>C</sup> | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | ✅︎ |
| `Phi3VForCausalLM`<sup>C</sup> | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | | ✅︎ |

View File

@@ -0,0 +1,323 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for ColPali late interaction model for multi-modal retrieval.
ColPali is a multi-vector retrieval model based on PaliGemma backbone
(SigLIP + Gemma) with ColBERT-style late interaction scoring (MaxSim).
It produces per-token embeddings for both text and image inputs.
"""
import base64
from io import BytesIO
import pytest
import torch
from PIL import Image
from vllm.entrypoints.chat_utils import (
ChatCompletionContentPartImageParam,
ChatCompletionContentPartTextParam,
)
from vllm.entrypoints.pooling.score.utils import ScoreMultiModalParam
from ....conftest import VllmRunner
MODELS = [
"vidore/colpali-v1.3-hf",
]
EMBED_DIMS = {
"vidore/colpali-v1.3-hf": 128,
}
TEXT_QUERIES = [
"What is the capital of France?",
"Describe the contents of the document.",
]
TEXT_DOCUMENTS = [
"The capital of France is Paris.",
"This document contains important financial data.",
]
DTYPE = "half"
GPU_MEMORY_UTILIZATION = 0.7
def _make_base64_image(
width: int = 64, height: int = 64, color: tuple[int, int, int] = (255, 0, 0)
) -> str:
"""Create a small solid-color PNG image and return its base64 data URI."""
img = Image.new("RGB", (width, height), color)
buf = BytesIO()
img.save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode()
return f"data:image/png;base64,{b64}"
def _make_image_mm_param(
image_uri: str,
text: str | None = None,
) -> ScoreMultiModalParam:
"""Build a ScoreMultiModalParam containing an image (and optional text)."""
content: list = [
ChatCompletionContentPartImageParam(
type="image_url",
image_url={"url": image_uri},
),
]
if text is not None:
content.append(
ChatCompletionContentPartTextParam(type="text", text=text),
)
return ScoreMultiModalParam(content=content)
def _run_token_embed_test(
vllm_runner: type[VllmRunner],
model: str,
*,
dtype: str,
) -> None:
"""Verify per-token embedding shape and L2 normalization."""
with vllm_runner(
model,
runner="pooling",
dtype=dtype,
max_model_len=4096,
enforce_eager=True,
gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
) as vllm_model:
outputs = vllm_model.token_embed([TEXT_QUERIES[0]])
assert len(outputs) == 1
emb = torch.tensor(outputs[0])
# Token embeddings should be 2D: [num_tokens, embed_dim]
assert emb.dim() == 2
assert emb.shape[1] == EMBED_DIMS[model]
assert emb.shape[0] > 1
# Verify L2 normalization
norms = torch.norm(emb, p=2, dim=-1)
torch.testing.assert_close(
norms,
torch.ones_like(norms),
rtol=1e-2,
atol=1e-2,
)
def _run_late_interaction_test(
vllm_runner: type[VllmRunner],
model: str,
*,
dtype: str,
) -> None:
"""Verify MaxSim scoring matches manual computation."""
from vllm.entrypoints.pooling.score.utils import compute_maxsim_score
with vllm_runner(
model,
runner="pooling",
dtype=dtype,
max_model_len=4096,
enforce_eager=True,
gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
) as vllm_model:
q_outputs = vllm_model.token_embed([TEXT_QUERIES[0]])
d_outputs = vllm_model.token_embed([TEXT_DOCUMENTS[0]])
q_emb = torch.tensor(q_outputs[0])
d_emb = torch.tensor(d_outputs[0])
manual_score = compute_maxsim_score(q_emb, d_emb).item()
vllm_scores = vllm_model.score(TEXT_QUERIES[0], TEXT_DOCUMENTS[0])
assert len(vllm_scores) == 1
assert vllm_scores[0] == pytest.approx(manual_score, rel=0.01)
def _run_relevance_test(
vllm_runner: type[VllmRunner],
model: str,
*,
dtype: str,
) -> None:
"""Verify that relevant documents score higher than irrelevant ones."""
query = "What is machine learning?"
documents = [
"Machine learning is a subset of artificial intelligence.",
"The weather forecast shows rain tomorrow.",
"Deep learning uses neural networks for complex tasks.",
]
with vllm_runner(
model,
runner="pooling",
dtype=dtype,
max_model_len=4096,
enforce_eager=True,
gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
) as vllm_model:
scores = vllm_model.score(query, documents)
assert len(scores) == 3
assert scores[0] > scores[1], "ML doc should score higher than weather doc"
assert scores[2] > scores[1], "DL doc should score higher than weather doc"
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colpali_token_embed(
vllm_runner,
model: str,
dtype: str,
) -> None:
_run_token_embed_test(vllm_runner, model, dtype=dtype)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colpali_late_interaction_scoring(
vllm_runner,
model: str,
dtype: str,
) -> None:
_run_late_interaction_test(vllm_runner, model, dtype=dtype)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colpali_relevance_ordering(
vllm_runner,
model: str,
dtype: str,
) -> None:
_run_relevance_test(vllm_runner, model, dtype=dtype)
# ── Multimodal scoring tests ────────────────────────────────
def _run_multimodal_text_query_image_docs_test(
vllm_runner: type[VllmRunner],
model: str,
*,
dtype: str,
) -> None:
"""Score a text query against image documents via the multimodal path."""
red_image = _make_base64_image(64, 64, color=(255, 0, 0))
blue_image = _make_base64_image(64, 64, color=(0, 0, 255))
query = "Describe the red object"
image_docs = [
_make_image_mm_param(red_image),
_make_image_mm_param(blue_image),
]
with vllm_runner(
model,
runner="pooling",
dtype=dtype,
max_model_len=4096,
enforce_eager=True,
gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
) as vllm_model:
scores = vllm_model.llm.score(query, image_docs)
assert len(scores) == 2
for s in scores:
assert isinstance(s.outputs.score, float)
def _run_multimodal_mixed_docs_test(
vllm_runner: type[VllmRunner],
model: str,
*,
dtype: str,
) -> None:
"""Score a text query against a mix of text and image documents."""
red_image = _make_base64_image(64, 64, color=(255, 0, 0))
query = "What is the capital of France?"
documents: list = [
"The capital of France is Paris.",
_make_image_mm_param(red_image),
]
with vllm_runner(
model,
runner="pooling",
dtype=dtype,
max_model_len=4096,
enforce_eager=True,
gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
) as vllm_model:
scores = vllm_model.llm.score(query, documents)
assert len(scores) == 2
for s in scores:
assert isinstance(s.outputs.score, float)
# Text document about France should score higher than a random image
assert scores[0].outputs.score > scores[1].outputs.score
def _run_multimodal_image_query_text_docs_test(
vllm_runner: type[VllmRunner],
model: str,
*,
dtype: str,
) -> None:
"""Score an image query against text documents."""
red_image = _make_base64_image(64, 64, color=(255, 0, 0))
image_query = _make_image_mm_param(red_image, text="red color")
documents = [
"A bright red sports car.",
"The weather forecast shows rain tomorrow.",
]
with vllm_runner(
model,
runner="pooling",
dtype=dtype,
max_model_len=4096,
enforce_eager=True,
gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
) as vllm_model:
scores = vllm_model.llm.score(image_query, documents)
assert len(scores) == 2
for s in scores:
assert isinstance(s.outputs.score, float)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colpali_multimodal_text_query_image_docs(
vllm_runner,
model: str,
dtype: str,
) -> None:
_run_multimodal_text_query_image_docs_test(vllm_runner, model, dtype=dtype)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colpali_multimodal_mixed_docs(
vllm_runner,
model: str,
dtype: str,
) -> None:
_run_multimodal_mixed_docs_test(vllm_runner, model, dtype=dtype)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colpali_multimodal_image_query_text_docs(
vllm_runner,
model: str,
dtype: str,
) -> None:
_run_multimodal_image_query_text_docs_test(vllm_runner, model, dtype=dtype)

View File

@@ -631,6 +631,7 @@ _LATE_INTERACTION_EXAMPLE_MODELS = {
"ColModernVBertForRetrieval": _HfExamplesInfo(
"ModernVBERT/colmodernvbert-merged",
),
"ColPaliForRetrieval": _HfExamplesInfo("vidore/colpali-v1.3-hf"),
"ColQwen3": _HfExamplesInfo(
"TomoroAI/tomoro-colqwen3-embed-4b", trust_remote_code=True
),

View File

@@ -0,0 +1,245 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
ColPali late interaction model for multi-modal retrieval and reranking.
ColPali extends PaliGemma with a ColBERT-style late interaction head,
producing per-token embeddings for both text and image inputs. It uses
MaxSim scoring for retrieval/reranking tasks.
This model supports the "token_embed" pooling task and is designed for
multi-vector retrieval of documents containing both text and images.
Reference: https://arxiv.org/abs/2407.01449 (ColPali)
Based on: PaliGemma backbone (SigLIP + Gemma) with custom text projection
Target models:
- vidore/colpali-v1.3-hf
"""
from collections.abc import Iterable, Mapping
import torch
import torch.nn as nn
from transformers import BatchFeature, PaliGemmaProcessor
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from .interfaces import SupportsLateInteraction
from .interfaces_base import default_pooling_type
from .paligemma import (
PaliGemmaDummyInputsBuilder,
PaliGemmaForConditionalGeneration,
PaliGemmaMultiModalProcessor,
PaliGemmaProcessingInfo,
)
from .utils import AutoWeightsLoader, WeightsMapper
class ColPaliProcessingInfo(PaliGemmaProcessingInfo):
"""Processing info for ColPali models.
ColPali models use a custom HuggingFace config (ColPaliConfig) that is
not an instance of PaliGemmaConfig. We override get_hf_config() and
get_hf_processor() to skip the strict type check.
"""
def get_hf_config(self):
return self.ctx.get_hf_config()
def get_hf_processor(self, **kwargs: object) -> PaliGemmaProcessor:
# Force standard PaliGemmaProcessor even when trust_remote_code=True.
return self.ctx.get_hf_processor(PaliGemmaProcessor, **kwargs)
class ColPaliMultiModalProcessor(PaliGemmaMultiModalProcessor):
"""Multimodal processor for ColPali."""
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
if mm_data:
# The ColPali tokenizer_config.json ships with a small default
# max_length (50) that truncates the 1024 image tokens inserted
# by PaliGemmaProcessor, causing a token-count mismatch.
# vLLM enforces its own max_model_len, so we disable HF
# truncation to keep all image + text tokens intact.
tok_kwargs = dict(tok_kwargs, truncation=False)
return super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
@MULTIMODAL_REGISTRY.register_processor(
ColPaliMultiModalProcessor,
info=ColPaliProcessingInfo,
dummy_inputs=PaliGemmaDummyInputsBuilder,
)
class ColPaliModel(
PaliGemmaForConditionalGeneration,
SupportsLateInteraction,
):
"""ColPali late interaction model for multi-modal retrieval/reranking.
This model extends PaliGemmaForConditionalGeneration with a ColBERT-style
linear projection layer for per-token embeddings. It supports:
- "token_embed" task: Per-token embeddings for late interaction scoring
The model produces L2-normalized per-token embeddings by:
1. Running the PaliGemma backbone (vision + language) to get hidden states
2. Projecting hidden states through a linear layer (hidden_size -> embed_dim)
3. L2-normalizing the projected embeddings
"""
# Mark this as a pooling model so vLLM routes to pooler path
is_pooling_model = True
# Override hf_to_vllm_mapper to handle ColPali weight naming.
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# HF transformers checkpoint (vidore/colpali-v1.3-hf)
# Weights: vlm.vision_tower.*, vlm.language_model.*,
# vlm.multi_modal_projector.*
"vlm.vision_tower.": "vision_tower.",
"vlm.language_model.": "language_model.",
"vlm.multi_modal_projector.": "multi_modal_projector.",
# colpali-engine checkpoint naming
"model.vision_tower.": "vision_tower.",
"model.language_model.": "language_model.",
"model.multi_modal_projector.": "multi_modal_projector.",
"lm_head.": "language_model.lm_head.",
}
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
config = vllm_config.model_config.hf_config
head_dtype = vllm_config.model_config.head_dtype
hidden_size = getattr(config, "hidden_size", None)
if hidden_size is None and hasattr(config, "text_config"):
hidden_size = config.text_config.hidden_size
if hidden_size is None:
raise ValueError(
"Unable to determine text hidden size from config. "
"Expected 'hidden_size' or 'text_config.hidden_size'."
)
self._proj_hidden_size = hidden_size
# ColPali uses embedding_dim=128, but also check other naming variants
self.embed_dim: int | None = (
getattr(config, "embedding_dim", None)
or getattr(config, "embed_dim", None)
or getattr(config, "dim", None)
or getattr(config, "projection_dim", None)
or getattr(config, "colbert_dim", None)
)
# Build the projection layer if embed_dim is known
if self.embed_dim is not None:
self.custom_text_proj = nn.Linear(
hidden_size,
self.embed_dim,
bias=False,
dtype=head_dtype,
)
else:
# Will be created during load_weights when dim is inferred
self.custom_text_proj = None
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = pooler_for_token_embed(
pooler_config,
projector=self.custom_text_proj,
)
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors=None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> torch.Tensor:
return super().forward(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**kwargs,
)
# Names used for the projection layer across different ColPali variants
_PROJ_LAYER_NAMES = {
"custom_text_proj", # vLLM internal naming
"embedding_proj_layer", # colpali-engine / HF naming
}
def _is_proj_weight(self, name: str) -> bool:
"""Check if a weight name belongs to the projection layer."""
return any(proj_name in name for proj_name in self._PROJ_LAYER_NAMES)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
"""Load weights with special handling for ColPali projection layer."""
weights_list = list(weights)
proj_weights: list[tuple[str, torch.Tensor]] = []
model_weights: list[tuple[str, torch.Tensor]] = []
for name, weight in weights_list:
if self._is_proj_weight(name):
proj_weights.append((name, weight))
else:
model_weights.append((name, weight))
loader = AutoWeightsLoader(self)
loaded = loader.load_weights(model_weights, mapper=self.hf_to_vllm_mapper)
if proj_weights:
model_dtype = next(self.language_model.parameters()).dtype
model_device = next(self.language_model.parameters()).device
for name, weight in proj_weights:
if self.embed_dim is None and "weight" in name:
self.embed_dim = weight.shape[0]
has_bias = any("bias" in n for n, _ in proj_weights)
self.custom_text_proj = nn.Linear(
self._proj_hidden_size,
self.embed_dim,
bias=has_bias,
dtype=model_dtype,
)
self.custom_text_proj.to(model_device)
if self.custom_text_proj is not None:
param_name = name.split(".")[-1]
param = getattr(self.custom_text_proj, param_name, None)
if param is not None:
weight = weight.to(device=param.device, dtype=param.dtype)
default_weight_loader(param, weight)
loaded.add(f"custom_text_proj.{param_name}")
# Update pooler projector for the lazy-creation path
self.pooler.head.projector = self.custom_text_proj
# Mark pooler projector params as loaded
if hasattr(self, "pooler") and hasattr(self.pooler, "head"):
head = self.pooler.head
projector = getattr(head, "projector", None)
if projector is not None and isinstance(projector, nn.Module):
for pname, _ in projector.named_parameters():
loaded.add(f"pooler.head.projector.{pname}")
return loaded

View File

@@ -247,6 +247,7 @@ _EMBEDDING_MODELS = {
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
# [Multimodal]
"CLIPModel": ("clip", "CLIPEmbeddingModel"),
"ColPaliForRetrieval": ("colpali", "ColPaliModel"),
"LlavaNextForConditionalGeneration": (
"llava_next",
"LlavaNextForConditionalGeneration",

View File

@@ -33,6 +33,7 @@ _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = {
"blip-2": CHAT_TEMPLATES_DIR / "template_blip2.jinja",
"chameleon": CHAT_TEMPLATES_DIR / "template_basic.jinja",
"clip": CHAT_TEMPLATES_DIR / "template_basic.jinja",
"colpali": CHAT_TEMPLATES_DIR / "template_basic.jinja",
"deepseek_ocr": CHAT_TEMPLATES_DIR / "template_deepseek_ocr.jinja",
"deepseek_ocr2": CHAT_TEMPLATES_DIR / "template_deepseek_ocr.jinja",
"deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja",

View File

@@ -78,6 +78,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
bagel="BagelConfig",
chatglm="ChatGLMConfig",
colmodernvbert="ColModernVBertConfig",
colpali="ColPaliConfig",
colqwen3="ColQwen3Config",
ops_colqwen3="OpsColQwen3Config",
qwen3_vl_nemotron_embed="Qwen3VLNemotronEmbedConfig",

View File

@@ -20,6 +20,7 @@ _CLASS_TO_MODULE: dict[str, str] = {
"BagelConfig": "vllm.transformers_utils.configs.bagel",
"ChatGLMConfig": "vllm.transformers_utils.configs.chatglm",
"ColModernVBertConfig": "vllm.transformers_utils.configs.colmodernvbert",
"ColPaliConfig": "vllm.transformers_utils.configs.colpali",
"ColQwen3Config": "vllm.transformers_utils.configs.colqwen3",
"OpsColQwen3Config": "vllm.transformers_utils.configs.colqwen3",
"Qwen3VLNemotronEmbedConfig": "vllm.transformers_utils.configs.colqwen3",
@@ -76,6 +77,7 @@ __all__ = [
"BagelConfig",
"ChatGLMConfig",
"ColModernVBertConfig",
"ColPaliConfig",
"ColQwen3Config",
"OpsColQwen3Config",
"Qwen3VLNemotronEmbedConfig",

View File

@@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
ColPali configuration that extends PaliGemmaConfig with embedding projection
fields. This allows ColPali models to be loaded without trust_remote_code
by mapping their custom model_type (colpali) to a standard config class
that vLLM understands.
Supported model_types:
- colpali (vidore/colpali-v1.3-hf)
"""
from transformers import PaliGemmaConfig
class ColPaliConfig(PaliGemmaConfig):
"""Configuration class for ColPali models.
Extends PaliGemmaConfig with additional fields used by ColPali variants
for the embedding projection layer.
"""
model_type = "colpali"
def __init__(
self,
embedding_dim: int | None = None,
embed_dim: int | None = None,
dim: int | None = None,
projection_dim: int | None = None,
colbert_dim: int | None = None,
pooling: str | None = None,
vlm_config: dict | None = None,
**kwargs,
):
# Store embedding projection config fields
self.embedding_dim = embedding_dim
self.embed_dim = embed_dim
self.dim = dim
self.projection_dim = projection_dim
self.colbert_dim = colbert_dim
self.pooling = pooling
# The HF checkpoint nests PaliGemma config inside "vlm_config".
# Flatten it so PaliGemmaConfig receives vision_config, text_config,
# image_token_index, etc. directly.
# Use setdefault to avoid overwriting keys already set (e.g.
# model_type="colpali" would be clobbered by "paligemma" from
# vlm_config).
if vlm_config is not None:
vlm_dict = (
vlm_config if isinstance(vlm_config, dict) else vlm_config.to_dict()
)
_conflicting = {"model_type", "_name_or_path"}
for key, value in vlm_dict.items():
if key not in _conflicting:
kwargs.setdefault(key, value)
super().__init__(**kwargs)