[Perf] Compute MaxSim on the worker side, reducing redundant copies; 2.7% E2E throughput improvement (#36159)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Wentao Ye
2026-03-09 23:55:58 -04:00
committed by GitHub
parent 006aea17d7
commit 7279374f91
11 changed files with 518 additions and 58 deletions
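For context: before this change, the API server collected query and document token embeddings from the workers and computed MaxSim in the serving process (optionally on GPU via the now-removed --use-gpu-for-pooling-score flag). The diff below moves this to a two-stage worker-side flow: "cache_query" requests cache their token embeddings in worker memory, "score_doc" requests return a scalar MaxSim score computed against the cached query, and a CRC32 hash of the shared query key pins both stages to the same data-parallel engine. A minimal illustrative sketch of the MaxSim score itself, mirroring the compute_maxsim_score helper added in this commit:

import torch

def maxsim(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tensor:
    # q_emb: [num_query_tokens, dim], d_emb: [num_doc_tokens, dim].
    # For each query token, take its best (max) similarity over all doc
    # tokens, then sum those maxima into a single scalar relevance score.
    token_scores = torch.matmul(q_emb.float(), d_emb.float().T)
    return token_scores.amax(dim=-1).sum()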

View File

@@ -24,17 +24,23 @@ from vllm import SamplingParams
from vllm.distributed.kv_events import BlockStored, KVEventBatch, ZmqEventPublisher
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.pooling_params import LateInteractionParams, PoolingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore
from vllm.v1.engine.core_client import (
AsyncMPClient,
DPLBAsyncMPClient,
EngineCoreClient,
SyncMPClient,
)
from vllm.v1.engine.utils import CoreEngineProcManager
from vllm.v1.executor.abstract import Executor
from vllm.v1.pool.late_interaction import (
LATE_INTERACTION_MODE_CACHE_QUERY,
LATE_INTERACTION_MODE_SCORE_DOC,
)
from ...distributed.conftest import MockSubscriber
from ...utils import create_new_process_for_each_test
@@ -164,6 +170,71 @@ def test_mp_client_uses_env_timeout(monkeypatch: pytest.MonkeyPatch):
client.shutdown()
def _make_pooling_request(
request_id: str, *, mode: str | None = None, query_key: str | None = None
) -> EngineCoreRequest:
late_interaction_params = None
if mode is not None and query_key is not None:
late_interaction_params = LateInteractionParams(
mode=mode,
query_key=query_key,
)
return EngineCoreRequest(
request_id=request_id,
prompt_token_ids=[1, 2, 3],
mm_features=None,
sampling_params=None,
pooling_params=PoolingParams(
task="token_embed",
late_interaction_params=late_interaction_params,
),
arrival_time=time.time(),
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
)
def test_dplb_late_interaction_sticky_routing():
client = object.__new__(DPLBAsyncMPClient)
client.client_count = 1
client.reqs_in_flight = {}
client.core_engines = [b"\x00\x00", b"\x01\x00", b"\x02\x00"]
client.lb_engines = [[0, 0], [0, 0], [0, 0]]
client.eng_start_index = 0
query_key = "rerank-abc-query-0"
query_request = _make_pooling_request(
"query-req", mode=LATE_INTERACTION_MODE_CACHE_QUERY, query_key=query_key
)
doc_request = _make_pooling_request(
"doc-req", mode=LATE_INTERACTION_MODE_SCORE_DOC, query_key=query_key
)
query_engine = client.get_core_engine_for_request(query_request)
doc_engine = client.get_core_engine_for_request(doc_request)
assert query_engine == doc_engine
assert client.reqs_in_flight["query-req"] == query_engine
assert client.reqs_in_flight["doc-req"] == doc_engine
def test_dplb_non_late_interaction_still_uses_lb():
client = object.__new__(DPLBAsyncMPClient)
client.client_count = 1
client.reqs_in_flight = {}
client.core_engines = [b"\x00\x00", b"\x01\x00", b"\x02\x00"]
client.lb_engines = [[2, 1], [0, 0], [1, 0]]
client.eng_start_index = 0
request = make_request(SamplingParams(max_tokens=1))
chosen_engine = client.get_core_engine_for_request(request)
assert chosen_engine == client.core_engines[1]
assert client.lb_engines[1][0] == 1
def loop_until_done(client: EngineCoreClient, outputs: dict):
while True:
engine_core_outputs = client.get_output().outputs

View File

@@ -0,0 +1,113 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm.pooling_params import LateInteractionParams, PoolingParams
from vllm.v1.pool.late_interaction import (
LATE_INTERACTION_MODE_CACHE_QUERY,
build_late_interaction_doc_params,
build_late_interaction_query_params,
compute_maxsim_score,
)
from vllm.v1.worker.gpu.pool.late_interaction_runner import LateInteractionRunner
def _make_pooling_params(
late_interaction_params: LateInteractionParams,
) -> PoolingParams:
return PoolingParams(
task="token_embed",
late_interaction_params=late_interaction_params,
)
def test_postprocess_scores_and_releases_query_cache():
runner = LateInteractionRunner()
query_key = "query-0"
query_emb = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32)
doc_emb = torch.tensor([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]], dtype=torch.float32)
query_params = _make_pooling_params(
build_late_interaction_query_params(query_key=query_key, query_uses=1)
)
query_output = runner.postprocess_pooler_output(
raw_pooler_output=[query_emb],
pooling_params=[query_params],
req_ids=["query-req"],
finished_mask=[True],
)
assert isinstance(query_output, list)
assert query_output[0] is not None
assert query_output[0].shape == torch.Size([])
doc_params = _make_pooling_params(
build_late_interaction_doc_params(query_key=query_key)
)
doc_output = runner.postprocess_pooler_output(
raw_pooler_output=[doc_emb],
pooling_params=[doc_params],
req_ids=["doc-req"],
finished_mask=[True],
)
assert isinstance(doc_output, list)
assert doc_output[0] is not None
assert torch.allclose(doc_output[0], compute_maxsim_score(query_emb, doc_emb))
with pytest.raises(ValueError, match="query cache miss"):
runner.postprocess_pooler_output(
raw_pooler_output=[doc_emb],
pooling_params=[doc_params],
req_ids=["doc-req-2"],
finished_mask=[True],
)
def test_finished_request_releases_unscored_doc_use():
runner = LateInteractionRunner()
query_key = "query-cancel"
query_emb = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32)
doc_emb = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32)
query_params = _make_pooling_params(
build_late_interaction_query_params(query_key=query_key, query_uses=1)
)
runner.postprocess_pooler_output(
raw_pooler_output=[query_emb],
pooling_params=[query_params],
req_ids=["query-req"],
finished_mask=[True],
)
doc_params = _make_pooling_params(
build_late_interaction_doc_params(query_key=query_key)
)
runner.register_request("doc-req", doc_params)
runner.on_requests_finished({"doc-req"})
with pytest.raises(ValueError, match="query cache miss"):
runner.postprocess_pooler_output(
raw_pooler_output=[doc_emb],
pooling_params=[doc_params],
req_ids=["doc-req-retry"],
finished_mask=[True],
)
def test_invalid_query_uses_raises():
runner = LateInteractionRunner()
bad_meta = LateInteractionParams(
mode=LATE_INTERACTION_MODE_CACHE_QUERY,
query_key="query-bad",
)
bad_meta.query_uses = "bad-int" # type: ignore[assignment]
bad_query_params = _make_pooling_params(bad_meta)
with pytest.raises(ValueError, match="must be an integer value"):
runner.postprocess_pooler_output(
raw_pooler_output=[torch.ones((2, 2), dtype=torch.float32)],
pooling_params=[bad_query_params],
req_ids=["query-req"],
finished_mask=[True],
)

View File

@@ -225,12 +225,6 @@ def run_multi_api_server(args: argparse.Namespace):
num_api_servers: int = args.api_server_count
assert num_api_servers > 0
if num_api_servers > 1 and getattr(args, "use_gpu_for_pooling_score", False):
# TODO(wentao): remove this once well tested
raise ValueError(
"--use-gpu-for-pooling-score cannot be used with api_server_count > 1 now"
)
if num_api_servers > 1:
setup_multiprocess_prometheus()

View File

@@ -281,10 +281,6 @@ class FrontendArgs(BaseFrontendArgs):
Enable offline FastAPI documentation for air-gapped environments.
Uses vendored static assets bundled with vLLM.
"""
use_gpu_for_pooling_score: bool = False
"""If set, run pooling score MaxSim on GPU in the API server process.
Can significantly improve late-interaction scoring performance.
https://github.com/vllm-project/vllm/pull/35330"""
@classmethod
def _customize_cli_kwargs(

View File

@@ -111,7 +111,7 @@ def init_pooling_state(
state.openai_serving_models,
request_logger=request_logger,
score_template=resolved_chat_template,
use_gpu_for_pooling_score=getattr(args, "use_gpu_for_pooling_score", False),
log_error_stack=args.log_error_stack,
)
if any(t in supported_tasks for t in ("embed", "score", "token_embed"))
else None

View File

@@ -31,7 +31,6 @@ from vllm.entrypoints.pooling.score.utils import (
ScoreInputs,
_cosine_similarity,
compress_token_type_ids,
compute_maxsim_scores,
get_score_prompt,
parse_score_data_single,
validate_score_input,
@@ -43,6 +42,10 @@ from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import make_async, merge_async_iterators
from vllm.utils.mistral import is_mistral_tokenizer
from vllm.v1.pool.late_interaction import (
build_late_interaction_doc_params,
build_late_interaction_query_params,
)
logger = init_logger(__name__)
@@ -56,7 +59,6 @@ class ServingScores(OpenAIServing):
request_logger: RequestLogger | None,
score_template: str | None = None,
log_error_stack: bool = False,
use_gpu_for_pooling_score: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
@@ -64,7 +66,6 @@ class ServingScores(OpenAIServing):
request_logger=request_logger,
)
self.score_template = score_template
self.use_gpu_for_pooling_score = use_gpu_for_pooling_score
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
@@ -253,19 +254,30 @@ class ServingScores(OpenAIServing):
)
)
input_texts: list[str] = []
engine_prompts: list[TokensPrompt] = []
for text, engine_prompt in preprocessed:
input_texts.append(text)
engine_prompts.append(engine_prompt)
query_prompts: list[TokensPrompt] = [
prompt for _, prompt in preprocessed[: len(data_1)]
]
doc_prompts: list[TokensPrompt] = [
prompt for _, prompt in preprocessed[len(data_1) :]
]
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
default_pooling_params = request.to_pooling_params("token_embed")
pooling_params = request.to_pooling_params("token_embed")
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
# stage 1: encode queries and cache token embeddings on workers.
query_keys = [f"{request_id}-query-{i}" for i in range(len(query_prompts))]
query_uses = [len(doc_prompts) if len(query_prompts) == 1 else 1] * len(
query_prompts
)
query_generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
for i, engine_prompt in enumerate(query_prompts):
request_id_item = f"{request_id}-query-{i}"
pooling_params = default_pooling_params.clone()
pooling_params.late_interaction_params = (
build_late_interaction_query_params(
query_key=query_keys[i],
query_uses=query_uses[i],
)
)
self._log_inputs(
request_id_item,
@@ -274,7 +286,7 @@ class ServingScores(OpenAIServing):
lora_request=lora_request,
)
generators.append(
query_generators.append(
self.engine_client.encode(
engine_prompt,
pooling_params,
@@ -285,53 +297,71 @@ class ServingScores(OpenAIServing):
)
)
result_generator = merge_async_iterators(*generators)
query_outputs: list[PoolingRequestOutput | None] = [None] * len(query_prompts)
if query_generators:
async for i, res in merge_async_iterators(*query_generators):
query_outputs[i] = res
# Collect token embeddings
embeddings: list[PoolingRequestOutput | None] = [None] * len(engine_prompts)
assert all(res is not None for res in query_outputs)
query_results = [res for res in query_outputs if res is not None]
async for i, res in result_generator:
embeddings[i] = res
# Split into query and document embeddings
emb_data_1: list[PoolingRequestOutput] = []
emb_data_2: list[PoolingRequestOutput] = []
for i in range(0, len(data_1)):
assert (emb := embeddings[i]) is not None
emb_data_1.append(emb)
for i in range(len(data_1), len(embeddings)):
assert (emb := embeddings[i]) is not None
emb_data_2.append(emb)
# Expand queries if 1:N scoring
if len(emb_data_1) == 1:
emb_data_1 = emb_data_1 * len(emb_data_2)
# Compute MaxSim scores
from vllm.outputs import PoolingOutput
maxsim_scores = compute_maxsim_scores(
[emb.outputs.data for emb in emb_data_1],
[emb.outputs.data for emb in emb_data_2],
use_gpu_for_pooling_score=self.use_gpu_for_pooling_score,
# stage 2: encode docs and return scalar scores from workers.
doc_generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
for i, engine_prompt in enumerate(doc_prompts):
request_id_item = f"{request_id}-doc-{i}"
query_idx = 0 if len(query_prompts) == 1 else i
pooling_params = default_pooling_params.clone()
pooling_params.late_interaction_params = build_late_interaction_doc_params(
query_key=query_keys[query_idx]
)
self._log_inputs(
request_id_item,
engine_prompt,
params=pooling_params,
lora_request=lora_request,
)
doc_generators.append(
self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
)
doc_outputs: list[PoolingRequestOutput | None] = [None] * len(doc_prompts)
if doc_generators:
async for i, res in merge_async_iterators(*doc_generators):
doc_outputs[i] = res
assert all(res is not None for res in doc_outputs)
doc_results = [res for res in doc_outputs if res is not None]
scores: list[PoolingRequestOutput] = []
padding: list[int] = []
if (pad_token_id := tokenizer.pad_token_id) is not None:
padding = [pad_token_id]
for emb_1, emb_2, maxsim_score in zip(emb_data_1, emb_data_2, maxsim_scores):
tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
if len(query_results) == 1:
query_results = query_results * len(doc_results)
for query_result, doc_result in zip(query_results, doc_results):
tokens = (
query_result.prompt_token_ids + padding + doc_result.prompt_token_ids
)
scores.append(
PoolingRequestOutput(
request_id=f"{emb_1.request_id}_{emb_2.request_id}",
outputs=PoolingOutput(data=maxsim_score),
request_id=f"{query_result.request_id}_{doc_result.request_id}",
outputs=doc_result.outputs,
prompt_token_ids=tokens,
num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens,
num_cached_tokens=(
query_result.num_cached_tokens + doc_result.num_cached_tokens
),
finished=True,
)
)
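A small worked example (hypothetical request_id and prompt values) of the 1:N fan-out handled above: with one query and three documents, query_uses is set to 3 so the worker releases the cached query embeddings only after the third document is scored, and every document request points back at query index 0.

# Hypothetical values illustrating the 1:N scoring case above.
request_id = "score-demo"
query_prompts = ["q0"]
doc_prompts = ["d0", "d1", "d2"]
query_keys = [f"{request_id}-query-{i}" for i in range(len(query_prompts))]
query_uses = [len(doc_prompts) if len(query_prompts) == 1 else 1] * len(query_prompts)
assert query_keys == ["score-demo-query-0"]
assert query_uses == [3]
# Each doc reuses query index 0, matching `query_idx = 0 if len(query_prompts) == 1 else i`.
doc_query_idx = [0 if len(query_prompts) == 1 else i for i in range(len(doc_prompts))]
assert doc_query_idx == [0, 0, 0]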

View File

@@ -11,6 +11,26 @@ from vllm.sampling_params import RequestOutputKind
from vllm.tasks import PoolingTask
class LateInteractionParams(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True,
): # type: ignore[call-arg]
"""Metadata for worker-side late-interaction scoring.
Attributes:
mode:
- "cache_query": cache query token embeddings
- "score_doc": score a document against a cached query.
query_key: stable key used for both DP routing and worker cache lookup.
query_uses: expected number of document requests
"""
mode: str
query_key: str
query_uses: int | None = None
class PoolingParams(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
@@ -46,6 +66,7 @@ class PoolingParams(
task: PoolingTask | None = None
requires_token_ids: bool = False
skip_reading_prefix_cache: bool | None = None
late_interaction_params: LateInteractionParams | None = None
extra_kwargs: dict[str, Any] | None = None
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
@@ -193,6 +214,7 @@ class PoolingParams(
f"returned_token_ids={self.returned_token_ids}, "
f"requires_token_ids={self.requires_token_ids}, "
f"skip_reading_prefix_cache={self.skip_reading_prefix_cache}, "
f"late_interaction_params={self.late_interaction_params}, "
f"extra_kwargs={self.extra_kwargs})"
)

View File

@@ -52,6 +52,7 @@ from vllm.v1.engine.utils import (
launch_core_engines,
)
from vllm.v1.executor import Executor
from vllm.v1.pool.late_interaction import get_late_interaction_engine_index
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
logger = init_logger(__name__)
@@ -1360,7 +1361,11 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
def get_core_engine_for_request(self, request: EngineCoreRequest) -> EngineIdentity:
# Engines are in rank order.
if (eng_index := request.data_parallel_rank) is None:
if (eng_index := request.data_parallel_rank) is None and (
eng_index := get_late_interaction_engine_index(
request.pooling_params, len(self.core_engines)
)
) is None:
current_counts = self.lb_engines
# TODO use P2C alg for larger DP sizes
num_engines = len(current_counts)

View File

@@ -0,0 +1,64 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import zlib
import torch
from vllm.pooling_params import LateInteractionParams, PoolingParams
LATE_INTERACTION_MODE_CACHE_QUERY = "cache_query"
LATE_INTERACTION_MODE_SCORE_DOC = "score_doc"
def get_late_interaction_engine_index(
pooling_params: PoolingParams | None,
num_engines: int,
) -> int | None:
if pooling_params is None or pooling_params.late_interaction_params is None:
return None
late_interaction_params = pooling_params.late_interaction_params
mode = late_interaction_params.mode
if mode not in (
LATE_INTERACTION_MODE_CACHE_QUERY,
LATE_INTERACTION_MODE_SCORE_DOC,
):
return None
query_key = late_interaction_params.query_key
if not isinstance(query_key, str) or not query_key:
return None
# Query embeddings are cached in process-local worker memory, so pin
# requests sharing the same query key to the same engine.
return zlib.crc32(query_key.encode("utf-8")) % num_engines
def build_late_interaction_query_params(
query_key: str,
query_uses: int,
) -> LateInteractionParams:
return LateInteractionParams(
mode=LATE_INTERACTION_MODE_CACHE_QUERY,
query_key=query_key,
query_uses=max(1, int(query_uses)),
)
def build_late_interaction_doc_params(
query_key: str,
) -> LateInteractionParams:
return LateInteractionParams(
mode=LATE_INTERACTION_MODE_SCORE_DOC,
query_key=query_key,
)
def compute_maxsim_score(
q_emb: torch.Tensor,
d_emb: torch.Tensor,
) -> torch.Tensor:
# compute in float32 for numerical stability
token_scores = torch.matmul(q_emb.float(), d_emb.float().T)
return token_scores.amax(dim=-1).sum()
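A hedged usage sketch (not part of the diff) showing how the API server pairs query and document pooling params, and how both hash to the same data-parallel engine so the worker-side query cache is hit:

from vllm.pooling_params import PoolingParams
from vllm.v1.pool.late_interaction import (
    build_late_interaction_doc_params,
    build_late_interaction_query_params,
    get_late_interaction_engine_index,
)

query_key = "rerank-demo-query-0"  # hypothetical; real keys derive from the request id
query_params = PoolingParams(
    task="token_embed",
    late_interaction_params=build_late_interaction_query_params(
        query_key=query_key,
        query_uses=3,  # three documents will score against this query
    ),
)
doc_params = PoolingParams(
    task="token_embed",
    late_interaction_params=build_late_interaction_doc_params(query_key=query_key),
)
# With 4 engines, both requests map to the same engine index.
assert get_late_interaction_engine_index(query_params, 4) == get_late_interaction_engine_index(
    doc_params, 4
)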

View File

@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
import torch
from vllm.pooling_params import PoolingParams
from vllm.v1.outputs import PoolerOutput
from vllm.v1.pool.late_interaction import (
LATE_INTERACTION_MODE_CACHE_QUERY,
LATE_INTERACTION_MODE_SCORE_DOC,
compute_maxsim_score,
)
class LateInteractionRunner:
"""Worker-side state and postprocessing for late-interaction scoring."""
def __init__(self) -> None:
# query_key -> token embeddings for late-interaction scoring.
self._query_cache: dict[str, torch.Tensor] = {}
# query_key -> remaining number of docs that should use this query.
self._query_uses: dict[str, int] = {}
# doc request id -> query key.
self._doc_query_keys: dict[str, str] = {}
def clear(self) -> None:
self._query_cache.clear()
self._query_uses.clear()
self._doc_query_keys.clear()
def register_request(
self, req_id: str, pooling_params: PoolingParams | None
) -> None:
mode, query_key, _ = self._parse_late_interaction_meta(pooling_params)
if mode == LATE_INTERACTION_MODE_SCORE_DOC and query_key is not None:
self._doc_query_keys[req_id] = query_key
else:
self._doc_query_keys.pop(req_id, None)
def on_requests_finished(self, finished_req_ids: Iterable[str]) -> None:
for req_id in finished_req_ids:
query_key = self._doc_query_keys.pop(req_id, None)
if query_key is not None:
self._release_query_use(query_key)
def postprocess_pooler_output(
self,
raw_pooler_output: PoolerOutput,
pooling_params: list[PoolingParams],
req_ids: list[str],
finished_mask: list[bool],
) -> PoolerOutput:
if not isinstance(raw_pooler_output, list):
return raw_pooler_output
num_reqs = len(pooling_params)
if len(raw_pooler_output) != num_reqs:
raise ValueError(
"raw_pooler_output and pooling_params must have the same length."
)
if len(req_ids) != num_reqs:
raise ValueError("req_ids and pooling_params must have the same length.")
if len(finished_mask) != num_reqs:
raise ValueError(
"finished_mask and pooling_params must have the same length."
)
if not any(finished_mask):
return raw_pooler_output
if not any(p.late_interaction_params is not None for p in pooling_params):
return raw_pooler_output
outputs: list[torch.Tensor | None] = list(raw_pooler_output)
for i, (req_id, output, params, finished) in enumerate(
zip(req_ids, outputs, pooling_params, finished_mask)
):
if not finished or output is None:
continue
mode, query_key, query_uses = self._parse_late_interaction_meta(params)
if mode is None:
continue
assert query_key is not None
if mode == LATE_INTERACTION_MODE_CACHE_QUERY:
assert query_uses is not None
# `output` can be a view into the current step's hidden-states
# buffer, so clone it before storing across scheduling steps.
self._query_cache[query_key] = output.clone()
self._query_uses[query_key] = query_uses
outputs[i] = torch.zeros((), device=output.device, dtype=torch.float32)
continue
if mode == LATE_INTERACTION_MODE_SCORE_DOC:
query_output = self._query_cache.get(query_key)
if query_output is None:
raise ValueError(
"late-interaction query cache miss for key "
f"{query_key!r}. Ensure query requests are executed "
"before their paired document requests."
)
outputs[i] = compute_maxsim_score(query_output, output)
self._doc_query_keys.pop(req_id, None)
self._release_query_use(query_key)
continue
raise ValueError(f"Unsupported late-interaction mode: {mode!r}")
return outputs
def _release_query_use(self, query_key: str) -> None:
remaining = self._query_uses.get(query_key, 1) - 1
if remaining <= 0:
self._query_uses.pop(query_key, None)
self._query_cache.pop(query_key, None)
else:
self._query_uses[query_key] = remaining
@staticmethod
def _parse_late_interaction_meta(
pooling_params: PoolingParams | None,
) -> tuple[str | None, str | None, int | None]:
if pooling_params is None or pooling_params.late_interaction_params is None:
return None, None, None
late_interaction_params = pooling_params.late_interaction_params
mode = late_interaction_params.mode
query_key = late_interaction_params.query_key
if not isinstance(query_key, str) or not query_key:
raise ValueError(
"late-interaction request is missing a valid query key in "
"pooling_params.late_interaction_params."
)
if mode == LATE_INTERACTION_MODE_CACHE_QUERY:
query_uses_raw = late_interaction_params.query_uses
if query_uses_raw is None:
query_uses_raw = 1
try:
query_uses = max(1, int(query_uses_raw))
except (TypeError, ValueError) as exc:
raise ValueError(
"late-interaction query uses must be an integer value."
) from exc
return mode, query_key, query_uses
return mode, query_key, None
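A minimal sketch of the worker-side flow this runner implements, mirroring the unit test added above: the cache_query request stores its token embeddings and returns a scalar placeholder, and the subsequent score_doc request returns its MaxSim score while releasing the cached query once its uses are exhausted.

import torch
from vllm.pooling_params import PoolingParams
from vllm.v1.pool.late_interaction import (
    build_late_interaction_doc_params,
    build_late_interaction_query_params,
)
from vllm.v1.worker.gpu.pool.late_interaction_runner import LateInteractionRunner

runner = LateInteractionRunner()
q_emb = torch.randn(4, 8)  # [query_tokens, dim], illustrative
d_emb = torch.randn(6, 8)  # [doc_tokens, dim], illustrative
q_params = PoolingParams(
    task="token_embed",
    late_interaction_params=build_late_interaction_query_params("demo-key", query_uses=1),
)
d_params = PoolingParams(
    task="token_embed",
    late_interaction_params=build_late_interaction_doc_params("demo-key"),
)
# Stage 1: cache the query embeddings; the request returns a scalar placeholder.
runner.postprocess_pooler_output(
    raw_pooler_output=[q_emb], pooling_params=[q_params],
    req_ids=["q-req"], finished_mask=[True],
)
# Stage 2: the doc embeddings are replaced by the MaxSim score, and the
# cached query is released because query_uses has reached zero.
[score] = runner.postprocess_pooler_output(
    raw_pooler_output=[d_emb], pooling_params=[d_params],
    req_ids=["d-req"], finished_mask=[True],
)
print(float(score))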

View File

@@ -181,6 +181,7 @@ from vllm.v1.worker.cp_utils import (
)
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin
from vllm.v1.worker.gpu.pool.late_interaction_runner import LateInteractionRunner
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
@@ -491,6 +492,7 @@ class GPUModelRunner(
# mm_hash -> encoder_output
self.encoder_cache: dict[str, torch.Tensor] = {}
self.late_interaction_runner = LateInteractionRunner()
self.use_aux_hidden_state_outputs = False
# Set up speculative decoding.
@@ -831,6 +833,7 @@ class GPUModelRunner(
"""
if self.mm_budget:
self.mm_budget.reset_cache()
self.late_interaction_runner.clear()
def reset_encoder_cache(self) -> None:
"""Clear the GPU-side encoder cache storing vision embeddings.
@@ -839,6 +842,7 @@ class GPUModelRunner(
stale embeddings computed with old weights are not reused.
"""
self.encoder_cache.clear()
self.late_interaction_runner.clear()
@torch.inference_mode()
def init_fp8_kv_scales(self) -> None:
@@ -1002,6 +1006,9 @@ class GPUModelRunner(
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
self.num_prompt_logprobs.pop(req_id, None)
self.late_interaction_runner.on_requests_finished(
scheduler_output.finished_req_ids
)
# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
# scheduled_req_ids overlap. This happens when a request is aborted and
@@ -1089,6 +1096,7 @@ class GPUModelRunner(
lora_request=new_req_data.lora_request,
)
self.requests[req_id] = req_state
self.late_interaction_runner.register_request(req_id, pooling_params)
if sampling_params and sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[req_id] = (
@@ -1360,6 +1368,7 @@ class GPUModelRunner(
req_state.prompt_embeds = new_req_data.prompt_embeds
req_state.sampling_params = new_req_data.sampling_params
req_state.pooling_params = new_req_data.pooling_params
self.late_interaction_runner.register_request(req_id, req_state.pooling_params)
req_state.block_ids = new_req_data.block_ids
req_state.num_computed_tokens = new_req_data.num_computed_tokens
req_state.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
@@ -2875,6 +2884,12 @@ class GPUModelRunner(
seq_len == prompt_len
for seq_len, prompt_len in zip(seq_lens_cpu, pooling_metadata.prompt_lens)
]
raw_pooler_output = self.late_interaction_runner.postprocess_pooler_output(
raw_pooler_output=raw_pooler_output,
pooling_params=pooling_metadata.pooling_params,
req_ids=self.input_batch.req_ids,
finished_mask=finished_mask,
)
model_runner_output = ModelRunnerOutput(
req_ids=self.input_batch.req_ids.copy(),