feat: Add ColBERT late interaction model support (#33686)

Signed-off-by: Ilya Boytsov <ilyaboytsov1805@gmail.com>
Signed-off-by: Ilya Boytsov <boytsovpanamera@mail.ru>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
Authored by Ilya Boytsov on 2026-02-05 01:05:13 +01:00, committed by GitHub
parent fa4e0fb028
commit 439afa4eea
13 changed files with 974 additions and 3 deletions


@@ -44,6 +44,7 @@ from vllm.entrypoints.pooling.score.utils import (
     ScoreMultiModalParam,
     _cosine_similarity,
     compress_token_type_ids,
+    compute_maxsim_score,
     get_score_prompt,
     validate_score_input,
 )
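
For reference, MaxSim (the quantity the newly imported compute_maxsim_score helper produces) comes from the ColBERT paper: given per-token query embeddings q_1, ..., q_n and per-token document embeddings d_1, ..., d_m, the score of a query/document pair is

    MaxSim(q, d) = \sum_{i=1}^{n} \max_{1 \le j \le m} q_i \cdot d_j

i.e. each query token contributes the similarity of its best-matching document token, and the contributions are summed.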
@@ -1368,6 +1369,87 @@ class LLM:
         items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
         return [ScoringRequestOutput.from_base(item) for item in items]
 
+    def _late_interaction_score(
+        self,
+        data_1: list[ScoreData],
+        data_2: list[ScoreData],
+        *,
+        use_tqdm: bool | Callable[..., tqdm],
+        pooling_params: PoolingParams | None,
+        lora_request: list[LoRARequest] | LoRARequest | None,
+        tokenization_kwargs: dict[str, Any],
+    ) -> list[ScoringRequestOutput]:
+        """
+        Late interaction scoring (ColBERT MaxSim).
+
+        Encodes queries and documents into per-token embeddings, then computes
+        MaxSim: sum over query tokens of max similarity to any document token.
+        """
+        from vllm.outputs import PoolingOutput
+
+        tokenizer = self.get_tokenizer()
+
+        # Extract text from ScoreData
+        text_1: list[str] = []
+        for text in data_1:
+            if not isinstance(text, str):
+                raise NotImplementedError(
+                    "Late interaction scores currently do not support multimodal input."
+                )
+            text_1.append(text)
+
+        text_2: list[str] = []
+        for text in data_2:
+            if not isinstance(text, str):
+                raise NotImplementedError(
+                    "Late interaction scores currently do not support multimodal input."
+                )
+            text_2.append(text)
+
+        encoded_output: list[PoolingRequestOutput] = self.encode(
+            text_1 + text_2,
+            use_tqdm=use_tqdm,
+            lora_request=lora_request,
+            pooling_params=pooling_params,
+            pooling_task="token_embed",
+            tokenization_kwargs=tokenization_kwargs,
+        )
+
+        encoded_output_1: list[PoolingRequestOutput] = encoded_output[0 : len(text_1)]
+        encoded_output_2: list[PoolingRequestOutput] = encoded_output[len(text_1) :]
+
+        if len(encoded_output_1) == 1:
+            encoded_output_1 = encoded_output_1 * len(encoded_output_2)
+
+        # Compute MaxSim scores
+        scores: list[PoolingRequestOutput] = []
+        padding: list[int] = []
+        if (pad_token_id := tokenizer.pad_token_id) is not None:
+            padding = [pad_token_id]
+
+        for emb_1, emb_2 in zip(encoded_output_1, encoded_output_2):
+            # emb_1.outputs.data: [query_len, dim]
+            # emb_2.outputs.data: [doc_len, dim]
+            q_emb = emb_1.outputs.data
+            d_emb = emb_2.outputs.data
+
+            maxsim_score = compute_maxsim_score(q_emb, d_emb)
+
+            tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
+
+            scores.append(
+                PoolingRequestOutput(
+                    request_id=f"{emb_1.request_id}_{emb_2.request_id}",
+                    outputs=PoolingOutput(data=maxsim_score),
+                    prompt_token_ids=tokens,
+                    num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens,
+                    finished=True,
+                )
+            )
+
+        items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
+        return [ScoringRequestOutput.from_base(item) for item in items]
+
     def _cross_encoding_score(
         self,
         data_1: list[ScoreData],
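
The implementation of compute_maxsim_score itself is not part of this hunk (it lives in vllm.entrypoints.pooling.score.utils, per the import above). As a minimal sketch only, assuming the per-token embeddings arrive as 2-D torch tensors shaped [query_len, dim] and [doc_len, dim] as the inline comments state, the reduction it performs amounts to:

import torch

def maxsim_sketch(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tensor:
    # Pairwise token similarities: [query_len, dim] @ [dim, doc_len] -> [query_len, doc_len]
    sim = q_emb @ d_emb.T
    # For each query token, keep its best-matching document token, then sum over query tokens.
    return sim.max(dim=1).values.sum()

maxsim_sketch is an illustrative name, not the helper added by this PR, and whether the embeddings are already L2-normalized (so the dot product is a cosine similarity) depends on the model's pooler.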
@@ -1497,7 +1579,11 @@ class LLM:
         )
 
         supported_tasks = self.supported_tasks
-        if all(t not in supported_tasks for t in ("embed", "classify")):
+        # Late interaction models (e.g., ColBERT) use token_embed for scoring
+        is_late_interaction = model_config.is_late_interaction
+        if not is_late_interaction and all(
+            t not in supported_tasks for t in ("embed", "classify")
+        ):
             raise ValueError(
                 "Score API is not supported by this model. "
                 "Try converting the model using "
@@ -1538,6 +1624,15 @@ class LLM:
                 tokenization_kwargs=encode_kwargs,
                 score_template=chat_template,
             )
+        elif is_late_interaction:
+            return self._late_interaction_score(
+                score_data_1,
+                score_data_2,
+                use_tqdm=use_tqdm,
+                pooling_params=pooling_params,
+                lora_request=lora_request,
+                tokenization_kwargs=encode_kwargs,
+            )
         else:
             return self._embedding_score(
                 score_data_1,
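
Taken together, these hunks route late interaction checkpoints through the existing LLM.score() entry point. A rough usage sketch; the model name below is illustrative and not taken from this diff:

from vllm import LLM

# Hypothetical ColBERT-style checkpoint; any model whose config is detected as
# late interaction should take the new _late_interaction_score path.
llm = LLM(model="colbert-ir/colbertv2.0")

outputs = llm.score(
    "what is the boiling point of water?",
    [
        "Water boils at 100 degrees Celsius at sea level.",
        "The capital of France is Paris.",
    ],
)
for out in outputs:
    print(out.outputs.score)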