[Perf] Compute maxsim in worker side, reducing redundant copies, 2.7% E2E throughput improvement (#36159)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -52,6 +52,7 @@ from vllm.v1.engine.utils import (
|
||||
launch_core_engines,
|
||||
)
|
||||
from vllm.v1.executor import Executor
|
||||
from vllm.v1.pool.late_interaction import get_late_interaction_engine_index
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -1360,7 +1361,11 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
|
||||
|
||||
def get_core_engine_for_request(self, request: EngineCoreRequest) -> EngineIdentity:
|
||||
# Engines are in rank order.
|
||||
if (eng_index := request.data_parallel_rank) is None:
|
||||
if (eng_index := request.data_parallel_rank) is None and (
|
||||
eng_index := get_late_interaction_engine_index(
|
||||
request.pooling_params, len(self.core_engines)
|
||||
)
|
||||
) is None:
|
||||
current_counts = self.lb_engines
|
||||
# TODO use P2C alg for larger DP sizes
|
||||
num_engines = len(current_counts)
|
||||
|
||||
64
vllm/v1/pool/late_interaction.py
Normal file
64
vllm/v1/pool/late_interaction.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import zlib
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.pooling_params import LateInteractionParams, PoolingParams
|
||||
|
||||
LATE_INTERACTION_MODE_CACHE_QUERY = "cache_query"
|
||||
LATE_INTERACTION_MODE_SCORE_DOC = "score_doc"
|
||||
|
||||
|
||||
def get_late_interaction_engine_index(
|
||||
pooling_params: PoolingParams | None,
|
||||
num_engines: int,
|
||||
) -> int | None:
|
||||
if pooling_params is None or pooling_params.late_interaction_params is None:
|
||||
return None
|
||||
|
||||
late_interaction_params = pooling_params.late_interaction_params
|
||||
mode = late_interaction_params.mode
|
||||
if mode not in (
|
||||
LATE_INTERACTION_MODE_CACHE_QUERY,
|
||||
LATE_INTERACTION_MODE_SCORE_DOC,
|
||||
):
|
||||
return None
|
||||
|
||||
query_key = late_interaction_params.query_key
|
||||
if not isinstance(query_key, str) or not query_key:
|
||||
return None
|
||||
|
||||
# query embeddings are cached in process-local worker memory,
|
||||
# pin requests sharing the same query key to the same engine.
|
||||
return zlib.crc32(query_key.encode("utf-8")) % num_engines
|
||||
|
||||
|
||||
def build_late_interaction_query_params(
|
||||
query_key: str,
|
||||
query_uses: int,
|
||||
) -> LateInteractionParams:
|
||||
return LateInteractionParams(
|
||||
mode=LATE_INTERACTION_MODE_CACHE_QUERY,
|
||||
query_key=query_key,
|
||||
query_uses=max(1, int(query_uses)),
|
||||
)
|
||||
|
||||
|
||||
def build_late_interaction_doc_params(
|
||||
query_key: str,
|
||||
) -> LateInteractionParams:
|
||||
return LateInteractionParams(
|
||||
mode=LATE_INTERACTION_MODE_SCORE_DOC,
|
||||
query_key=query_key,
|
||||
)
|
||||
|
||||
|
||||
def compute_maxsim_score(
|
||||
q_emb: torch.Tensor,
|
||||
d_emb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# compute in float32 for numerical stability
|
||||
token_scores = torch.matmul(q_emb.float(), d_emb.float().T)
|
||||
return token_scores.amax(dim=-1).sum()
|
||||
150
vllm/v1/worker/gpu/pool/late_interaction_runner.py
Normal file
150
vllm/v1/worker/gpu/pool/late_interaction_runner.py
Normal file
@@ -0,0 +1,150 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.v1.outputs import PoolerOutput
|
||||
from vllm.v1.pool.late_interaction import (
|
||||
LATE_INTERACTION_MODE_CACHE_QUERY,
|
||||
LATE_INTERACTION_MODE_SCORE_DOC,
|
||||
compute_maxsim_score,
|
||||
)
|
||||
|
||||
|
||||
class LateInteractionRunner:
|
||||
"""Worker-side state and postprocessing for late-interaction scoring."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# query_key -> token embeddings for late-interaction scoring.
|
||||
self._query_cache: dict[str, torch.Tensor] = {}
|
||||
# query_key -> remaining number of docs that should use this query.
|
||||
self._query_uses: dict[str, int] = {}
|
||||
# doc request id -> query key.
|
||||
self._doc_query_keys: dict[str, str] = {}
|
||||
|
||||
def clear(self) -> None:
|
||||
self._query_cache.clear()
|
||||
self._query_uses.clear()
|
||||
self._doc_query_keys.clear()
|
||||
|
||||
def register_request(
|
||||
self, req_id: str, pooling_params: PoolingParams | None
|
||||
) -> None:
|
||||
mode, query_key, _ = self._parse_late_interaction_meta(pooling_params)
|
||||
if mode == LATE_INTERACTION_MODE_SCORE_DOC and query_key is not None:
|
||||
self._doc_query_keys[req_id] = query_key
|
||||
else:
|
||||
self._doc_query_keys.pop(req_id, None)
|
||||
|
||||
def on_requests_finished(self, finished_req_ids: Iterable[str]) -> None:
|
||||
for req_id in finished_req_ids:
|
||||
query_key = self._doc_query_keys.pop(req_id, None)
|
||||
if query_key is not None:
|
||||
self._release_query_use(query_key)
|
||||
|
||||
def postprocess_pooler_output(
|
||||
self,
|
||||
raw_pooler_output: PoolerOutput,
|
||||
pooling_params: list[PoolingParams],
|
||||
req_ids: list[str],
|
||||
finished_mask: list[bool],
|
||||
) -> PoolerOutput:
|
||||
if not isinstance(raw_pooler_output, list):
|
||||
return raw_pooler_output
|
||||
|
||||
num_reqs = len(pooling_params)
|
||||
if len(raw_pooler_output) != num_reqs:
|
||||
raise ValueError(
|
||||
"raw_pooler_output and pooling_params must have the same length."
|
||||
)
|
||||
if len(req_ids) != num_reqs:
|
||||
raise ValueError("req_ids and pooling_params must have the same length.")
|
||||
if len(finished_mask) != num_reqs:
|
||||
raise ValueError(
|
||||
"finished_mask and pooling_params must have the same length."
|
||||
)
|
||||
|
||||
if not any(finished_mask):
|
||||
return raw_pooler_output
|
||||
if not any(p.late_interaction_params is not None for p in pooling_params):
|
||||
return raw_pooler_output
|
||||
|
||||
outputs: list[torch.Tensor | None] = list(raw_pooler_output)
|
||||
for i, (req_id, output, params, finished) in enumerate(
|
||||
zip(req_ids, outputs, pooling_params, finished_mask)
|
||||
):
|
||||
if not finished or output is None:
|
||||
continue
|
||||
|
||||
mode, query_key, query_uses = self._parse_late_interaction_meta(params)
|
||||
if mode is None:
|
||||
continue
|
||||
|
||||
assert query_key is not None
|
||||
if mode == LATE_INTERACTION_MODE_CACHE_QUERY:
|
||||
assert query_uses is not None
|
||||
# `output` can be a view into the current step's hidden-states
|
||||
# buffer, so clone it before storing across scheduling steps.
|
||||
self._query_cache[query_key] = output.clone()
|
||||
self._query_uses[query_key] = query_uses
|
||||
outputs[i] = torch.zeros((), device=output.device, dtype=torch.float32)
|
||||
continue
|
||||
|
||||
if mode == LATE_INTERACTION_MODE_SCORE_DOC:
|
||||
query_output = self._query_cache.get(query_key)
|
||||
if query_output is None:
|
||||
raise ValueError(
|
||||
"late-interaction query cache miss for key "
|
||||
f"{query_key!r}. Ensure query requests are executed "
|
||||
"before their paired document requests."
|
||||
)
|
||||
|
||||
outputs[i] = compute_maxsim_score(query_output, output)
|
||||
self._doc_query_keys.pop(req_id, None)
|
||||
self._release_query_use(query_key)
|
||||
continue
|
||||
|
||||
raise ValueError(f"Unsupported late-interaction mode: {mode!r}")
|
||||
|
||||
return outputs
|
||||
|
||||
def _release_query_use(self, query_key: str) -> None:
|
||||
remaining = self._query_uses.get(query_key, 1) - 1
|
||||
if remaining <= 0:
|
||||
self._query_uses.pop(query_key, None)
|
||||
self._query_cache.pop(query_key, None)
|
||||
else:
|
||||
self._query_uses[query_key] = remaining
|
||||
|
||||
@staticmethod
|
||||
def _parse_late_interaction_meta(
|
||||
pooling_params: PoolingParams | None,
|
||||
) -> tuple[str | None, str | None, int | None]:
|
||||
if pooling_params is None or pooling_params.late_interaction_params is None:
|
||||
return None, None, None
|
||||
|
||||
late_interaction_params = pooling_params.late_interaction_params
|
||||
mode = late_interaction_params.mode
|
||||
|
||||
query_key = late_interaction_params.query_key
|
||||
if not isinstance(query_key, str) or not query_key:
|
||||
raise ValueError(
|
||||
"late-interaction request is missing a valid query key in "
|
||||
"pooling_params.late_interaction_params."
|
||||
)
|
||||
|
||||
if mode == LATE_INTERACTION_MODE_CACHE_QUERY:
|
||||
query_uses_raw = late_interaction_params.query_uses
|
||||
if query_uses_raw is None:
|
||||
query_uses_raw = 1
|
||||
try:
|
||||
query_uses = max(1, int(query_uses_raw))
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise ValueError(
|
||||
"late-interaction query uses must be an integer value."
|
||||
) from exc
|
||||
return mode, query_key, query_uses
|
||||
|
||||
return mode, query_key, None
|
||||
@@ -181,6 +181,7 @@ from vllm.v1.worker.cp_utils import (
|
||||
)
|
||||
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
|
||||
from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin
|
||||
from vllm.v1.worker.gpu.pool.late_interaction_runner import LateInteractionRunner
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
|
||||
@@ -491,6 +492,7 @@ class GPUModelRunner(
|
||||
|
||||
# mm_hash -> encoder_output
|
||||
self.encoder_cache: dict[str, torch.Tensor] = {}
|
||||
self.late_interaction_runner = LateInteractionRunner()
|
||||
|
||||
self.use_aux_hidden_state_outputs = False
|
||||
# Set up speculative decoding.
|
||||
@@ -831,6 +833,7 @@ class GPUModelRunner(
|
||||
"""
|
||||
if self.mm_budget:
|
||||
self.mm_budget.reset_cache()
|
||||
self.late_interaction_runner.clear()
|
||||
|
||||
def reset_encoder_cache(self) -> None:
|
||||
"""Clear the GPU-side encoder cache storing vision embeddings.
|
||||
@@ -839,6 +842,7 @@ class GPUModelRunner(
|
||||
stale embeddings computed with old weights are not reused.
|
||||
"""
|
||||
self.encoder_cache.clear()
|
||||
self.late_interaction_runner.clear()
|
||||
|
||||
@torch.inference_mode()
|
||||
def init_fp8_kv_scales(self) -> None:
|
||||
@@ -1002,6 +1006,9 @@ class GPUModelRunner(
|
||||
for req_id in scheduler_output.finished_req_ids:
|
||||
self.requests.pop(req_id, None)
|
||||
self.num_prompt_logprobs.pop(req_id, None)
|
||||
self.late_interaction_runner.on_requests_finished(
|
||||
scheduler_output.finished_req_ids
|
||||
)
|
||||
# Remove the finished requests from the persistent batch.
|
||||
# NOTE(woosuk): There could be an edge case where finished_req_ids and
|
||||
# scheduled_req_ids overlap. This happens when a request is aborted and
|
||||
@@ -1089,6 +1096,7 @@ class GPUModelRunner(
|
||||
lora_request=new_req_data.lora_request,
|
||||
)
|
||||
self.requests[req_id] = req_state
|
||||
self.late_interaction_runner.register_request(req_id, pooling_params)
|
||||
|
||||
if sampling_params and sampling_params.prompt_logprobs is not None:
|
||||
self.num_prompt_logprobs[req_id] = (
|
||||
@@ -1360,6 +1368,7 @@ class GPUModelRunner(
|
||||
req_state.prompt_embeds = new_req_data.prompt_embeds
|
||||
req_state.sampling_params = new_req_data.sampling_params
|
||||
req_state.pooling_params = new_req_data.pooling_params
|
||||
self.late_interaction_runner.register_request(req_id, req_state.pooling_params)
|
||||
req_state.block_ids = new_req_data.block_ids
|
||||
req_state.num_computed_tokens = new_req_data.num_computed_tokens
|
||||
req_state.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
||||
@@ -2875,6 +2884,12 @@ class GPUModelRunner(
|
||||
seq_len == prompt_len
|
||||
for seq_len, prompt_len in zip(seq_lens_cpu, pooling_metadata.prompt_lens)
|
||||
]
|
||||
raw_pooler_output = self.late_interaction_runner.postprocess_pooler_output(
|
||||
raw_pooler_output=raw_pooler_output,
|
||||
pooling_params=pooling_metadata.pooling_params,
|
||||
req_ids=self.input_batch.req_ids,
|
||||
finished_mask=finished_mask,
|
||||
)
|
||||
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=self.input_batch.req_ids.copy(),
|
||||
|
||||
Reference in New Issue
Block a user