diff --git a/tests/entrypoints/pooling/score/test_utils.py b/tests/entrypoints/pooling/score/test_utils.py index e5e1fd606..20b6df4a9 100644 --- a/tests/entrypoints/pooling/score/test_utils.py +++ b/tests/entrypoints/pooling/score/test_utils.py @@ -4,13 +4,10 @@ from unittest.mock import patch import pytest -import torch from vllm.config import ModelConfig from vllm.entrypoints.chat_utils import ChatTemplateResolutionError from vllm.entrypoints.pooling.score.utils import ( - compute_maxsim_score, - compute_maxsim_scores, get_score_prompt, ) from vllm.inputs import TokensPrompt @@ -354,36 +351,3 @@ class TestGetScorePrompt: assert_prompt_tokenization_consistent( cross_encoder_tokenizer, full_prompt, engine_prompt ) - - -def test_compute_maxsim_scores_matches_reference_per_pair() -> None: - generator = torch.Generator() - generator.manual_seed(7) - - shared_query = torch.randn(5, 8, generator=generator) - q_embs = [ - shared_query, # 1:N style shared query - shared_query, - torch.randn(2, 8, generator=generator), - torch.randn(4, 8, generator=generator), - ] - d_embs = [ - torch.randn(6, 8, generator=generator), - torch.randn(3, 8, generator=generator), - torch.randn(5, 8, generator=generator), - torch.randn(7, 8, generator=generator), - ] - - batched_scores = compute_maxsim_scores( - q_embs, - d_embs, - max_batch_size=4, - max_score_matrix_elements=40, # batch shrinking path. - ) - reference_scores = [ - compute_maxsim_score(q, d).to("cpu") for q, d in zip(q_embs, d_embs) - ] - - assert len(batched_scores) == len(reference_scores) - for batched, reference in zip(batched_scores, reference_scores): - torch.testing.assert_close(batched, reference, rtol=1e-4, atol=1e-4) diff --git a/tests/v1/worker/test_late_interaction_runner.py b/tests/v1/worker/test_late_interaction_runner.py index 00a54a9e1..5be3f6e6f 100644 --- a/tests/v1/worker/test_late_interaction_runner.py +++ b/tests/v1/worker/test_late_interaction_runner.py @@ -64,6 +64,47 @@ def test_postprocess_scores_and_releases_query_cache(): ) +def test_postprocess_scores_docs_in_batch(): + runner = LateInteractionRunner() + query_key = "query-batch" + query_emb = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32) + doc_emb_1 = torch.tensor([[1.0, 0.0], [0.5, 0.5]], dtype=torch.float32) + doc_emb_2 = torch.tensor([[0.0, 1.0], [0.3, 0.7], [1.0, 0.0]], dtype=torch.float32) + + query_params = _make_pooling_params( + build_late_interaction_query_params(query_key=query_key, query_uses=2) + ) + runner.postprocess_pooler_output( + raw_pooler_output=[query_emb], + pooling_params=[query_params], + req_ids=["query-req"], + finished_mask=[True], + ) + + doc_params = _make_pooling_params( + build_late_interaction_doc_params(query_key=query_key) + ) + doc_output = runner.postprocess_pooler_output( + raw_pooler_output=[doc_emb_1, doc_emb_2], + pooling_params=[doc_params, doc_params], + req_ids=["doc-req-1", "doc-req-2"], + finished_mask=[True, True], + ) + assert isinstance(doc_output, list) + assert doc_output[0] is not None + assert doc_output[1] is not None + assert torch.allclose(doc_output[0], compute_maxsim_score(query_emb, doc_emb_1)) + assert torch.allclose(doc_output[1], compute_maxsim_score(query_emb, doc_emb_2)) + + with pytest.raises(ValueError, match="query cache miss"): + runner.postprocess_pooler_output( + raw_pooler_output=[doc_emb_1], + pooling_params=[doc_params], + req_ids=["doc-req-3"], + finished_mask=[True], + ) + + def test_finished_request_releases_unscored_doc_use(): runner = LateInteractionRunner() query_key = 
"query-cancel" diff --git a/vllm/entrypoints/pooling/score/utils.py b/vllm/entrypoints/pooling/score/utils.py index 65611dc3a..60e71ff73 100644 --- a/vllm/entrypoints/pooling/score/utils.py +++ b/vllm/entrypoints/pooling/score/utils.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Iterable, Sequence +from collections.abc import Iterable from typing import Any, TypeAlias, cast import torch @@ -25,7 +25,6 @@ from vllm.inputs.data import PromptType, TextPrompt from vllm.model_executor.models.interfaces import supports_score_template from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict from vllm.outputs import PoolingRequestOutput -from vllm.platforms import current_platform from vllm.renderers.hf import safe_apply_chat_template from vllm.tokenizers import TokenizerLike @@ -54,91 +53,6 @@ def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tens return token_scores.amax(dim=-1).sum() -def _should_use_gpu_for_maxsim(use_gpu_for_pooling_score: bool) -> bool: - return use_gpu_for_pooling_score and not current_platform.is_cpu() - - -def compute_maxsim_scores( - q_embs: Sequence[torch.Tensor], - d_embs: Sequence[torch.Tensor], - max_batch_size: int = 16, - max_score_matrix_elements: int = 16_000_000, - use_gpu_for_pooling_score: bool = False, -) -> list[torch.Tensor]: - """Compute ColBERT MaxSim scores in padded mini-batches.""" - if len(q_embs) != len(d_embs): - raise ValueError("q_embs and d_embs must have the same length") - - num_pairs = len(q_embs) - if num_pairs == 0: - return [] - - for q_emb, d_emb in zip(q_embs, d_embs): - if q_emb.ndim != 2 or d_emb.ndim != 2: - raise ValueError("Each embedding tensor must be 2-D") - if q_emb.shape[1] != d_emb.shape[1]: - raise ValueError("Query and document embeddings must have same dim") - - compute_device = torch.device( - current_platform.device_type - if _should_use_gpu_for_maxsim(use_gpu_for_pooling_score) - else "cpu" - ) - scores: list[torch.Tensor] = [] - start = 0 - while start < num_pairs: - end = min(start + max_batch_size, num_pairs) - max_q = max(int(x.shape[0]) for x in q_embs[start:end]) - max_d = max(int(x.shape[0]) for x in d_embs[start:end]) - - # keep score matrix bounded to avoid oversized allocations. 
- while ( - end - start > 1 - and (end - start) * max_q * max_d > max_score_matrix_elements - ): - end -= 1 - max_q = max(int(x.shape[0]) for x in q_embs[start:end]) - max_d = max(int(x.shape[0]) for x in d_embs[start:end]) - - batch_q = q_embs[start:end] - batch_d = d_embs[start:end] - batch_size = end - start - dim = int(batch_q[0].shape[1]) - dtype = batch_q[0].dtype - - q_batch = torch.zeros( - (batch_size, max_q, dim), dtype=dtype, device=compute_device - ) - d_batch = torch.zeros( - (batch_size, max_d, dim), dtype=dtype, device=compute_device - ) - q_mask = torch.zeros( - (batch_size, max_q), dtype=torch.bool, device=compute_device - ) - d_mask = torch.zeros( - (batch_size, max_d), dtype=torch.bool, device=compute_device - ) - - # copy to padded tensors - for i, (q_emb, d_emb) in enumerate(zip(batch_q, batch_d)): - q_len = int(q_emb.shape[0]) - d_len = int(d_emb.shape[0]) - q_batch[i, :q_len] = q_emb.to(device=compute_device, dtype=dtype) - d_batch[i, :d_len] = d_emb.to(device=compute_device, dtype=dtype) - q_mask[i, :q_len] = True - d_mask[i, :d_len] = True - - token_scores = torch.bmm(q_batch, d_batch.transpose(1, 2)) - token_scores.masked_fill_(~d_mask.unsqueeze(1), float("-inf")) - max_per_query = token_scores.amax(dim=-1) - max_per_query.masked_fill_(~q_mask, 0) - batch_scores = max_per_query.sum(dim=-1).to("cpu") - scores.extend(batch_scores.unbind(0)) - start = end - - return [cast(torch.Tensor, score) for score in scores] - - class ScoreMultiModalParam(TypedDict, total=False): """ A specialized parameter type for scoring multimodal content diff --git a/vllm/v1/pool/late_interaction.py b/vllm/v1/pool/late_interaction.py index dc21528c2..4a465bd2f 100644 --- a/vllm/v1/pool/late_interaction.py +++ b/vllm/v1/pool/late_interaction.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import zlib +from collections.abc import Sequence import torch @@ -62,3 +63,81 @@ def compute_maxsim_score( # compute in float32 for numerical stability token_scores = torch.matmul(q_emb.float(), d_emb.float().T) return token_scores.amax(dim=-1).sum() + + +def compute_maxsim_scores( + q_embs: Sequence[torch.Tensor], + d_embs: Sequence[torch.Tensor], + max_batch_size: int = 64, + max_score_matrix_elements: int = 64_000_000, +) -> list[torch.Tensor]: + """Compute MaxSim for multiple query/doc pairs in mini-batches.""" + if len(q_embs) != len(d_embs): + raise ValueError("q_embs and d_embs must have the same length") + + num_pairs = len(q_embs) + if num_pairs == 0: + return [] + + if max_batch_size <= 0: + raise ValueError("max_batch_size must be greater than 0") + if max_score_matrix_elements <= 0: + raise ValueError("max_score_matrix_elements must be greater than 0") + + for q_emb, d_emb in zip(q_embs, d_embs): + if q_emb.ndim != 2 or d_emb.ndim != 2: + raise ValueError("Each embedding tensor must be 2-D") + if q_emb.shape[1] != d_emb.shape[1]: + raise ValueError("Query and document embeddings must have same dim") + if q_emb.device != d_emb.device: + raise ValueError("Query and document embeddings must be on same device") + + scores: list[torch.Tensor] = [] + start = 0 + while start < num_pairs: + end = min(start + max_batch_size, num_pairs) + max_q = max(int(x.shape[0]) for x in q_embs[start:end]) + max_d = max(int(x.shape[0]) for x in d_embs[start:end]) + + # keep score matrix bounded to avoid oversized allocations. 
+ while ( + end - start > 1 + and (end - start) * max_q * max_d > max_score_matrix_elements + ): + end -= 1 + max_q = max(int(x.shape[0]) for x in q_embs[start:end]) + max_d = max(int(x.shape[0]) for x in d_embs[start:end]) + + batch_q = q_embs[start:end] + batch_d = d_embs[start:end] + batch_size = end - start + device = batch_q[0].device + dim = int(batch_q[0].shape[1]) + + q_batch = torch.zeros( + (batch_size, max_q, dim), dtype=torch.float32, device=device + ) + d_batch = torch.zeros( + (batch_size, max_d, dim), dtype=torch.float32, device=device + ) + q_mask = torch.zeros((batch_size, max_q), dtype=torch.bool, device=device) + d_mask = torch.zeros((batch_size, max_d), dtype=torch.bool, device=device) + + # copy to padded tensors + for i, (q_emb, d_emb) in enumerate(zip(batch_q, batch_d)): + q_len = int(q_emb.shape[0]) + d_len = int(d_emb.shape[0]) + q_batch[i, :q_len] = q_emb.to(device=device, dtype=torch.float32) + d_batch[i, :d_len] = d_emb.to(device=device, dtype=torch.float32) + q_mask[i, :q_len] = True + d_mask[i, :d_len] = True + + token_scores = torch.bmm(q_batch, d_batch.transpose(1, 2)) + token_scores.masked_fill_(~d_mask.unsqueeze(1), float("-inf")) + max_per_query = token_scores.amax(dim=-1) + max_per_query.masked_fill_(~q_mask, 0.0) + batch_scores = max_per_query.sum(dim=-1) + scores.extend(batch_scores.unbind(0)) + start = end + + return scores diff --git a/vllm/v1/worker/gpu/pool/late_interaction_runner.py b/vllm/v1/worker/gpu/pool/late_interaction_runner.py index 3ad00bc7c..221dee558 100644 --- a/vllm/v1/worker/gpu/pool/late_interaction_runner.py +++ b/vllm/v1/worker/gpu/pool/late_interaction_runner.py @@ -9,7 +9,7 @@ from vllm.v1.outputs import PoolerOutput from vllm.v1.pool.late_interaction import ( LATE_INTERACTION_MODE_CACHE_QUERY, LATE_INTERACTION_MODE_SCORE_DOC, - compute_maxsim_score, + compute_maxsim_scores, ) @@ -72,6 +72,11 @@ class LateInteractionRunner: return raw_pooler_output outputs: list[torch.Tensor | None] = list(raw_pooler_output) + score_indices: list[int] = [] + score_req_ids: list[str] = [] + score_query_keys: list[str] = [] + score_queries: list[torch.Tensor] = [] + score_docs: list[torch.Tensor] = [] for i, (req_id, output, params, finished) in enumerate( zip(req_ids, outputs, pooling_params, finished_mask) ): @@ -101,13 +106,24 @@ class LateInteractionRunner: "before their paired document requests." ) - outputs[i] = compute_maxsim_score(query_output, output) - self._doc_query_keys.pop(req_id, None) - self._release_query_use(query_key) + score_indices.append(i) + score_req_ids.append(req_id) + score_query_keys.append(query_key) + score_queries.append(query_output) + score_docs.append(output) continue raise ValueError(f"Unsupported late-interaction mode: {mode!r}") + if score_indices: + score_values = compute_maxsim_scores(score_queries, score_docs) + for i, req_id, query_key, score in zip( + score_indices, score_req_ids, score_query_keys, score_values + ): + outputs[i] = score + self._doc_query_keys.pop(req_id, None) + self._release_query_use(query_key) + return outputs def _release_query_use(self, query_key: str) -> None:
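
Usage note: the batched helper introduced above keeps the same per-pair semantics as the existing `compute_maxsim_score`, just padded and computed with one `bmm` per mini-batch. Below is a minimal sketch of how the two relate, assuming this patch is applied and `vllm` is importable; the tensor shapes and the tightened `max_batch_size` are purely illustrative, not values used by the runner.

```python
# Minimal sketch: batched MaxSim vs. the per-pair reference.
# Assumes this patch is applied so both helpers live in vllm.v1.pool.late_interaction.
import torch

from vllm.v1.pool.late_interaction import (
    compute_maxsim_score,
    compute_maxsim_scores,
)

# Per-pair query/document token embeddings (shapes are illustrative).
q_embs = [torch.randn(5, 8), torch.randn(2, 8)]
d_embs = [torch.randn(6, 8), torch.randn(3, 8)]

# Batched path: pads each mini-batch and scores it with a single bmm.
batched = compute_maxsim_scores(q_embs, d_embs, max_batch_size=2)

# Reference path: one matmul per query/document pair.
reference = [compute_maxsim_score(q, d) for q, d in zip(q_embs, d_embs)]

for b, r in zip(batched, reference):
    torch.testing.assert_close(b, r, rtol=1e-4, atol=1e-4)
```

Lowering `max_score_matrix_elements` forces the batch-shrinking path (the case the removed `test_compute_maxsim_scores_matches_reference_per_pair` exercised); the new runner-level test covers the batched scoring path through `postprocess_pooler_output` instead.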