feat: Add ColBERT late interaction model support (#33686)
Signed-off-by: Ilya Boytsov <ilyaboytsov1805@gmail.com> Signed-off-by: Ilya Boytsov <boytsovpanamera@mail.ru> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
@@ -44,6 +44,7 @@ from vllm.entrypoints.pooling.score.utils import (
|
||||
ScoreMultiModalParam,
|
||||
_cosine_similarity,
|
||||
compress_token_type_ids,
|
||||
compute_maxsim_score,
|
||||
get_score_prompt,
|
||||
validate_score_input,
|
||||
)
|
||||
@@ -1368,6 +1369,87 @@ class LLM:
|
||||
items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
|
||||
return [ScoringRequestOutput.from_base(item) for item in items]
|
||||
|
||||
def _late_interaction_score(
|
||||
self,
|
||||
data_1: list[ScoreData],
|
||||
data_2: list[ScoreData],
|
||||
*,
|
||||
use_tqdm: bool | Callable[..., tqdm],
|
||||
pooling_params: PoolingParams | None,
|
||||
lora_request: list[LoRARequest] | LoRARequest | None,
|
||||
tokenization_kwargs: dict[str, Any],
|
||||
) -> list[ScoringRequestOutput]:
|
||||
"""
|
||||
Late interaction scoring (ColBERT MaxSim).
|
||||
|
||||
Encodes queries and documents into per-token embeddings, then computes
|
||||
MaxSim: sum over query tokens of max similarity to any document token.
|
||||
"""
|
||||
from vllm.outputs import PoolingOutput
|
||||
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
# Extract text from ScoreData
|
||||
text_1: list[str] = []
|
||||
for text in data_1:
|
||||
if not isinstance(text, str):
|
||||
raise NotImplementedError(
|
||||
"Late interaction scores currently do not support multimodal input."
|
||||
)
|
||||
text_1.append(text)
|
||||
|
||||
text_2: list[str] = []
|
||||
for text in data_2:
|
||||
if not isinstance(text, str):
|
||||
raise NotImplementedError(
|
||||
"Late interaction scores currently do not support multimodal input."
|
||||
)
|
||||
text_2.append(text)
|
||||
|
||||
encoded_output: list[PoolingRequestOutput] = self.encode(
|
||||
text_1 + text_2,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
pooling_params=pooling_params,
|
||||
pooling_task="token_embed",
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
encoded_output_1: list[PoolingRequestOutput] = encoded_output[0 : len(text_1)]
|
||||
encoded_output_2: list[PoolingRequestOutput] = encoded_output[len(text_1) :]
|
||||
|
||||
if len(encoded_output_1) == 1:
|
||||
encoded_output_1 = encoded_output_1 * len(encoded_output_2)
|
||||
|
||||
# Compute MaxSim scores
|
||||
scores: list[PoolingRequestOutput] = []
|
||||
padding: list[int] = []
|
||||
if (pad_token_id := tokenizer.pad_token_id) is not None:
|
||||
padding = [pad_token_id]
|
||||
|
||||
for emb_1, emb_2 in zip(encoded_output_1, encoded_output_2):
|
||||
# emb_1.outputs.data: [query_len, dim]
|
||||
# emb_2.outputs.data: [doc_len, dim]
|
||||
q_emb = emb_1.outputs.data
|
||||
d_emb = emb_2.outputs.data
|
||||
|
||||
maxsim_score = compute_maxsim_score(q_emb, d_emb)
|
||||
|
||||
tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
|
||||
|
||||
scores.append(
|
||||
PoolingRequestOutput(
|
||||
request_id=f"{emb_1.request_id}_{emb_2.request_id}",
|
||||
outputs=PoolingOutput(data=maxsim_score),
|
||||
prompt_token_ids=tokens,
|
||||
num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens,
|
||||
finished=True,
|
||||
)
|
||||
)
|
||||
|
||||
items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
|
||||
return [ScoringRequestOutput.from_base(item) for item in items]
|
||||
|
||||
def _cross_encoding_score(
|
||||
self,
|
||||
data_1: list[ScoreData],
|
||||
@@ -1497,7 +1579,11 @@ class LLM:
|
||||
)
|
||||
|
||||
supported_tasks = self.supported_tasks
|
||||
if all(t not in supported_tasks for t in ("embed", "classify")):
|
||||
# Late interaction models (e.g., ColBERT) use token_embed for scoring
|
||||
is_late_interaction = model_config.is_late_interaction
|
||||
if not is_late_interaction and all(
|
||||
t not in supported_tasks for t in ("embed", "classify")
|
||||
):
|
||||
raise ValueError(
|
||||
"Score API is not supported by this model. "
|
||||
"Try converting the model using "
|
||||
@@ -1538,6 +1624,15 @@ class LLM:
|
||||
tokenization_kwargs=encode_kwargs,
|
||||
score_template=chat_template,
|
||||
)
|
||||
elif is_late_interaction:
|
||||
return self._late_interaction_score(
|
||||
score_data_1,
|
||||
score_data_2,
|
||||
use_tqdm=use_tqdm,
|
||||
pooling_params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=encode_kwargs,
|
||||
)
|
||||
else:
|
||||
return self._embedding_score(
|
||||
score_data_1,
|
||||
|
||||
@@ -37,7 +37,11 @@ def register_pooling_api_routers(
|
||||
|
||||
app.include_router(embed_router)
|
||||
|
||||
if "score" in supported_tasks or "embed" in supported_tasks:
|
||||
# Score/rerank endpoints are available for:
|
||||
# - "score" task (cross-encoder models)
|
||||
# - "embed" task (bi-encoder models)
|
||||
# - "token_embed" task (late interaction models like ColBERT)
|
||||
if any(t in supported_tasks for t in ("score", "embed", "token_embed")):
|
||||
from vllm.entrypoints.pooling.score.api_router import router as score_router
|
||||
|
||||
app.include_router(score_router)
|
||||
@@ -101,6 +105,10 @@ def init_pooling_state(
|
||||
if "classify" in supported_tasks
|
||||
else None
|
||||
)
|
||||
# ServingScores handles score/rerank for:
|
||||
# - "score" task (cross-encoder models)
|
||||
# - "embed" task (bi-encoder models)
|
||||
# - "token_embed" task (late interaction models like ColBERT)
|
||||
state.openai_serving_scores = (
|
||||
ServingScores(
|
||||
engine_client,
|
||||
@@ -109,6 +117,6 @@ def init_pooling_state(
|
||||
score_template=resolved_chat_template,
|
||||
log_error_stack=args.log_error_stack,
|
||||
)
|
||||
if ("embed" in supported_tasks or "score" in supported_tasks)
|
||||
if any(t in supported_tasks for t in ("embed", "score", "token_embed"))
|
||||
else None
|
||||
)
|
||||
|
||||
@@ -31,6 +31,7 @@ from vllm.entrypoints.pooling.score.utils import (
|
||||
ScoreInputs,
|
||||
_cosine_similarity,
|
||||
compress_token_type_ids,
|
||||
compute_maxsim_score,
|
||||
get_score_prompt,
|
||||
validate_score_input,
|
||||
)
|
||||
@@ -68,9 +69,12 @@ class ServingScores(OpenAIServing):
|
||||
self.is_cross_encoder = self.model_config.is_cross_encoder
|
||||
self.is_multimodal_model = self.model_config.is_multimodal_model
|
||||
self.architecture = self.model_config.architecture
|
||||
self.is_late_interaction = self.model_config.is_late_interaction
|
||||
|
||||
if self.is_cross_encoder:
|
||||
self._score_func = self._cross_encoding_score
|
||||
elif self.is_late_interaction:
|
||||
self._score_func = self._late_interaction_score
|
||||
else:
|
||||
self._score_func = self._embedding_score
|
||||
|
||||
@@ -172,6 +176,142 @@ class ServingScores(OpenAIServing):
|
||||
|
||||
return final_res_batch
|
||||
|
||||
async def _late_interaction_score(
|
||||
self,
|
||||
data_1: list[ScoreData],
|
||||
data_2: list[ScoreData],
|
||||
request: RerankRequest | ScoreRequest,
|
||||
request_id: str,
|
||||
lora_request: LoRARequest | None = None,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
) -> list[PoolingRequestOutput] | ErrorResponse:
|
||||
"""
|
||||
Late interaction scoring (ColBERT MaxSim).
|
||||
|
||||
Encodes queries and documents into per-token embeddings, then computes
|
||||
MaxSim: sum over query tokens of max similarity to any document token.
|
||||
"""
|
||||
input_texts: list[str] = []
|
||||
for text in data_1 + data_2:
|
||||
if not isinstance(text, str):
|
||||
raise NotImplementedError(
|
||||
"Late interaction scores currently do not support multimodal input."
|
||||
)
|
||||
input_texts.append(text)
|
||||
|
||||
model_config = self.model_config
|
||||
tokenizer = self.renderer.get_tokenizer()
|
||||
|
||||
encode_async = make_async(
|
||||
tokenizer.encode,
|
||||
executor=self._tokenizer_executor,
|
||||
)
|
||||
|
||||
tokenization_kwargs = request.build_tok_params(model_config).get_encode_kwargs()
|
||||
tokenized_prompts = await asyncio.gather(
|
||||
*(encode_async(t, **tokenization_kwargs) for t in input_texts)
|
||||
)
|
||||
|
||||
engine_prompts: list[TokensPrompt] = []
|
||||
for tok_result, input_text in zip(tokenized_prompts, input_texts):
|
||||
text_token_prompt = self._validate_input(request, tok_result, input_text)
|
||||
|
||||
engine_prompts.append(
|
||||
TokensPrompt(prompt_token_ids=text_token_prompt["prompt_token_ids"])
|
||||
)
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
|
||||
# Use token_embed task for late interaction models
|
||||
from vllm import PoolingParams
|
||||
|
||||
pooling_params = PoolingParams(
|
||||
task="token_embed",
|
||||
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||
use_activation=request.use_activation,
|
||||
)
|
||||
|
||||
try:
|
||||
pooling_params.verify("token_embed", self.model_config)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
input_texts[i],
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
generators.append(
|
||||
self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
)
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
# Collect token embeddings
|
||||
embeddings: list[PoolingRequestOutput | None] = [None] * len(engine_prompts)
|
||||
|
||||
async for i, res in result_generator:
|
||||
embeddings[i] = res
|
||||
|
||||
# Split into query and document embeddings
|
||||
emb_data_1: list[PoolingRequestOutput] = []
|
||||
emb_data_2: list[PoolingRequestOutput] = []
|
||||
|
||||
for i in range(0, len(data_1)):
|
||||
assert (emb := embeddings[i]) is not None
|
||||
emb_data_1.append(emb)
|
||||
|
||||
for i in range(len(data_1), len(embeddings)):
|
||||
assert (emb := embeddings[i]) is not None
|
||||
emb_data_2.append(emb)
|
||||
|
||||
# Expand queries if 1:N scoring
|
||||
if len(emb_data_1) == 1:
|
||||
emb_data_1 = emb_data_1 * len(emb_data_2)
|
||||
|
||||
# Compute MaxSim scores
|
||||
from vllm.outputs import PoolingOutput
|
||||
|
||||
scores: list[PoolingRequestOutput] = []
|
||||
padding: list[int] = []
|
||||
if (pad_token_id := tokenizer.pad_token_id) is not None:
|
||||
padding = [pad_token_id]
|
||||
|
||||
for emb_1, emb_2 in zip(emb_data_1, emb_data_2):
|
||||
# emb_1.outputs.data: [query_len, dim]
|
||||
# emb_2.outputs.data: [doc_len, dim]
|
||||
q_emb = emb_1.outputs.data
|
||||
d_emb = emb_2.outputs.data
|
||||
|
||||
maxsim_score = compute_maxsim_score(q_emb, d_emb)
|
||||
|
||||
tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
|
||||
|
||||
scores.append(
|
||||
PoolingRequestOutput(
|
||||
request_id=f"{emb_1.request_id}_{emb_2.request_id}",
|
||||
outputs=PoolingOutput(data=maxsim_score),
|
||||
prompt_token_ids=tokens,
|
||||
num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens,
|
||||
finished=True,
|
||||
)
|
||||
)
|
||||
|
||||
return scores
|
||||
|
||||
async def _cross_encoding_score(
|
||||
self,
|
||||
data_1: list[ScoreData],
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, TypeAlias, cast
|
||||
|
||||
import torch
|
||||
from torch.nn import CosineSimilarity
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
@@ -34,6 +35,23 @@ ScoreContentPartParam: TypeAlias = (
|
||||
)
|
||||
|
||||
|
||||
def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Compute ColBERT MaxSim score.
|
||||
|
||||
Args:
|
||||
q_emb: Query token embeddings [query_len, dim]
|
||||
d_emb: Document token embeddings [doc_len, dim]
|
||||
|
||||
Returns:
|
||||
MaxSim score (sum over query tokens of max similarity to any doc token)
|
||||
"""
|
||||
# [query_len, doc_len]
|
||||
token_scores = torch.matmul(q_emb, d_emb.T)
|
||||
# Max over document tokens, sum over query tokens
|
||||
return token_scores.amax(dim=-1).sum()
|
||||
|
||||
|
||||
class ScoreMultiModalParam(TypedDict, total=False):
|
||||
"""
|
||||
A specialized parameter type for scoring multimodal content
|
||||
|
||||
Reference in New Issue
Block a user