feat: Add ColBERT late interaction model support (#33686)
Signed-off-by: Ilya Boytsov <ilyaboytsov1805@gmail.com> Signed-off-by: Ilya Boytsov <boytsovpanamera@mail.ru> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
@@ -44,6 +44,7 @@ from vllm.entrypoints.pooling.score.utils import (
|
||||
ScoreMultiModalParam,
|
||||
_cosine_similarity,
|
||||
compress_token_type_ids,
|
||||
compute_maxsim_score,
|
||||
get_score_prompt,
|
||||
validate_score_input,
|
||||
)
|
||||
@@ -1368,6 +1369,87 @@ class LLM:
|
||||
items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
|
||||
return [ScoringRequestOutput.from_base(item) for item in items]
|
||||
|
||||
def _late_interaction_score(
|
||||
self,
|
||||
data_1: list[ScoreData],
|
||||
data_2: list[ScoreData],
|
||||
*,
|
||||
use_tqdm: bool | Callable[..., tqdm],
|
||||
pooling_params: PoolingParams | None,
|
||||
lora_request: list[LoRARequest] | LoRARequest | None,
|
||||
tokenization_kwargs: dict[str, Any],
|
||||
) -> list[ScoringRequestOutput]:
|
||||
"""
|
||||
Late interaction scoring (ColBERT MaxSim).
|
||||
|
||||
Encodes queries and documents into per-token embeddings, then computes
|
||||
MaxSim: sum over query tokens of max similarity to any document token.
|
||||
"""
|
||||
from vllm.outputs import PoolingOutput
|
||||
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
# Extract text from ScoreData
|
||||
text_1: list[str] = []
|
||||
for text in data_1:
|
||||
if not isinstance(text, str):
|
||||
raise NotImplementedError(
|
||||
"Late interaction scores currently do not support multimodal input."
|
||||
)
|
||||
text_1.append(text)
|
||||
|
||||
text_2: list[str] = []
|
||||
for text in data_2:
|
||||
if not isinstance(text, str):
|
||||
raise NotImplementedError(
|
||||
"Late interaction scores currently do not support multimodal input."
|
||||
)
|
||||
text_2.append(text)
|
||||
|
||||
encoded_output: list[PoolingRequestOutput] = self.encode(
|
||||
text_1 + text_2,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
pooling_params=pooling_params,
|
||||
pooling_task="token_embed",
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
encoded_output_1: list[PoolingRequestOutput] = encoded_output[0 : len(text_1)]
|
||||
encoded_output_2: list[PoolingRequestOutput] = encoded_output[len(text_1) :]
|
||||
|
||||
if len(encoded_output_1) == 1:
|
||||
encoded_output_1 = encoded_output_1 * len(encoded_output_2)
|
||||
|
||||
# Compute MaxSim scores
|
||||
scores: list[PoolingRequestOutput] = []
|
||||
padding: list[int] = []
|
||||
if (pad_token_id := tokenizer.pad_token_id) is not None:
|
||||
padding = [pad_token_id]
|
||||
|
||||
for emb_1, emb_2 in zip(encoded_output_1, encoded_output_2):
|
||||
# emb_1.outputs.data: [query_len, dim]
|
||||
# emb_2.outputs.data: [doc_len, dim]
|
||||
q_emb = emb_1.outputs.data
|
||||
d_emb = emb_2.outputs.data
|
||||
|
||||
maxsim_score = compute_maxsim_score(q_emb, d_emb)
|
||||
|
||||
tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
|
||||
|
||||
scores.append(
|
||||
PoolingRequestOutput(
|
||||
request_id=f"{emb_1.request_id}_{emb_2.request_id}",
|
||||
outputs=PoolingOutput(data=maxsim_score),
|
||||
prompt_token_ids=tokens,
|
||||
num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens,
|
||||
finished=True,
|
||||
)
|
||||
)
|
||||
|
||||
items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
|
||||
return [ScoringRequestOutput.from_base(item) for item in items]
|
||||
|
||||
def _cross_encoding_score(
|
||||
self,
|
||||
data_1: list[ScoreData],
|
||||
@@ -1497,7 +1579,11 @@ class LLM:
|
||||
)
|
||||
|
||||
supported_tasks = self.supported_tasks
|
||||
if all(t not in supported_tasks for t in ("embed", "classify")):
|
||||
# Late interaction models (e.g., ColBERT) use token_embed for scoring
|
||||
is_late_interaction = model_config.is_late_interaction
|
||||
if not is_late_interaction and all(
|
||||
t not in supported_tasks for t in ("embed", "classify")
|
||||
):
|
||||
raise ValueError(
|
||||
"Score API is not supported by this model. "
|
||||
"Try converting the model using "
|
||||
@@ -1538,6 +1624,15 @@ class LLM:
|
||||
tokenization_kwargs=encode_kwargs,
|
||||
score_template=chat_template,
|
||||
)
|
||||
elif is_late_interaction:
|
||||
return self._late_interaction_score(
|
||||
score_data_1,
|
||||
score_data_2,
|
||||
use_tqdm=use_tqdm,
|
||||
pooling_params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=encode_kwargs,
|
||||
)
|
||||
else:
|
||||
return self._embedding_score(
|
||||
score_data_1,
|
||||
|
||||
Reference in New Issue
Block a user