[Perf] Optimize compute maxsim using batched version, 3.2% E2E throughput improvement (#36710)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -4,13 +4,10 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateResolutionError
|
||||
from vllm.entrypoints.pooling.score.utils import (
|
||||
compute_maxsim_score,
|
||||
compute_maxsim_scores,
|
||||
get_score_prompt,
|
||||
)
|
||||
from vllm.inputs import TokensPrompt
|
||||
@@ -354,36 +351,3 @@ class TestGetScorePrompt:
|
||||
assert_prompt_tokenization_consistent(
|
||||
cross_encoder_tokenizer, full_prompt, engine_prompt
|
||||
)
|
||||
|
||||
|
||||
def test_compute_maxsim_scores_matches_reference_per_pair() -> None:
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(7)
|
||||
|
||||
shared_query = torch.randn(5, 8, generator=generator)
|
||||
q_embs = [
|
||||
shared_query, # 1:N style shared query
|
||||
shared_query,
|
||||
torch.randn(2, 8, generator=generator),
|
||||
torch.randn(4, 8, generator=generator),
|
||||
]
|
||||
d_embs = [
|
||||
torch.randn(6, 8, generator=generator),
|
||||
torch.randn(3, 8, generator=generator),
|
||||
torch.randn(5, 8, generator=generator),
|
||||
torch.randn(7, 8, generator=generator),
|
||||
]
|
||||
|
||||
batched_scores = compute_maxsim_scores(
|
||||
q_embs,
|
||||
d_embs,
|
||||
max_batch_size=4,
|
||||
max_score_matrix_elements=40, # batch shrinking path.
|
||||
)
|
||||
reference_scores = [
|
||||
compute_maxsim_score(q, d).to("cpu") for q, d in zip(q_embs, d_embs)
|
||||
]
|
||||
|
||||
assert len(batched_scores) == len(reference_scores)
|
||||
for batched, reference in zip(batched_scores, reference_scores):
|
||||
torch.testing.assert_close(batched, reference, rtol=1e-4, atol=1e-4)
|
||||
|
||||
@@ -64,6 +64,47 @@ def test_postprocess_scores_and_releases_query_cache():
|
||||
)
|
||||
|
||||
|
||||
def test_postprocess_scores_docs_in_batch():
|
||||
runner = LateInteractionRunner()
|
||||
query_key = "query-batch"
|
||||
query_emb = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32)
|
||||
doc_emb_1 = torch.tensor([[1.0, 0.0], [0.5, 0.5]], dtype=torch.float32)
|
||||
doc_emb_2 = torch.tensor([[0.0, 1.0], [0.3, 0.7], [1.0, 0.0]], dtype=torch.float32)
|
||||
|
||||
query_params = _make_pooling_params(
|
||||
build_late_interaction_query_params(query_key=query_key, query_uses=2)
|
||||
)
|
||||
runner.postprocess_pooler_output(
|
||||
raw_pooler_output=[query_emb],
|
||||
pooling_params=[query_params],
|
||||
req_ids=["query-req"],
|
||||
finished_mask=[True],
|
||||
)
|
||||
|
||||
doc_params = _make_pooling_params(
|
||||
build_late_interaction_doc_params(query_key=query_key)
|
||||
)
|
||||
doc_output = runner.postprocess_pooler_output(
|
||||
raw_pooler_output=[doc_emb_1, doc_emb_2],
|
||||
pooling_params=[doc_params, doc_params],
|
||||
req_ids=["doc-req-1", "doc-req-2"],
|
||||
finished_mask=[True, True],
|
||||
)
|
||||
assert isinstance(doc_output, list)
|
||||
assert doc_output[0] is not None
|
||||
assert doc_output[1] is not None
|
||||
assert torch.allclose(doc_output[0], compute_maxsim_score(query_emb, doc_emb_1))
|
||||
assert torch.allclose(doc_output[1], compute_maxsim_score(query_emb, doc_emb_2))
|
||||
|
||||
with pytest.raises(ValueError, match="query cache miss"):
|
||||
runner.postprocess_pooler_output(
|
||||
raw_pooler_output=[doc_emb_1],
|
||||
pooling_params=[doc_params],
|
||||
req_ids=["doc-req-3"],
|
||||
finished_mask=[True],
|
||||
)
|
||||
|
||||
|
||||
def test_finished_request_releases_unscored_doc_use():
|
||||
runner = LateInteractionRunner()
|
||||
query_key = "query-cancel"
|
||||
|
||||
Reference in New Issue
Block a user