feat: Add ColBERT late interaction model support (#33686)

Signed-off-by: Ilya Boytsov <ilyaboytsov1805@gmail.com>
Signed-off-by: Ilya Boytsov <boytsovpanamera@mail.ru>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
Ilya Boytsov
2026-02-05 01:05:13 +01:00
committed by GitHub
parent fa4e0fb028
commit 439afa4eea
13 changed files with 974 additions and 3 deletions

View File

@@ -44,6 +44,7 @@ from vllm.entrypoints.pooling.score.utils import (
ScoreMultiModalParam,
_cosine_similarity,
compress_token_type_ids,
compute_maxsim_score,
get_score_prompt,
validate_score_input,
)
@@ -1368,6 +1369,87 @@ class LLM:
items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
return [ScoringRequestOutput.from_base(item) for item in items]
def _late_interaction_score(
self,
data_1: list[ScoreData],
data_2: list[ScoreData],
*,
use_tqdm: bool | Callable[..., tqdm],
pooling_params: PoolingParams | None,
lora_request: list[LoRARequest] | LoRARequest | None,
tokenization_kwargs: dict[str, Any],
) -> list[ScoringRequestOutput]:
"""
Late interaction scoring (ColBERT MaxSim).
Encodes queries and documents into per-token embeddings, then computes
MaxSim: sum over query tokens of max similarity to any document token.
"""
from vllm.outputs import PoolingOutput
tokenizer = self.get_tokenizer()
# Extract text from ScoreData
text_1: list[str] = []
for text in data_1:
if not isinstance(text, str):
raise NotImplementedError(
"Late interaction scores currently do not support multimodal input."
)
text_1.append(text)
text_2: list[str] = []
for text in data_2:
if not isinstance(text, str):
raise NotImplementedError(
"Late interaction scores currently do not support multimodal input."
)
text_2.append(text)
encoded_output: list[PoolingRequestOutput] = self.encode(
text_1 + text_2,
use_tqdm=use_tqdm,
lora_request=lora_request,
pooling_params=pooling_params,
pooling_task="token_embed",
tokenization_kwargs=tokenization_kwargs,
)
encoded_output_1: list[PoolingRequestOutput] = encoded_output[0 : len(text_1)]
encoded_output_2: list[PoolingRequestOutput] = encoded_output[len(text_1) :]
if len(encoded_output_1) == 1:
encoded_output_1 = encoded_output_1 * len(encoded_output_2)
# Compute MaxSim scores
scores: list[PoolingRequestOutput] = []
padding: list[int] = []
if (pad_token_id := tokenizer.pad_token_id) is not None:
padding = [pad_token_id]
for emb_1, emb_2 in zip(encoded_output_1, encoded_output_2):
# emb_1.outputs.data: [query_len, dim]
# emb_2.outputs.data: [doc_len, dim]
q_emb = emb_1.outputs.data
d_emb = emb_2.outputs.data
maxsim_score = compute_maxsim_score(q_emb, d_emb)
tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
scores.append(
PoolingRequestOutput(
request_id=f"{emb_1.request_id}_{emb_2.request_id}",
outputs=PoolingOutput(data=maxsim_score),
prompt_token_ids=tokens,
num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens,
finished=True,
)
)
items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
return [ScoringRequestOutput.from_base(item) for item in items]
def _cross_encoding_score(
self,
data_1: list[ScoreData],
@@ -1497,7 +1579,11 @@ class LLM:
)
supported_tasks = self.supported_tasks
if all(t not in supported_tasks for t in ("embed", "classify")):
# Late interaction models (e.g., ColBERT) use token_embed for scoring
is_late_interaction = model_config.is_late_interaction
if not is_late_interaction and all(
t not in supported_tasks for t in ("embed", "classify")
):
raise ValueError(
"Score API is not supported by this model. "
"Try converting the model using "
@@ -1538,6 +1624,15 @@ class LLM:
tokenization_kwargs=encode_kwargs,
score_template=chat_template,
)
elif is_late_interaction:
return self._late_interaction_score(
score_data_1,
score_data_2,
use_tqdm=use_tqdm,
pooling_params=pooling_params,
lora_request=lora_request,
tokenization_kwargs=encode_kwargs,
)
else:
return self._embedding_score(
score_data_1,

View File

@@ -37,7 +37,11 @@ def register_pooling_api_routers(
app.include_router(embed_router)
if "score" in supported_tasks or "embed" in supported_tasks:
# Score/rerank endpoints are available for:
# - "score" task (cross-encoder models)
# - "embed" task (bi-encoder models)
# - "token_embed" task (late interaction models like ColBERT)
if any(t in supported_tasks for t in ("score", "embed", "token_embed")):
from vllm.entrypoints.pooling.score.api_router import router as score_router
app.include_router(score_router)
@@ -101,6 +105,10 @@ def init_pooling_state(
if "classify" in supported_tasks
else None
)
# ServingScores handles score/rerank for:
# - "score" task (cross-encoder models)
# - "embed" task (bi-encoder models)
# - "token_embed" task (late interaction models like ColBERT)
state.openai_serving_scores = (
ServingScores(
engine_client,
@@ -109,6 +117,6 @@ def init_pooling_state(
score_template=resolved_chat_template,
log_error_stack=args.log_error_stack,
)
if ("embed" in supported_tasks or "score" in supported_tasks)
if any(t in supported_tasks for t in ("embed", "score", "token_embed"))
else None
)

View File

@@ -31,6 +31,7 @@ from vllm.entrypoints.pooling.score.utils import (
ScoreInputs,
_cosine_similarity,
compress_token_type_ids,
compute_maxsim_score,
get_score_prompt,
validate_score_input,
)
@@ -68,9 +69,12 @@ class ServingScores(OpenAIServing):
self.is_cross_encoder = self.model_config.is_cross_encoder
self.is_multimodal_model = self.model_config.is_multimodal_model
self.architecture = self.model_config.architecture
self.is_late_interaction = self.model_config.is_late_interaction
if self.is_cross_encoder:
self._score_func = self._cross_encoding_score
elif self.is_late_interaction:
self._score_func = self._late_interaction_score
else:
self._score_func = self._embedding_score
@@ -172,6 +176,142 @@ class ServingScores(OpenAIServing):
return final_res_batch
async def _late_interaction_score(
self,
data_1: list[ScoreData],
data_2: list[ScoreData],
request: RerankRequest | ScoreRequest,
request_id: str,
lora_request: LoRARequest | None = None,
trace_headers: Mapping[str, str] | None = None,
) -> list[PoolingRequestOutput] | ErrorResponse:
"""
Late interaction scoring (ColBERT MaxSim).
Encodes queries and documents into per-token embeddings, then computes
MaxSim: sum over query tokens of max similarity to any document token.
"""
input_texts: list[str] = []
for text in data_1 + data_2:
if not isinstance(text, str):
raise NotImplementedError(
"Late interaction scores currently do not support multimodal input."
)
input_texts.append(text)
model_config = self.model_config
tokenizer = self.renderer.get_tokenizer()
encode_async = make_async(
tokenizer.encode,
executor=self._tokenizer_executor,
)
tokenization_kwargs = request.build_tok_params(model_config).get_encode_kwargs()
tokenized_prompts = await asyncio.gather(
*(encode_async(t, **tokenization_kwargs) for t in input_texts)
)
engine_prompts: list[TokensPrompt] = []
for tok_result, input_text in zip(tokenized_prompts, input_texts):
text_token_prompt = self._validate_input(request, tok_result, input_text)
engine_prompts.append(
TokensPrompt(prompt_token_ids=text_token_prompt["prompt_token_ids"])
)
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
# Use token_embed task for late interaction models
from vllm import PoolingParams
pooling_params = PoolingParams(
task="token_embed",
truncate_prompt_tokens=request.truncate_prompt_tokens,
use_activation=request.use_activation,
)
try:
pooling_params.verify("token_embed", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
self._log_inputs(
request_id_item,
input_texts[i],
params=pooling_params,
lora_request=lora_request,
)
generators.append(
self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
)
result_generator = merge_async_iterators(*generators)
# Collect token embeddings
embeddings: list[PoolingRequestOutput | None] = [None] * len(engine_prompts)
async for i, res in result_generator:
embeddings[i] = res
# Split into query and document embeddings
emb_data_1: list[PoolingRequestOutput] = []
emb_data_2: list[PoolingRequestOutput] = []
for i in range(0, len(data_1)):
assert (emb := embeddings[i]) is not None
emb_data_1.append(emb)
for i in range(len(data_1), len(embeddings)):
assert (emb := embeddings[i]) is not None
emb_data_2.append(emb)
# Expand queries if 1:N scoring
if len(emb_data_1) == 1:
emb_data_1 = emb_data_1 * len(emb_data_2)
# Compute MaxSim scores
from vllm.outputs import PoolingOutput
scores: list[PoolingRequestOutput] = []
padding: list[int] = []
if (pad_token_id := tokenizer.pad_token_id) is not None:
padding = [pad_token_id]
for emb_1, emb_2 in zip(emb_data_1, emb_data_2):
# emb_1.outputs.data: [query_len, dim]
# emb_2.outputs.data: [doc_len, dim]
q_emb = emb_1.outputs.data
d_emb = emb_2.outputs.data
maxsim_score = compute_maxsim_score(q_emb, d_emb)
tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
scores.append(
PoolingRequestOutput(
request_id=f"{emb_1.request_id}_{emb_2.request_id}",
outputs=PoolingOutput(data=maxsim_score),
prompt_token_ids=tokens,
num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens,
finished=True,
)
)
return scores
async def _cross_encoding_score(
self,
data_1: list[ScoreData],

View File

@@ -3,6 +3,7 @@
from collections.abc import Iterable
from typing import Any, TypeAlias, cast
import torch
from torch.nn import CosineSimilarity
from typing_extensions import Required, TypedDict
@@ -34,6 +35,23 @@ ScoreContentPartParam: TypeAlias = (
)
def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tensor:
"""
Compute ColBERT MaxSim score.
Args:
q_emb: Query token embeddings [query_len, dim]
d_emb: Document token embeddings [doc_len, dim]
Returns:
MaxSim score (sum over query tokens of max similarity to any doc token)
"""
# [query_len, doc_len]
token_scores = torch.matmul(q_emb, d_emb.T)
# Max over document tokens, sum over query tokens
return token_scores.amax(dim=-1).sum()
class ScoreMultiModalParam(TypedDict, total=False):
"""
A specialized parameter type for scoring multimodal content