diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py
index 9e3988b15..944fb88a0 100644
--- a/vllm/entrypoints/cli/serve.py
+++ b/vllm/entrypoints/cli/serve.py
@@ -220,6 +220,12 @@ def run_multi_api_server(args: argparse.Namespace):
     num_api_servers: int = args.api_server_count
     assert num_api_servers > 0
 
+    if num_api_servers > 1 and getattr(args, "use_gpu_for_pooling_score", False):
+        # TODO(wentao): remove this once well tested
+        raise ValueError(
+            "--use-gpu-for-pooling-score cannot be used with api_server_count > 1 now"
+        )
+
     if num_api_servers > 1:
         setup_multiprocess_prometheus()
 
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index 5655491fd..d3a66c183 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -278,6 +278,10 @@ class FrontendArgs(BaseFrontendArgs):
     Enable offline FastAPI documentation for air-gapped environments.
     Uses vendored static assets bundled with vLLM.
     """
+    use_gpu_for_pooling_score: bool = False
+    """If set, run pooling score MaxSim on GPU in the API server process.
+    Can significantly improve late-interaction scoring performance.
+    https://github.com/vllm-project/vllm/pull/35330"""
 
     @classmethod
     def _customize_cli_kwargs(
diff --git a/vllm/entrypoints/pooling/__init__.py b/vllm/entrypoints/pooling/__init__.py
index 1108be175..3ba131d5f 100644
--- a/vllm/entrypoints/pooling/__init__.py
+++ b/vllm/entrypoints/pooling/__init__.py
@@ -115,6 +115,7 @@ def init_pooling_state(
             request_logger=request_logger,
             score_template=resolved_chat_template,
             log_error_stack=args.log_error_stack,
+            use_gpu_for_pooling_score=getattr(args, "use_gpu_for_pooling_score", False),
         )
         if any(t in supported_tasks for t in ("embed", "score", "token_embed"))
         else None
diff --git a/vllm/entrypoints/pooling/score/serving.py b/vllm/entrypoints/pooling/score/serving.py
index aec6e909d..60d6db6a7 100644
--- a/vllm/entrypoints/pooling/score/serving.py
+++ b/vllm/entrypoints/pooling/score/serving.py
@@ -56,6 +56,7 @@ class ServingScores(OpenAIServing):
         request_logger: RequestLogger | None,
         score_template: str | None = None,
         log_error_stack: bool = False,
+        use_gpu_for_pooling_score: bool = False,
     ) -> None:
         super().__init__(
             engine_client=engine_client,
@@ -64,6 +65,7 @@
             log_error_stack=log_error_stack,
         )
         self.score_template = score_template
+        self.use_gpu_for_pooling_score = use_gpu_for_pooling_score
 
         self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
 
@@ -314,6 +316,7 @@
             maxsim_scores = compute_maxsim_scores(
                 [emb.outputs.data for emb in emb_data_1],
                 [emb.outputs.data for emb in emb_data_2],
+                use_gpu_for_pooling_score=self.use_gpu_for_pooling_score,
             )
 
             scores: list[PoolingRequestOutput] = []
diff --git a/vllm/entrypoints/pooling/score/utils.py b/vllm/entrypoints/pooling/score/utils.py
index 98c24856b..65611dc3a 100644
--- a/vllm/entrypoints/pooling/score/utils.py
+++ b/vllm/entrypoints/pooling/score/utils.py
@@ -25,6 +25,7 @@ from vllm.inputs.data import PromptType, TextPrompt
 from vllm.model_executor.models.interfaces import supports_score_template
 from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict
 from vllm.outputs import PoolingRequestOutput
+from vllm.platforms import current_platform
 from vllm.renderers.hf import safe_apply_chat_template
 from vllm.tokenizers import TokenizerLike
 
@@ -53,11 +54,16 @@ def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tensor:
     return token_scores.amax(dim=-1).sum()
 
 
+def _should_use_gpu_for_maxsim(use_gpu_for_pooling_score: bool) -> bool:
+    return use_gpu_for_pooling_score and not current_platform.is_cpu()
+
+
 def compute_maxsim_scores(
     q_embs: Sequence[torch.Tensor],
     d_embs: Sequence[torch.Tensor],
     max_batch_size: int = 16,
     max_score_matrix_elements: int = 16_000_000,
+    use_gpu_for_pooling_score: bool = False,
 ) -> list[torch.Tensor]:
     """Compute ColBERT MaxSim scores in padded mini-batches."""
     if len(q_embs) != len(d_embs):
@@ -73,7 +79,11 @@
     if q_emb.shape[1] != d_emb.shape[1]:
         raise ValueError("Query and document embeddings must have same dim")
 
-    compute_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    compute_device = torch.device(
+        current_platform.device_type
+        if _should_use_gpu_for_maxsim(use_gpu_for_pooling_score)
+        else "cpu"
+    )
     scores: list[torch.Tensor] = []
     start = 0
     while start < num_pairs:
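
A minimal usage sketch, not part of the patch, showing how the new keyword is expected to flow into the batched MaxSim helper. The import path and the use_gpu_for_pooling_score keyword are taken from the diff above; the tensor shapes and the assumption that each returned tensor is a scalar score per query/document pair are illustrative. At the server level the same behavior is enabled with --use-gpu-for-pooling-score, which the first hunk rejects when api_server_count > 1.

# Illustrative sketch only: call compute_maxsim_scores() directly with the new
# use_gpu_for_pooling_score flag. When the flag is set and the current platform
# is not CPU-only, scoring runs on the platform's accelerator; otherwise it
# falls back to CPU.
import torch

from vllm.entrypoints.pooling.score.utils import compute_maxsim_scores

# Per-query and per-document token embeddings (num_tokens x hidden_dim);
# shapes here are arbitrary example values.
q_embs = [torch.randn(8, 128), torch.randn(5, 128)]
d_embs = [torch.randn(200, 128), torch.randn(64, 128)]

scores = compute_maxsim_scores(q_embs, d_embs, use_gpu_for_pooling_score=True)
print([float(s) for s in scores])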