feat: Add ColBERT late interaction model support (#33686)

Signed-off-by: Ilya Boytsov <ilyaboytsov1805@gmail.com>
Signed-off-by: Ilya Boytsov <boytsovpanamera@mail.ru>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
Ilya Boytsov
2026-02-05 01:05:13 +01:00
committed by GitHub
parent fa4e0fb028
commit 439afa4eea
13 changed files with 974 additions and 3 deletions

View File

@@ -307,6 +307,62 @@ An OpenAI client example can be found here: [examples/pooling/embed/openai_embed
## Specific models
### ColBERT Late Interaction Models
[ColBERT](https://arxiv.org/abs/2004.12832) (Contextualized Late Interaction over BERT) is a retrieval model that uses per-token embeddings and MaxSim scoring for document ranking. Unlike single-vector embedding models, ColBERT retains token-level representations and computes relevance scores through late interaction, providing better accuracy while being more efficient than cross-encoders.
vLLM supports ColBERT models for reranking tasks, automatically applying MaxSim scoring for query-document relevance:
```shell
vllm serve answerdotai/answerai-colbert-small-v1
```
Currently supports ColBERT models with standard BERT encoders (e.g., `answerdotai/answerai-colbert-small-v1`, `colbert-ir/colbertv2.0`).
ColBERT models with modified encoder architectures are not yet supported, including BERT variants with rotary embeddings (e.g., `jinaai/jina-colbert-v2`) or other custom encoders (e.g., `LiquidAI/LFM2-ColBERT-350M`).
If your standard BERT ColBERT model's config doesn't specify the architecture as `HF_ColBERT`, override it with:
```shell
vllm serve your-colbert-model --hf-overrides '{"architectures": ["HF_ColBERT"]}'
```
Then you can use the rerank endpoint:
```shell
curl -s http://localhost:8000/rerank -H "Content-Type: application/json" -d '{
"model": "answerdotai/answerai-colbert-small-v1",
"query": "What is machine learning?",
"documents": [
"Machine learning is a subset of artificial intelligence.",
"Python is a programming language.",
"Deep learning uses neural networks."
]
}'
```
Or the score endpoint:
```shell
curl -s http://localhost:8000/score -H "Content-Type: application/json" -d '{
"model": "answerdotai/answerai-colbert-small-v1",
"text_1": "What is machine learning?",
"text_2": ["Machine learning is a subset of AI.", "The weather is sunny."]
}'
```
You can also get the raw token embeddings using the pooling endpoint with `token_embed` task:
```shell
curl -s http://localhost:8000/pooling -H "Content-Type: application/json" -d '{
"model": "answerdotai/answerai-colbert-small-v1",
"input": "What is machine learning?",
"task": "token_embed"
}'
```
An example can be found here: [examples/pooling/score/colbert_rerank_online.py](../../examples/pooling/score/colbert_rerank_online.py)
### BAAI/bge-m3
The `BAAI/bge-m3` model comes with extra weights for sparse and colbert embeddings but unfortunately in its `config.json`

View File

@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example of using ColBERT late interaction model for reranking.
ColBERT (Contextualized Late Interaction over BERT) uses per-token embeddings
and MaxSim scoring for document reranking, providing better accuracy than
single-vector models while being more efficient than cross-encoders.
Start the server with:
vllm serve answerdotai/answerai-colbert-small-v1
Then run this script:
python colbert_rerank_online.py
"""
import json
import requests
# Endpoint of a locally running vLLM server (see module docstring for startup).
url = "http://127.0.0.1:8000/rerank"
headers = {"accept": "application/json", "Content-Type": "application/json"}
# Rerank request body: one query plus candidate documents. The server returns
# one ColBERT MaxSim relevance score per document.
data = {
    "model": "answerdotai/answerai-colbert-small-v1",
    "query": "What is machine learning?",
    "documents": [
        "Machine learning is a subset of artificial intelligence.",
        "Python is a programming language.",
        "Deep learning uses neural networks for complex tasks.",
        "The weather today is sunny.",
    ],
}
def main():
    """Send a rerank request to the local vLLM server and print ranked results."""
    response = requests.post(url, headers=headers, json=data)

    # Guard clause: report failure details and bail out early.
    if response.status_code != 200:
        print(f"Request failed with status code: {response.status_code}")
        print(response.text)
        return

    print("ColBERT Rerank Request successful!")
    result = response.json()
    print(json.dumps(result, indent=2))

    # Show ranked results
    print("\nRanked documents (most relevant first):")
    for entry in result["results"]:
        doc_idx, score = entry["index"], entry["relevance_score"]
        print(f" Score {score:.4f}: {data['documents'][doc_idx]}")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,154 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Online API tests for ColBERT late interaction scoring."""
import pytest
import requests
from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.pooling.score.protocol import RerankResponse, ScoreResponse
# ColBERT model - using answerai-colbert-small-v1 as it's a smaller model
MODEL_NAME = "answerdotai/answerai-colbert-small-v1"
COLBERT_DIM = 96  # This model uses 96-dimensional output
# NOTE(review): DTYPE is not referenced by the server args below — confirm
# it is needed or drop it.
DTYPE = "half"
MAX_MODEL_LEN = 512
@pytest.fixture(scope="module")
def server():
    """Launch one shared OpenAI-compatible vLLM server for all tests in this module."""
    args = [
        "--max-model-len",
        str(MAX_MODEL_LEN),
    ]
    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_colbert_rerank(server: RemoteOpenAIServer, model_name: str):
    """The /rerank endpoint must rank the on-topic document above the distractor."""
    payload = {
        "model": model_name,
        "query": "What is the capital of France?",
        "documents": [
            "The capital of Brazil is Brasilia.",
            "The capital of France is Paris.",
        ],
    }
    rerank_response = requests.post(server.url_for("rerank"), json=payload)
    rerank_response.raise_for_status()

    rerank = RerankResponse.model_validate(rerank_response.json())
    assert rerank.id is not None
    assert rerank.results is not None
    assert len(rerank.results) == 2

    # Index 1 is the Paris document; it must outscore the Brazil document.
    by_index = {result.index: result for result in rerank.results}
    assert by_index[1].relevance_score > by_index[0].relevance_score
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_colbert_rerank_top_n(server: RemoteOpenAIServer, model_name: str):
    """Test ColBERT rerank with top_n parameter.

    With three candidates and top_n=2, the response must contain exactly two
    results and the best-ranked one must be the Paris document (index 1).
    """
    query = "What is the capital of France?"
    documents = [
        "The capital of Brazil is Brasilia.",
        "The capital of France is Paris.",
        "Machine learning is a field of AI.",
    ]
    rerank_response = requests.post(
        server.url_for("rerank"),
        json={
            "model": model_name,
            "query": query,
            "documents": documents,
            "top_n": 2,
        },
    )
    rerank_response.raise_for_status()
    rerank = RerankResponse.model_validate(rerank_response.json())
    assert len(rerank.results) == 2
    # Top result should be about Paris (index 1)
    assert rerank.results[0].index == 1
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_colbert_score(server: RemoteOpenAIServer, model_name: str):
    """The /score endpoint must score the on-topic document above the off-topic one."""
    payload = {
        "model": model_name,
        "text_1": "What is the capital of France?",
        "text_2": ["The capital of France is Paris.", "Python is a language."],
    }
    score_response = requests.post(server.url_for("score"), json=payload)
    score_response.raise_for_status()

    score = ScoreResponse.model_validate(score_response.json())
    assert score.id is not None
    assert score.data is not None
    assert len(score.data) == 2

    # First document answers the query; second does not.
    assert score.data[0].score > score.data[1].score
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_colbert_token_embed(server: RemoteOpenAIServer, model_name: str):
    """Test ColBERT token_embed task via pooling endpoint.

    The pooling endpoint with task="token_embed" should return one matrix of
    per-token embeddings whose row width equals COLBERT_DIM.
    """
    text = "What is the capital of France?"
    pooling_response = requests.post(
        server.url_for("pooling"),
        json={
            "model": model_name,
            "input": text,
            "task": "token_embed",
        },
    )
    pooling_response.raise_for_status()
    pooling = pooling_response.json()
    assert "data" in pooling
    assert len(pooling["data"]) == 1
    # Token embeddings should be 2D: [num_tokens, COLBERT_DIM]
    embeddings = pooling["data"][0]["data"]
    assert isinstance(embeddings, list)
    assert len(embeddings) > 0  # Should have tokens
    assert len(embeddings[0]) == COLBERT_DIM
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_colbert_embed_not_supported(server: RemoteOpenAIServer, model_name: str):
    """Test that ColBERT model does not support 'embed' task.

    ColBERT only exposes per-token embeddings, so a single-vector "embed"
    request must be rejected with HTTP 400 and an explanatory message.
    """
    text = "What is the capital of France?"
    pooling_response = requests.post(
        server.url_for("pooling"),
        json={
            "model": model_name,
            "input": text,
            "task": "embed",
        },
    )
    # Should return error
    assert pooling_response.status_code == 400
    assert "Task embed is not supported" in pooling_response.text

View File

@@ -0,0 +1,247 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for ColBERT late interaction scoring."""
import pytest
import torch
from vllm.entrypoints.pooling.score.utils import compute_maxsim_score
# ColBERT model - using answerai-colbert-small-v1 as it's a smaller model
# suitable for testing (based on BERT-base)
COLBERT_MODEL = "answerdotai/answerai-colbert-small-v1"
COLBERT_DIM = 96  # This model uses 96-dimensional output
# Paired fixtures: TEXTS_2[i] is the document that answers TEXTS_1[i].
TEXTS_1 = [
    "What is the capital of France?",
    "What is the capital of Germany?",
]
TEXTS_2 = [
    "The capital of France is Paris.",
    "The capital of Germany is Berlin.",
]
DTYPE = "half"
@pytest.fixture(scope="module")
def colbert_model_name():
    """Model identifier fixture shared by all tests in this module."""
    return COLBERT_MODEL
def test_colbert_token_embed(vllm_runner, colbert_model_name):
    """Test that ColBERT model produces token embeddings.

    Each input text should yield a 2-D matrix [num_tokens, COLBERT_DIM]
    rather than a single pooled vector.
    """
    with vllm_runner(
        colbert_model_name,
        runner="pooling",
        dtype=DTYPE,
        max_model_len=512,
        enforce_eager=True,
    ) as vllm_model:
        # Get token embeddings for a single text
        outputs = vllm_model.token_embed([TEXTS_1[0]])
        assert len(outputs) == 1
        # Token embeddings should be 2D: [num_tokens, colbert_dim]
        emb = torch.tensor(outputs[0])
        assert emb.dim() == 2
        assert emb.shape[1] == COLBERT_DIM
        # Should have at least a few tokens
        assert emb.shape[0] > 1
def test_colbert_late_interaction_1_to_1(vllm_runner, colbert_model_name):
    """Test ColBERT late interaction scoring with 1:1 query-document pair.

    The score API must agree (within 1% relative tolerance, to absorb fp16
    rounding) with a MaxSim score computed manually from token embeddings.
    """
    with vllm_runner(
        colbert_model_name,
        runner="pooling",
        dtype=DTYPE,
        max_model_len=512,
        enforce_eager=True,
    ) as vllm_model:
        # Get token embeddings
        q_outputs = vllm_model.token_embed([TEXTS_1[0]])
        d_outputs = vllm_model.token_embed([TEXTS_2[0]])
        q_emb = torch.tensor(q_outputs[0])
        d_emb = torch.tensor(d_outputs[0])
        # Compute MaxSim manually
        manual_score = compute_maxsim_score(q_emb, d_emb).item()
        # Use the score API (which should internally use _late_interaction_score)
        vllm_scores = vllm_model.score(TEXTS_1[0], TEXTS_2[0])
        assert len(vllm_scores) == 1
        assert vllm_scores[0] == pytest.approx(manual_score, rel=0.01)
def test_colbert_late_interaction_1_to_N(vllm_runner, colbert_model_name):
    """Test ColBERT late interaction scoring with 1:N query-documents."""
    with vllm_runner(
        colbert_model_name,
        runner="pooling",
        dtype=DTYPE,
        max_model_len=512,
        enforce_eager=True,
    ) as vllm_model:
        # Per-token embeddings for the single query and both documents.
        query_emb = torch.tensor(vllm_model.token_embed([TEXTS_1[0]])[0])
        doc_outputs = vllm_model.token_embed(TEXTS_2)

        # Reference MaxSim score, computed manually per document.
        manual_scores = [
            compute_maxsim_score(query_emb, torch.tensor(doc_out)).item()
            for doc_out in doc_outputs
        ]

        # The score API must broadcast the single query over all documents
        # and reproduce the manual scores within fp16 tolerance.
        vllm_scores = vllm_model.score(TEXTS_1[0], TEXTS_2)
        assert len(vllm_scores) == 2
        for actual, expected in zip(vllm_scores, manual_scores):
            assert actual == pytest.approx(expected, rel=0.01)
def test_colbert_late_interaction_N_to_N(vllm_runner, colbert_model_name):
    """Test ColBERT late interaction scoring with N:N query-documents.

    Queries and documents are scored pairwise (query i against document i),
    and each pair must match a manually computed MaxSim score.
    """
    with vllm_runner(
        colbert_model_name,
        runner="pooling",
        dtype=DTYPE,
        max_model_len=512,
        enforce_eager=True,
    ) as vllm_model:
        # Get token embeddings
        q_outputs = vllm_model.token_embed(TEXTS_1)
        d_outputs = vllm_model.token_embed(TEXTS_2)
        # Compute MaxSim manually for each pair
        manual_scores = []
        for q_out, d_out in zip(q_outputs, d_outputs):
            q_emb = torch.tensor(q_out)
            d_emb = torch.tensor(d_out)
            manual_scores.append(compute_maxsim_score(q_emb, d_emb).item())
        # Use the score API
        vllm_scores = vllm_model.score(TEXTS_1, TEXTS_2)
        assert len(vllm_scores) == 2
        for i in range(2):
            assert vllm_scores[i] == pytest.approx(manual_scores[i], rel=0.01)
def test_colbert_relevance_ordering(vllm_runner, colbert_model_name):
    """ColBERT must score on-topic ML documents above the off-topic Python one."""
    query = "What is machine learning?"
    documents = [
        "Machine learning is a subset of artificial intelligence.",
        "Python is a programming language.",
        "Deep learning uses neural networks.",
    ]
    with vllm_runner(
        colbert_model_name,
        runner="pooling",
        dtype=DTYPE,
        max_model_len=512,
        enforce_eager=True,
    ) as vllm_model:
        scores = vllm_model.score(query, documents)

    assert len(scores) == 3
    # documents[0] (ML definition) and documents[2] (deep learning) are
    # on-topic; documents[1] (Python) is the off-topic distractor.
    ml_score, python_score, dl_score = scores
    assert python_score < ml_score, "ML doc should score higher than Python doc"
    assert python_score < dl_score, "DL doc should score higher than Python doc"
def test_colbert_embed_not_supported(vllm_runner, colbert_model_name):
    """Test that ColBERT model does not support 'embed' task.

    ColBERT is per-token only, so requesting single-vector embeddings must
    raise a ValueError mentioning the Embedding API.
    """
    with (
        vllm_runner(
            colbert_model_name,
            runner="pooling",
            dtype=DTYPE,
            max_model_len=512,
            enforce_eager=True,
        ) as vllm_model,
        pytest.raises(ValueError, match="Embedding API is not supported"),
    ):
        vllm_model.embed([TEXTS_1[0]])
def test_colbert_hf_comparison(vllm_runner, colbert_model_name):
    """Test that vLLM ColBERT produces same embeddings as HuggingFace.

    Reference pipeline: base BertModel last_hidden_state -> ColBERT linear
    projection (weights loaded directly from the checkpoint's safetensors)
    -> L2 normalization. vLLM's "token_embed" output must match per token.
    """
    import torch.nn.functional as F
    from huggingface_hub import hf_hub_download
    from safetensors.torch import load_file
    from transformers import AutoTokenizer, BertModel

    test_texts = [TEXTS_1[0], TEXTS_2[0]]
    # Get vLLM embeddings first (to avoid GPU memory contention)
    # Use fp32 to match HuggingFace default precision for fair comparison
    with vllm_runner(
        colbert_model_name,
        runner="pooling",
        dtype="float32",
        max_model_len=512,
        enforce_eager=True,
    ) as vllm_model:
        vllm_outputs = vllm_model.token_embed(test_texts)
    # Get HuggingFace reference embeddings on CPU
    # Load the base BERT model and manually apply the ColBERT linear projection
    hf_tokenizer = AutoTokenizer.from_pretrained(colbert_model_name)
    hf_bert = BertModel.from_pretrained(colbert_model_name)
    hf_bert.eval()
    # Load the ColBERT linear weights from safetensors
    weights_path = hf_hub_download(colbert_model_name, filename="model.safetensors")
    weights = load_file(weights_path)
    linear_weight = weights["linear.weight"]  # [96, 384]
    hf_embeddings = []
    for text in test_texts:
        inputs = hf_tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            outputs = hf_bert(**inputs)
            # Get last hidden state: [1, seq_len, 384]
            hidden_states = outputs.last_hidden_state
            # Apply ColBERT linear projection: [1, seq_len, 96]
            token_emb = F.linear(hidden_states, linear_weight)
            # L2 normalize
            token_emb = F.normalize(token_emb, p=2, dim=-1)
        hf_embeddings.append(token_emb.squeeze(0).float())
    # Compare embeddings
    for i, (hf_emb, vllm_out) in enumerate(zip(hf_embeddings, vllm_outputs)):
        vllm_emb = torch.tensor(vllm_out).float()
        # Print first few components for debugging
        print(f"\n=== Text {i}: '{test_texts[i][:30]}...' ===")
        print(f"HF shape: {hf_emb.shape}, vLLM shape: {vllm_emb.shape}")
        print(f"HF first token, first 10 dims: {hf_emb[0, :10].tolist()}")
        print(f"vLLM first token, first 10 dims: {vllm_emb[0, :10].tolist()}")
        print(f"HF last token, first 10 dims: {hf_emb[-1, :10].tolist()}")
        print(f"vLLM last token, first 10 dims: {vllm_emb[-1, :10].tolist()}")
        # Should have same shape
        assert hf_emb.shape == vllm_emb.shape, (
            f"Shape mismatch for text {i}: HF {hf_emb.shape} vs vLLM {vllm_emb.shape}"
        )
        # Should have same values (with tolerance for fp16)
        torch.testing.assert_close(
            vllm_emb,
            hf_emb,
            rtol=1e-2,
            atol=1e-2,
            msg=f"Embedding mismatch for text {i}",
        )

View File

@@ -520,6 +520,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
_EMBEDDING_EXAMPLE_MODELS = {
# [Text-only]
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
"HF_ColBERT": _HfExamplesInfo("answerdotai/answerai-colbert-small-v1"),
"BgeM3EmbeddingModel": _HfExamplesInfo("BAAI/bge-m3"),
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
"Gemma3TextModel": _HfExamplesInfo("google/embeddinggemma-300m"),

View File

@@ -1411,6 +1411,11 @@ class ModelConfig:
self._model_info.supports_cross_encoding or self.convert_type == "classify"
)
@property
def is_late_interaction(self) -> bool:
    """Check if model uses late interaction (ColBERT-style) scoring.

    Delegates to the registry-derived model info flag, which is set when the
    model class declares ``supports_late_interaction = True``.
    """
    return self._model_info.supports_late_interaction
@property
def is_pp_supported(self) -> bool:
    """Check if the model supports pipeline parallelism."""
    return self._model_info.supports_pp

View File

@@ -44,6 +44,7 @@ from vllm.entrypoints.pooling.score.utils import (
ScoreMultiModalParam,
_cosine_similarity,
compress_token_type_ids,
compute_maxsim_score,
get_score_prompt,
validate_score_input,
)
@@ -1368,6 +1369,87 @@ class LLM:
items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
return [ScoringRequestOutput.from_base(item) for item in items]
def _late_interaction_score(
    self,
    data_1: list[ScoreData],
    data_2: list[ScoreData],
    *,
    use_tqdm: bool | Callable[..., tqdm],
    pooling_params: PoolingParams | None,
    lora_request: list[LoRARequest] | LoRARequest | None,
    tokenization_kwargs: dict[str, Any],
) -> list[ScoringRequestOutput]:
    """
    Late interaction scoring (ColBERT MaxSim).

    Encodes queries and documents into per-token embeddings, then computes
    MaxSim: sum over query tokens of max similarity to any document token.

    Args:
        data_1: Queries. A single query is broadcast over all of ``data_2``.
        data_2: Documents, scored pairwise against ``data_1``.
        use_tqdm: Progress-bar control forwarded to ``self.encode``.
        pooling_params: Optional pooling parameters forwarded to the engine.
        lora_request: Optional LoRA adapter(s) forwarded to the engine.
        tokenization_kwargs: Extra kwargs forwarded to tokenization.

    Returns:
        One ``ScoringRequestOutput`` per query/document pair.

    Raises:
        NotImplementedError: If any input is not plain text (multimodal).
    """
    from vllm.outputs import PoolingOutput

    tokenizer = self.get_tokenizer()
    # Extract text from ScoreData; this path only supports plain strings.
    text_1: list[str] = []
    for text in data_1:
        if not isinstance(text, str):
            raise NotImplementedError(
                "Late interaction scores currently do not support multimodal input."
            )
        text_1.append(text)
    text_2: list[str] = []
    for text in data_2:
        if not isinstance(text, str):
            raise NotImplementedError(
                "Late interaction scores currently do not support multimodal input."
            )
        text_2.append(text)
    # Encode queries and documents in a single batch using the token_embed
    # task, then split the outputs back by position.
    encoded_output: list[PoolingRequestOutput] = self.encode(
        text_1 + text_2,
        use_tqdm=use_tqdm,
        lora_request=lora_request,
        pooling_params=pooling_params,
        pooling_task="token_embed",
        tokenization_kwargs=tokenization_kwargs,
    )
    encoded_output_1: list[PoolingRequestOutput] = encoded_output[0 : len(text_1)]
    encoded_output_2: list[PoolingRequestOutput] = encoded_output[len(text_1) :]
    # Broadcast a single query across all documents (1:N scoring).
    # NOTE(review): N:M with N not in (1, M) silently truncates via zip below
    # — presumably rejected upstream by input validation; confirm.
    if len(encoded_output_1) == 1:
        encoded_output_1 = encoded_output_1 * len(encoded_output_2)
    # Compute MaxSim scores
    scores: list[PoolingRequestOutput] = []
    # Pad-token separator between query and document token ids (for usage
    # accounting); omitted if the tokenizer has no pad token.
    padding: list[int] = []
    if (pad_token_id := tokenizer.pad_token_id) is not None:
        padding = [pad_token_id]
    for emb_1, emb_2 in zip(encoded_output_1, encoded_output_2):
        # emb_1.outputs.data: [query_len, dim]
        # emb_2.outputs.data: [doc_len, dim]
        q_emb = emb_1.outputs.data
        d_emb = emb_2.outputs.data
        maxsim_score = compute_maxsim_score(q_emb, d_emb)
        tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
        scores.append(
            PoolingRequestOutput(
                request_id=f"{emb_1.request_id}_{emb_2.request_id}",
                outputs=PoolingOutput(data=maxsim_score),
                prompt_token_ids=tokens,
                num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens,
                finished=True,
            )
        )
    items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
    return [ScoringRequestOutput.from_base(item) for item in items]
def _cross_encoding_score(
self,
data_1: list[ScoreData],
@@ -1497,7 +1579,11 @@ class LLM:
)
supported_tasks = self.supported_tasks
if all(t not in supported_tasks for t in ("embed", "classify")):
# Late interaction models (e.g., ColBERT) use token_embed for scoring
is_late_interaction = model_config.is_late_interaction
if not is_late_interaction and all(
t not in supported_tasks for t in ("embed", "classify")
):
raise ValueError(
"Score API is not supported by this model. "
"Try converting the model using "
@@ -1538,6 +1624,15 @@ class LLM:
tokenization_kwargs=encode_kwargs,
score_template=chat_template,
)
elif is_late_interaction:
return self._late_interaction_score(
score_data_1,
score_data_2,
use_tqdm=use_tqdm,
pooling_params=pooling_params,
lora_request=lora_request,
tokenization_kwargs=encode_kwargs,
)
else:
return self._embedding_score(
score_data_1,

View File

@@ -37,7 +37,11 @@ def register_pooling_api_routers(
app.include_router(embed_router)
if "score" in supported_tasks or "embed" in supported_tasks:
# Score/rerank endpoints are available for:
# - "score" task (cross-encoder models)
# - "embed" task (bi-encoder models)
# - "token_embed" task (late interaction models like ColBERT)
if any(t in supported_tasks for t in ("score", "embed", "token_embed")):
from vllm.entrypoints.pooling.score.api_router import router as score_router
app.include_router(score_router)
@@ -101,6 +105,10 @@ def init_pooling_state(
if "classify" in supported_tasks
else None
)
# ServingScores handles score/rerank for:
# - "score" task (cross-encoder models)
# - "embed" task (bi-encoder models)
# - "token_embed" task (late interaction models like ColBERT)
state.openai_serving_scores = (
ServingScores(
engine_client,
@@ -109,6 +117,6 @@ def init_pooling_state(
score_template=resolved_chat_template,
log_error_stack=args.log_error_stack,
)
if ("embed" in supported_tasks or "score" in supported_tasks)
if any(t in supported_tasks for t in ("embed", "score", "token_embed"))
else None
)

View File

@@ -31,6 +31,7 @@ from vllm.entrypoints.pooling.score.utils import (
ScoreInputs,
_cosine_similarity,
compress_token_type_ids,
compute_maxsim_score,
get_score_prompt,
validate_score_input,
)
@@ -68,9 +69,12 @@ class ServingScores(OpenAIServing):
self.is_cross_encoder = self.model_config.is_cross_encoder
self.is_multimodal_model = self.model_config.is_multimodal_model
self.architecture = self.model_config.architecture
self.is_late_interaction = self.model_config.is_late_interaction
if self.is_cross_encoder:
self._score_func = self._cross_encoding_score
elif self.is_late_interaction:
self._score_func = self._late_interaction_score
else:
self._score_func = self._embedding_score
@@ -172,6 +176,142 @@ class ServingScores(OpenAIServing):
return final_res_batch
async def _late_interaction_score(
    self,
    data_1: list[ScoreData],
    data_2: list[ScoreData],
    request: RerankRequest | ScoreRequest,
    request_id: str,
    lora_request: LoRARequest | None = None,
    trace_headers: Mapping[str, str] | None = None,
) -> list[PoolingRequestOutput] | ErrorResponse:
    """
    Late interaction scoring (ColBERT MaxSim).

    Encodes queries and documents into per-token embeddings, then computes
    MaxSim: sum over query tokens of max similarity to any document token.

    Args:
        data_1: Queries; a single query is broadcast over all documents.
        data_2: Documents, scored pairwise against the queries.
        request: The originating score/rerank request (tokenization and
            priority settings are read from it).
        request_id: Base id; per-prompt ids are suffixed with ``-<i>``.
        lora_request: Optional LoRA adapter for encoding.
        trace_headers: Optional tracing headers forwarded to the engine.

    Returns:
        One ``PoolingRequestOutput`` per pair, or an ``ErrorResponse`` when
        the pooling parameters fail verification.

    Raises:
        NotImplementedError: If any input is not plain text (multimodal).
    """
    input_texts: list[str] = []
    for text in data_1 + data_2:
        if not isinstance(text, str):
            raise NotImplementedError(
                "Late interaction scores currently do not support multimodal input."
            )
        input_texts.append(text)
    model_config = self.model_config
    tokenizer = self.renderer.get_tokenizer()
    # Tokenize all inputs concurrently on the tokenizer executor so the
    # event loop is not blocked.
    encode_async = make_async(
        tokenizer.encode,
        executor=self._tokenizer_executor,
    )
    tokenization_kwargs = request.build_tok_params(model_config).get_encode_kwargs()
    tokenized_prompts = await asyncio.gather(
        *(encode_async(t, **tokenization_kwargs) for t in input_texts)
    )
    engine_prompts: list[TokensPrompt] = []
    for tok_result, input_text in zip(tokenized_prompts, input_texts):
        text_token_prompt = self._validate_input(request, tok_result, input_text)
        engine_prompts.append(
            TokensPrompt(prompt_token_ids=text_token_prompt["prompt_token_ids"])
        )
    # Schedule the request and get the result generator.
    generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
    # Use token_embed task for late interaction models
    from vllm import PoolingParams

    pooling_params = PoolingParams(
        task="token_embed",
        truncate_prompt_tokens=request.truncate_prompt_tokens,
        use_activation=request.use_activation,
    )
    try:
        pooling_params.verify("token_embed", self.model_config)
    except ValueError as e:
        return self.create_error_response(str(e))
    for i, engine_prompt in enumerate(engine_prompts):
        request_id_item = f"{request_id}-{i}"
        self._log_inputs(
            request_id_item,
            input_texts[i],
            params=pooling_params,
            lora_request=lora_request,
        )
        generators.append(
            self.engine_client.encode(
                engine_prompt,
                pooling_params,
                request_id_item,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=request.priority,
            )
        )
    result_generator = merge_async_iterators(*generators)
    # Collect token embeddings; merge_async_iterators yields out of order,
    # so place each result back at its original index.
    embeddings: list[PoolingRequestOutput | None] = [None] * len(engine_prompts)
    async for i, res in result_generator:
        embeddings[i] = res
    # Split into query and document embeddings
    emb_data_1: list[PoolingRequestOutput] = []
    emb_data_2: list[PoolingRequestOutput] = []
    for i in range(0, len(data_1)):
        assert (emb := embeddings[i]) is not None
        emb_data_1.append(emb)
    for i in range(len(data_1), len(embeddings)):
        assert (emb := embeddings[i]) is not None
        emb_data_2.append(emb)
    # Expand queries if 1:N scoring
    if len(emb_data_1) == 1:
        emb_data_1 = emb_data_1 * len(emb_data_2)
    # Compute MaxSim scores
    from vllm.outputs import PoolingOutput

    scores: list[PoolingRequestOutput] = []
    # Pad-token separator between query and document token ids; omitted when
    # the tokenizer defines no pad token.
    padding: list[int] = []
    if (pad_token_id := tokenizer.pad_token_id) is not None:
        padding = [pad_token_id]
    for emb_1, emb_2 in zip(emb_data_1, emb_data_2):
        # emb_1.outputs.data: [query_len, dim]
        # emb_2.outputs.data: [doc_len, dim]
        q_emb = emb_1.outputs.data
        d_emb = emb_2.outputs.data
        maxsim_score = compute_maxsim_score(q_emb, d_emb)
        tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
        scores.append(
            PoolingRequestOutput(
                request_id=f"{emb_1.request_id}_{emb_2.request_id}",
                outputs=PoolingOutput(data=maxsim_score),
                prompt_token_ids=tokens,
                num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens,
                finished=True,
            )
        )
    return scores
async def _cross_encoding_score(
self,
data_1: list[ScoreData],

View File

@@ -3,6 +3,7 @@
from collections.abc import Iterable
from typing import Any, TypeAlias, cast
import torch
from torch.nn import CosineSimilarity
from typing_extensions import Required, TypedDict
@@ -34,6 +35,23 @@ ScoreContentPartParam: TypeAlias = (
)
def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tensor:
    """
    Compute ColBERT MaxSim score.

    For each query token, take the maximum dot-product similarity against all
    document tokens, then sum those maxima over the query tokens.
    (Presumably both inputs are L2-normalized so the dot product is cosine
    similarity — confirm with the producing pooler.)

    Args:
        q_emb: Query token embeddings [query_len, dim]
        d_emb: Document token embeddings [doc_len, dim]

    Returns:
        MaxSim score (sum over query tokens of max similarity to any doc token)
    """
    # Pairwise token similarities: [query_len, doc_len]
    similarity = q_emb @ d_emb.transpose(-2, -1)
    # Best-matching document token per query token, summed over the query.
    best_per_query = similarity.max(dim=-1).values
    return best_per_query.sum()
class ScoreMultiModalParam(TypedDict, total=False):
"""
A specialized parameter type for scoring multimodal content

View File

@@ -0,0 +1,152 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
ColBERT late interaction model for retrieval and reranking.
ColBERT uses per-token embeddings and late interaction (MaxSim) scoring
instead of single-vector representations or cross-encoder concatenation.
Reference: https://arxiv.org/abs/2004.12832
"""
from collections.abc import Iterable
from typing import ClassVar, Literal
import torch
from torch import nn
from vllm.config import PoolerConfig, VllmConfig
from vllm.model_executor.layers.pooler import Pooler
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
from .bert import BertEmbeddingModel, BertModel
from .interfaces_base import default_pooling_type
@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
class ColBERTModel(BertEmbeddingModel):
    """ColBERT late interaction model for retrieval/reranking.

    This model extends BertEmbeddingModel with a ColBERT-style linear
    projection layer for per-token embeddings. It supports only:

    - "token_embed" task: Per-token embeddings for late interaction

    ColBERT is fundamentally a per-token embedding model - the linear
    projection is trained for per-token representations, not for CLS
    pooling. Use a dedicated dense embedding model if you need single-
    vector representations.

    The ColBERT scoring (MaxSim) is computed externally, either client-side
    or via the late interaction scoring path in ServingScores.

    Attributes:
        colbert_linear: Linear projection from hidden_size to colbert_dim
        supports_late_interaction: Flag indicating this model uses late
            interaction scoring
    """

    # Mark this model as supporting late interaction scoring
    supports_late_interaction: ClassVar[Literal[True]] = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Get config before calling super().__init__
        config = vllm_config.model_config.hf_config
        self.hidden_size = config.hidden_size
        self.head_dtype = vllm_config.model_config.head_dtype
        # ColBERT dimension - check various config field names used by different
        # ColBERT implementations. If not found in config, will be inferred
        # from loaded weights in load_weights()
        self.colbert_dim: int | None = (
            getattr(config, "colbert_dim", None)
            or getattr(config, "dim", None)
            or getattr(config, "projection_dim", None)
        )
        # Initialize parent (this will call _build_pooler)
        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> BertModel:
        """Build the underlying standard-BERT encoder."""
        return BertModel(vllm_config=vllm_config, prefix=prefix)

    def _build_colbert_linear(self) -> nn.Linear:
        """Build the ColBERT linear projection layer."""
        if self.colbert_dim is None:
            raise ValueError("colbert_dim must be set before building the linear layer")
        # Original ColBERT uses bias=False for the projection.
        return nn.Linear(
            self.hidden_size,
            self.colbert_dim,
            bias=False,
            dtype=self.head_dtype,
        )

    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
        """Build the token-embedding pooler, wiring in the ColBERT projection.

        If colbert_dim is unknown at this point the projector is left as None
        and created later in load_weights() once the weight shape is known.
        """
        # ColBERT linear projection: hidden_size -> colbert_dim
        # Original ColBERT uses bias=False
        # If colbert_dim is not set from config, it will be inferred during
        # load_weights and the linear layer will be created there
        if self.colbert_dim is not None:
            self.colbert_linear = self._build_colbert_linear()
        else:
            # Placeholder - will be created when weights are loaded
            self.colbert_linear = None
        # ColBERT only supports token_embed - it's fundamentally a per-token
        # embedding model.
        return pooler_for_token_embed(
            pooler_config,
            projector=self.colbert_linear,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights, routing the ColBERT projection separately.

        Base encoder weights go through BertModel.load_weights (which handles
        QKV fusion); the projection weight is copied into the pooler's
        projector, creating it first if colbert_dim had to be inferred.
        """

        def _strip(name: str) -> str:
            # Normalize checkpoint prefixes ("model." / "bert.") away.
            for p in ("model.", "bert."):
                if name.startswith(p):
                    name = name[len(p) :]
            return name

        weights_list = list(weights)
        model_side: list[tuple[str, torch.Tensor]] = []
        colbert_side: list[tuple[str, torch.Tensor]] = []
        for name, weight in weights_list:
            stripped = _strip(name)
            # Handle different checkpoint naming conventions for ColBERT linear
            if stripped in ("linear.weight", "colbert_linear.weight"):
                colbert_side.append(("colbert_linear.weight", weight))
            elif stripped.startswith("linear.") or stripped.startswith(
                "colbert_linear."
            ):
                new_name = stripped.replace("linear.", "colbert_linear.")
                colbert_side.append((new_name, weight))
            else:
                model_side.append((stripped, weight))
        # Load base BERT weights using BertModel.load_weights which handles QKV fusion
        loaded: set[str] = set()
        loaded_model = self.model.load_weights(model_side)
        loaded.update({"model." + n for n in loaded_model})
        # Load ColBERT linear weights
        if colbert_side:
            for name, weight in colbert_side:
                if name == "colbert_linear.weight":
                    # Infer colbert_dim from weights if not set in config
                    if self.colbert_dim is None:
                        # Weight shape is [colbert_dim, hidden_size]
                        self.colbert_dim = weight.shape[0]
                        # Create the linear layer now that we know the dimension
                        self.colbert_linear = self._build_colbert_linear()
                        # Move to the same device as the model's existing parameters
                        device = next(self.model.parameters()).device
                        self.colbert_linear.to(device)
                        # Update the pooler's projector to use the new linear layer
                        self.pooler.head.projector = self.colbert_linear
                    # Load weights directly into the pooler's projector
                    weight = weight.to(self.pooler.head.projector.weight.device)
                    self.pooler.head.projector.weight.data.copy_(weight)
                    loaded.add("pooler.head.projector.weight")
                    # NOTE(review): stops after the first projection weight —
                    # assumes exactly one ColBERT linear tensor per checkpoint
                    # (any "linear.bias" would be silently skipped); confirm.
                    break
        return loaded

View File

@@ -981,6 +981,40 @@ def supports_cross_encoding(
return is_pooling_model(model) and _supports_cross_encoding(model)
@runtime_checkable
@runtime_checkable
class SupportsLateInteraction(Protocol):
    """The interface required for all models that support late interaction.

    Late interaction models (like ColBERT) encode queries and documents
    separately into per-token embeddings, then compute similarity via
    MaxSim (max over document tokens, sum over query tokens).
    """

    # Marker attribute checked structurally by supports_late_interaction().
    supports_late_interaction: ClassVar[Literal[True]] = True
@overload
def supports_late_interaction(
    model: type[object],
) -> TypeIs[type[SupportsLateInteraction]]: ...


@overload
def supports_late_interaction(model: object) -> TypeIs[SupportsLateInteraction]: ...


def _supports_late_interaction(
    model: type[object] | object,
) -> TypeIs[type[SupportsLateInteraction]] | TypeIs[SupportsLateInteraction]:
    # Duck-typed check: any class or instance exposing a truthy
    # `supports_late_interaction` attribute qualifies.
    return getattr(model, "supports_late_interaction", False)


def supports_late_interaction(
    model: type[object] | object,
) -> TypeIs[type[SupportsLateInteraction]] | TypeIs[SupportsLateInteraction]:
    """Return whether *model* is a pooling model that supports late interaction."""
    return is_pooling_model(model) and _supports_late_interaction(model)
class SupportsQuant:
"""The interface required for all models that support quantization."""

View File

@@ -49,6 +49,7 @@ from .interfaces import (
is_hybrid,
requires_raw_input_tokens,
supports_cross_encoding,
supports_late_interaction,
supports_mamba_prefix_caching,
supports_multimodal,
supports_multimodal_encoder_tp_data,
@@ -205,6 +206,7 @@ _EMBEDDING_MODELS = {
# [Text-only]
"BertModel": ("bert", "BertEmbeddingModel"),
"BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
"HF_ColBERT": ("colbert", "ColBERTModel"),
"DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
"Gemma3TextModel": ("gemma3", "Gemma3Model"),
@@ -593,6 +595,7 @@ class _ModelInfo:
default_seq_pooling_type: SequencePoolingType
default_tok_pooling_type: TokenPoolingType
supports_cross_encoding: bool
supports_late_interaction: bool
supports_multimodal: bool
supports_multimodal_raw_input_only: bool
requires_raw_input_tokens: bool
@@ -616,6 +619,7 @@ class _ModelInfo:
default_tok_pooling_type=get_default_tok_pooling_type(model),
attn_type=get_attn_type(model),
supports_cross_encoding=supports_cross_encoding(model),
supports_late_interaction=supports_late_interaction(model),
supports_multimodal=supports_multimodal(model),
supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
model