Extend ColBERT support to non-standard BERT backbones (#34170)

Signed-off-by: Ilya Boytsov <ilya.boytsov@aleph-alpha.com>
This commit is contained in:
Ilya Boytsov
2026-02-13 10:53:09 +01:00
committed by GitHub
parent 0916e7960b
commit 071d863e20
9 changed files with 775 additions and 291 deletions

View File

@@ -311,20 +311,31 @@ An OpenAI client example can be found here: [examples/pooling/embed/openai_embed
[ColBERT](https://arxiv.org/abs/2004.12832) (Contextualized Late Interaction over BERT) is a retrieval model that uses per-token embeddings and MaxSim scoring for document ranking. Unlike single-vector embedding models, ColBERT retains token-level representations and computes relevance scores through late interaction, providing better accuracy while being more efficient than cross-encoders.
vLLM supports ColBERT models for reranking tasks, automatically applying MaxSim scoring for query-document relevance:
vLLM supports ColBERT models with multiple encoder backbones:
| Architecture | Backbone | Example HF Models |
|---|---|---|
| `HF_ColBERT` | BERT | `answerdotai/answerai-colbert-small-v1`, `colbert-ir/colbertv2.0` |
| `ColBERTModernBertModel` | ModernBERT | `lightonai/GTE-ModernColBERT-v1` |
| `ColBERTJinaRobertaModel` | Jina XLM-RoBERTa | `jinaai/jina-colbert-v2` |
**BERT-based ColBERT** models work out of the box:
```shell
vllm serve answerdotai/answerai-colbert-small-v1
```
Currently supports ColBERT models with standard BERT encoders (e.g., `answerdotai/answerai-colbert-small-v1`, `colbert-ir/colbertv2.0`).
ColBERT models with modified encoder architectures are not yet supported, including BERT variants with rotary embeddings (e.g., `jinaai/jina-colbert-v2`) or other custom encoders (e.g., `LiquidAI/LFM2-ColBERT-350M`).
If your standard BERT ColBERT model's config doesn't specify the architecture as `HF_ColBERT`, override it with:
For **non-BERT backbones**, use `--hf-overrides` to set the correct architecture:
```shell
vllm serve your-colbert-model --hf-overrides '{"architectures": ["HF_ColBERT"]}'
# ModernBERT backbone
vllm serve lightonai/GTE-ModernColBERT-v1 \
--hf-overrides '{"architectures": ["ColBERTModernBertModel"]}'
# Jina XLM-RoBERTa backbone
vllm serve jinaai/jina-colbert-v2 \
--hf-overrides '{"architectures": ["ColBERTJinaRobertaModel"]}' \
--trust-remote-code
```
Then you can use the rerank endpoint:

View File

@@ -1,15 +1,27 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example of using ColBERT late interaction model for reranking.
Example of using ColBERT late interaction models for reranking and scoring.
ColBERT (Contextualized Late Interaction over BERT) uses per-token embeddings
and MaxSim scoring for document reranking, providing better accuracy than
single-vector models while being more efficient than cross-encoders.
Start the server with:
vLLM supports ColBERT with multiple encoder backbones. Start the server
with one of the following:
# BERT backbone (works out of the box)
vllm serve answerdotai/answerai-colbert-small-v1
# ModernBERT backbone
vllm serve lightonai/GTE-ModernColBERT-v1 \
--hf-overrides '{"architectures": ["ColBERTModernBertModel"]}'
# Jina XLM-RoBERTa backbone
vllm serve jinaai/jina-colbert-v2 \
--hf-overrides '{"architectures": ["ColBERTJinaRobertaModel"]}' \
--trust-remote-code
Then run this script:
python colbert_rerank_online.py
"""
@@ -18,39 +30,62 @@ import json
import requests
url = "http://127.0.0.1:8000/rerank"
# Change this to match the model you started the server with
MODEL = "answerdotai/answerai-colbert-small-v1"
BASE_URL = "http://127.0.0.1:8000"
headers = {"accept": "application/json", "Content-Type": "application/json"}
data = {
"model": "answerdotai/answerai-colbert-small-v1",
"query": "What is machine learning?",
"documents": [
"Machine learning is a subset of artificial intelligence.",
"Python is a programming language.",
"Deep learning uses neural networks for complex tasks.",
"The weather today is sunny.",
],
}
documents = [
"Machine learning is a subset of artificial intelligence.",
"Python is a programming language.",
"Deep learning uses neural networks for complex tasks.",
"The weather today is sunny.",
]
def rerank_example():
    """Use the /rerank endpoint to rank documents by query relevance.

    Posts a fixed query together with the module-level ``documents`` list to
    ``{BASE_URL}/rerank``, prints the raw JSON response, and then prints each
    returned result's score next to the document its ``index`` refers to.

    If the server does not answer with HTTP 200, the status code and response
    body are printed and the function returns without parsing the payload.
    """
    print("=== Rerank Example ===")
    data = {
        "model": MODEL,
        "query": "What is machine learning?",
        "documents": documents,
    }

    response = requests.post(f"{BASE_URL}/rerank", headers=headers, json=data)

    # An error payload has no "results" key; report the failure instead of
    # crashing with a KeyError further down.
    if response.status_code != 200:
        print(f"Request failed with status code: {response.status_code}")
        print(response.text)
        return

    result = response.json()
    print(json.dumps(result, indent=2))

    print("\nRanked documents (most relevant first):")
    for item in result["results"]:
        doc_idx = item["index"]
        score = item["relevance_score"]
        print(f" Score {score:.4f}: {documents[doc_idx]}")
def score_example():
    """Use the /score endpoint for pairwise query-document scoring.

    Posts one query (``text_1``) and two candidate texts (``text_2``) to
    ``{BASE_URL}/score`` and prints the JSON response.

    If the server does not answer with HTTP 200, the status code and response
    body are printed and the function returns without parsing the payload.
    """
    print("\n=== Score Example ===")
    data = {
        "model": MODEL,
        "text_1": "What is machine learning?",
        "text_2": [
            "Machine learning is a subset of AI.",
            "The weather is sunny.",
        ],
    }

    response = requests.post(f"{BASE_URL}/score", headers=headers, json=data)

    # Do not present an error body as if it were a successful score result.
    if response.status_code != 200:
        print(f"Request failed with status code: {response.status_code}")
        print(response.text)
        return

    result = response.json()
    print(json.dumps(result, indent=2))
def main():
response = requests.post(url, headers=headers, json=data)
if response.status_code == 200:
print("ColBERT Rerank Request successful!")
result = response.json()
print(json.dumps(result, indent=2))
# Show ranked results
print("\nRanked documents (most relevant first):")
for item in result["results"]:
doc_idx = item["index"]
score = item["relevance_score"]
print(f" Score {score:.4f}: {data['documents'][doc_idx]}")
else:
print(f"Request failed with status code: {response.status_code}")
print(response.text)
rerank_example()
score_example()
if __name__ == "__main__":

View File

@@ -8,10 +8,8 @@ import requests
from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.pooling.score.protocol import RerankResponse, ScoreResponse
# ColBERT model - using answerai-colbert-small-v1 as it's a smaller model
MODEL_NAME = "answerdotai/answerai-colbert-small-v1"
COLBERT_DIM = 96 # This model uses 96-dimensional output
DTYPE = "half"
COLBERT_DIM = 96
MAX_MODEL_LEN = 512
@@ -26,129 +24,119 @@ def server():
yield remote_server
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_colbert_rerank(server: RemoteOpenAIServer, model_name: str):
"""Test ColBERT rerank endpoint."""
query = "What is the capital of France?"
documents = [
"The capital of Brazil is Brasilia.",
"The capital of France is Paris.",
]
class TestColBERTOnline:
def test_rerank(self, server: RemoteOpenAIServer):
"""Test ColBERT rerank endpoint."""
query = "What is the capital of France?"
documents = [
"The capital of Brazil is Brasilia.",
"The capital of France is Paris.",
]
rerank_response = requests.post(
server.url_for("rerank"),
json={
"model": model_name,
"query": query,
"documents": documents,
},
)
rerank_response.raise_for_status()
rerank = RerankResponse.model_validate(rerank_response.json())
rerank_response = requests.post(
server.url_for("rerank"),
json={
"model": MODEL_NAME,
"query": query,
"documents": documents,
},
)
rerank_response.raise_for_status()
rerank = RerankResponse.model_validate(rerank_response.json())
assert rerank.id is not None
assert rerank.results is not None
assert len(rerank.results) == 2
assert rerank.id is not None
assert rerank.results is not None
assert len(rerank.results) == 2
# The relevant document (Paris) should have higher score
paris_result = next(r for r in rerank.results if r.index == 1)
brazil_result = next(r for r in rerank.results if r.index == 0)
paris_result = next(r for r in rerank.results if r.index == 1)
brazil_result = next(r for r in rerank.results if r.index == 0)
assert paris_result.relevance_score > brazil_result.relevance_score
assert paris_result.relevance_score > brazil_result.relevance_score
def test_rerank_top_n(self, server: RemoteOpenAIServer):
"""Test ColBERT rerank with top_n parameter."""
query = "What is the capital of France?"
documents = [
"The capital of Brazil is Brasilia.",
"The capital of France is Paris.",
"Machine learning is a field of AI.",
]
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_colbert_rerank_top_n(server: RemoteOpenAIServer, model_name: str):
"""Test ColBERT rerank with top_n parameter."""
query = "What is the capital of France?"
documents = [
"The capital of Brazil is Brasilia.",
"The capital of France is Paris.",
"Machine learning is a field of AI.",
]
rerank_response = requests.post(
server.url_for("rerank"),
json={
"model": MODEL_NAME,
"query": query,
"documents": documents,
"top_n": 2,
},
)
rerank_response.raise_for_status()
rerank = RerankResponse.model_validate(rerank_response.json())
rerank_response = requests.post(
server.url_for("rerank"),
json={
"model": model_name,
"query": query,
"documents": documents,
"top_n": 2,
},
)
rerank_response.raise_for_status()
rerank = RerankResponse.model_validate(rerank_response.json())
assert len(rerank.results) == 2
assert rerank.results[0].index == 1
assert len(rerank.results) == 2
# Top result should be about Paris (index 1)
assert rerank.results[0].index == 1
def test_score(self, server: RemoteOpenAIServer):
"""Test ColBERT score endpoint."""
text_1 = "What is the capital of France?"
text_2 = ["The capital of France is Paris.", "Python is a language."]
score_response = requests.post(
server.url_for("score"),
json={
"model": MODEL_NAME,
"text_1": text_1,
"text_2": text_2,
},
)
score_response.raise_for_status()
score = ScoreResponse.model_validate(score_response.json())
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_colbert_score(server: RemoteOpenAIServer, model_name: str):
"""Test ColBERT score endpoint."""
text_1 = "What is the capital of France?"
text_2 = ["The capital of France is Paris.", "Python is a language."]
assert score.id is not None
assert score.data is not None
assert len(score.data) == 2
score_response = requests.post(
server.url_for("score"),
json={
"model": model_name,
"text_1": text_1,
"text_2": text_2,
},
)
score_response.raise_for_status()
score = ScoreResponse.model_validate(score_response.json())
assert score.data[0].score > score.data[1].score
assert score.id is not None
assert score.data is not None
assert len(score.data) == 2
def test_token_embed(self, server: RemoteOpenAIServer):
"""Test ColBERT token_embed task via pooling endpoint."""
text = "What is the capital of France?"
# The relevant document should have higher score
assert score.data[0].score > score.data[1].score
pooling_response = requests.post(
server.url_for("pooling"),
json={
"model": MODEL_NAME,
"input": text,
"task": "token_embed",
},
)
pooling_response.raise_for_status()
pooling = pooling_response.json()
assert "data" in pooling
assert len(pooling["data"]) == 1
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_colbert_token_embed(server: RemoteOpenAIServer, model_name: str):
"""Test ColBERT token_embed task via pooling endpoint."""
text = "What is the capital of France?"
embeddings = pooling["data"][0]["data"]
assert isinstance(embeddings, list)
assert len(embeddings) > 0
assert len(embeddings[0]) == COLBERT_DIM
pooling_response = requests.post(
server.url_for("pooling"),
json={
"model": model_name,
"input": text,
"task": "token_embed",
},
)
pooling_response.raise_for_status()
pooling = pooling_response.json()
def test_embed_not_supported(self, server: RemoteOpenAIServer):
"""Test that ColBERT model does not support 'embed' task."""
task = "embed"
text = "What is the capital of France?"
assert "data" in pooling
assert len(pooling["data"]) == 1
response = requests.post(
server.url_for("pooling"),
json={
"model": MODEL_NAME,
"input": text,
"task": task,
},
)
# Token embeddings should be 2D
embeddings = pooling["data"][0]["data"]
assert isinstance(embeddings, list)
assert len(embeddings) > 0 # Should have tokens
assert len(embeddings[0]) == COLBERT_DIM
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_colbert_embed_not_supported(server: RemoteOpenAIServer, model_name: str):
"""Test that ColBERT model does not support 'embed' task."""
task = "embed"
text = "What is the capital of France?"
response = requests.post(
server.url_for("pooling"),
json={
"model": model_name,
"input": text,
"task": task,
},
)
assert response.json()["error"]["type"] == "BadRequestError"
assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
assert response.json()["error"]["type"] == "BadRequestError"
assert response.json()["error"]["message"].startswith(
f"Unsupported task: {task!r}"
)

View File

@@ -1,16 +1,47 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for ColBERT late interaction scoring."""
"""Tests for ColBERT late interaction scoring.
Tests are parametrized across multiple ColBERT backbones to ensure the
generic ColBERT support works with different encoder architectures.
"""
import pytest
import torch
from vllm.entrypoints.pooling.score.utils import compute_maxsim_score
# ColBERT model - using answerai-colbert-small-v1 as it's a smaller model
# suitable for testing (based on BERT-base)
COLBERT_MODEL = "answerdotai/answerai-colbert-small-v1"
COLBERT_DIM = 96 # This model uses 96-dimensional output
# -----------------------------------------------------------------------
# Model definitions: key -> {model, colbert_dim, max_model_len, extra_kwargs}
# (extra_kwargs are forwarded to vllm_runner)
# -----------------------------------------------------------------------
COLBERT_MODELS = {
"bert": {
"model": "answerdotai/answerai-colbert-small-v1",
"colbert_dim": 96,
"max_model_len": 512,
"extra_kwargs": {},
},
"modernbert": {
"model": "lightonai/GTE-ModernColBERT-v1",
"colbert_dim": 128,
"max_model_len": 299,
"extra_kwargs": {
"hf_overrides": {
"architectures": ["ColBERTModernBertModel"],
},
},
},
"jina": {
"model": "jinaai/jina-colbert-v2",
"colbert_dim": 128,
"max_model_len": 8192,
"extra_kwargs": {
"hf_overrides": {
"architectures": ["ColBERTJinaRobertaModel"],
},
},
},
}
TEXTS_1 = [
"What is the capital of France?",
@@ -25,80 +56,121 @@ TEXTS_2 = [
DTYPE = "half"
# -----------------------------------------------------------------------
# Fixtures
# -----------------------------------------------------------------------
@pytest.fixture(params=list(COLBERT_MODELS.keys()), scope="module")
def colbert_spec(request):
"""Return the model spec dict for the current parametrization."""
return COLBERT_MODELS[request.param]
@pytest.fixture(scope="module")
def colbert_model_name():
return COLBERT_MODEL
def colbert_model_name(colbert_spec):
return colbert_spec["model"]
def test_colbert_token_embed(vllm_runner, colbert_model_name):
@pytest.fixture(scope="module")
def colbert_dim(colbert_spec):
return colbert_spec["colbert_dim"]
@pytest.fixture(scope="module")
def colbert_max_model_len(colbert_spec):
return colbert_spec["max_model_len"]
@pytest.fixture(scope="module")
def colbert_extra_kwargs(colbert_spec):
return colbert_spec["extra_kwargs"]
# -----------------------------------------------------------------------
# Tests
# -----------------------------------------------------------------------
def test_colbert_token_embed(
vllm_runner,
colbert_model_name,
colbert_dim,
colbert_max_model_len,
colbert_extra_kwargs,
):
"""Test that ColBERT model produces token embeddings."""
with vllm_runner(
colbert_model_name,
runner="pooling",
dtype=DTYPE,
max_model_len=512,
max_model_len=colbert_max_model_len,
enforce_eager=True,
**colbert_extra_kwargs,
) as vllm_model:
# Get token embeddings for a single text
outputs = vllm_model.token_embed([TEXTS_1[0]])
assert len(outputs) == 1
# Token embeddings should be 2D: [num_tokens, colbert_dim]
emb = torch.tensor(outputs[0])
assert emb.dim() == 2
assert emb.shape[1] == COLBERT_DIM
# Should have at least a few tokens
assert emb.shape[1] == colbert_dim
assert emb.shape[0] > 1
def test_colbert_late_interaction_1_to_1(vllm_runner, colbert_model_name):
def test_colbert_late_interaction_1_to_1(
vllm_runner,
colbert_model_name,
colbert_max_model_len,
colbert_extra_kwargs,
):
"""Test ColBERT late interaction scoring with 1:1 query-document pair."""
with vllm_runner(
colbert_model_name,
runner="pooling",
dtype=DTYPE,
max_model_len=512,
max_model_len=colbert_max_model_len,
enforce_eager=True,
**colbert_extra_kwargs,
) as vllm_model:
# Get token embeddings
q_outputs = vllm_model.token_embed([TEXTS_1[0]])
d_outputs = vllm_model.token_embed([TEXTS_2[0]])
q_emb = torch.tensor(q_outputs[0])
d_emb = torch.tensor(d_outputs[0])
# Compute MaxSim manually
manual_score = compute_maxsim_score(q_emb, d_emb).item()
# Use the score API (which should internally use _late_interaction_score)
vllm_scores = vllm_model.score(TEXTS_1[0], TEXTS_2[0])
assert len(vllm_scores) == 1
assert vllm_scores[0] == pytest.approx(manual_score, rel=0.01)
def test_colbert_late_interaction_1_to_N(vllm_runner, colbert_model_name):
def test_colbert_late_interaction_1_to_N(
vllm_runner,
colbert_model_name,
colbert_max_model_len,
colbert_extra_kwargs,
):
"""Test ColBERT late interaction scoring with 1:N query-documents."""
with vllm_runner(
colbert_model_name,
runner="pooling",
dtype=DTYPE,
max_model_len=512,
max_model_len=colbert_max_model_len,
enforce_eager=True,
**colbert_extra_kwargs,
) as vllm_model:
# Get token embeddings
q_outputs = vllm_model.token_embed([TEXTS_1[0]])
d_outputs = vllm_model.token_embed(TEXTS_2)
q_emb = torch.tensor(q_outputs[0])
# Compute MaxSim manually for each document
manual_scores = []
for d_out in d_outputs:
d_emb = torch.tensor(d_out)
manual_scores.append(compute_maxsim_score(q_emb, d_emb).item())
# Use the score API
vllm_scores = vllm_model.score(TEXTS_1[0], TEXTS_2)
assert len(vllm_scores) == 2
@@ -106,27 +178,30 @@ def test_colbert_late_interaction_1_to_N(vllm_runner, colbert_model_name):
assert vllm_scores[i] == pytest.approx(manual_scores[i], rel=0.01)
def test_colbert_late_interaction_N_to_N(vllm_runner, colbert_model_name):
def test_colbert_late_interaction_N_to_N(
vllm_runner,
colbert_model_name,
colbert_max_model_len,
colbert_extra_kwargs,
):
"""Test ColBERT late interaction scoring with N:N query-documents."""
with vllm_runner(
colbert_model_name,
runner="pooling",
dtype=DTYPE,
max_model_len=512,
max_model_len=colbert_max_model_len,
enforce_eager=True,
**colbert_extra_kwargs,
) as vllm_model:
# Get token embeddings
q_outputs = vllm_model.token_embed(TEXTS_1)
d_outputs = vllm_model.token_embed(TEXTS_2)
# Compute MaxSim manually for each pair
manual_scores = []
for q_out, d_out in zip(q_outputs, d_outputs):
q_emb = torch.tensor(q_out)
d_emb = torch.tensor(d_out)
manual_scores.append(compute_maxsim_score(q_emb, d_emb).item())
# Use the score API
vllm_scores = vllm_model.score(TEXTS_1, TEXTS_2)
assert len(vllm_scores) == 2
@@ -134,8 +209,13 @@ def test_colbert_late_interaction_N_to_N(vllm_runner, colbert_model_name):
assert vllm_scores[i] == pytest.approx(manual_scores[i], rel=0.01)
def test_colbert_relevance_ordering(vllm_runner, colbert_model_name):
"""Test that ColBERT scores relevant documents higher than irrelevant ones."""
def test_colbert_relevance_ordering(
vllm_runner,
colbert_model_name,
colbert_max_model_len,
colbert_extra_kwargs,
):
"""Test that ColBERT scores relevant documents higher than irrelevant."""
query = "What is machine learning?"
documents = [
"Machine learning is a subset of artificial intelligence.",
@@ -147,48 +227,73 @@ def test_colbert_relevance_ordering(vllm_runner, colbert_model_name):
colbert_model_name,
runner="pooling",
dtype=DTYPE,
max_model_len=512,
max_model_len=colbert_max_model_len,
enforce_eager=True,
**colbert_extra_kwargs,
) as vllm_model:
scores = vllm_model.score(query, documents)
assert len(scores) == 3
# ML-related documents should score higher than unrelated Python doc
# Document 0 (ML definition) should be most relevant
# Document 2 (Deep learning) should also be relevant
# Document 1 (Python) should be least relevant
assert scores[0] > scores[1], "ML doc should score higher than Python doc"
assert scores[2] > scores[1], "DL doc should score higher than Python doc"
def test_colbert_embed_not_supported(vllm_runner, colbert_model_name):
def test_colbert_embed_not_supported(
vllm_runner,
colbert_model_name,
colbert_max_model_len,
colbert_extra_kwargs,
):
"""Test that ColBERT model does not support 'embed' task."""
with (
vllm_runner(
colbert_model_name,
runner="pooling",
dtype=DTYPE,
max_model_len=512,
max_model_len=colbert_max_model_len,
enforce_eager=True,
**colbert_extra_kwargs,
) as vllm_model,
pytest.raises(ValueError, match="Embedding API is not supported"),
):
vllm_model.embed([TEXTS_1[0]])
def test_colbert_hf_comparison(vllm_runner, colbert_model_name):
"""Test that vLLM ColBERT produces same embeddings as HuggingFace."""
# -----------------------------------------------------------------------
# Per-model HuggingFace comparison tests
# -----------------------------------------------------------------------
def _assert_embeddings_close(vllm_outputs, hf_embeddings):
"""Assert that vLLM and HuggingFace embeddings match."""
for i, (hf_emb, vllm_out) in enumerate(zip(hf_embeddings, vllm_outputs)):
vllm_emb = torch.tensor(vllm_out).float()
assert hf_emb.shape == vllm_emb.shape, (
f"Shape mismatch for text {i}: HF {hf_emb.shape} vs vLLM {vllm_emb.shape}"
)
torch.testing.assert_close(
vllm_emb,
hf_emb,
rtol=1e-2,
atol=1e-2,
msg=f"Embedding mismatch for text {i}",
)
def test_colbert_hf_comparison_bert(vllm_runner):
"""Test that vLLM ColBERT produces same embeddings as HuggingFace (BERT)."""
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoTokenizer, BertModel
model_name = COLBERT_MODELS["bert"]["model"]
test_texts = [TEXTS_1[0], TEXTS_2[0]]
# Get vLLM embeddings first (to avoid GPU memory contention)
# Use fp32 to match HuggingFace default precision for fair comparison
with vllm_runner(
colbert_model_name,
model_name,
runner="pooling",
dtype="float32",
max_model_len=512,
@@ -196,14 +301,11 @@ def test_colbert_hf_comparison(vllm_runner, colbert_model_name):
) as vllm_model:
vllm_outputs = vllm_model.token_embed(test_texts)
# Get HuggingFace reference embeddings on CPU
# Load the base BERT model and manually apply the ColBERT linear projection
hf_tokenizer = AutoTokenizer.from_pretrained(colbert_model_name)
hf_bert = BertModel.from_pretrained(colbert_model_name)
hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_bert = BertModel.from_pretrained(model_name)
hf_bert.eval()
# Load the ColBERT linear weights from safetensors
weights_path = hf_hub_download(colbert_model_name, filename="model.safetensors")
weights_path = hf_hub_download(model_name, filename="model.safetensors")
weights = load_file(weights_path)
linear_weight = weights["linear.weight"] # [96, 384]
@@ -212,36 +314,103 @@ def test_colbert_hf_comparison(vllm_runner, colbert_model_name):
inputs = hf_tokenizer(text, return_tensors="pt")
with torch.no_grad():
outputs = hf_bert(**inputs)
# Get last hidden state: [1, seq_len, 384]
hidden_states = outputs.last_hidden_state
# Apply ColBERT linear projection: [1, seq_len, 96]
token_emb = F.linear(hidden_states, linear_weight)
# L2 normalize
token_emb = F.normalize(token_emb, p=2, dim=-1)
hf_embeddings.append(token_emb.squeeze(0).float())
# Compare embeddings
for i, (hf_emb, vllm_out) in enumerate(zip(hf_embeddings, vllm_outputs)):
vllm_emb = torch.tensor(vllm_out).float()
_assert_embeddings_close(vllm_outputs, hf_embeddings)
# Print first few components for debugging
print(f"\n=== Text {i}: '{test_texts[i][:30]}...' ===")
print(f"HF shape: {hf_emb.shape}, vLLM shape: {vllm_emb.shape}")
print(f"HF first token, first 10 dims: {hf_emb[0, :10].tolist()}")
print(f"vLLM first token, first 10 dims: {vllm_emb[0, :10].tolist()}")
print(f"HF last token, first 10 dims: {hf_emb[-1, :10].tolist()}")
print(f"vLLM last token, first 10 dims: {vllm_emb[-1, :10].tolist()}")
# Should have same shape
assert hf_emb.shape == vllm_emb.shape, (
f"Shape mismatch for text {i}: HF {hf_emb.shape} vs vLLM {vllm_emb.shape}"
)
def test_colbert_hf_comparison_modernbert(vllm_runner):
"""Test that vLLM ColBERT produces same embeddings as HuggingFace
(ModernBERT)."""
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer
# Should have same values (with tolerance for fp16)
torch.testing.assert_close(
vllm_emb,
hf_emb,
rtol=1e-2,
atol=1e-2,
msg=f"Embedding mismatch for text {i}",
)
spec = COLBERT_MODELS["modernbert"]
model_name = spec["model"]
test_texts = [TEXTS_1[0], TEXTS_2[0]]
with vllm_runner(
model_name,
runner="pooling",
dtype="float32",
max_model_len=spec["max_model_len"],
enforce_eager=True,
**spec["extra_kwargs"],
) as vllm_model:
vllm_outputs = vllm_model.token_embed(test_texts)
hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModel.from_pretrained(model_name)
hf_model.eval()
# Load projection from sentence-transformers 1_Dense layer
dense_path = hf_hub_download(model_name, filename="1_Dense/model.safetensors")
dense_weights = load_file(dense_path)
linear_weight = dense_weights["linear.weight"] # [128, 768]
hf_embeddings = []
for text in test_texts:
inputs = hf_tokenizer(text, return_tensors="pt")
with torch.no_grad():
outputs = hf_model(**inputs)
hidden_states = outputs.last_hidden_state
token_emb = F.linear(hidden_states, linear_weight)
token_emb = F.normalize(token_emb, p=2, dim=-1)
hf_embeddings.append(token_emb.squeeze(0).float())
_assert_embeddings_close(vllm_outputs, hf_embeddings)
def test_colbert_hf_comparison_jina(vllm_runner):
    """Compare vLLM token embeddings with a HuggingFace reference
    (Jina XLM-RoBERTa backbone)."""
    import torch.nn.functional as F
    from huggingface_hub import hf_hub_download
    from safetensors.torch import load_file
    from transformers import AutoModel, AutoTokenizer

    jina_spec = COLBERT_MODELS["jina"]
    model_name = jina_spec["model"]
    test_texts = [TEXTS_1[0], TEXTS_2[0]]

    # vLLM side: per-token embeddings in float32.
    with vllm_runner(
        model_name,
        runner="pooling",
        dtype="float32",
        max_model_len=jina_spec["max_model_len"],
        enforce_eager=True,
        **jina_spec["extra_kwargs"],
    ) as vllm_model:
        vllm_outputs = vllm_model.token_embed(test_texts)

    # HuggingFace side: tokenizer and backbone are loaded with remote code
    # enabled.
    hf_tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
    )
    hf_model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True,
    )
    hf_model.eval()

    # The projection weight is stored in the main checkpoint.
    checkpoint = load_file(hf_hub_download(model_name, filename="model.safetensors"))
    projection = checkpoint["linear.weight"].float()  # [128, 1024]

    def _reference_embedding(text):
        """Backbone hidden states -> linear projection -> L2 normalisation."""
        encoded = hf_tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            backbone_out = hf_model(**encoded)
            projected = F.linear(backbone_out.last_hidden_state.float(), projection)
            normalised = F.normalize(projected, p=2, dim=-1)
        return normalised.squeeze(0).float()

    hf_embeddings = [_reference_embedding(text) for text in test_texts]

    _assert_embeddings_close(vllm_outputs, hf_embeddings)

View File

@@ -529,6 +529,15 @@ _EMBEDDING_EXAMPLE_MODELS = {
# [Text-only]
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
"HF_ColBERT": _HfExamplesInfo("answerdotai/answerai-colbert-small-v1"),
"ColBERTModernBertModel": _HfExamplesInfo(
"lightonai/GTE-ModernColBERT-v1",
hf_overrides={"architectures": ["ColBERTModernBertModel"]},
),
"ColBERTJinaRobertaModel": _HfExamplesInfo(
"jinaai/jina-colbert-v2",
trust_remote_code=True,
hf_overrides={"architectures": ["ColBERTJinaRobertaModel"]},
),
"BgeM3EmbeddingModel": _HfExamplesInfo("BAAI/bge-m3"),
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
"Gemma3TextModel": _HfExamplesInfo("google/embeddinggemma-300m"),

View File

@@ -6,6 +6,14 @@ ColBERT late interaction model for retrieval and reranking.
ColBERT uses per-token embeddings and late interaction (MaxSim) scoring
instead of single-vector representations or cross-encoder concatenation.
This module provides:
- :class:`ColBERTMixin` — mixin that adds ColBERT late-interaction support
to any embedding model.
- :class:`ColBERTModel` — ColBERT with BERT backbone (original architecture).
- :class:`ColBERTModernBertModel` — ColBERT with ModernBERT backbone.
- :class:`ColBERTJinaRobertaModel` — ColBERT with Jina XLM-RoBERTa backbone.
Reference: https://arxiv.org/abs/2004.12832
"""
@@ -23,51 +31,60 @@ from .bert import BertEmbeddingModel, BertModel
from .interfaces_base import default_pooling_type
@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
class ColBERTModel(BertEmbeddingModel):
"""ColBERT late interaction model for retrieval/reranking.
class ColBERTMixin:
"""Mixin that adds ColBERT late interaction support to any embedding model.
This model extends BertEmbeddingModel with a ColBERT-style linear
projection layer for per-token embeddings. It supports only:
- "token_embed" task: Per-token embeddings for late interaction
ColBERT (Contextualized Late Interaction over BERT) uses per-token
embeddings with a linear projection layer. This mixin provides:
ColBERT is fundamentally a per-token embedding model - the linear
projection is trained for per-token representations, not for CLS
pooling. Use a dedicated dense embedding model if you need single-
vector representations.
- ``supports_late_interaction`` class-var
- ColBERT linear projection initialisation / lazy creation
- Weight loading helpers for the projection layer
- A builder for the token-embedding pooler
The ColBERT scoring (MaxSim) is computed externally, either client-side
or via the late interaction scoring path in ServingScores.
**Integration:**
Attributes:
colbert_linear: Linear projection from hidden_size to colbert_dim
supports_late_interaction: Flag indicating this model uses late
interaction scoring
1. Inherit from both ``ColBERTMixin`` and ``nn.Module``.
2. In ``__init__``: call ``super().__init__()``, then
:meth:`_init_colbert_components`, then create ``self.model``
(the backbone) and ``self.pooler`` via :meth:`_build_colbert_pooler`.
3. In ``load_weights``: use :meth:`_load_colbert_weights` to separate
the ColBERT projection weight, then delegate the rest to the backbone.
"""
# Mark this model as supporting late interaction scoring
supports_late_interaction: ClassVar[Literal[True]] = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# Get config before calling super().__init__
config = vllm_config.model_config.hf_config
self.hidden_size = config.hidden_size
self.head_dtype = vllm_config.model_config.head_dtype
# Set during _init_colbert_components
colbert_dim: int | None
colbert_linear: nn.Linear | None
hidden_size: int
head_dtype: torch.dtype
# ColBERT dimension - check various config field names used by different
# ColBERT implementations. If not found in config, will be inferred
# from loaded weights in load_weights()
self.colbert_dim: int | None = (
getattr(config, "colbert_dim", None)
or getattr(config, "dim", None)
or getattr(config, "projection_dim", None)
)
# ------------------------------------------------------------------ init
# Initialize parent (this will call _build_pooler)
super().__init__(vllm_config=vllm_config, prefix=prefix)
def _init_colbert_components(
self,
hidden_size: int,
colbert_dim: int | None,
head_dtype: torch.dtype,
) -> None:
"""Initialise ColBERT projection layer.
def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> BertModel:
return BertModel(vllm_config=vllm_config, prefix=prefix)
Args:
hidden_size: Hidden dimension of the encoder backbone.
colbert_dim: Output dimension for ColBERT embeddings. If
``None``, will be inferred from weights during loading (or
auto-loaded from sentence-transformers Dense layers).
head_dtype: Data type for the projection layer.
"""
self.hidden_size = hidden_size
self.colbert_dim = colbert_dim
self.head_dtype = head_dtype
if colbert_dim is not None:
self.colbert_linear = self._build_colbert_linear()
else:
self.colbert_linear = None
def _build_colbert_linear(self) -> nn.Linear:
"""Build the ColBERT linear projection layer."""
@@ -80,24 +97,127 @@ class ColBERTModel(BertEmbeddingModel):
dtype=self.head_dtype,
)
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
# ColBERT linear projection: hidden_size -> colbert_dim
# Original ColBERT uses bias=False
# If colbert_dim is not set from config, it will be inferred during
# load_weights and the linear layer will be created there
if self.colbert_dim is not None:
self.colbert_linear = self._build_colbert_linear()
else:
# Placeholder - will be created when weights are loaded
self.colbert_linear = None
# ---------------------------------------------------------------- pooler
# ColBERT only supports token_embed - it's fundamentally a per-token
# embedding model.
def _build_colbert_pooler(self, pooler_config: PoolerConfig) -> Pooler:
    """Create the token-embedding pooler used by this ColBERT model.

    ``self.colbert_linear`` is handed to ``pooler_for_token_embed`` as the
    projector. When it is ``None`` (dimension not known yet),
    ``pooler_for_token_embed`` falls back to auto-loading
    sentence-transformers Dense layers (``1_Dense/`` etc.).

    Args:
        pooler_config: Pooler configuration forwarded unchanged.

    Returns:
        The pooler produced by ``pooler_for_token_embed``.
    """
    # May legitimately be None; see the docstring above.
    projector = self.colbert_linear
    return pooler_for_token_embed(pooler_config, projector=projector)
# --------------------------------------------------------- config helper
@classmethod
def get_colbert_dim_from_config(cls, hf_config) -> int | None:
"""Extract ColBERT dimension from a HuggingFace config.
Checks ``colbert_dim``, ``dim`` and ``projection_dim`` in that order.
"""
return (
getattr(hf_config, "colbert_dim", None)
or getattr(hf_config, "dim", None)
or getattr(hf_config, "projection_dim", None)
)
# -------------------------------------------------------- weight loading
def _load_colbert_weights(
    self,
    weights: Iterable[tuple[str, torch.Tensor]],
    colbert_weight_names: tuple[str, ...] = (
        "linear.weight",
        "colbert_linear.weight",
    ),
) -> tuple[list[tuple[str, torch.Tensor]], set[str]]:
    """Split off the ColBERT projection weight and copy it into place.

    Every ``(name, tensor)`` pair whose name ends with one of
    *colbert_weight_names* is treated as a projection candidate; if
    several match, only the last one is used. A 2-D candidate is copied
    into ``self.colbert_linear``. When ``self.colbert_dim`` is still
    ``None`` the dimension is taken from the tensor's first axis and the
    linear layer is built on the spot (and installed as the pooler's
    projector when the pooler exposes a ``head``).

    Args:
        weights: Iterable of ``(name, tensor)`` weight pairs.
        colbert_weight_names: Name suffixes identifying the ColBERT
            linear weight.

    Returns:
        ``(remaining_weights, loaded_names)``: the pairs that did not
        match any suffix, and the set of parameter names that were
        loaded (empty when nothing was copied).
    """
    suffixes = tuple(colbert_weight_names)
    remaining: list[tuple[str, torch.Tensor]] = []
    candidates: list[tuple[str, torch.Tensor]] = []
    for name, tensor in weights:
        # ``str.endswith`` accepts a tuple of alternative suffixes.
        bucket = candidates if name.endswith(suffixes) else remaining
        bucket.append((name, tensor))

    consumed: set[str] = set()
    # Nothing matched, or the (last) match is not a 2-D matrix: no copy.
    if not candidates or candidates[-1][1].dim() != 2:
        return remaining, consumed

    projection = candidates[-1][1]
    if self.colbert_dim is None:
        # Dimension was not available at init time; derive it from the
        # weight's first axis and create the layer now.
        self.colbert_dim = projection.shape[0]
        self.colbert_linear = self._build_colbert_linear()
        if hasattr(self, "pooler") and hasattr(self.pooler, "head"):
            # Point the already-built pooler at the new layer.
            self.pooler.head.projector = self.colbert_linear

    linear = self.colbert_linear
    assert linear is not None

    if hasattr(self, "model"):
        # Keep the projection on the same device as the backbone.
        linear.to(next(self.model.parameters()).device)

    linear.weight.data.copy_(projection.to(linear.weight.device))
    consumed.add("pooler.head.projector.weight")
    return remaining, consumed
# -----------------------------------------------------------------------
# Concrete model: ColBERT + BERT backbone (original architecture)
# -----------------------------------------------------------------------
@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
class ColBERTModel(ColBERTMixin, BertEmbeddingModel):
"""ColBERT late interaction model with BERT backbone.
Supports the ``token_embed`` task (per-token embeddings for late
interaction). MaxSim scoring is computed externally.
"""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
# Must run before super().__init__ because _build_pooler reads these.
colbert_dim = self.get_colbert_dim_from_config(config)
self._init_colbert_components(
hidden_size=config.hidden_size,
colbert_dim=colbert_dim,
head_dtype=vllm_config.model_config.head_dtype,
)
super().__init__(vllm_config=vllm_config, prefix=prefix)
def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> BertModel:
return BertModel(vllm_config=vllm_config, prefix=prefix)
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
return self._build_colbert_pooler(pooler_config)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
def _strip(name: str) -> str:
for p in ("model.", "bert."):
@@ -111,7 +231,7 @@ class ColBERTModel(BertEmbeddingModel):
for name, weight in weights_list:
stripped = _strip(name)
# Handle different checkpoint naming conventions for ColBERT linear
# Handle different checkpoint naming conventions
if stripped in ("linear.weight", "colbert_linear.weight"):
colbert_side.append(("colbert_linear.weight", weight))
elif stripped.startswith("linear.") or stripped.startswith(
@@ -122,31 +242,178 @@ class ColBERTModel(BertEmbeddingModel):
else:
model_side.append((stripped, weight))
# Load base BERT weights using BertModel.load_weights which handles QKV fusion
loaded: set[str] = set()
loaded_model = self.model.load_weights(model_side)
loaded.update({"model." + n for n in loaded_model})
# Load ColBERT linear weights
if colbert_side:
for name, weight in colbert_side:
if name == "colbert_linear.weight":
# Infer colbert_dim from weights if not set in config
if self.colbert_dim is None:
# Weight shape is [colbert_dim, hidden_size]
self.colbert_dim = weight.shape[0]
# Create the linear layer now that we know the dimension
self.colbert_linear = self._build_colbert_linear()
# Move to the same device as the model's existing parameters
device = next(self.model.parameters()).device
self.colbert_linear.to(device)
# Update the pooler's projector to use the new linear layer
self.pooler.head.projector = self.colbert_linear
# Load weights directly into the pooler's projector
weight = weight.to(self.pooler.head.projector.weight.device)
self.pooler.head.projector.weight.data.copy_(weight)
loaded.add("pooler.head.projector.weight")
break
_, colbert_loaded = self._load_colbert_weights(colbert_side)
loaded.update(colbert_loaded)
return loaded
# -----------------------------------------------------------------------
# Concrete model: ColBERT + ModernBERT backbone
# -----------------------------------------------------------------------
from .modernbert import ModernBertModel # noqa: E402
@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
class ColBERTModernBertModel(ColBERTMixin, nn.Module):
    """ColBERT late-interaction model on top of a ModernBERT encoder.

    Intended for ``lightonai/GTE-ModernColBERT-v1`` and similar models.
    If the main checkpoint does not contain the projection, it is
    auto-loaded from the sentence-transformers ``1_Dense/`` module.
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the projection settings, the encoder and the pooler."""
        super().__init__()

        model_config = vllm_config.model_config
        hf_config = model_config.hf_config

        # The projection state must exist before the pooler is built,
        # because the pooler reads ``self.colbert_linear``.
        self._init_colbert_components(
            hidden_size=hf_config.hidden_size,
            colbert_dim=self.get_colbert_dim_from_config(hf_config),
            head_dtype=model_config.head_dtype,
        )

        self.model = ModernBertModel(vllm_config=vllm_config, prefix=prefix)

        pooler_config = model_config.pooler_config
        assert pooler_config is not None
        self.pooler = self._build_colbert_pooler(pooler_config)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate input embedding lookup to the ModernBERT encoder."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors=None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the encoder; all arguments are forwarded unchanged."""
        return self.model(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load the projection and encoder weights.

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint entries.

        Returns:
            Set of parameter names that count as loaded.
        """
        backbone_weights, colbert_loaded = self._load_colbert_weights(weights)

        # Drop the "model." prefix added by the embedding adapter.
        renamed = [
            (name.removeprefix("model."), tensor)
            for name, tensor in backbone_weights
        ]
        loaded_backbone = self.model.load_weights(renamed)

        loaded = {"model." + name for name in loaded_backbone}
        loaded |= colbert_loaded

        # A projector that was auto-loaded from sentence-transformers at
        # init time never appears in the main checkpoint; report its
        # parameters as loaded so the weight validator doesn't complain.
        head = getattr(self.pooler, "head", None)
        projector = getattr(head, "projector", None)
        if isinstance(projector, nn.Module):
            loaded.update(
                f"pooler.head.projector.{param_name}"
                for param_name, _ in projector.named_parameters()
            )
        return loaded
# -----------------------------------------------------------------------
# Concrete model: ColBERT + Jina XLM-RoBERTa backbone
# -----------------------------------------------------------------------
from .bert_with_rope import JinaRobertaModel # noqa: E402
@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
class ColBERTJinaRobertaModel(ColBERTMixin, nn.Module):
    """ColBERT late-interaction model on top of a Jina XLM-RoBERTa encoder.

    Intended for ``jinaai/jina-colbert-v2`` and similar models.
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the projection settings, the encoder and the pooler."""
        super().__init__()

        model_config = vllm_config.model_config
        hf_config = model_config.hf_config

        # The projection state must exist before the pooler is built,
        # because the pooler reads ``self.colbert_linear``.
        self._init_colbert_components(
            hidden_size=hf_config.hidden_size,
            colbert_dim=self.get_colbert_dim_from_config(hf_config),
            head_dtype=model_config.head_dtype,
        )

        self.model = JinaRobertaModel(vllm_config=vllm_config, prefix=prefix)

        pooler_config = model_config.pooler_config
        assert pooler_config is not None
        self.pooler = self._build_colbert_pooler(pooler_config)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate input embedding lookup to the Jina RoBERTa encoder."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors=None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the encoder; all arguments are forwarded unchanged."""
        return self.model(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Route checkpoint entries to the encoder or the projection.

        Names are normalised by removing a leading ``model.`` (added by
        the embedding adapter) and then a leading ``roberta.`` (from the
        checkpoint). Entries named ``linear.weight`` or
        ``colbert_linear.weight`` go to the ColBERT projection, entries
        under ``pooler.`` are skipped (HF pooler weights are not used in
        ColBERT), and everything else is passed to the encoder.

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint entries.

        Returns:
            Set of parameter names that were loaded.
        """
        backbone_weights: list[tuple[str, torch.Tensor]] = []
        projection_weights: list[tuple[str, torch.Tensor]] = []

        for name, tensor in weights:
            key = name.removeprefix("model.").removeprefix("roberta.")
            if key in ("linear.weight", "colbert_linear.weight"):
                projection_weights.append(("colbert_linear.weight", tensor))
            elif not key.startswith("pooler."):
                backbone_weights.append((key, tensor))

        loaded_backbone = self.model.load_weights(backbone_weights)
        loaded: set[str] = {"model." + name for name in loaded_backbone}

        if projection_weights:
            _, colbert_loaded = self._load_colbert_weights(projection_weights)
            loaded.update(colbert_loaded)
        return loaded

View File

@@ -629,6 +629,7 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
"Qwen3VLForSequenceClassification": Qwen3VLForSequenceClassificationConfig,
"XLMRobertaModel": JinaRobertaModelConfig,
"ColBERTJinaRobertaModel": JinaRobertaModelConfig,
"JinaVLForRanking": JinaVLForSequenceClassificationConfig,
"JambaForSequenceClassification": JambaForSequenceClassificationConfig,
"GptOssForCausalLM": GptOssForCausalLMConfig,

View File

@@ -208,6 +208,8 @@ _EMBEDDING_MODELS = {
"BertModel": ("bert", "BertEmbeddingModel"),
"BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
"HF_ColBERT": ("colbert", "ColBERTModel"),
"ColBERTModernBertModel": ("colbert", "ColBERTModernBertModel"),
"ColBERTJinaRobertaModel": ("colbert", "ColBERTJinaRobertaModel"),
"DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
"Gemma3TextModel": ("gemma3", "Gemma3Model"),

View File

@@ -1068,9 +1068,11 @@ def try_get_dense_modules(
if isinstance(modules, dict):
modules = modules.get("modules", [])
dense_modules = [
m for m in modules if m.get("type") == "sentence_transformers.models.Dense"
]
_DENSE_MODULE_TYPES = {
"sentence_transformers.models.Dense",
"pylate.models.Dense.Dense",
}
dense_modules = [m for m in modules if m.get("type") in _DENSE_MODULE_TYPES]
if not dense_modules:
return None