[Perf] Compute maxsim in worker side, reducing redundant copies, 2.7% E2E throughput improvement (#36159)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-03-09 23:55:58 -04:00
committed by GitHub
parent 006aea17d7
commit 7279374f91
11 changed files with 518 additions and 58 deletions

View File

@@ -52,6 +52,7 @@ from vllm.v1.engine.utils import (
launch_core_engines,
)
from vllm.v1.executor import Executor
from vllm.v1.pool.late_interaction import get_late_interaction_engine_index
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
logger = init_logger(__name__)
@@ -1360,7 +1361,11 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
def get_core_engine_for_request(self, request: EngineCoreRequest) -> EngineIdentity:
# Engines are in rank order.
if (eng_index := request.data_parallel_rank) is None:
if (eng_index := request.data_parallel_rank) is None and (
eng_index := get_late_interaction_engine_index(
request.pooling_params, len(self.core_engines)
)
) is None:
current_counts = self.lb_engines
# TODO use P2C alg for larger DP sizes
num_engines = len(current_counts)

View File

@@ -0,0 +1,64 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import zlib
import torch
from vllm.pooling_params import LateInteractionParams, PoolingParams
LATE_INTERACTION_MODE_CACHE_QUERY = "cache_query"
LATE_INTERACTION_MODE_SCORE_DOC = "score_doc"
def get_late_interaction_engine_index(
    pooling_params: PoolingParams | None,
    num_engines: int,
) -> int | None:
    """Return the engine index a late-interaction request should be pinned to.

    Args:
        pooling_params: Pooling parameters of the request, if any.
        num_engines: Number of engines the index is taken modulo.

    Returns:
        ``zlib.crc32(query_key) % num_engines`` for a cache-query or
        score-doc request that carries a non-empty string query key.
        ``None`` when the request has no late-interaction params, uses a
        different mode, has no usable query key, or when ``num_engines`` is
        not positive; ``None`` means "no pinning required".
    """
    if pooling_params is None or pooling_params.late_interaction_params is None:
        return None

    late_interaction_params = pooling_params.late_interaction_params
    mode = late_interaction_params.mode
    if mode not in (
        LATE_INTERACTION_MODE_CACHE_QUERY,
        LATE_INTERACTION_MODE_SCORE_DOC,
    ):
        return None

    query_key = late_interaction_params.query_key
    if not isinstance(query_key, str) or not query_key:
        return None

    if num_engines <= 0:
        # There is no engine to pin to. Without this guard the modulo below
        # raises ZeroDivisionError for 0 and produces a non-positive index
        # for a negative count.
        return None

    # Query embeddings are cached in process-local worker memory, so requests
    # sharing the same query key must be pinned to the same engine.
    return zlib.crc32(query_key.encode("utf-8")) % num_engines
def build_late_interaction_query_params(
    query_key: str,
    query_uses: int,
) -> LateInteractionParams:
    """Create the params that mark a request as a cache-query request.

    ``query_uses`` is converted with ``int()`` and clamped so that the
    resulting params always carry a value of at least 1.
    """
    uses = int(query_uses)
    if uses < 1:
        uses = 1
    return LateInteractionParams(
        query_uses=uses,
        query_key=query_key,
        mode=LATE_INTERACTION_MODE_CACHE_QUERY,
    )
def build_late_interaction_doc_params(query_key: str) -> LateInteractionParams:
    """Create the params that mark a request as a score-doc request.

    The returned params reference ``query_key`` and leave every other field
    at its default.
    """
    doc_params = LateInteractionParams(
        query_key=query_key,
        mode=LATE_INTERACTION_MODE_SCORE_DOC,
    )
    return doc_params
def compute_maxsim_score(
    q_emb: torch.Tensor,
    d_emb: torch.Tensor,
) -> torch.Tensor:
    """Return the MaxSim score of a document against a query.

    The similarity matrix is ``q_emb @ d_emb.T``; the score is the maximum
    over its last dimension, summed over everything that remains, which
    yields a 0-dim tensor.

    NOTE(review): presumably both inputs are 2-D ``(num_tokens, dim)`` token
    embeddings -- confirm against the callers.
    """
    # Multiply in float32 for numerical stability, whatever the input dtype.
    query_fp32 = q_emb.to(torch.float32)
    doc_fp32 = d_emb.to(torch.float32)
    similarity = query_fp32 @ doc_fp32.T
    best_match = torch.amax(similarity, dim=-1)
    return best_match.sum()

View File

@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
import torch
from vllm.pooling_params import PoolingParams
from vllm.v1.outputs import PoolerOutput
from vllm.v1.pool.late_interaction import (
LATE_INTERACTION_MODE_CACHE_QUERY,
LATE_INTERACTION_MODE_SCORE_DOC,
compute_maxsim_score,
)
class LateInteractionRunner:
    """Worker-side bookkeeping and scoring for late-interaction requests.

    Cache-query requests store their pooled output under a query key together
    with a use budget. Score-doc requests are scored against the stored
    output and each one consumes one use; once the budget reaches zero the
    stored output is dropped.
    """

    def __init__(self) -> None:
        # Stored query token embeddings, keyed by query key.
        self._query_cache: dict[str, torch.Tensor] = {}
        # Number of document requests that may still consume each query.
        self._query_uses: dict[str, int] = {}
        # Document request id -> query key that request scores against.
        self._doc_query_keys: dict[str, str] = {}

    def clear(self) -> None:
        """Drop every stored query and all per-request bookkeeping."""
        for table in (self._query_cache, self._query_uses, self._doc_query_keys):
            table.clear()

    def register_request(
        self, req_id: str, pooling_params: PoolingParams | None
    ) -> None:
        """Track the query key of a score-doc request.

        For any request that is not in score-doc mode, an existing entry
        under the same request id is removed instead.
        """
        mode, query_key, _ = self._parse_late_interaction_meta(pooling_params)
        if mode != LATE_INTERACTION_MODE_SCORE_DOC or query_key is None:
            self._doc_query_keys.pop(req_id, None)
            return
        self._doc_query_keys[req_id] = query_key

    def on_requests_finished(self, finished_req_ids: Iterable[str]) -> None:
        """Release the query use held by each tracked finished request.

        Request ids that are not tracked are ignored.
        """
        for finished_id in finished_req_ids:
            held_key = self._doc_query_keys.pop(finished_id, None)
            if held_key is None:
                continue
            self._release_query_use(held_key)

    def postprocess_pooler_output(
        self,
        raw_pooler_output: PoolerOutput,
        pooling_params: list[PoolingParams],
        req_ids: list[str],
        finished_mask: list[bool],
    ) -> PoolerOutput:
        """Replace late-interaction pooler outputs with their final values.

        For every finished request with a non-``None`` output:

        * cache-query mode stores a clone of the output and its use budget,
          and replaces the output with a float32 scalar zero;
        * score-doc mode replaces the output with the MaxSim score against
          the stored query output and releases one use of that query.

        ``raw_pooler_output`` is returned untouched when it is not a list,
        when no request finished, or when no request carries
        late-interaction params; otherwise a new list is returned.

        Raises:
            ValueError: If the argument lengths disagree, a score-doc request
                has no stored query, or a request uses an unsupported mode.
        """
        if not isinstance(raw_pooler_output, list):
            return raw_pooler_output

        expected_len = len(pooling_params)
        length_checks = (
            (
                raw_pooler_output,
                "raw_pooler_output and pooling_params must have the same length.",
            ),
            (req_ids, "req_ids and pooling_params must have the same length."),
            (
                finished_mask,
                "finished_mask and pooling_params must have the same length.",
            ),
        )
        for sized, message in length_checks:
            if len(sized) != expected_len:
                raise ValueError(message)

        if not any(finished_mask) or all(
            p.late_interaction_params is None for p in pooling_params
        ):
            return raw_pooler_output

        results: list[torch.Tensor | None] = list(raw_pooler_output)
        for idx, req_id in enumerate(req_ids):
            embedding = results[idx]
            if not finished_mask[idx] or embedding is None:
                continue

            mode, query_key, query_uses = self._parse_late_interaction_meta(
                pooling_params[idx]
            )
            if mode is None:
                continue
            assert query_key is not None

            if mode == LATE_INTERACTION_MODE_CACHE_QUERY:
                assert query_uses is not None
                # The output can be a view into the current step's
                # hidden-states buffer, so it is cloned before being kept
                # across scheduling steps.
                self._query_cache[query_key] = embedding.clone()
                self._query_uses[query_key] = query_uses
                results[idx] = torch.zeros(
                    (), device=embedding.device, dtype=torch.float32
                )
            elif mode == LATE_INTERACTION_MODE_SCORE_DOC:
                cached_query = self._query_cache.get(query_key)
                if cached_query is None:
                    raise ValueError(
                        "late-interaction query cache miss for key "
                        f"{query_key!r}. Ensure query requests are executed "
                        "before their paired document requests."
                    )
                results[idx] = compute_maxsim_score(cached_query, embedding)
                # Untrack the request first so that a later
                # on_requests_finished() call cannot release a second use.
                self._doc_query_keys.pop(req_id, None)
                self._release_query_use(query_key)
            else:
                raise ValueError(f"Unsupported late-interaction mode: {mode!r}")
        return results

    def _release_query_use(self, query_key: str) -> None:
        """Consume one use of a stored query and evict it once exhausted.

        A key without a recorded budget is treated as having one use left,
        so releasing it removes any stored output for that key.
        """
        uses_left = self._query_uses.get(query_key, 1) - 1
        if uses_left > 0:
            self._query_uses[query_key] = uses_left
            return
        self._query_uses.pop(query_key, None)
        self._query_cache.pop(query_key, None)

    @staticmethod
    def _parse_late_interaction_meta(
        pooling_params: PoolingParams | None,
    ) -> tuple[str | None, str | None, int | None]:
        """Extract ``(mode, query_key, query_uses)`` from pooling params.

        Returns ``(None, None, None)`` when there are no late-interaction
        params. ``query_uses`` is only filled in for cache-query mode, where
        a missing value defaults to 1 and the result is clamped to at
        least 1; for every other mode it is ``None``.

        Raises:
            ValueError: If the query key is not a non-empty string, or the
                query uses cannot be converted to an integer.
        """
        if pooling_params is None:
            return None, None, None
        li_params = pooling_params.late_interaction_params
        if li_params is None:
            return None, None, None

        mode = li_params.mode
        query_key = li_params.query_key
        if not (isinstance(query_key, str) and query_key):
            raise ValueError(
                "late-interaction request is missing a valid query key in "
                "pooling_params.late_interaction_params."
            )

        if mode != LATE_INTERACTION_MODE_CACHE_QUERY:
            return mode, query_key, None

        raw_uses = li_params.query_uses
        try:
            query_uses = max(1, int(1 if raw_uses is None else raw_uses))
        except (TypeError, ValueError) as exc:
            raise ValueError(
                "late-interaction query uses must be an integer value."
            ) from exc
        return mode, query_key, query_uses

View File

@@ -181,6 +181,7 @@ from vllm.v1.worker.cp_utils import (
)
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin
from vllm.v1.worker.gpu.pool.late_interaction_runner import LateInteractionRunner
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
@@ -491,6 +492,7 @@ class GPUModelRunner(
# mm_hash -> encoder_output
self.encoder_cache: dict[str, torch.Tensor] = {}
self.late_interaction_runner = LateInteractionRunner()
self.use_aux_hidden_state_outputs = False
# Set up speculative decoding.
@@ -831,6 +833,7 @@ class GPUModelRunner(
"""
if self.mm_budget:
self.mm_budget.reset_cache()
self.late_interaction_runner.clear()
def reset_encoder_cache(self) -> None:
"""Clear the GPU-side encoder cache storing vision embeddings.
@@ -839,6 +842,7 @@ class GPUModelRunner(
stale embeddings computed with old weights are not reused.
"""
self.encoder_cache.clear()
self.late_interaction_runner.clear()
@torch.inference_mode()
def init_fp8_kv_scales(self) -> None:
@@ -1002,6 +1006,9 @@ class GPUModelRunner(
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
self.num_prompt_logprobs.pop(req_id, None)
self.late_interaction_runner.on_requests_finished(
scheduler_output.finished_req_ids
)
# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
# scheduled_req_ids overlap. This happens when a request is aborted and
@@ -1089,6 +1096,7 @@ class GPUModelRunner(
lora_request=new_req_data.lora_request,
)
self.requests[req_id] = req_state
self.late_interaction_runner.register_request(req_id, pooling_params)
if sampling_params and sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[req_id] = (
@@ -1360,6 +1368,7 @@ class GPUModelRunner(
req_state.prompt_embeds = new_req_data.prompt_embeds
req_state.sampling_params = new_req_data.sampling_params
req_state.pooling_params = new_req_data.pooling_params
self.late_interaction_runner.register_request(req_id, req_state.pooling_params)
req_state.block_ids = new_req_data.block_ids
req_state.num_computed_tokens = new_req_data.num_computed_tokens
req_state.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
@@ -2875,6 +2884,12 @@ class GPUModelRunner(
seq_len == prompt_len
for seq_len, prompt_len in zip(seq_lens_cpu, pooling_metadata.prompt_lens)
]
raw_pooler_output = self.late_interaction_runner.postprocess_pooler_output(
raw_pooler_output=raw_pooler_output,
pooling_params=pooling_metadata.pooling_params,
req_ids=self.input_batch.req_ids,
finished_mask=finished_mask,
)
model_runner_output = ModelRunnerOutput(
req_ids=self.input_batch.req_ids.copy(),