[Perf] Optimize compute maxsim using batched version, 3.2% E2E throughput improvement (#36710)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-03-11 20:37:01 -04:00
committed by GitHub
parent 24062b704f
commit c34ba6b961
5 changed files with 141 additions and 127 deletions

View File

@@ -4,13 +4,10 @@
from unittest.mock import patch
import pytest
import torch
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateResolutionError
from vllm.entrypoints.pooling.score.utils import (
compute_maxsim_score,
compute_maxsim_scores,
get_score_prompt,
)
from vllm.inputs import TokensPrompt
@@ -354,36 +351,3 @@ class TestGetScorePrompt:
assert_prompt_tokenization_consistent(
cross_encoder_tokenizer, full_prompt, engine_prompt
)
def test_compute_maxsim_scores_matches_reference_per_pair() -> None:
    """Batched MaxSim scoring must match the per-pair reference implementation."""
    rng = torch.Generator()
    rng.manual_seed(7)

    # Reusing one query tensor for the first two pairs mimics 1:N-style
    # scoring where a single query is compared against many documents.
    # NOTE: the order of the randn() calls below must stay fixed, since
    # each draw advances the generator state.
    query = torch.randn(5, 8, generator=rng)
    queries = [
        query,
        query,
        torch.randn(2, 8, generator=rng),
        torch.randn(4, 8, generator=rng),
    ]
    documents = [
        torch.randn(6, 8, generator=rng),
        torch.randn(3, 8, generator=rng),
        torch.randn(5, 8, generator=rng),
        torch.randn(7, 8, generator=rng),
    ]

    # A tight element cap forces the batch-shrinking code path.
    actual = compute_maxsim_scores(
        queries,
        documents,
        max_batch_size=4,
        max_score_matrix_elements=40,
    )

    expected = []
    for q, d in zip(queries, documents):
        expected.append(compute_maxsim_score(q, d).to("cpu"))

    assert len(actual) == len(expected)
    for got, want in zip(actual, expected):
        torch.testing.assert_close(got, want, rtol=1e-4, atol=1e-4)