[Perf] Optimize maxsim scores computation for pooling models, 13.9% E2E throughput improvement (#35330)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-02-26 12:14:54 -05:00
committed by GitHub
parent ec8f943db1
commit 99c7892c5b
3 changed files with 123 additions and 11 deletions

View File

@@ -4,10 +4,15 @@
from unittest.mock import patch
import pytest
import torch
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateResolutionError
from vllm.entrypoints.pooling.score.utils import get_score_prompt
from vllm.entrypoints.pooling.score.utils import (
compute_maxsim_score,
compute_maxsim_scores,
get_score_prompt,
)
from vllm.inputs import TokensPrompt
from vllm.tokenizers import get_tokenizer
@@ -349,3 +354,36 @@ class TestGetScorePrompt:
assert_prompt_tokenization_consistent(
cross_encoder_tokenizer, full_prompt, engine_prompt
)
def test_compute_maxsim_scores_matches_reference_per_pair() -> None:
generator = torch.Generator()
generator.manual_seed(7)
shared_query = torch.randn(5, 8, generator=generator)
q_embs = [
shared_query, # 1:N style shared query
shared_query,
torch.randn(2, 8, generator=generator),
torch.randn(4, 8, generator=generator),
]
d_embs = [
torch.randn(6, 8, generator=generator),
torch.randn(3, 8, generator=generator),
torch.randn(5, 8, generator=generator),
torch.randn(7, 8, generator=generator),
]
batched_scores = compute_maxsim_scores(
q_embs,
d_embs,
max_batch_size=4,
max_score_matrix_elements=40, # batch shrinking path.
)
reference_scores = [
compute_maxsim_score(q, d).to("cpu") for q, d in zip(q_embs, d_embs)
]
assert len(batched_scores) == len(reference_scores)
for batched, reference in zip(batched_scores, reference_scores):
torch.testing.assert_close(batched, reference, rtol=1e-4, atol=1e-4)

View File

@@ -31,7 +31,7 @@ from vllm.entrypoints.pooling.score.utils import (
ScoreInputs,
_cosine_similarity,
compress_token_type_ids,
compute_maxsim_score,
compute_maxsim_scores,
get_score_prompt,
parse_score_data_single,
validate_score_input,
@@ -311,19 +311,17 @@ class ServingScores(OpenAIServing):
# Compute MaxSim scores
from vllm.outputs import PoolingOutput
maxsim_scores = compute_maxsim_scores(
[emb.outputs.data for emb in emb_data_1],
[emb.outputs.data for emb in emb_data_2],
)
scores: list[PoolingRequestOutput] = []
padding: list[int] = []
if (pad_token_id := tokenizer.pad_token_id) is not None:
padding = [pad_token_id]
for emb_1, emb_2 in zip(emb_data_1, emb_data_2):
# emb_1.outputs.data: [query_len, dim]
# emb_2.outputs.data: [doc_len, dim]
q_emb = emb_1.outputs.data
d_emb = emb_2.outputs.data
maxsim_score = compute_maxsim_score(q_emb, d_emb)
for emb_1, emb_2, maxsim_score in zip(emb_data_1, emb_data_2, maxsim_scores):
tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
scores.append(

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from collections.abc import Iterable, Sequence
from typing import Any, TypeAlias, cast
import torch
@@ -53,6 +53,82 @@ def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tens
return token_scores.amax(dim=-1).sum()
def compute_maxsim_scores(
q_embs: Sequence[torch.Tensor],
d_embs: Sequence[torch.Tensor],
max_batch_size: int = 16,
max_score_matrix_elements: int = 16_000_000,
) -> list[torch.Tensor]:
"""Compute ColBERT MaxSim scores in padded mini-batches."""
if len(q_embs) != len(d_embs):
raise ValueError("q_embs and d_embs must have the same length")
num_pairs = len(q_embs)
if num_pairs == 0:
return []
for q_emb, d_emb in zip(q_embs, d_embs):
if q_emb.ndim != 2 or d_emb.ndim != 2:
raise ValueError("Each embedding tensor must be 2-D")
if q_emb.shape[1] != d_emb.shape[1]:
raise ValueError("Query and document embeddings must have same dim")
compute_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
scores: list[torch.Tensor] = []
start = 0
while start < num_pairs:
end = min(start + max_batch_size, num_pairs)
max_q = max(int(x.shape[0]) for x in q_embs[start:end])
max_d = max(int(x.shape[0]) for x in d_embs[start:end])
# keep score matrix bounded to avoid oversized allocations.
while (
end - start > 1
and (end - start) * max_q * max_d > max_score_matrix_elements
):
end -= 1
max_q = max(int(x.shape[0]) for x in q_embs[start:end])
max_d = max(int(x.shape[0]) for x in d_embs[start:end])
batch_q = q_embs[start:end]
batch_d = d_embs[start:end]
batch_size = end - start
dim = int(batch_q[0].shape[1])
dtype = batch_q[0].dtype
q_batch = torch.zeros(
(batch_size, max_q, dim), dtype=dtype, device=compute_device
)
d_batch = torch.zeros(
(batch_size, max_d, dim), dtype=dtype, device=compute_device
)
q_mask = torch.zeros(
(batch_size, max_q), dtype=torch.bool, device=compute_device
)
d_mask = torch.zeros(
(batch_size, max_d), dtype=torch.bool, device=compute_device
)
# copy to padded tensors
for i, (q_emb, d_emb) in enumerate(zip(batch_q, batch_d)):
q_len = int(q_emb.shape[0])
d_len = int(d_emb.shape[0])
q_batch[i, :q_len] = q_emb.to(device=compute_device, dtype=dtype)
d_batch[i, :d_len] = d_emb.to(device=compute_device, dtype=dtype)
q_mask[i, :q_len] = True
d_mask[i, :d_len] = True
token_scores = torch.bmm(q_batch, d_batch.transpose(1, 2))
token_scores.masked_fill_(~d_mask.unsqueeze(1), float("-inf"))
max_per_query = token_scores.amax(dim=-1)
max_per_query.masked_fill_(~q_mask, 0)
batch_scores = max_per_query.sum(dim=-1).to("cpu")
scores.extend(batch_scores.unbind(0))
start = end
return [cast(torch.Tensor, score) for score in scores]
class ScoreMultiModalParam(TypedDict, total=False):
"""
A specialized parameter type for scoring multimodal content