[Frontend] Re-enable running MaxSim on GPU (#38620)

Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
wang.yuqi
2026-04-03 00:03:13 +08:00
committed by GitHub
parent d9408ffba3
commit a9b4f07ba2
12 changed files with 207 additions and 54 deletions

View File

@@ -26,13 +26,18 @@ TEXTS_2 = [
]
@pytest.fixture(scope="module")
def server():
@pytest.fixture(scope="module", params=[True, False])
def server(request):
args = [
"--max-model-len",
str(MAX_MODEL_LEN),
]
# Cover both paths: pooling score MaxSim run on the worker side (GPU),
# aka flash-late-interaction, enabled (the default) and explicitly disabled.
if not request.param:
args += ["--no-enable-flash-late-interaction"]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server

View File

@@ -4,12 +4,12 @@
import pytest
import torch
from vllm.entrypoints.pooling.scoring.utils import compute_maxsim_score
from vllm.pooling_params import LateInteractionParams, PoolingParams
from vllm.v1.pool.late_interaction import (
LATE_INTERACTION_MODE_CACHE_QUERY,
build_late_interaction_doc_params,
build_late_interaction_query_params,
compute_maxsim_score,
)
from vllm.v1.worker.gpu.pool.late_interaction_runner import LateInteractionRunner

View File

@@ -1449,14 +1449,14 @@ class LLM:
pooling_task = io_processor.pooling_task
scoring_data = io_processor.valid_inputs(data_1, data_2)
offset = len(scoring_data.data_1)
n_queries = len(scoring_data.data_1)
ctx = OfflineInputsContext(
prompts=scoring_data,
pooling_params=pooling_params,
tokenization_kwargs=tokenization_kwargs,
chat_template=chat_template,
offset=offset,
n_queries=n_queries,
)
processor_inputs = io_processor.pre_process_offline(ctx)
@@ -1487,7 +1487,7 @@ class LLM:
outputs = self._run_engine(use_tqdm=use_tqdm, output_type=PoolingRequestOutput)
outputs = io_processor.post_process_offline(
ctx=OfflineOutputsContext(outputs=outputs, offset=offset),
ctx=OfflineOutputsContext(outputs=outputs, n_queries=n_queries),
)
return [ScoringRequestOutput.from_base(item) for item in outputs]

View File

@@ -278,6 +278,9 @@ class FrontendArgs(BaseFrontendArgs):
Enable offline FastAPI documentation for air-gapped environments.
Uses vendored static assets bundled with vLLM.
"""
enable_flash_late_interaction: bool = True
"""If set, run pooling score MaxSim on GPU in the API server process.
Can significantly improve late-interaction scoring performance."""
@classmethod
def _customize_cli_kwargs(

View File

@@ -123,6 +123,9 @@ def init_pooling_state(
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
enable_flash_late_interaction=getattr(
args, "enable_flash_late_interaction", True
),
)
if enable_scoring_api(supported_tasks, model_config)
else None

View File

@@ -82,9 +82,20 @@ class PoolingServing:
request: AnyPoolingRequest,
raw_request: Request | None = None,
) -> Response:
ctx = await self._init_ctx(request, raw_request)
await self.io_processor.pre_process_online_async(ctx)
await self._prepare_generators(ctx)
await self._collect_batch(ctx)
await self.io_processor.post_process_online_async(ctx)
return await self._build_response(ctx)
async def _init_ctx(
self,
request: AnyPoolingRequest,
raw_request: Request | None = None,
):
model_name = self.models.model_name()
request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
await self._check_model(request)
ctx = PoolingServeContext(
@@ -96,11 +107,7 @@ class PoolingServing:
self._validate_request(ctx)
self._maybe_get_adapters(ctx)
await self.io_processor.pre_process_online_async(ctx)
await self._prepare_generators(ctx)
await self._collect_batch(ctx)
await self.io_processor.post_process_online_async(ctx)
return await self._build_response(ctx)
return ctx
async def _prepare_generators(
self,

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from collections.abc import Sequence
from typing import Any, TypeAlias, cast
from typing import Any, TypeAlias
import torch.nn.functional as F
@@ -16,7 +16,7 @@ from vllm.entrypoints.pooling.typing import (
from vllm.inputs import EngineInput
from vllm.renderers import TokenizeParams
from vllm.renderers.hf import safe_apply_chat_template
from vllm.tasks import PoolingTask, ScoreType
from vllm.tasks import PoolingTask
from vllm.utils.mistral import is_mistral_tokenizer
from ...chat_utils import ChatTemplateResolutionError
@@ -34,7 +34,7 @@ ScoringServeContext: TypeAlias = PoolingServeContext[ScoringRequest]
class ScoringIOProcessor(PoolingIOProcessor):
name: ScoreType
name: str
pooling_task: PoolingTask
def __init__(self, *args, **kwargs):
@@ -63,7 +63,7 @@ class ScoringIOProcessor(PoolingIOProcessor):
class BiEncoderIOProcessor(ScoringIOProcessor):
name: ScoreType = "bi-encoder"
name = "bi-encoder"
pooling_task: PoolingTask = "embed"
#######################################
@@ -94,20 +94,17 @@ class BiEncoderIOProcessor(ScoringIOProcessor):
)
ctx.engine_inputs = engine_inputs
ctx.intermediates = len(scoring_data.data_1)
ctx.n_queries = len(scoring_data.data_1)
def post_process_online(
self,
ctx: ScoringServeContext,
):
if ctx.final_res_batch is None:
raise ValueError("Final response batch not available")
if ctx.intermediates is None:
raise ValueError("data_1 len not available")
assert ctx.final_res_batch is not None
assert isinstance(ctx.n_queries, int)
ctx.final_res_batch = self._post_process(
outputs=ctx.final_res_batch, offset=cast(int, ctx.intermediates)
outputs=ctx.final_res_batch, n_queries=ctx.n_queries
)
#######################################
@@ -124,8 +121,8 @@ class BiEncoderIOProcessor(ScoringIOProcessor):
self,
ctx: OfflineOutputsContext,
) -> list[PoolingRequestOutput]:
assert ctx.offset is not None
return self._post_process(outputs=ctx.outputs, offset=ctx.offset)
assert ctx.n_queries is not None
return self._post_process(outputs=ctx.outputs, n_queries=ctx.n_queries)
#######################################
# helpers
@@ -145,9 +142,9 @@ class BiEncoderIOProcessor(ScoringIOProcessor):
prompts=data_1 + data_2, tok_params=tok_params, prompt_extras=prompt_extras
)
def _post_process(self, outputs: list[PoolingRequestOutput], offset: int):
emb_data_1 = outputs[:offset]
emb_data_2 = outputs[offset:]
def _post_process(self, outputs: list[PoolingRequestOutput], n_queries: int):
emb_data_1 = outputs[:n_queries]
emb_data_2 = outputs[n_queries:]
if len(emb_data_1) == 1:
emb_data_1 = emb_data_1 * len(emb_data_2)
@@ -177,13 +174,13 @@ class BiEncoderIOProcessor(ScoringIOProcessor):
class LateInteractionIOProcessor(BiEncoderIOProcessor):
name: ScoreType = "late-interaction"
name = "late-interaction"
pooling_task: PoolingTask = "token_embed"
def _post_process(self, outputs: list[PoolingRequestOutput], offset: int):
def _post_process(self, outputs: list[PoolingRequestOutput], n_queries: int):
# Split into query and document embeddings
emb_data_1 = outputs[:offset]
emb_data_2 = outputs[offset:]
emb_data_1 = outputs[:n_queries]
emb_data_2 = outputs[n_queries:]
# Expand queries if 1:N scoring
if len(emb_data_1) == 1:
@@ -217,8 +214,15 @@ class LateInteractionIOProcessor(BiEncoderIOProcessor):
return final_res_batch
class FlashLateInteractionIOProcessor(LateInteractionIOProcessor):
    """Late-interaction IO processor whose post-processing is a pass-through.

    Identified by the name ``"flash-late-interaction"``. Unlike its parent,
    it does not split the outputs into query/document embeddings.

    NOTE(review): presumably the engine outputs already carry the final
    scores when this processor is used, so nothing is left to compute here
    -- confirm against the serving code that selects it.
    """

    name = "flash-late-interaction"

    def _post_process(self, outputs: list[PoolingRequestOutput], n_queries: int):
        """Return ``outputs`` unchanged; ``n_queries`` is accepted but unused."""
        return outputs
class CrossEncoderIOProcessor(ScoringIOProcessor):
name: ScoreType = "cross-encoder"
name = "cross-encoder"
pooling_task: PoolingTask = "classify"
def __init__(self, *args, **kwargs):
@@ -412,8 +416,12 @@ class CrossEncoderIOProcessor(ScoringIOProcessor):
return full_prompt, engine_prompt
ScoringIOProcessors: dict[ScoreType, type[ScoringIOProcessor]] = {
"bi-encoder": BiEncoderIOProcessor,
"late-interaction": LateInteractionIOProcessor,
"cross-encoder": CrossEncoderIOProcessor,
ScoringIOProcessors: dict[str, type[ScoringIOProcessor]] = {
p.name: p
for p in [
BiEncoderIOProcessor,
LateInteractionIOProcessor,
FlashLateInteractionIOProcessor,
CrossEncoderIOProcessor,
]
}

View File

@@ -1,9 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from fastapi.responses import JSONResponse
from fastapi.responses import JSONResponse, Response
from vllm import PoolingParams
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
@@ -11,6 +13,10 @@ from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.renderers import BaseRenderer
from vllm.v1.pool.late_interaction import (
build_late_interaction_doc_params,
build_late_interaction_query_params,
)
from .io_processor import ScoringIOProcessors, ScoringServeContext
from .protocol import (
@@ -31,13 +37,30 @@ logger = init_logger(__name__)
class ServingScores(PoolingServing):
request_id_prefix = "score"
def __init__(
    self,
    engine_client: EngineClient,
    *args,
    enable_flash_late_interaction: bool = True,
    **kwargs,
):
    """Record the model's score type and whether the flash path is active.

    Args:
        engine_client: Engine client; its ``model_config.score_type`` is
            stored on the instance as ``score_type``.
        *args: Forwarded unchanged to the parent initializer.
        enable_flash_late_interaction: Requested flash late-interaction
            switch. It only takes effect when the score type is
            ``"late-interaction"``.
        **kwargs: Forwarded unchanged to the parent initializer.
    """
    score_type = engine_client.model_config.score_type
    self.score_type = score_type

    # The flash path is only meaningful for late-interaction models.
    is_late_interaction = score_type == "late-interaction"
    self.enable_flash_late_interaction = (
        is_late_interaction and enable_flash_late_interaction
    )

    # NOTE(review): both attributes are set before delegating, presumably
    # because the parent initializer triggers init_io_processor(), which
    # reads the flag -- confirm.
    super().__init__(engine_client, *args, **kwargs)
def init_io_processor(
self,
model_config: ModelConfig,
renderer: BaseRenderer,
chat_template_config: ChatTemplateConfig,
) -> PoolingIOProcessor:
score_type = model_config.score_type
score_type: str = model_config.score_type
if self.enable_flash_late_interaction:
score_type = "flash-late-interaction"
assert score_type in ScoringIOProcessors
processor_cls = ScoringIOProcessors[score_type]
return processor_cls(
@@ -46,6 +69,12 @@ class ServingScores(PoolingServing):
chat_template_config=chat_template_config,
)
async def __call__(self, *args, **kwargs) -> Response:
    """Serve a scoring request, routing to the flash path when enabled.

    All arguments are forwarded untouched to whichever handler is chosen.
    """
    if self.enable_flash_late_interaction:
        return await self.flash_late_interaction(*args, **kwargs)

    # Flash path disabled: fall back to the parent implementation.
    return await super().__call__(*args, **kwargs)
async def _build_response(
self,
ctx: ScoringServeContext,
@@ -158,3 +187,106 @@ class ServingScores(PoolingServing):
)
return JSONResponse(content=response.model_dump())
###################################################################################
### Run pooling score MaxSim on the worker side (GPU), driven from the API
### server process. This can significantly improve late-interaction scoring
### performance.
async def flash_late_interaction(self, *args, **kwargs) -> Response:
    """Serve one scoring request through the two-stage flash path.

    The arguments are forwarded to ``_init_ctx``. The stages run strictly
    in sequence: queries are encoded first, then documents, then the
    result is post-processed and turned into the response.
    """
    serve_ctx = await self._init_ctx(*args, **kwargs)

    processor = self.io_processor
    serve_ctx.pooling_params = processor.create_pooling_params(serve_ctx.request)
    await processor.pre_process_online_async(serve_ctx)

    # Stage 1: encode queries and cache token embeddings on workers.
    await self._flash_late_interaction_encode_queries(serve_ctx)

    # Stage 2: encode docs and return scalar scores from workers.
    await self._flash_late_interaction_encode_docs(serve_ctx)

    await processor.post_process_online_async(serve_ctx)
    return await self._build_response(serve_ctx)
async def _flash_late_interaction_encode_queries(self, ctx: ScoringServeContext):
    """Stage 1: submit the query prompts with per-query late-interaction params.

    The first ``ctx.n_queries`` entries of ``ctx.engine_inputs`` are treated
    as queries. Each one gets a clone of ``ctx.pooling_params`` carrying a
    request-scoped query key and a use count. They are submitted through a
    dedicated context, and that context's results are not copied back onto
    ``ctx``.
    """
    assert ctx.n_queries is not None
    assert ctx.engine_inputs is not None
    assert isinstance(ctx.pooling_params, PoolingParams)

    num_queries = ctx.n_queries
    num_docs = len(ctx.engine_inputs) - num_queries
    query_inputs = ctx.engine_inputs[:num_queries]

    # A lone query is paired with every document (1:N scoring); otherwise
    # each query is used exactly once.
    uses_per_query = num_docs if num_queries == 1 else 1

    request_ids = []
    per_query_params = []
    for idx in range(num_queries):
        key = f"{ctx.request_id}-query-{idx}"
        params = ctx.pooling_params.clone()
        params.late_interaction_params = build_late_interaction_query_params(
            query_key=key,
            query_uses=uses_per_query,
        )
        request_ids.append(key)
        per_query_params.append(params)

    assert (
        num_queries
        == len(per_query_params)
        == len(query_inputs)
        == len(request_ids)
    )

    stage_ctx = ScoringServeContext(
        request=ctx.request,
        raw_request=ctx.raw_request,
        model_name=ctx.model_name,
        request_id=ctx.request_id,
        pooling_params=per_query_params,
        prompt_request_ids=request_ids,
        engine_inputs=query_inputs,
    )

    await self._prepare_generators(stage_ctx)
    await self._collect_batch(stage_ctx)
async def _flash_late_interaction_encode_docs(self, ctx: ScoringServeContext):
    """Stage 2: submit the document prompts, each tied to its query key.

    Entries of ``ctx.engine_inputs`` after the first ``ctx.n_queries`` are
    treated as documents. Each one gets a clone of ``ctx.pooling_params``
    that references the key of the query it is paired with. The batch
    collected for the documents becomes ``ctx.final_res_batch``.
    """
    assert ctx.n_queries is not None
    assert ctx.engine_inputs is not None
    assert isinstance(ctx.pooling_params, PoolingParams)

    num_queries = ctx.n_queries
    num_docs = len(ctx.engine_inputs) - num_queries
    doc_inputs = ctx.engine_inputs[num_queries:]

    # Query keys must match the ones generated for the query stage.
    query_keys = [f"{ctx.request_id}-query-{q}" for q in range(num_queries)]
    doc_keys = [f"{ctx.request_id}-doc-{d}" for d in range(num_docs)]

    # 1:N scoring pairs every document with the single query; otherwise
    # document i is paired with query i.
    single_query = num_queries == 1

    per_doc_params = []
    for doc_idx in range(num_docs):
        params = ctx.pooling_params.clone()
        paired_key = query_keys[0 if single_query else doc_idx]
        params.late_interaction_params = build_late_interaction_doc_params(
            query_key=paired_key
        )
        per_doc_params.append(params)

    assert (
        num_docs
        == len(per_doc_params)
        == len(doc_inputs)
        == len(doc_keys)
    )

    stage_ctx = ScoringServeContext(
        request=ctx.request,
        raw_request=ctx.raw_request,
        model_name=ctx.model_name,
        request_id=ctx.request_id,
        pooling_params=per_doc_params,
        prompt_request_ids=doc_keys,
        engine_inputs=doc_inputs,
    )

    await self._prepare_generators(stage_ctx)
    await self._collect_batch(stage_ctx)

    # Hand the document-stage results back to the caller's context.
    ctx.final_res_batch = stage_ctx.final_res_batch

View File

@@ -36,8 +36,9 @@ def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tens
Returns:
MaxSim score (sum over query tokens of max similarity to any doc token)
"""
# compute in float32 for numerical stability
# [query_len, doc_len]
token_scores = torch.matmul(q_emb, d_emb.T)
token_scores = torch.matmul(q_emb.float(), d_emb.float().T)
# Max over document tokens, sum over query tokens
return token_scores.amax(dim=-1).sum()

View File

@@ -83,6 +83,9 @@ class PoolingServeContext(Generic[PoolingRequestT]):
model_config = ConfigDict(arbitrary_types_allowed=True)
## for bi-encoder & late-interaction
n_queries: int | None = None
@dataclass
class OfflineInputsContext:
@@ -92,7 +95,7 @@ class OfflineInputsContext:
chat_template: str | None = None
## for bi-encoder & late-interaction
offset: int | None = None
n_queries: int | None = None
@dataclass
@@ -100,4 +103,4 @@ class OfflineOutputsContext:
outputs: list[PoolingRequestOutput]
## for bi-encoder & late-interaction
offset: int | None = None
n_queries: int | None = None

View File

@@ -56,16 +56,7 @@ def build_late_interaction_doc_params(
)
def compute_maxsim_score(
q_emb: torch.Tensor,
d_emb: torch.Tensor,
) -> torch.Tensor:
# compute in float32 for numerical stability
token_scores = torch.matmul(q_emb.float(), d_emb.float().T)
return token_scores.amax(dim=-1).sum()
def compute_maxsim_scores(
def compute_maxsim_score_batched(
q_embs: Sequence[torch.Tensor],
d_embs: Sequence[torch.Tensor],
max_batch_size: int = 64,

View File

@@ -9,7 +9,7 @@ from vllm.v1.outputs import PoolerOutput
from vllm.v1.pool.late_interaction import (
LATE_INTERACTION_MODE_CACHE_QUERY,
LATE_INTERACTION_MODE_SCORE_DOC,
compute_maxsim_scores,
compute_maxsim_score_batched,
)
@@ -116,7 +116,7 @@ class LateInteractionRunner:
raise ValueError(f"Unsupported late-interaction mode: {mode!r}")
if score_indices:
score_values = compute_maxsim_scores(score_queries, score_docs)
score_values = compute_maxsim_score_batched(score_queries, score_docs)
for i, req_id, query_key, score in zip(
score_indices, score_req_ids, score_query_keys, score_values
):