[Perf] Optimize compute maxsim using batched version, 3.2% E2E throughput improvement (#36710)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-03-11 20:37:01 -04:00
committed by GitHub
parent 24062b704f
commit c34ba6b961
5 changed files with 141 additions and 127 deletions

View File

@@ -4,13 +4,10 @@
from unittest.mock import patch
import pytest
import torch
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateResolutionError
from vllm.entrypoints.pooling.score.utils import (
compute_maxsim_score,
compute_maxsim_scores,
get_score_prompt,
)
from vllm.inputs import TokensPrompt
@@ -354,36 +351,3 @@ class TestGetScorePrompt:
assert_prompt_tokenization_consistent(
cross_encoder_tokenizer, full_prompt, engine_prompt
)
def test_compute_maxsim_scores_matches_reference_per_pair() -> None:
generator = torch.Generator()
generator.manual_seed(7)
shared_query = torch.randn(5, 8, generator=generator)
q_embs = [
shared_query, # 1:N style shared query
shared_query,
torch.randn(2, 8, generator=generator),
torch.randn(4, 8, generator=generator),
]
d_embs = [
torch.randn(6, 8, generator=generator),
torch.randn(3, 8, generator=generator),
torch.randn(5, 8, generator=generator),
torch.randn(7, 8, generator=generator),
]
batched_scores = compute_maxsim_scores(
q_embs,
d_embs,
max_batch_size=4,
max_score_matrix_elements=40, # batch shrinking path.
)
reference_scores = [
compute_maxsim_score(q, d).to("cpu") for q, d in zip(q_embs, d_embs)
]
assert len(batched_scores) == len(reference_scores)
for batched, reference in zip(batched_scores, reference_scores):
torch.testing.assert_close(batched, reference, rtol=1e-4, atol=1e-4)

View File

@@ -64,6 +64,47 @@ def test_postprocess_scores_and_releases_query_cache():
)
def test_postprocess_scores_docs_in_batch():
runner = LateInteractionRunner()
query_key = "query-batch"
query_emb = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32)
doc_emb_1 = torch.tensor([[1.0, 0.0], [0.5, 0.5]], dtype=torch.float32)
doc_emb_2 = torch.tensor([[0.0, 1.0], [0.3, 0.7], [1.0, 0.0]], dtype=torch.float32)
query_params = _make_pooling_params(
build_late_interaction_query_params(query_key=query_key, query_uses=2)
)
runner.postprocess_pooler_output(
raw_pooler_output=[query_emb],
pooling_params=[query_params],
req_ids=["query-req"],
finished_mask=[True],
)
doc_params = _make_pooling_params(
build_late_interaction_doc_params(query_key=query_key)
)
doc_output = runner.postprocess_pooler_output(
raw_pooler_output=[doc_emb_1, doc_emb_2],
pooling_params=[doc_params, doc_params],
req_ids=["doc-req-1", "doc-req-2"],
finished_mask=[True, True],
)
assert isinstance(doc_output, list)
assert doc_output[0] is not None
assert doc_output[1] is not None
assert torch.allclose(doc_output[0], compute_maxsim_score(query_emb, doc_emb_1))
assert torch.allclose(doc_output[1], compute_maxsim_score(query_emb, doc_emb_2))
with pytest.raises(ValueError, match="query cache miss"):
runner.postprocess_pooler_output(
raw_pooler_output=[doc_emb_1],
pooling_params=[doc_params],
req_ids=["doc-req-3"],
finished_mask=[True],
)
def test_finished_request_releases_unscored_doc_use():
runner = LateInteractionRunner()
query_key = "query-cancel"

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Sequence
from collections.abc import Iterable
from typing import Any, TypeAlias, cast
import torch
@@ -25,7 +25,6 @@ from vllm.inputs.data import PromptType, TextPrompt
from vllm.model_executor.models.interfaces import supports_score_template
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict
from vllm.outputs import PoolingRequestOutput
from vllm.platforms import current_platform
from vllm.renderers.hf import safe_apply_chat_template
from vllm.tokenizers import TokenizerLike
@@ -54,91 +53,6 @@ def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tens
return token_scores.amax(dim=-1).sum()
def _should_use_gpu_for_maxsim(use_gpu_for_pooling_score: bool) -> bool:
return use_gpu_for_pooling_score and not current_platform.is_cpu()
def compute_maxsim_scores(
q_embs: Sequence[torch.Tensor],
d_embs: Sequence[torch.Tensor],
max_batch_size: int = 16,
max_score_matrix_elements: int = 16_000_000,
use_gpu_for_pooling_score: bool = False,
) -> list[torch.Tensor]:
"""Compute ColBERT MaxSim scores in padded mini-batches."""
if len(q_embs) != len(d_embs):
raise ValueError("q_embs and d_embs must have the same length")
num_pairs = len(q_embs)
if num_pairs == 0:
return []
for q_emb, d_emb in zip(q_embs, d_embs):
if q_emb.ndim != 2 or d_emb.ndim != 2:
raise ValueError("Each embedding tensor must be 2-D")
if q_emb.shape[1] != d_emb.shape[1]:
raise ValueError("Query and document embeddings must have same dim")
compute_device = torch.device(
current_platform.device_type
if _should_use_gpu_for_maxsim(use_gpu_for_pooling_score)
else "cpu"
)
scores: list[torch.Tensor] = []
start = 0
while start < num_pairs:
end = min(start + max_batch_size, num_pairs)
max_q = max(int(x.shape[0]) for x in q_embs[start:end])
max_d = max(int(x.shape[0]) for x in d_embs[start:end])
# keep score matrix bounded to avoid oversized allocations.
while (
end - start > 1
and (end - start) * max_q * max_d > max_score_matrix_elements
):
end -= 1
max_q = max(int(x.shape[0]) for x in q_embs[start:end])
max_d = max(int(x.shape[0]) for x in d_embs[start:end])
batch_q = q_embs[start:end]
batch_d = d_embs[start:end]
batch_size = end - start
dim = int(batch_q[0].shape[1])
dtype = batch_q[0].dtype
q_batch = torch.zeros(
(batch_size, max_q, dim), dtype=dtype, device=compute_device
)
d_batch = torch.zeros(
(batch_size, max_d, dim), dtype=dtype, device=compute_device
)
q_mask = torch.zeros(
(batch_size, max_q), dtype=torch.bool, device=compute_device
)
d_mask = torch.zeros(
(batch_size, max_d), dtype=torch.bool, device=compute_device
)
# copy to padded tensors
for i, (q_emb, d_emb) in enumerate(zip(batch_q, batch_d)):
q_len = int(q_emb.shape[0])
d_len = int(d_emb.shape[0])
q_batch[i, :q_len] = q_emb.to(device=compute_device, dtype=dtype)
d_batch[i, :d_len] = d_emb.to(device=compute_device, dtype=dtype)
q_mask[i, :q_len] = True
d_mask[i, :d_len] = True
token_scores = torch.bmm(q_batch, d_batch.transpose(1, 2))
token_scores.masked_fill_(~d_mask.unsqueeze(1), float("-inf"))
max_per_query = token_scores.amax(dim=-1)
max_per_query.masked_fill_(~q_mask, 0)
batch_scores = max_per_query.sum(dim=-1).to("cpu")
scores.extend(batch_scores.unbind(0))
start = end
return [cast(torch.Tensor, score) for score in scores]
class ScoreMultiModalParam(TypedDict, total=False):
"""
A specialized parameter type for scoring multimodal content

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import zlib
from collections.abc import Sequence
import torch
@@ -62,3 +63,81 @@ def compute_maxsim_score(
# compute in float32 for numerical stability
token_scores = torch.matmul(q_emb.float(), d_emb.float().T)
return token_scores.amax(dim=-1).sum()
def compute_maxsim_scores(
q_embs: Sequence[torch.Tensor],
d_embs: Sequence[torch.Tensor],
max_batch_size: int = 64,
max_score_matrix_elements: int = 64_000_000,
) -> list[torch.Tensor]:
"""Compute MaxSim for multiple query/doc pairs in mini-batches."""
if len(q_embs) != len(d_embs):
raise ValueError("q_embs and d_embs must have the same length")
num_pairs = len(q_embs)
if num_pairs == 0:
return []
if max_batch_size <= 0:
raise ValueError("max_batch_size must be greater than 0")
if max_score_matrix_elements <= 0:
raise ValueError("max_score_matrix_elements must be greater than 0")
for q_emb, d_emb in zip(q_embs, d_embs):
if q_emb.ndim != 2 or d_emb.ndim != 2:
raise ValueError("Each embedding tensor must be 2-D")
if q_emb.shape[1] != d_emb.shape[1]:
raise ValueError("Query and document embeddings must have same dim")
if q_emb.device != d_emb.device:
raise ValueError("Query and document embeddings must be on same device")
scores: list[torch.Tensor] = []
start = 0
while start < num_pairs:
end = min(start + max_batch_size, num_pairs)
max_q = max(int(x.shape[0]) for x in q_embs[start:end])
max_d = max(int(x.shape[0]) for x in d_embs[start:end])
# keep score matrix bounded to avoid oversized allocations.
while (
end - start > 1
and (end - start) * max_q * max_d > max_score_matrix_elements
):
end -= 1
max_q = max(int(x.shape[0]) for x in q_embs[start:end])
max_d = max(int(x.shape[0]) for x in d_embs[start:end])
batch_q = q_embs[start:end]
batch_d = d_embs[start:end]
batch_size = end - start
device = batch_q[0].device
dim = int(batch_q[0].shape[1])
q_batch = torch.zeros(
(batch_size, max_q, dim), dtype=torch.float32, device=device
)
d_batch = torch.zeros(
(batch_size, max_d, dim), dtype=torch.float32, device=device
)
q_mask = torch.zeros((batch_size, max_q), dtype=torch.bool, device=device)
d_mask = torch.zeros((batch_size, max_d), dtype=torch.bool, device=device)
# copy to padded tensors
for i, (q_emb, d_emb) in enumerate(zip(batch_q, batch_d)):
q_len = int(q_emb.shape[0])
d_len = int(d_emb.shape[0])
q_batch[i, :q_len] = q_emb.to(device=device, dtype=torch.float32)
d_batch[i, :d_len] = d_emb.to(device=device, dtype=torch.float32)
q_mask[i, :q_len] = True
d_mask[i, :d_len] = True
token_scores = torch.bmm(q_batch, d_batch.transpose(1, 2))
token_scores.masked_fill_(~d_mask.unsqueeze(1), float("-inf"))
max_per_query = token_scores.amax(dim=-1)
max_per_query.masked_fill_(~q_mask, 0.0)
batch_scores = max_per_query.sum(dim=-1)
scores.extend(batch_scores.unbind(0))
start = end
return scores

View File

@@ -9,7 +9,7 @@ from vllm.v1.outputs import PoolerOutput
from vllm.v1.pool.late_interaction import (
LATE_INTERACTION_MODE_CACHE_QUERY,
LATE_INTERACTION_MODE_SCORE_DOC,
compute_maxsim_score,
compute_maxsim_scores,
)
@@ -72,6 +72,11 @@ class LateInteractionRunner:
return raw_pooler_output
outputs: list[torch.Tensor | None] = list(raw_pooler_output)
score_indices: list[int] = []
score_req_ids: list[str] = []
score_query_keys: list[str] = []
score_queries: list[torch.Tensor] = []
score_docs: list[torch.Tensor] = []
for i, (req_id, output, params, finished) in enumerate(
zip(req_ids, outputs, pooling_params, finished_mask)
):
@@ -101,13 +106,24 @@ class LateInteractionRunner:
"before their paired document requests."
)
outputs[i] = compute_maxsim_score(query_output, output)
self._doc_query_keys.pop(req_id, None)
self._release_query_use(query_key)
score_indices.append(i)
score_req_ids.append(req_id)
score_query_keys.append(query_key)
score_queries.append(query_output)
score_docs.append(output)
continue
raise ValueError(f"Unsupported late-interaction mode: {mode!r}")
if score_indices:
score_values = compute_maxsim_scores(score_queries, score_docs)
for i, req_id, query_key, score in zip(
score_indices, score_req_ids, score_query_keys, score_values
):
outputs[i] = score
self._doc_query_keys.pop(req_id, None)
self._release_query_use(query_key)
return outputs
def _release_query_use(self, query_key: str) -> None: