[Model] Add ColQwen3.5 4.5B support (#36887)

Signed-off-by: Athrael Soju <athrael.soju@gmail.com>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
Athrael Soju
2026-03-17 21:17:02 +00:00
committed by GitHub
parent b5ca9c3557
commit c0745a851a
8 changed files with 579 additions and 0 deletions

View File

@@ -625,6 +625,46 @@ curl -s http://localhost:8000/rerank -H "Content-Type: application/json" -d '{
}'
```
### ColQwen3.5 Multi-Modal Late Interaction Models
ColQwen3.5 is based on [ColPali](https://arxiv.org/abs/2407.01449), extending ColBERT's late interaction approach to **multi-modal** inputs. It uses the Qwen3.5 hybrid backbone (linear + full attention) and produces per-token L2-normalized vectors for MaxSim scoring.
| Architecture | Backbone | Example HF Models |
| - | - | - |
| `ColQwen3_5` | Qwen3.5 | `athrael-soju/colqwen3.5-4.5B` |
Start the server:
```shell
vllm serve athrael-soju/colqwen3.5-4.5B --max-model-len 4096
```
Then you can use the rerank endpoint:
```shell
curl -s http://localhost:8000/rerank -H "Content-Type: application/json" -d '{
"model": "athrael-soju/colqwen3.5-4.5B",
"query": "What is machine learning?",
"documents": [
"Machine learning is a subset of artificial intelligence.",
"Python is a programming language.",
"Deep learning uses neural networks."
]
}'
```
Or the score endpoint:
```shell
curl -s http://localhost:8000/score -H "Content-Type: application/json" -d '{
"model": "athrael-soju/colqwen3.5-4.5B",
"text_1": "What is the capital of France?",
"text_2": ["The capital of France is Paris.", "Python is a programming language."]
}'
```
An example can be found here: [examples/pooling/score/colqwen3_5_rerank_online.py](../../examples/pooling/score/colqwen3_5_rerank_online.py)
### BAAI/bge-m3
The `BAAI/bge-m3` model comes with extra weights for sparse and colbert embeddings but unfortunately in its `config.json`

View File

@@ -834,6 +834,7 @@ The following table lists those that are tested in vLLM.
| `CLIPModel` | CLIP | T / I | `openai/clip-vit-base-patch32`, `openai/clip-vit-large-patch14`, etc. | | |
| `ColModernVBertForRetrieval` | ColModernVBERT | T / I | `ModernVBERT/colmodernvbert-merged` | | |
| `ColPaliForRetrieval` | ColPali | T / I | `vidore/colpali-v1.3-hf` | | |
| `ColQwen3_5` | ColQwen3.5 | T + I + V | `athrael-soju/colqwen3.5-4.5B-v3` | | |
| `LlamaNemotronVLModel` | Llama Nemotron Embedding + SigLIP | T + I | `nvidia/llama-nemotron-embed-vl-1b-v2` | | |
| `LlavaNextForConditionalGeneration`<sup>C</sup> | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | ✅︎ |
| `Phi3VForCausalLM`<sup>C</sup> | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | | ✅︎ |

View File

@@ -0,0 +1,130 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example of using ColQwen3.5 late interaction model for reranking.
ColQwen3.5 is a multi-modal ColBERT-style model based on Qwen3.5.
It produces per-token embeddings and uses MaxSim scoring for retrieval
and reranking. Supports both text and image inputs.
Start the server with:
vllm serve athrael-soju/colqwen3.5-4.5B --max-model-len 4096
Then run this script:
python colqwen3_5_rerank_online.py
"""
import requests
MODEL = "athrael-soju/colqwen3.5-4.5B"
BASE_URL = "http://127.0.0.1:8000"
headers = {"accept": "application/json", "Content-Type": "application/json"}
def rerank_text():
    """Text-only reranking via the /rerank endpoint."""
    print("=" * 60)
    print("1. Text reranking (/rerank)")
    print("=" * 60)
    payload = {
        "model": MODEL,
        "query": "What is machine learning?",
        "documents": [
            "Machine learning is a subset of artificial intelligence.",
            "Python is a programming language.",
            "Deep learning uses neural networks for complex tasks.",
            "The weather today is sunny.",
        ],
    }
    response = requests.post(f"{BASE_URL}/rerank", headers=headers, json=payload)
    # Guard clause: report the failure and bail out early.
    if response.status_code != 200:
        print(f" Request failed: {response.status_code}")
        print(f" {response.text[:300]}")
        return
    print("\n Ranked documents (most relevant first):")
    # Results arrive sorted by relevance; echo score + original document.
    for item in response.json()["results"]:
        print(f" [{item['relevance_score']:.4f}] {payload['documents'][item['index']]}")
def score_text():
    """Text-only scoring via the /score endpoint."""
    print()
    print("=" * 60)
    print("2. Text scoring (/score)")
    print("=" * 60)
    query = "What is the capital of France?"
    documents = [
        "The capital of France is Paris.",
        "Berlin is the capital of Germany.",
        "Python is a programming language.",
    ]
    response = requests.post(
        f"{BASE_URL}/score",
        headers=headers,
        json={"model": MODEL, "text_1": query, "text_2": documents},
    )
    # Guard clause: report the failure and bail out early.
    if response.status_code != 200:
        print(f" Request failed: {response.status_code}")
        print(f" {response.text[:300]}")
        return
    print(f"\n Query: {query}\n")
    # One score entry per document in text_2, keyed by its original index.
    for entry in response.json()["data"]:
        print(f" Doc {entry['index']} (score={entry['score']:.4f}): {documents[entry['index']]}")
def score_text_top_n():
    """Text reranking with top_n filtering via the /rerank endpoint."""
    print()
    print("=" * 60)
    print("3. Text reranking with top_n=2 (/rerank)")
    print("=" * 60)
    top_n = 2
    documents = [
        "The capital of France is Paris.",
        "Berlin is the capital of Germany.",
        "Python is a programming language.",
        "The Eiffel Tower is in Paris.",
    ]
    payload = {
        "model": MODEL,
        "query": "What is the capital of France?",
        "documents": documents,
        "top_n": top_n,
    }
    response = requests.post(f"{BASE_URL}/rerank", headers=headers, json=payload)
    # Guard clause: report the failure and bail out early.
    if response.status_code != 200:
        print(f" Request failed: {response.status_code}")
        print(f" {response.text[:300]}")
        return
    # With top_n set, the server returns only the best-scoring documents.
    print(f"\n Top {top_n} results:")
    for item in response.json()["results"]:
        print(f" [{item['relevance_score']:.4f}] {documents[item['index']]}")
def main():
    """Run all ColQwen3.5 rerank/score demos in order."""
    for demo in (rerank_text, score_text, score_text_top_n):
        demo()


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,154 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for ColQwen3.5 late interaction model for multi-modal retrieval.
ColQwen3.5 is a multi-vector retrieval model based on Qwen3.5 backbone with
ColBERT-style late interaction scoring (MaxSim). It produces per-token
embeddings for both text and image inputs.
"""
import pytest
import torch
from ....conftest import VllmRunner
MODELS = [
"athrael-soju/colqwen3.5-4.5B-v3",
]
EMBED_DIMS = {
"athrael-soju/colqwen3.5-4.5B-v3": 320,
}
TEXT_QUERIES = [
"What is the capital of France?",
"Describe the contents of the document.",
]
TEXT_DOCUMENTS = [
"The capital of France is Paris.",
"This document contains important financial data.",
]
DTYPE = "half"
def _run_token_embed_test(
    vllm_runner: type[VllmRunner],
    model: str,
    *,
    dtype: str,
) -> None:
    """Check that token embeddings are 2D, correctly sized, and L2-normalized."""
    with vllm_runner(
        model,
        runner="pooling",
        dtype=dtype,
        max_model_len=4096,
        enforce_eager=True,
    ) as vllm_model:
        outputs = vllm_model.token_embed([TEXT_QUERIES[0]])
        assert len(outputs) == 1
        embeddings = torch.tensor(outputs[0])
        # Expect a [num_tokens, embed_dim] matrix with more than one token.
        assert embeddings.dim() == 2
        assert embeddings.shape[1] == EMBED_DIMS[model]
        assert embeddings.shape[0] > 1
        # Every per-token vector should be unit length (L2-normalized).
        row_norms = embeddings.norm(p=2, dim=-1)
        torch.testing.assert_close(
            row_norms,
            torch.ones_like(row_norms),
            rtol=1e-2,
            atol=1e-2,
        )
def _run_late_interaction_test(
    vllm_runner: type[VllmRunner],
    model: str,
    *,
    dtype: str,
) -> None:
    """Check that /score MaxSim output matches a manual MaxSim computation."""
    from vllm.entrypoints.pooling.score.utils import compute_maxsim_score

    with vllm_runner(
        model,
        runner="pooling",
        dtype=dtype,
        max_model_len=4096,
        enforce_eager=True,
    ) as vllm_model:
        query, document = TEXT_QUERIES[0], TEXT_DOCUMENTS[0]
        query_emb = torch.tensor(vllm_model.token_embed([query])[0])
        doc_emb = torch.tensor(vllm_model.token_embed([document])[0])
        # Reference score computed directly from the per-token embeddings.
        expected = compute_maxsim_score(query_emb, doc_emb).item()
        scores = vllm_model.score(query, document)
        assert len(scores) == 1
        assert scores[0] == pytest.approx(expected, rel=0.01)
def _run_relevance_test(
    vllm_runner: type[VllmRunner],
    model: str,
    *,
    dtype: str,
) -> None:
    """Check that on-topic documents outscore the off-topic one."""
    query = "What is machine learning?"
    documents = [
        "Machine learning is a subset of artificial intelligence.",
        "The weather forecast shows rain tomorrow.",
        "Deep learning uses neural networks for complex tasks.",
    ]
    with vllm_runner(
        model,
        runner="pooling",
        dtype=dtype,
        max_model_len=4096,
        enforce_eager=True,
    ) as vllm_model:
        scores = vllm_model.score(query, documents)
        assert len(scores) == 3
        ml_score, weather_score, dl_score = scores
        assert ml_score > weather_score, "ML doc should score higher than weather doc"
        assert dl_score > weather_score, "DL doc should score higher than weather doc"
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colqwen3_5_token_embed(
    vllm_runner,
    model: str,
    dtype: str,
) -> None:
    """Per-token embeddings have the expected shape and unit norm."""
    _run_token_embed_test(vllm_runner, model=model, dtype=dtype)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colqwen3_5_late_interaction_scoring(
    vllm_runner,
    model: str,
    dtype: str,
) -> None:
    """MaxSim scoring through /score matches a manual computation."""
    _run_late_interaction_test(vllm_runner, model=model, dtype=dtype)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colqwen3_5_relevance_ordering(
    vllm_runner,
    model: str,
    dtype: str,
) -> None:
    """Relevant documents rank above the irrelevant distractor."""
    _run_relevance_test(vllm_runner, model=model, dtype=dtype)

View File

@@ -639,6 +639,11 @@ _LATE_INTERACTION_EXAMPLE_MODELS = {
"OpsColQwen3Model": _HfExamplesInfo(
"OpenSearch-AI/Ops-Colqwen3-4B", trust_remote_code=True
),
"ColQwen3_5": _HfExamplesInfo(
"athrael-soju/colqwen3.5-4.5B-v3",
trust_remote_code=True,
max_model_len=4096,
),
"Qwen3VLNemotronEmbedModel": _HfExamplesInfo(
"nvidia/nemotron-colembed-vl-4b-v2",
),

View File

@@ -0,0 +1,246 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
ColQwen3.5 late interaction model for multi-modal retrieval and reranking.
ColQwen3.5 extends Qwen3.5 with a ColBERT-style late interaction head,
producing per-token embeddings for both text and image inputs. It uses
MaxSim scoring for retrieval/reranking tasks.
This model supports the "token_embed" pooling task and is designed for
multi-vector retrieval of documents containing both text and images.
Reference: https://arxiv.org/abs/2407.01449 (ColPali)
Based on: Qwen3.5 backbone with custom text projection
Target models:
- athrael-soju/colqwen3.5-4.5B-v3
"""
from collections.abc import Iterable, Mapping
import torch
import torch.nn as nn
from transformers.models.qwen3_vl import Qwen3VLProcessor
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from .interfaces import SupportsLateInteraction
from .interfaces_base import default_pooling_type
from .qwen2_vl import Qwen2VLMultiModalDataParser
from .qwen3_5 import (
Qwen3_5ForConditionalGeneration,
Qwen3_5ProcessingInfo,
)
from .qwen3_vl import (
Qwen3VLDummyInputsBuilder,
Qwen3VLMultiModalProcessor,
)
from .utils import AutoWeightsLoader, WeightsMapper
class ColQwen3_5ProcessingInfo(Qwen3_5ProcessingInfo):
    """Processing info for ColQwen3.5 models.

    ColQwen3.5 checkpoints ship custom HuggingFace processors (e.g.
    ColQwen3_5Processor) that are incompatible with vLLM's
    Qwen3VLMultiModalProcessor. get_hf_config() and get_hf_processor()
    are therefore overridden to bypass the strict type check and always
    hand back the stock Qwen3VLProcessor.
    """

    def get_hf_config(self):
        # No type constraint: accept whatever config class the checkpoint declares.
        return self.ctx.get_hf_config()

    def get_hf_processor(self, **kwargs: object) -> Qwen3VLProcessor:
        # Default to the fast processor unless the caller overrides it.
        kwargs.setdefault("use_fast", True)
        return self.ctx.get_hf_processor(Qwen3VLProcessor, **kwargs)

    @property
    def _supports_video(self) -> bool:
        """Whether the underlying HF processor can handle video inputs."""
        return hasattr(self.get_hf_processor(), "video_processor")

    def get_video_processor(self, **kwargs: object):
        if not self._supports_video:
            raise AttributeError(
                f"The processor for {self.ctx.model_config.model} does not "
                "support video inputs (no video_processor attribute)."
            )
        return self.get_hf_processor(**kwargs).video_processor  # type: ignore[attr-defined]

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # Images are always supported; video only when the processor has one.
        video_limit: dict[str, int | None] = (
            {"video": None} if self._supports_video else {}
        )
        return {"image": None, **video_limit}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        per_item: dict[str, int] = {"image": self.get_max_image_tokens()}
        if self._supports_video:
            per_item["video"] = self.get_max_video_tokens(seq_len, mm_counts)
        return per_item

    def get_data_parser(self):
        merge_size = self.get_hf_config().vision_config.spatial_merge_size
        return Qwen2VLMultiModalDataParser(
            merge_size,
            video_needs_metadata=self._supports_video,
            expected_hidden_size=self._get_expected_hidden_size(),
        )
@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
@MULTIMODAL_REGISTRY.register_processor(
    Qwen3VLMultiModalProcessor,
    info=ColQwen3_5ProcessingInfo,
    dummy_inputs=Qwen3VLDummyInputsBuilder,
)
class ColQwen3_5Model(
    Qwen3_5ForConditionalGeneration,
    SupportsLateInteraction,
):
    """ColQwen3.5 late interaction model for multi-modal retrieval/reranking.

    This model extends Qwen3_5ForConditionalGeneration with a ColBERT-style
    linear projection layer for per-token embeddings. It supports:
    - "token_embed" task: Per-token embeddings for late interaction scoring

    The model produces per-token embeddings by:
    1. Running the Qwen3.5 backbone (vision + language) to get hidden states
    2. Projecting hidden states through a linear layer (hidden_size -> embed_dim)
    3. L2 normalization is handled by the pooler via PoolerNormalize

    Attributes:
        custom_text_proj: Linear projection from hidden_size to embed_dim
    """

    # Mark this as a pooling model so vLLM routes to pooler path
    is_pooling_model = True

    # Override hf_to_vllm_mapper to handle ColQwen3.5 weight naming.
    # ColPali saves weights as "language_model.*" but vLLM's
    # Qwen3_5ForCausalLM has them under "language_model.model.*".
    # Visual weights ("visual.*") already match the vLLM module path.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "language_model.": "language_model.model.",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, the projection head, and the token pooler.

        Args:
            vllm_config: Engine-wide configuration; supplies the HF config,
                head dtype, and pooler config read below.
            prefix: Module-name prefix forwarded to the parent constructor.

        Raises:
            ValueError: If neither ``hidden_size`` nor
                ``text_config.hidden_size`` is present on the HF config.
        """
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        config = vllm_config.model_config.hf_config
        head_dtype = vllm_config.model_config.head_dtype
        # Multi-modal configs may nest the text hidden size under text_config.
        hidden_size = getattr(config, "hidden_size", None)
        if hidden_size is None and hasattr(config, "text_config"):
            hidden_size = config.text_config.hidden_size
        if hidden_size is None:
            raise ValueError(
                "Unable to determine text hidden size from config. "
                "Expected 'hidden_size' or 'text_config.hidden_size'."
            )
        # The embedding-dim attribute name varies across checkpoint variants.
        # NOTE: `or` would also skip a value of 0, but 0 is never a valid
        # embedding dimension, so the fallback chain is safe here.
        # (ColPali: dim, projection_dim, colbert_dim)
        self.embed_dim: int = (
            getattr(config, "embed_dim", None)
            or getattr(config, "dims", None)
            or getattr(config, "dim", None)
            or getattr(config, "projection_dim", None)
            or getattr(config, "colbert_dim", None)
            or 128  # default from reference implementation
        )
        # ColBERT-style head: bias-free linear projection built in head_dtype
        # (which may differ from the backbone dtype; forward() aligns dtypes).
        self.custom_text_proj = nn.Linear(
            hidden_size,
            self.embed_dim,
            bias=False,
            dtype=head_dtype,
        )
        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None
        # projector=None: forward() already projects to embed_dim, so the
        # pooler only gathers (and normalizes) the per-token vectors.
        self.pooler = pooler_for_token_embed(
            pooler_config,
            projector=None,
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors=None,
        inputs_embeds: torch.Tensor | None = None,
    **kwargs: object,
    ) -> torch.Tensor:
        """Run forward pass producing per-token embeddings.

        Returns the backbone hidden states projected to ``embed_dim``;
        L2 normalization is deferred to the pooler.
        """
        hidden_states = super().forward(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )
        # The parent may return a non-tensor (NOTE(review): presumably
        # IntermediateTensors under pipeline parallelism — confirm upstream);
        # pass that through untouched.
        if not isinstance(hidden_states, torch.Tensor):
            return hidden_states  # type: ignore
        # Align dtypes: the projection head lives in head_dtype, which may
        # differ from the backbone's output dtype.
        proj_dtype = self.custom_text_proj.weight.dtype
        if hidden_states.dtype != proj_dtype:
            hidden_states = hidden_states.to(proj_dtype)
        # Project to embedding dimension (normalization handled by pooler)
        return self.custom_text_proj(hidden_states)

    # Names used for the projection layer across different ColQwen3.5 variants
    _PROJ_LAYER_NAMES = {
        "custom_text_proj",  # ColPali naming
        "embedding_proj_layer",  # Alternative naming
    }

    def _is_proj_weight(self, name: str) -> bool:
        """Check if a weight name belongs to the projection layer."""
        return any(proj_name in name for proj_name in self._PROJ_LAYER_NAMES)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights with special handling for projection layer.

        Projection-head weights are loaded manually because their checkpoint
        name varies by variant (see ``_PROJ_LAYER_NAMES``); everything else is
        delegated to ``AutoWeightsLoader`` with the ``hf_to_vllm_mapper``
        prefix rewrite.

        Returns:
            The set of loaded parameter names in vLLM module paths;
            projection weights are recorded as ``custom_text_proj.*``
            regardless of their checkpoint name.
        """
        weights_list = list(weights)
        proj_weights: list[tuple[str, torch.Tensor]] = []
        model_weights: list[tuple[str, torch.Tensor]] = []
        for name, weight in weights_list:
            if self._is_proj_weight(name):
                proj_weights.append((name, weight))
            else:
                model_weights.append((name, weight))
        # "mtp." weights are skipped — presumably a multi-token-prediction
        # head that a pooling-only model never runs; TODO confirm.
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=["mtp."],
        )
        loaded = loader.load_weights(model_weights, mapper=self.hf_to_vllm_mapper)
        for name, weight in proj_weights:
            # Map the checkpoint leaf name (e.g. "weight") onto the local
            # custom_text_proj parameter of the same name.
            param_name = name.split(".")[-1]
            param = getattr(self.custom_text_proj, param_name, None)
            # Leaves the local layer lacks (e.g. a checkpoint bias when this
            # layer was built with bias=False) are silently skipped.
            if param is not None:
                weight = weight.to(device=param.device, dtype=param.dtype)
                default_weight_loader(param, weight)
                loaded.add(f"custom_text_proj.{param_name}")
        return loaded

View File

@@ -647,6 +647,7 @@ class VoyageQwen3BidirectionalEmbedModelConfig(VerifyAndUpdateConfig):
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"ColBERTJinaRobertaModel": JinaRobertaModelConfig,
"ColQwen3_5": Qwen3_5ForConditionalGenerationConfig,
"DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
"Ernie4_5_VLMoeForConditionalGeneration": Ernie4_5_VLMoeForConditionalGenerationConfig, # noqa: E501
"FalconMambaForCausalLM": MambaModelConfig,

View File

@@ -274,8 +274,10 @@ _LATE_INTERACTION_MODELS = {
"ColBERTJinaRobertaModel": ("colbert", "ColBERTJinaRobertaModel"),
# [Multimodal]
"ColModernVBertForRetrieval": ("colmodernvbert", "ColModernVBertForRetrieval"),
"ColPaliForRetrieval": ("colpali", "ColPaliModel"),
"ColQwen3": ("colqwen3", "ColQwen3Model"),
"OpsColQwen3Model": ("colqwen3", "ColQwen3Model"),
"ColQwen3_5": ("colqwen3_5", "ColQwen3_5Model"),
"Qwen3VLNemotronEmbedModel": ("colqwen3", "ColQwen3Model"),
}