[Perf] Compute maxsim in worker side, reducing redundant copies, 2.7% E2E throughput improvement (#36159)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -24,17 +24,23 @@ from vllm import SamplingParams
|
||||
from vllm.distributed.kv_events import BlockStored, KVEventBatch, ZmqEventPublisher
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.pooling_params import LateInteractionParams, PoolingParams
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils.torch_utils import set_default_torch_num_threads
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.core import EngineCore
|
||||
from vllm.v1.engine.core_client import (
|
||||
AsyncMPClient,
|
||||
DPLBAsyncMPClient,
|
||||
EngineCoreClient,
|
||||
SyncMPClient,
|
||||
)
|
||||
from vllm.v1.engine.utils import CoreEngineProcManager
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.pool.late_interaction import (
|
||||
LATE_INTERACTION_MODE_CACHE_QUERY,
|
||||
LATE_INTERACTION_MODE_SCORE_DOC,
|
||||
)
|
||||
|
||||
from ...distributed.conftest import MockSubscriber
|
||||
from ...utils import create_new_process_for_each_test
|
||||
@@ -164,6 +170,71 @@ def test_mp_client_uses_env_timeout(monkeypatch: pytest.MonkeyPatch):
|
||||
client.shutdown()
|
||||
|
||||
|
||||
def _make_pooling_request(
|
||||
request_id: str, *, mode: str | None = None, query_key: str | None = None
|
||||
) -> EngineCoreRequest:
|
||||
late_interaction_params = None
|
||||
if mode is not None and query_key is not None:
|
||||
late_interaction_params = LateInteractionParams(
|
||||
mode=mode,
|
||||
query_key=query_key,
|
||||
)
|
||||
|
||||
return EngineCoreRequest(
|
||||
request_id=request_id,
|
||||
prompt_token_ids=[1, 2, 3],
|
||||
mm_features=None,
|
||||
sampling_params=None,
|
||||
pooling_params=PoolingParams(
|
||||
task="token_embed",
|
||||
late_interaction_params=late_interaction_params,
|
||||
),
|
||||
arrival_time=time.time(),
|
||||
lora_request=None,
|
||||
cache_salt=None,
|
||||
data_parallel_rank=None,
|
||||
)
|
||||
|
||||
|
||||
def test_dplb_late_interaction_sticky_routing():
|
||||
client = object.__new__(DPLBAsyncMPClient)
|
||||
client.client_count = 1
|
||||
client.reqs_in_flight = {}
|
||||
client.core_engines = [b"\x00\x00", b"\x01\x00", b"\x02\x00"]
|
||||
client.lb_engines = [[0, 0], [0, 0], [0, 0]]
|
||||
client.eng_start_index = 0
|
||||
|
||||
query_key = "rerank-abc-query-0"
|
||||
query_request = _make_pooling_request(
|
||||
"query-req", mode=LATE_INTERACTION_MODE_CACHE_QUERY, query_key=query_key
|
||||
)
|
||||
doc_request = _make_pooling_request(
|
||||
"doc-req", mode=LATE_INTERACTION_MODE_SCORE_DOC, query_key=query_key
|
||||
)
|
||||
|
||||
query_engine = client.get_core_engine_for_request(query_request)
|
||||
doc_engine = client.get_core_engine_for_request(doc_request)
|
||||
|
||||
assert query_engine == doc_engine
|
||||
assert client.reqs_in_flight["query-req"] == query_engine
|
||||
assert client.reqs_in_flight["doc-req"] == doc_engine
|
||||
|
||||
|
||||
def test_dplb_non_late_interaction_still_uses_lb():
|
||||
client = object.__new__(DPLBAsyncMPClient)
|
||||
client.client_count = 1
|
||||
client.reqs_in_flight = {}
|
||||
client.core_engines = [b"\x00\x00", b"\x01\x00", b"\x02\x00"]
|
||||
client.lb_engines = [[2, 1], [0, 0], [1, 0]]
|
||||
client.eng_start_index = 0
|
||||
|
||||
request = make_request(SamplingParams(max_tokens=1))
|
||||
chosen_engine = client.get_core_engine_for_request(request)
|
||||
|
||||
assert chosen_engine == client.core_engines[1]
|
||||
assert client.lb_engines[1][0] == 1
|
||||
|
||||
|
||||
def loop_until_done(client: EngineCoreClient, outputs: dict):
|
||||
while True:
|
||||
engine_core_outputs = client.get_output().outputs
|
||||
|
||||
113
tests/v1/worker/test_late_interaction_runner.py
Normal file
113
tests/v1/worker/test_late_interaction_runner.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.pooling_params import LateInteractionParams, PoolingParams
|
||||
from vllm.v1.pool.late_interaction import (
|
||||
LATE_INTERACTION_MODE_CACHE_QUERY,
|
||||
build_late_interaction_doc_params,
|
||||
build_late_interaction_query_params,
|
||||
compute_maxsim_score,
|
||||
)
|
||||
from vllm.v1.worker.gpu.pool.late_interaction_runner import LateInteractionRunner
|
||||
|
||||
|
||||
def _make_pooling_params(
|
||||
late_interaction_params: LateInteractionParams,
|
||||
) -> PoolingParams:
|
||||
return PoolingParams(
|
||||
task="token_embed",
|
||||
late_interaction_params=late_interaction_params,
|
||||
)
|
||||
|
||||
|
||||
def test_postprocess_scores_and_releases_query_cache():
|
||||
runner = LateInteractionRunner()
|
||||
query_key = "query-0"
|
||||
query_emb = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32)
|
||||
doc_emb = torch.tensor([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]], dtype=torch.float32)
|
||||
|
||||
query_params = _make_pooling_params(
|
||||
build_late_interaction_query_params(query_key=query_key, query_uses=1)
|
||||
)
|
||||
query_output = runner.postprocess_pooler_output(
|
||||
raw_pooler_output=[query_emb],
|
||||
pooling_params=[query_params],
|
||||
req_ids=["query-req"],
|
||||
finished_mask=[True],
|
||||
)
|
||||
assert isinstance(query_output, list)
|
||||
assert query_output[0] is not None
|
||||
assert query_output[0].shape == torch.Size([])
|
||||
|
||||
doc_params = _make_pooling_params(
|
||||
build_late_interaction_doc_params(query_key=query_key)
|
||||
)
|
||||
doc_output = runner.postprocess_pooler_output(
|
||||
raw_pooler_output=[doc_emb],
|
||||
pooling_params=[doc_params],
|
||||
req_ids=["doc-req"],
|
||||
finished_mask=[True],
|
||||
)
|
||||
assert isinstance(doc_output, list)
|
||||
assert doc_output[0] is not None
|
||||
assert torch.allclose(doc_output[0], compute_maxsim_score(query_emb, doc_emb))
|
||||
|
||||
with pytest.raises(ValueError, match="query cache miss"):
|
||||
runner.postprocess_pooler_output(
|
||||
raw_pooler_output=[doc_emb],
|
||||
pooling_params=[doc_params],
|
||||
req_ids=["doc-req-2"],
|
||||
finished_mask=[True],
|
||||
)
|
||||
|
||||
|
||||
def test_finished_request_releases_unscored_doc_use():
|
||||
runner = LateInteractionRunner()
|
||||
query_key = "query-cancel"
|
||||
query_emb = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32)
|
||||
doc_emb = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32)
|
||||
|
||||
query_params = _make_pooling_params(
|
||||
build_late_interaction_query_params(query_key=query_key, query_uses=1)
|
||||
)
|
||||
runner.postprocess_pooler_output(
|
||||
raw_pooler_output=[query_emb],
|
||||
pooling_params=[query_params],
|
||||
req_ids=["query-req"],
|
||||
finished_mask=[True],
|
||||
)
|
||||
|
||||
doc_params = _make_pooling_params(
|
||||
build_late_interaction_doc_params(query_key=query_key)
|
||||
)
|
||||
runner.register_request("doc-req", doc_params)
|
||||
runner.on_requests_finished({"doc-req"})
|
||||
|
||||
with pytest.raises(ValueError, match="query cache miss"):
|
||||
runner.postprocess_pooler_output(
|
||||
raw_pooler_output=[doc_emb],
|
||||
pooling_params=[doc_params],
|
||||
req_ids=["doc-req-retry"],
|
||||
finished_mask=[True],
|
||||
)
|
||||
|
||||
|
||||
def test_invalid_query_uses_raises():
|
||||
runner = LateInteractionRunner()
|
||||
bad_meta = LateInteractionParams(
|
||||
mode=LATE_INTERACTION_MODE_CACHE_QUERY,
|
||||
query_key="query-bad",
|
||||
)
|
||||
bad_meta.query_uses = "bad-int" # type: ignore[assignment]
|
||||
bad_query_params = _make_pooling_params(bad_meta)
|
||||
|
||||
with pytest.raises(ValueError, match="must be an integer value"):
|
||||
runner.postprocess_pooler_output(
|
||||
raw_pooler_output=[torch.ones((2, 2), dtype=torch.float32)],
|
||||
pooling_params=[bad_query_params],
|
||||
req_ids=["query-req"],
|
||||
finished_mask=[True],
|
||||
)
|
||||
@@ -225,12 +225,6 @@ def run_multi_api_server(args: argparse.Namespace):
|
||||
num_api_servers: int = args.api_server_count
|
||||
assert num_api_servers > 0
|
||||
|
||||
if num_api_servers > 1 and getattr(args, "use_gpu_for_pooling_score", False):
|
||||
# TODO(wentao): remove this once well tested
|
||||
raise ValueError(
|
||||
"--use-gpu-for-pooling-score cannot be used with api_server_count > 1 now"
|
||||
)
|
||||
|
||||
if num_api_servers > 1:
|
||||
setup_multiprocess_prometheus()
|
||||
|
||||
|
||||
@@ -281,10 +281,6 @@ class FrontendArgs(BaseFrontendArgs):
|
||||
Enable offline FastAPI documentation for air-gapped environments.
|
||||
Uses vendored static assets bundled with vLLM.
|
||||
"""
|
||||
use_gpu_for_pooling_score: bool = False
|
||||
"""If set, run pooling score MaxSim on GPU in the API server process.
|
||||
Can significantly improve late-interaction scoring performance.
|
||||
https://github.com/vllm-project/vllm/pull/35330"""
|
||||
|
||||
@classmethod
|
||||
def _customize_cli_kwargs(
|
||||
|
||||
@@ -111,7 +111,7 @@ def init_pooling_state(
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
score_template=resolved_chat_template,
|
||||
use_gpu_for_pooling_score=getattr(args, "use_gpu_for_pooling_score", False),
|
||||
log_error_stack=args.log_error_stack,
|
||||
)
|
||||
if any(t in supported_tasks for t in ("embed", "score", "token_embed"))
|
||||
else None
|
||||
|
||||
@@ -31,7 +31,6 @@ from vllm.entrypoints.pooling.score.utils import (
|
||||
ScoreInputs,
|
||||
_cosine_similarity,
|
||||
compress_token_type_ids,
|
||||
compute_maxsim_scores,
|
||||
get_score_prompt,
|
||||
parse_score_data_single,
|
||||
validate_score_input,
|
||||
@@ -43,6 +42,10 @@ from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils.async_utils import make_async, merge_async_iterators
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
from vllm.v1.pool.late_interaction import (
|
||||
build_late_interaction_doc_params,
|
||||
build_late_interaction_query_params,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -56,7 +59,6 @@ class ServingScores(OpenAIServing):
|
||||
request_logger: RequestLogger | None,
|
||||
score_template: str | None = None,
|
||||
log_error_stack: bool = False,
|
||||
use_gpu_for_pooling_score: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
@@ -64,7 +66,6 @@ class ServingScores(OpenAIServing):
|
||||
request_logger=request_logger,
|
||||
)
|
||||
self.score_template = score_template
|
||||
self.use_gpu_for_pooling_score = use_gpu_for_pooling_score
|
||||
|
||||
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
@@ -253,19 +254,30 @@ class ServingScores(OpenAIServing):
|
||||
)
|
||||
)
|
||||
|
||||
input_texts: list[str] = []
|
||||
engine_prompts: list[TokensPrompt] = []
|
||||
for text, engine_prompt in preprocessed:
|
||||
input_texts.append(text)
|
||||
engine_prompts.append(engine_prompt)
|
||||
query_prompts: list[TokensPrompt] = [
|
||||
prompt for _, prompt in preprocessed[: len(data_1)]
|
||||
]
|
||||
doc_prompts: list[TokensPrompt] = [
|
||||
prompt for _, prompt in preprocessed[len(data_1) :]
|
||||
]
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
default_pooling_params = request.to_pooling_params("token_embed")
|
||||
|
||||
pooling_params = request.to_pooling_params("token_embed")
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
# stage 1: encode queries and cache token embeddings on workers.
|
||||
query_keys = [f"{request_id}-query-{i}" for i in range(len(query_prompts))]
|
||||
query_uses = [len(doc_prompts) if len(query_prompts) == 1 else 1] * len(
|
||||
query_prompts
|
||||
)
|
||||
query_generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
for i, engine_prompt in enumerate(query_prompts):
|
||||
request_id_item = f"{request_id}-query-{i}"
|
||||
pooling_params = default_pooling_params.clone()
|
||||
pooling_params.late_interaction_params = (
|
||||
build_late_interaction_query_params(
|
||||
query_key=query_keys[i],
|
||||
query_uses=query_uses[i],
|
||||
)
|
||||
)
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
@@ -274,7 +286,7 @@ class ServingScores(OpenAIServing):
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
generators.append(
|
||||
query_generators.append(
|
||||
self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
@@ -285,53 +297,71 @@ class ServingScores(OpenAIServing):
|
||||
)
|
||||
)
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
query_outputs: list[PoolingRequestOutput | None] = [None] * len(query_prompts)
|
||||
if query_generators:
|
||||
async for i, res in merge_async_iterators(*query_generators):
|
||||
query_outputs[i] = res
|
||||
|
||||
# Collect token embeddings
|
||||
embeddings: list[PoolingRequestOutput | None] = [None] * len(engine_prompts)
|
||||
assert all(res is not None for res in query_outputs)
|
||||
query_results = [res for res in query_outputs if res is not None]
|
||||
|
||||
async for i, res in result_generator:
|
||||
embeddings[i] = res
|
||||
|
||||
# Split into query and document embeddings
|
||||
emb_data_1: list[PoolingRequestOutput] = []
|
||||
emb_data_2: list[PoolingRequestOutput] = []
|
||||
|
||||
for i in range(0, len(data_1)):
|
||||
assert (emb := embeddings[i]) is not None
|
||||
emb_data_1.append(emb)
|
||||
|
||||
for i in range(len(data_1), len(embeddings)):
|
||||
assert (emb := embeddings[i]) is not None
|
||||
emb_data_2.append(emb)
|
||||
|
||||
# Expand queries if 1:N scoring
|
||||
if len(emb_data_1) == 1:
|
||||
emb_data_1 = emb_data_1 * len(emb_data_2)
|
||||
|
||||
# Compute MaxSim scores
|
||||
from vllm.outputs import PoolingOutput
|
||||
|
||||
maxsim_scores = compute_maxsim_scores(
|
||||
[emb.outputs.data for emb in emb_data_1],
|
||||
[emb.outputs.data for emb in emb_data_2],
|
||||
use_gpu_for_pooling_score=self.use_gpu_for_pooling_score,
|
||||
# stage 2: encode docs and return scalar scores from workers.
|
||||
doc_generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
for i, engine_prompt in enumerate(doc_prompts):
|
||||
request_id_item = f"{request_id}-doc-{i}"
|
||||
query_idx = 0 if len(query_prompts) == 1 else i
|
||||
pooling_params = default_pooling_params.clone()
|
||||
pooling_params.late_interaction_params = build_late_interaction_doc_params(
|
||||
query_key=query_keys[query_idx]
|
||||
)
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
engine_prompt,
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
doc_generators.append(
|
||||
self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
)
|
||||
|
||||
doc_outputs: list[PoolingRequestOutput | None] = [None] * len(doc_prompts)
|
||||
if doc_generators:
|
||||
async for i, res in merge_async_iterators(*doc_generators):
|
||||
doc_outputs[i] = res
|
||||
|
||||
assert all(res is not None for res in doc_outputs)
|
||||
doc_results = [res for res in doc_outputs if res is not None]
|
||||
|
||||
scores: list[PoolingRequestOutput] = []
|
||||
padding: list[int] = []
|
||||
if (pad_token_id := tokenizer.pad_token_id) is not None:
|
||||
padding = [pad_token_id]
|
||||
|
||||
for emb_1, emb_2, maxsim_score in zip(emb_data_1, emb_data_2, maxsim_scores):
|
||||
tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
|
||||
if len(query_results) == 1:
|
||||
query_results = query_results * len(doc_results)
|
||||
|
||||
for query_result, doc_result in zip(query_results, doc_results):
|
||||
tokens = (
|
||||
query_result.prompt_token_ids + padding + doc_result.prompt_token_ids
|
||||
)
|
||||
|
||||
scores.append(
|
||||
PoolingRequestOutput(
|
||||
request_id=f"{emb_1.request_id}_{emb_2.request_id}",
|
||||
outputs=PoolingOutput(data=maxsim_score),
|
||||
request_id=f"{query_result.request_id}_{doc_result.request_id}",
|
||||
outputs=doc_result.outputs,
|
||||
prompt_token_ids=tokens,
|
||||
num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens,
|
||||
num_cached_tokens=(
|
||||
query_result.num_cached_tokens + doc_result.num_cached_tokens
|
||||
),
|
||||
finished=True,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -11,6 +11,26 @@ from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.tasks import PoolingTask
|
||||
|
||||
|
||||
class LateInteractionParams(
|
||||
msgspec.Struct,
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
array_like=True,
|
||||
): # type: ignore[call-arg]
|
||||
"""Metadata for worker-side late-interaction scoring.
|
||||
|
||||
Attributes:
|
||||
mode:
|
||||
- "cache_query": cache query token embeddings
|
||||
- "score_doc": score a document against a cached query.
|
||||
query_key: stable key used for both DP routing and worker cache lookup.
|
||||
query_uses: expected number of document requests
|
||||
"""
|
||||
|
||||
mode: str
|
||||
query_key: str
|
||||
query_uses: int | None = None
|
||||
|
||||
|
||||
class PoolingParams(
|
||||
msgspec.Struct,
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
@@ -46,6 +66,7 @@ class PoolingParams(
|
||||
task: PoolingTask | None = None
|
||||
requires_token_ids: bool = False
|
||||
skip_reading_prefix_cache: bool | None = None
|
||||
late_interaction_params: LateInteractionParams | None = None
|
||||
extra_kwargs: dict[str, Any] | None = None
|
||||
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
|
||||
|
||||
@@ -193,6 +214,7 @@ class PoolingParams(
|
||||
f"returned_token_ids={self.returned_token_ids}, "
|
||||
f"requires_token_ids={self.requires_token_ids}, "
|
||||
f"skip_reading_prefix_cache={self.skip_reading_prefix_cache}, "
|
||||
f"late_interaction_params={self.late_interaction_params}, "
|
||||
f"extra_kwargs={self.extra_kwargs})"
|
||||
)
|
||||
|
||||
|
||||
@@ -52,6 +52,7 @@ from vllm.v1.engine.utils import (
|
||||
launch_core_engines,
|
||||
)
|
||||
from vllm.v1.executor import Executor
|
||||
from vllm.v1.pool.late_interaction import get_late_interaction_engine_index
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -1360,7 +1361,11 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
|
||||
|
||||
def get_core_engine_for_request(self, request: EngineCoreRequest) -> EngineIdentity:
|
||||
# Engines are in rank order.
|
||||
if (eng_index := request.data_parallel_rank) is None:
|
||||
if (eng_index := request.data_parallel_rank) is None and (
|
||||
eng_index := get_late_interaction_engine_index(
|
||||
request.pooling_params, len(self.core_engines)
|
||||
)
|
||||
) is None:
|
||||
current_counts = self.lb_engines
|
||||
# TODO use P2C alg for larger DP sizes
|
||||
num_engines = len(current_counts)
|
||||
|
||||
64
vllm/v1/pool/late_interaction.py
Normal file
64
vllm/v1/pool/late_interaction.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import zlib
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.pooling_params import LateInteractionParams, PoolingParams
|
||||
|
||||
LATE_INTERACTION_MODE_CACHE_QUERY = "cache_query"
|
||||
LATE_INTERACTION_MODE_SCORE_DOC = "score_doc"
|
||||
|
||||
|
||||
def get_late_interaction_engine_index(
|
||||
pooling_params: PoolingParams | None,
|
||||
num_engines: int,
|
||||
) -> int | None:
|
||||
if pooling_params is None or pooling_params.late_interaction_params is None:
|
||||
return None
|
||||
|
||||
late_interaction_params = pooling_params.late_interaction_params
|
||||
mode = late_interaction_params.mode
|
||||
if mode not in (
|
||||
LATE_INTERACTION_MODE_CACHE_QUERY,
|
||||
LATE_INTERACTION_MODE_SCORE_DOC,
|
||||
):
|
||||
return None
|
||||
|
||||
query_key = late_interaction_params.query_key
|
||||
if not isinstance(query_key, str) or not query_key:
|
||||
return None
|
||||
|
||||
# query embeddings are cached in process-local worker memory,
|
||||
# pin requests sharing the same query key to the same engine.
|
||||
return zlib.crc32(query_key.encode("utf-8")) % num_engines
|
||||
|
||||
|
||||
def build_late_interaction_query_params(
|
||||
query_key: str,
|
||||
query_uses: int,
|
||||
) -> LateInteractionParams:
|
||||
return LateInteractionParams(
|
||||
mode=LATE_INTERACTION_MODE_CACHE_QUERY,
|
||||
query_key=query_key,
|
||||
query_uses=max(1, int(query_uses)),
|
||||
)
|
||||
|
||||
|
||||
def build_late_interaction_doc_params(
|
||||
query_key: str,
|
||||
) -> LateInteractionParams:
|
||||
return LateInteractionParams(
|
||||
mode=LATE_INTERACTION_MODE_SCORE_DOC,
|
||||
query_key=query_key,
|
||||
)
|
||||
|
||||
|
||||
def compute_maxsim_score(
|
||||
q_emb: torch.Tensor,
|
||||
d_emb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# compute in float32 for numerical stability
|
||||
token_scores = torch.matmul(q_emb.float(), d_emb.float().T)
|
||||
return token_scores.amax(dim=-1).sum()
|
||||
150
vllm/v1/worker/gpu/pool/late_interaction_runner.py
Normal file
150
vllm/v1/worker/gpu/pool/late_interaction_runner.py
Normal file
@@ -0,0 +1,150 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.v1.outputs import PoolerOutput
|
||||
from vllm.v1.pool.late_interaction import (
|
||||
LATE_INTERACTION_MODE_CACHE_QUERY,
|
||||
LATE_INTERACTION_MODE_SCORE_DOC,
|
||||
compute_maxsim_score,
|
||||
)
|
||||
|
||||
|
||||
class LateInteractionRunner:
|
||||
"""Worker-side state and postprocessing for late-interaction scoring."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# query_key -> token embeddings for late-interaction scoring.
|
||||
self._query_cache: dict[str, torch.Tensor] = {}
|
||||
# query_key -> remaining number of docs that should use this query.
|
||||
self._query_uses: dict[str, int] = {}
|
||||
# doc request id -> query key.
|
||||
self._doc_query_keys: dict[str, str] = {}
|
||||
|
||||
def clear(self) -> None:
|
||||
self._query_cache.clear()
|
||||
self._query_uses.clear()
|
||||
self._doc_query_keys.clear()
|
||||
|
||||
def register_request(
|
||||
self, req_id: str, pooling_params: PoolingParams | None
|
||||
) -> None:
|
||||
mode, query_key, _ = self._parse_late_interaction_meta(pooling_params)
|
||||
if mode == LATE_INTERACTION_MODE_SCORE_DOC and query_key is not None:
|
||||
self._doc_query_keys[req_id] = query_key
|
||||
else:
|
||||
self._doc_query_keys.pop(req_id, None)
|
||||
|
||||
def on_requests_finished(self, finished_req_ids: Iterable[str]) -> None:
|
||||
for req_id in finished_req_ids:
|
||||
query_key = self._doc_query_keys.pop(req_id, None)
|
||||
if query_key is not None:
|
||||
self._release_query_use(query_key)
|
||||
|
||||
def postprocess_pooler_output(
|
||||
self,
|
||||
raw_pooler_output: PoolerOutput,
|
||||
pooling_params: list[PoolingParams],
|
||||
req_ids: list[str],
|
||||
finished_mask: list[bool],
|
||||
) -> PoolerOutput:
|
||||
if not isinstance(raw_pooler_output, list):
|
||||
return raw_pooler_output
|
||||
|
||||
num_reqs = len(pooling_params)
|
||||
if len(raw_pooler_output) != num_reqs:
|
||||
raise ValueError(
|
||||
"raw_pooler_output and pooling_params must have the same length."
|
||||
)
|
||||
if len(req_ids) != num_reqs:
|
||||
raise ValueError("req_ids and pooling_params must have the same length.")
|
||||
if len(finished_mask) != num_reqs:
|
||||
raise ValueError(
|
||||
"finished_mask and pooling_params must have the same length."
|
||||
)
|
||||
|
||||
if not any(finished_mask):
|
||||
return raw_pooler_output
|
||||
if not any(p.late_interaction_params is not None for p in pooling_params):
|
||||
return raw_pooler_output
|
||||
|
||||
outputs: list[torch.Tensor | None] = list(raw_pooler_output)
|
||||
for i, (req_id, output, params, finished) in enumerate(
|
||||
zip(req_ids, outputs, pooling_params, finished_mask)
|
||||
):
|
||||
if not finished or output is None:
|
||||
continue
|
||||
|
||||
mode, query_key, query_uses = self._parse_late_interaction_meta(params)
|
||||
if mode is None:
|
||||
continue
|
||||
|
||||
assert query_key is not None
|
||||
if mode == LATE_INTERACTION_MODE_CACHE_QUERY:
|
||||
assert query_uses is not None
|
||||
# `output` can be a view into the current step's hidden-states
|
||||
# buffer, so clone it before storing across scheduling steps.
|
||||
self._query_cache[query_key] = output.clone()
|
||||
self._query_uses[query_key] = query_uses
|
||||
outputs[i] = torch.zeros((), device=output.device, dtype=torch.float32)
|
||||
continue
|
||||
|
||||
if mode == LATE_INTERACTION_MODE_SCORE_DOC:
|
||||
query_output = self._query_cache.get(query_key)
|
||||
if query_output is None:
|
||||
raise ValueError(
|
||||
"late-interaction query cache miss for key "
|
||||
f"{query_key!r}. Ensure query requests are executed "
|
||||
"before their paired document requests."
|
||||
)
|
||||
|
||||
outputs[i] = compute_maxsim_score(query_output, output)
|
||||
self._doc_query_keys.pop(req_id, None)
|
||||
self._release_query_use(query_key)
|
||||
continue
|
||||
|
||||
raise ValueError(f"Unsupported late-interaction mode: {mode!r}")
|
||||
|
||||
return outputs
|
||||
|
||||
def _release_query_use(self, query_key: str) -> None:
|
||||
remaining = self._query_uses.get(query_key, 1) - 1
|
||||
if remaining <= 0:
|
||||
self._query_uses.pop(query_key, None)
|
||||
self._query_cache.pop(query_key, None)
|
||||
else:
|
||||
self._query_uses[query_key] = remaining
|
||||
|
||||
@staticmethod
|
||||
def _parse_late_interaction_meta(
|
||||
pooling_params: PoolingParams | None,
|
||||
) -> tuple[str | None, str | None, int | None]:
|
||||
if pooling_params is None or pooling_params.late_interaction_params is None:
|
||||
return None, None, None
|
||||
|
||||
late_interaction_params = pooling_params.late_interaction_params
|
||||
mode = late_interaction_params.mode
|
||||
|
||||
query_key = late_interaction_params.query_key
|
||||
if not isinstance(query_key, str) or not query_key:
|
||||
raise ValueError(
|
||||
"late-interaction request is missing a valid query key in "
|
||||
"pooling_params.late_interaction_params."
|
||||
)
|
||||
|
||||
if mode == LATE_INTERACTION_MODE_CACHE_QUERY:
|
||||
query_uses_raw = late_interaction_params.query_uses
|
||||
if query_uses_raw is None:
|
||||
query_uses_raw = 1
|
||||
try:
|
||||
query_uses = max(1, int(query_uses_raw))
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise ValueError(
|
||||
"late-interaction query uses must be an integer value."
|
||||
) from exc
|
||||
return mode, query_key, query_uses
|
||||
|
||||
return mode, query_key, None
|
||||
@@ -181,6 +181,7 @@ from vllm.v1.worker.cp_utils import (
|
||||
)
|
||||
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
|
||||
from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin
|
||||
from vllm.v1.worker.gpu.pool.late_interaction_runner import LateInteractionRunner
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
|
||||
@@ -491,6 +492,7 @@ class GPUModelRunner(
|
||||
|
||||
# mm_hash -> encoder_output
|
||||
self.encoder_cache: dict[str, torch.Tensor] = {}
|
||||
self.late_interaction_runner = LateInteractionRunner()
|
||||
|
||||
self.use_aux_hidden_state_outputs = False
|
||||
# Set up speculative decoding.
|
||||
@@ -831,6 +833,7 @@ class GPUModelRunner(
|
||||
"""
|
||||
if self.mm_budget:
|
||||
self.mm_budget.reset_cache()
|
||||
self.late_interaction_runner.clear()
|
||||
|
||||
def reset_encoder_cache(self) -> None:
|
||||
"""Clear the GPU-side encoder cache storing vision embeddings.
|
||||
@@ -839,6 +842,7 @@ class GPUModelRunner(
|
||||
stale embeddings computed with old weights are not reused.
|
||||
"""
|
||||
self.encoder_cache.clear()
|
||||
self.late_interaction_runner.clear()
|
||||
|
||||
@torch.inference_mode()
|
||||
def init_fp8_kv_scales(self) -> None:
|
||||
@@ -1002,6 +1006,9 @@ class GPUModelRunner(
|
||||
for req_id in scheduler_output.finished_req_ids:
|
||||
self.requests.pop(req_id, None)
|
||||
self.num_prompt_logprobs.pop(req_id, None)
|
||||
self.late_interaction_runner.on_requests_finished(
|
||||
scheduler_output.finished_req_ids
|
||||
)
|
||||
# Remove the finished requests from the persistent batch.
|
||||
# NOTE(woosuk): There could be an edge case where finished_req_ids and
|
||||
# scheduled_req_ids overlap. This happens when a request is aborted and
|
||||
@@ -1089,6 +1096,7 @@ class GPUModelRunner(
|
||||
lora_request=new_req_data.lora_request,
|
||||
)
|
||||
self.requests[req_id] = req_state
|
||||
self.late_interaction_runner.register_request(req_id, pooling_params)
|
||||
|
||||
if sampling_params and sampling_params.prompt_logprobs is not None:
|
||||
self.num_prompt_logprobs[req_id] = (
|
||||
@@ -1360,6 +1368,7 @@ class GPUModelRunner(
|
||||
req_state.prompt_embeds = new_req_data.prompt_embeds
|
||||
req_state.sampling_params = new_req_data.sampling_params
|
||||
req_state.pooling_params = new_req_data.pooling_params
|
||||
self.late_interaction_runner.register_request(req_id, req_state.pooling_params)
|
||||
req_state.block_ids = new_req_data.block_ids
|
||||
req_state.num_computed_tokens = new_req_data.num_computed_tokens
|
||||
req_state.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
||||
@@ -2875,6 +2884,12 @@ class GPUModelRunner(
|
||||
seq_len == prompt_len
|
||||
for seq_len, prompt_len in zip(seq_lens_cpu, pooling_metadata.prompt_lens)
|
||||
]
|
||||
raw_pooler_output = self.late_interaction_runner.postprocess_pooler_output(
|
||||
raw_pooler_output=raw_pooler_output,
|
||||
pooling_params=pooling_metadata.pooling_params,
|
||||
req_ids=self.input_batch.req_ids,
|
||||
finished_mask=finished_mask,
|
||||
)
|
||||
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=self.input_batch.req_ids.copy(),
|
||||
|
||||
Reference in New Issue
Block a user