[Frontend] Re-enable running MaxSim on GPU (#38620)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
@@ -26,13 +26,18 @@ TEXTS_2 = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
@pytest.fixture(scope="module", params=[True, False])
|
||||
def server(request):
|
||||
args = [
|
||||
"--max-model-len",
|
||||
str(MAX_MODEL_LEN),
|
||||
]
|
||||
|
||||
# Test run pooling score MaxSim on worker side (GPU)
|
||||
# aka flash-late-interaction
|
||||
if not request.param:
|
||||
args += ["--no-enable-flash-late-interaction"]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@@ -4,12 +4,12 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.entrypoints.pooling.scoring.utils import compute_maxsim_score
|
||||
from vllm.pooling_params import LateInteractionParams, PoolingParams
|
||||
from vllm.v1.pool.late_interaction import (
|
||||
LATE_INTERACTION_MODE_CACHE_QUERY,
|
||||
build_late_interaction_doc_params,
|
||||
build_late_interaction_query_params,
|
||||
compute_maxsim_score,
|
||||
)
|
||||
from vllm.v1.worker.gpu.pool.late_interaction_runner import LateInteractionRunner
|
||||
|
||||
|
||||
@@ -1449,14 +1449,14 @@ class LLM:
|
||||
|
||||
pooling_task = io_processor.pooling_task
|
||||
scoring_data = io_processor.valid_inputs(data_1, data_2)
|
||||
offset = len(scoring_data.data_1)
|
||||
n_queries = len(scoring_data.data_1)
|
||||
|
||||
ctx = OfflineInputsContext(
|
||||
prompts=scoring_data,
|
||||
pooling_params=pooling_params,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
chat_template=chat_template,
|
||||
offset=offset,
|
||||
n_queries=n_queries,
|
||||
)
|
||||
|
||||
processor_inputs = io_processor.pre_process_offline(ctx)
|
||||
@@ -1487,7 +1487,7 @@ class LLM:
|
||||
|
||||
outputs = self._run_engine(use_tqdm=use_tqdm, output_type=PoolingRequestOutput)
|
||||
outputs = io_processor.post_process_offline(
|
||||
ctx=OfflineOutputsContext(outputs=outputs, offset=offset),
|
||||
ctx=OfflineOutputsContext(outputs=outputs, n_queries=n_queries),
|
||||
)
|
||||
|
||||
return [ScoringRequestOutput.from_base(item) for item in outputs]
|
||||
|
||||
@@ -278,6 +278,9 @@ class FrontendArgs(BaseFrontendArgs):
|
||||
Enable offline FastAPI documentation for air-gapped environments.
|
||||
Uses vendored static assets bundled with vLLM.
|
||||
"""
|
||||
enable_flash_late_interaction: bool = True
|
||||
"""If set, run pooling score MaxSim on GPU in the API server process.
|
||||
Can significantly improve late-interaction scoring performance."""
|
||||
|
||||
@classmethod
|
||||
def _customize_cli_kwargs(
|
||||
|
||||
@@ -123,6 +123,9 @@ def init_pooling_state(
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
trust_request_chat_template=args.trust_request_chat_template,
|
||||
enable_flash_late_interaction=getattr(
|
||||
args, "enable_flash_late_interaction", True
|
||||
),
|
||||
)
|
||||
if enable_scoring_api(supported_tasks, model_config)
|
||||
else None
|
||||
|
||||
@@ -82,9 +82,20 @@ class PoolingServing:
|
||||
request: AnyPoolingRequest,
|
||||
raw_request: Request | None = None,
|
||||
) -> Response:
|
||||
ctx = await self._init_ctx(request, raw_request)
|
||||
await self.io_processor.pre_process_online_async(ctx)
|
||||
await self._prepare_generators(ctx)
|
||||
await self._collect_batch(ctx)
|
||||
await self.io_processor.post_process_online_async(ctx)
|
||||
return await self._build_response(ctx)
|
||||
|
||||
async def _init_ctx(
|
||||
self,
|
||||
request: AnyPoolingRequest,
|
||||
raw_request: Request | None = None,
|
||||
):
|
||||
model_name = self.models.model_name()
|
||||
request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
|
||||
|
||||
await self._check_model(request)
|
||||
|
||||
ctx = PoolingServeContext(
|
||||
@@ -96,11 +107,7 @@ class PoolingServing:
|
||||
|
||||
self._validate_request(ctx)
|
||||
self._maybe_get_adapters(ctx)
|
||||
await self.io_processor.pre_process_online_async(ctx)
|
||||
await self._prepare_generators(ctx)
|
||||
await self._collect_batch(ctx)
|
||||
await self.io_processor.post_process_online_async(ctx)
|
||||
return await self._build_response(ctx)
|
||||
return ctx
|
||||
|
||||
async def _prepare_generators(
|
||||
self,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, TypeAlias, cast
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
@@ -16,7 +16,7 @@ from vllm.entrypoints.pooling.typing import (
|
||||
from vllm.inputs import EngineInput
|
||||
from vllm.renderers import TokenizeParams
|
||||
from vllm.renderers.hf import safe_apply_chat_template
|
||||
from vllm.tasks import PoolingTask, ScoreType
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
|
||||
from ...chat_utils import ChatTemplateResolutionError
|
||||
@@ -34,7 +34,7 @@ ScoringServeContext: TypeAlias = PoolingServeContext[ScoringRequest]
|
||||
|
||||
|
||||
class ScoringIOProcessor(PoolingIOProcessor):
|
||||
name: ScoreType
|
||||
name: str
|
||||
pooling_task: PoolingTask
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -63,7 +63,7 @@ class ScoringIOProcessor(PoolingIOProcessor):
|
||||
|
||||
|
||||
class BiEncoderIOProcessor(ScoringIOProcessor):
|
||||
name: ScoreType = "bi-encoder"
|
||||
name = "bi-encoder"
|
||||
pooling_task: PoolingTask = "embed"
|
||||
|
||||
#######################################
|
||||
@@ -94,20 +94,17 @@ class BiEncoderIOProcessor(ScoringIOProcessor):
|
||||
)
|
||||
|
||||
ctx.engine_inputs = engine_inputs
|
||||
ctx.intermediates = len(scoring_data.data_1)
|
||||
ctx.n_queries = len(scoring_data.data_1)
|
||||
|
||||
def post_process_online(
|
||||
self,
|
||||
ctx: ScoringServeContext,
|
||||
):
|
||||
if ctx.final_res_batch is None:
|
||||
raise ValueError("Final response batch not available")
|
||||
|
||||
if ctx.intermediates is None:
|
||||
raise ValueError("data_1 len not available")
|
||||
assert ctx.final_res_batch is not None
|
||||
assert isinstance(ctx.n_queries, int)
|
||||
|
||||
ctx.final_res_batch = self._post_process(
|
||||
outputs=ctx.final_res_batch, offset=cast(int, ctx.intermediates)
|
||||
outputs=ctx.final_res_batch, n_queries=ctx.n_queries
|
||||
)
|
||||
|
||||
#######################################
|
||||
@@ -124,8 +121,8 @@ class BiEncoderIOProcessor(ScoringIOProcessor):
|
||||
self,
|
||||
ctx: OfflineOutputsContext,
|
||||
) -> list[PoolingRequestOutput]:
|
||||
assert ctx.offset is not None
|
||||
return self._post_process(outputs=ctx.outputs, offset=ctx.offset)
|
||||
assert ctx.n_queries is not None
|
||||
return self._post_process(outputs=ctx.outputs, n_queries=ctx.n_queries)
|
||||
|
||||
#######################################
|
||||
# helpers
|
||||
@@ -145,9 +142,9 @@ class BiEncoderIOProcessor(ScoringIOProcessor):
|
||||
prompts=data_1 + data_2, tok_params=tok_params, prompt_extras=prompt_extras
|
||||
)
|
||||
|
||||
def _post_process(self, outputs: list[PoolingRequestOutput], offset: int):
|
||||
emb_data_1 = outputs[:offset]
|
||||
emb_data_2 = outputs[offset:]
|
||||
def _post_process(self, outputs: list[PoolingRequestOutput], n_queries: int):
|
||||
emb_data_1 = outputs[:n_queries]
|
||||
emb_data_2 = outputs[n_queries:]
|
||||
|
||||
if len(emb_data_1) == 1:
|
||||
emb_data_1 = emb_data_1 * len(emb_data_2)
|
||||
@@ -177,13 +174,13 @@ class BiEncoderIOProcessor(ScoringIOProcessor):
|
||||
|
||||
|
||||
class LateInteractionIOProcessor(BiEncoderIOProcessor):
|
||||
name: ScoreType = "late-interaction"
|
||||
name = "late-interaction"
|
||||
pooling_task: PoolingTask = "token_embed"
|
||||
|
||||
def _post_process(self, outputs: list[PoolingRequestOutput], offset: int):
|
||||
def _post_process(self, outputs: list[PoolingRequestOutput], n_queries: int):
|
||||
# Split into query and document embeddings
|
||||
emb_data_1 = outputs[:offset]
|
||||
emb_data_2 = outputs[offset:]
|
||||
emb_data_1 = outputs[:n_queries]
|
||||
emb_data_2 = outputs[n_queries:]
|
||||
|
||||
# Expand queries if 1:N scoring
|
||||
if len(emb_data_1) == 1:
|
||||
@@ -217,8 +214,15 @@ class LateInteractionIOProcessor(BiEncoderIOProcessor):
|
||||
return final_res_batch
|
||||
|
||||
|
||||
class FlashLateInteractionIOProcessor(LateInteractionIOProcessor):
|
||||
name = "flash-late-interaction"
|
||||
|
||||
def _post_process(self, outputs: list[PoolingRequestOutput], n_queries: int):
|
||||
return outputs
|
||||
|
||||
|
||||
class CrossEncoderIOProcessor(ScoringIOProcessor):
|
||||
name: ScoreType = "cross-encoder"
|
||||
name = "cross-encoder"
|
||||
pooling_task: PoolingTask = "classify"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -412,8 +416,12 @@ class CrossEncoderIOProcessor(ScoringIOProcessor):
|
||||
return full_prompt, engine_prompt
|
||||
|
||||
|
||||
ScoringIOProcessors: dict[ScoreType, type[ScoringIOProcessor]] = {
|
||||
"bi-encoder": BiEncoderIOProcessor,
|
||||
"late-interaction": LateInteractionIOProcessor,
|
||||
"cross-encoder": CrossEncoderIOProcessor,
|
||||
ScoringIOProcessors: dict[str, type[ScoringIOProcessor]] = {
|
||||
p.name: p
|
||||
for p in [
|
||||
BiEncoderIOProcessor,
|
||||
LateInteractionIOProcessor,
|
||||
FlashLateInteractionIOProcessor,
|
||||
CrossEncoderIOProcessor,
|
||||
]
|
||||
}
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
|
||||
from vllm import PoolingParams
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateConfig
|
||||
from vllm.entrypoints.openai.engine.protocol import UsageInfo
|
||||
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
|
||||
@@ -11,6 +13,10 @@ from vllm.entrypoints.pooling.base.serving import PoolingServing
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
|
||||
from vllm.renderers import BaseRenderer
|
||||
from vllm.v1.pool.late_interaction import (
|
||||
build_late_interaction_doc_params,
|
||||
build_late_interaction_query_params,
|
||||
)
|
||||
|
||||
from .io_processor import ScoringIOProcessors, ScoringServeContext
|
||||
from .protocol import (
|
||||
@@ -31,13 +37,30 @@ logger = init_logger(__name__)
|
||||
class ServingScores(PoolingServing):
|
||||
request_id_prefix = "score"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
*args,
|
||||
enable_flash_late_interaction: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
self.score_type = engine_client.model_config.score_type
|
||||
self.enable_flash_late_interaction = (
|
||||
self.score_type == "late-interaction" and enable_flash_late_interaction
|
||||
)
|
||||
|
||||
super().__init__(engine_client, *args, **kwargs)
|
||||
|
||||
def init_io_processor(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
renderer: BaseRenderer,
|
||||
chat_template_config: ChatTemplateConfig,
|
||||
) -> PoolingIOProcessor:
|
||||
score_type = model_config.score_type
|
||||
score_type: str = model_config.score_type
|
||||
if self.enable_flash_late_interaction:
|
||||
score_type = "flash-late-interaction"
|
||||
|
||||
assert score_type in ScoringIOProcessors
|
||||
processor_cls = ScoringIOProcessors[score_type]
|
||||
return processor_cls(
|
||||
@@ -46,6 +69,12 @@ class ServingScores(PoolingServing):
|
||||
chat_template_config=chat_template_config,
|
||||
)
|
||||
|
||||
async def __call__(self, *args, **kwargs) -> Response:
|
||||
if not self.enable_flash_late_interaction:
|
||||
return await super().__call__(*args, **kwargs)
|
||||
|
||||
return await self.flash_late_interaction(*args, **kwargs)
|
||||
|
||||
async def _build_response(
|
||||
self,
|
||||
ctx: ScoringServeContext,
|
||||
@@ -158,3 +187,106 @@ class ServingScores(PoolingServing):
|
||||
)
|
||||
|
||||
return JSONResponse(content=response.model_dump())
|
||||
|
||||
###################################################################################
|
||||
### Run pooling score MaxSim on worker side (GPU) in the API server process
|
||||
### Can significantly improve late-interaction scoring performance.
|
||||
|
||||
async def flash_late_interaction(self, *args, **kwargs) -> Response:
|
||||
ctx = await self._init_ctx(*args, **kwargs)
|
||||
ctx.pooling_params = self.io_processor.create_pooling_params(ctx.request)
|
||||
await self.io_processor.pre_process_online_async(ctx)
|
||||
|
||||
# stage 1: encode queries and cache token embeddings on workers.
|
||||
await self._flash_late_interaction_encode_queries(ctx)
|
||||
# stage 2: encode docs and return scalar scores from workers.
|
||||
await self._flash_late_interaction_encode_docs(ctx)
|
||||
|
||||
await self.io_processor.post_process_online_async(ctx)
|
||||
return await self._build_response(ctx)
|
||||
|
||||
async def _flash_late_interaction_encode_queries(self, ctx: ScoringServeContext):
|
||||
assert ctx.n_queries is not None
|
||||
assert ctx.engine_inputs is not None
|
||||
assert isinstance(ctx.pooling_params, PoolingParams)
|
||||
|
||||
n_queries = ctx.n_queries
|
||||
n_docs = len(ctx.engine_inputs) - n_queries
|
||||
query_engine_inputs = ctx.engine_inputs[:n_queries]
|
||||
|
||||
query_keys = [f"{ctx.request_id}-query-{i}" for i in range(n_queries)]
|
||||
query_uses = [n_docs if n_queries == 1 else 1] * n_queries
|
||||
|
||||
query_pooling_params_list = []
|
||||
for i in range(n_queries):
|
||||
pooling_params = ctx.pooling_params.clone()
|
||||
pooling_params.late_interaction_params = (
|
||||
build_late_interaction_query_params(
|
||||
query_key=query_keys[i],
|
||||
query_uses=query_uses[i],
|
||||
)
|
||||
)
|
||||
query_pooling_params_list.append(pooling_params)
|
||||
|
||||
assert (
|
||||
n_queries
|
||||
== len(query_pooling_params_list)
|
||||
== len(query_engine_inputs)
|
||||
== len(query_keys)
|
||||
)
|
||||
|
||||
query_ctx = ScoringServeContext(
|
||||
request=ctx.request,
|
||||
raw_request=ctx.raw_request,
|
||||
model_name=ctx.model_name,
|
||||
request_id=ctx.request_id,
|
||||
pooling_params=query_pooling_params_list,
|
||||
prompt_request_ids=query_keys,
|
||||
engine_inputs=query_engine_inputs,
|
||||
)
|
||||
|
||||
await self._prepare_generators(query_ctx)
|
||||
await self._collect_batch(query_ctx)
|
||||
|
||||
async def _flash_late_interaction_encode_docs(self, ctx: ScoringServeContext):
|
||||
assert ctx.n_queries is not None
|
||||
assert ctx.engine_inputs is not None
|
||||
assert isinstance(ctx.pooling_params, PoolingParams)
|
||||
|
||||
n_queries = ctx.n_queries
|
||||
n_docs = len(ctx.engine_inputs) - n_queries
|
||||
doc_engine_inputs = ctx.engine_inputs[n_queries:]
|
||||
|
||||
query_keys = [f"{ctx.request_id}-query-{i}" for i in range(n_queries)]
|
||||
doc_keys = [f"{ctx.request_id}-doc-{i}" for i in range(n_docs)]
|
||||
|
||||
doc_pooling_params_list = []
|
||||
for i in range(n_docs):
|
||||
query_idx = 0 if n_queries == 1 else i
|
||||
pooling_params = ctx.pooling_params.clone()
|
||||
pooling_params.late_interaction_params = build_late_interaction_doc_params(
|
||||
query_key=query_keys[query_idx]
|
||||
)
|
||||
doc_pooling_params_list.append(pooling_params)
|
||||
|
||||
assert (
|
||||
n_docs
|
||||
== len(doc_pooling_params_list)
|
||||
== len(doc_engine_inputs)
|
||||
== len(doc_keys)
|
||||
)
|
||||
|
||||
doc_ctx = ScoringServeContext(
|
||||
request=ctx.request,
|
||||
raw_request=ctx.raw_request,
|
||||
model_name=ctx.model_name,
|
||||
request_id=ctx.request_id,
|
||||
pooling_params=doc_pooling_params_list,
|
||||
prompt_request_ids=doc_keys,
|
||||
engine_inputs=doc_engine_inputs,
|
||||
)
|
||||
|
||||
await self._prepare_generators(doc_ctx)
|
||||
await self._collect_batch(doc_ctx)
|
||||
|
||||
ctx.final_res_batch = doc_ctx.final_res_batch
|
||||
|
||||
@@ -36,8 +36,9 @@ def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tens
|
||||
Returns:
|
||||
MaxSim score (sum over query tokens of max similarity to any doc token)
|
||||
"""
|
||||
# compute in float32 for numerical stability
|
||||
# [query_len, doc_len]
|
||||
token_scores = torch.matmul(q_emb, d_emb.T)
|
||||
token_scores = torch.matmul(q_emb.float(), d_emb.float().T)
|
||||
# Max over document tokens, sum over query tokens
|
||||
return token_scores.amax(dim=-1).sum()
|
||||
|
||||
|
||||
@@ -83,6 +83,9 @@ class PoolingServeContext(Generic[PoolingRequestT]):
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
## for bi-encoder & late-interaction
|
||||
n_queries: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class OfflineInputsContext:
|
||||
@@ -92,7 +95,7 @@ class OfflineInputsContext:
|
||||
chat_template: str | None = None
|
||||
|
||||
## for bi-encoder & late-interaction
|
||||
offset: int | None = None
|
||||
n_queries: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -100,4 +103,4 @@ class OfflineOutputsContext:
|
||||
outputs: list[PoolingRequestOutput]
|
||||
|
||||
## for bi-encoder & late-interaction
|
||||
offset: int | None = None
|
||||
n_queries: int | None = None
|
||||
|
||||
@@ -56,16 +56,7 @@ def build_late_interaction_doc_params(
|
||||
)
|
||||
|
||||
|
||||
def compute_maxsim_score(
|
||||
q_emb: torch.Tensor,
|
||||
d_emb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# compute in float32 for numerical stability
|
||||
token_scores = torch.matmul(q_emb.float(), d_emb.float().T)
|
||||
return token_scores.amax(dim=-1).sum()
|
||||
|
||||
|
||||
def compute_maxsim_scores(
|
||||
def compute_maxsim_score_batched(
|
||||
q_embs: Sequence[torch.Tensor],
|
||||
d_embs: Sequence[torch.Tensor],
|
||||
max_batch_size: int = 64,
|
||||
|
||||
@@ -9,7 +9,7 @@ from vllm.v1.outputs import PoolerOutput
|
||||
from vllm.v1.pool.late_interaction import (
|
||||
LATE_INTERACTION_MODE_CACHE_QUERY,
|
||||
LATE_INTERACTION_MODE_SCORE_DOC,
|
||||
compute_maxsim_scores,
|
||||
compute_maxsim_score_batched,
|
||||
)
|
||||
|
||||
|
||||
@@ -116,7 +116,7 @@ class LateInteractionRunner:
|
||||
raise ValueError(f"Unsupported late-interaction mode: {mode!r}")
|
||||
|
||||
if score_indices:
|
||||
score_values = compute_maxsim_scores(score_queries, score_docs)
|
||||
score_values = compute_maxsim_score_batched(score_queries, score_docs)
|
||||
for i, req_id, query_key, score in zip(
|
||||
score_indices, score_req_ids, score_query_keys, score_values
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user