[Frontend][3/n] Improve pooling entrypoints | scoring. (#28631)

Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Author: wang.yuqi
Date: 2026-03-31 15:52:00 +08:00
Committed by: GitHub
Parent: f09daea261
Commit: d9d21eb8e3
37 changed files with 1256 additions and 1779 deletions

View File

@@ -11,9 +11,7 @@ from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
 import numpy as np
 from fastapi import Request
-from openai.types.responses import (
-    ToolChoiceFunction,
-)
+from openai.types.responses import ToolChoiceFunction
 from pydantic import ConfigDict, TypeAdapter, ValidationError
 from starlette.datastructures import Headers
@@ -21,9 +19,7 @@ import vllm.envs as envs
 from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
-from vllm.entrypoints.chat_utils import (
-    ChatTemplateContentFormatOption,
-)
+from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.chat_completion.protocol import (
     BatchChatCompletionRequest,
@@ -42,9 +38,7 @@ from vllm.entrypoints.openai.engine.protocol import (
     GenerationError,
 )
 from vllm.entrypoints.openai.models.serving import OpenAIServingModels
-from vllm.entrypoints.openai.responses.protocol import (
-    ResponsesRequest,
-)
+from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
 from vllm.entrypoints.openai.speech_to_text.protocol import (
     TranscriptionRequest,
     TranscriptionResponse,
@@ -56,14 +50,6 @@ from vllm.entrypoints.pooling.pooling.protocol import (
     PoolingCompletionRequest,
     PoolingResponse,
 )
-from vllm.entrypoints.pooling.score.protocol import (
-    RerankRequest,
-    ScoreDataRequest,
-    ScoreQueriesDocumentsRequest,
-    ScoreRequest,
-    ScoreResponse,
-    ScoreTextRequest,
-)
 from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
 from vllm.entrypoints.serve.tokenize.protocol import (
     DetokenizeRequest,
@@ -72,8 +58,7 @@ from vllm.entrypoints.serve.tokenize.protocol import (
     TokenizeResponse,
 )
 from vllm.entrypoints.utils import create_error_response
-from vllm.exceptions import VLLMValidationError
-from vllm.inputs import EngineInput, PromptType, TokensPrompt
+from vllm.inputs import EngineInput, PromptType
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob, PromptLogprobs
 from vllm.lora.request import LoRARequest
@@ -119,8 +104,6 @@ CompletionLikeRequest: TypeAlias = (
     CompletionRequest
     | TokenizeCompletionRequest
     | DetokenizeRequest
-    | RerankRequest
-    | ScoreRequest
     | PoolingCompletionRequest
 )
@@ -148,7 +131,6 @@ AnyResponse: TypeAlias = (
     | TranscriptionResponse
     | TokenizeResponse
     | PoolingResponse
-    | ScoreResponse
     | GenerateResponse
 )
@@ -692,88 +674,6 @@ class OpenAIServing:
                     message_types.add(content_dict["type"].split("_")[0])

         return message_types

-    def _validate_input(
-        self,
-        request: object,
-        input_ids: list[int],
-        input_text: str,
-    ) -> TokensPrompt:
-        token_num = len(input_ids)
-        max_model_len = self.model_config.max_model_len
-
-        # Note: ScoreRequest doesn't have max_tokens
-        if isinstance(
-            request,
-            (
-                ScoreDataRequest,
-                ScoreTextRequest,
-                ScoreQueriesDocumentsRequest,
-                RerankRequest,
-            ),
-        ):
-            # Note: input length can be up to the entire model context length
-            # since these requests don't generate tokens.
-            if token_num > max_model_len:
-                operations: dict[type[AnyRequest], str] = {
-                    ScoreDataRequest: "score",
-                    ScoreTextRequest: "score",
-                    ScoreQueriesDocumentsRequest: "score",
-                }
-                operation = operations.get(type(request), "embedding generation")
-
-                raise VLLMValidationError(
-                    f"This model's maximum context length is "
-                    f"{max_model_len} tokens. However, you requested "
-                    f"{token_num} tokens in the input for {operation}. "
-                    f"Please reduce the length of the input prompt.",
-                    parameter="input_tokens",
-                    value=token_num,
-                )
-
-            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
-
-        # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
-        # and does not require model context length validation
-        if isinstance(
-            request,
-            (TokenizeCompletionRequest, TokenizeChatRequest, DetokenizeRequest),
-        ):
-            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
-
-        # chat completion endpoint supports max_completion_tokens
-        if isinstance(request, ChatCompletionRequest):
-            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
-            max_tokens = request.max_completion_tokens or request.max_tokens
-        else:
-            max_tokens = getattr(request, "max_tokens", None)
-
-        # Note: input length can be up to model context length - 1 for
-        # completion-like requests.
-        if token_num >= max_model_len:
-            raise VLLMValidationError(
-                f"This model's maximum context length is "
-                f"{max_model_len} tokens. However, your request has "
-                f"{token_num} input tokens. Please reduce the length of "
-                "the input messages.",
-                parameter="input_tokens",
-                value=token_num,
-            )
-
-        if max_tokens is not None and token_num + max_tokens > max_model_len:
-            raise VLLMValidationError(
-                f"This model's maximum context length is "
-                f"{max_model_len} tokens. However, you requested "
-                f"{max_tokens} output tokens and your prompt contains "
-                f"{token_num} input tokens, for a total of "
-                f"{token_num + max_tokens} tokens "
-                f"({token_num} + {max_tokens} = "
-                f"{token_num + max_tokens} > {max_model_len}). "
-                f"Please reduce the length of the input prompt or the "
-                f"number of requested output tokens.",
-                parameter="max_tokens",
-                value=max_tokens,
-            )
-
-        return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
-
     def _validate_chat_template(
         self,
         request_chat_template: str | None,
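The deleted method enforced two different context-length rules: score/rerank inputs may use the full model context (they generate no output tokens), while completion-like requests must leave room for the requested output. The same arithmetic, restated as a minimal standalone sketch; plain ValueError replaces vLLM's VLLMValidationError and a boolean flag replaces the request-type dispatch, both simplifications of mine rather than code from this commit:

def check_context_length(
    token_num: int,
    max_model_len: int,
    max_tokens: int | None = None,
    generates_tokens: bool = True,
) -> None:
    """Raise if the prompt (plus any requested output) would exceed the context."""
    if not generates_tokens:
        # Scoring/reranking: the whole context window is available for input.
        if token_num > max_model_len:
            raise ValueError(
                f"Input has {token_num} tokens but the model context is "
                f"only {max_model_len} tokens."
            )
        return
    # Completion-like requests must keep at least one slot for generation.
    if token_num >= max_model_len:
        raise ValueError(
            f"Prompt has {token_num} tokens; the model context is "
            f"{max_model_len} tokens."
        )
    if max_tokens is not None and token_num + max_tokens > max_model_len:
        raise ValueError(
            f"{token_num} input + {max_tokens} output tokens exceed the "
            f"{max_model_len}-token context."
        )


check_context_length(4096, 4096, generates_tokens=False)  # OK: scoring may fill the context
# check_context_length(4000, 4096, max_tokens=200)  # would raise: 4000 + 200 > 4096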

View File

@@ -2,6 +2,8 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import asyncio
+import contextlib
+import json
 import sys
 import tempfile
 from argparse import Namespace
@@ -13,12 +15,14 @@ from urllib.parse import urlparse
 import aiohttp
 import pybase64 as base64
+import pydantic
 import torch
 from fastapi import UploadFile
 from prometheus_client import start_http_server
 from pydantic import Field, TypeAdapter, field_validator, model_validator
 from pydantic_core.core_schema import ValidationInfo
 from starlette.datastructures import State
+from starlette.responses import JSONResponse
 from tqdm import tqdm
 from urllib3.util import parse_url
@@ -49,7 +53,7 @@ from vllm.entrypoints.pooling.embed.protocol import (
     EmbeddingRequest,
     EmbeddingResponse,
 )
-from vllm.entrypoints.pooling.score.protocol import (
+from vllm.entrypoints.pooling.scoring.protocol import (
     RerankRequest,
     RerankResponse,
     ScoreRequest,
@@ -180,6 +184,18 @@ class BatchRequestInput(OpenAIBaseModel):
         return TypeAdapter(BatchRequestInputBody).validate_python(value)


+AllResponse: TypeAlias = (
+    ChatCompletionResponse
+    | EmbeddingResponse
+    | ScoreResponse
+    | RerankResponse
+    | TranscriptionResponse
+    | TranscriptionResponseVerbose
+    | TranslationResponse
+    | TranslationResponseVerbose
+)
+
+
 class BatchResponseData(OpenAIBaseModel):
     # HTTP status code of the response.
     status_code: int = 200
@@ -188,17 +204,7 @@ class BatchResponseData(OpenAIBaseModel):
     request_id: str

     # The body of the response.
-    body: (
-        ChatCompletionResponse
-        | EmbeddingResponse
-        | ScoreResponse
-        | RerankResponse
-        | TranscriptionResponse
-        | TranscriptionResponseVerbose
-        | TranslationResponse
-        | TranslationResponseVerbose
-        | None
-    ) = None
+    body: AllResponse | None = None


 class BatchRequestOutput(OpenAIBaseModel):
@@ -536,19 +542,13 @@ async def run_request(
     except Exception as e:
         response = create_error_response(e)

-    if isinstance(
-        response,
-        (
-            ChatCompletionResponse,
-            EmbeddingResponse,
-            ScoreResponse,
-            RerankResponse,
-            TranscriptionResponse,
-            TranscriptionResponseVerbose,
-            TranslationResponse,
-            TranslationResponseVerbose,
-        ),
-    ):
+    if isinstance(response, JSONResponse):
+        with contextlib.suppress(pydantic.ValidationError):
+            response = TypeAdapter(AllResponse | ErrorResponse).validate_python(
+                json.loads(response.body)
+            )
+
+    if isinstance(response, AllResponse):
         batch_output = BatchRequestOutput(
             id=f"vllm-{random_uuid()}",
             custom_id=request.custom_id,
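The batch runner now converts any starlette JSONResponse back into a typed pydantic object before deciding whether it can be packed into a BatchRequestOutput. The same round-trip in isolation, as a small sketch: the two-member union, the Demo* models, and the sample payload stand in for the real AllResponse alias and handler output and are not part of this commit.

import json

from pydantic import BaseModel, TypeAdapter


class DemoScoreResponse(BaseModel):
    id: str
    score: float


class DemoErrorResponse(BaseModel):
    error: str


# Stand-in for the AllResponse union, which covers every batch-capable response type.
DemoBody = DemoScoreResponse | DemoErrorResponse

raw = b'{"id": "score-1", "score": 0.87}'  # what JSONResponse.body would carry

# TypeAdapter validates the decoded JSON against the union and returns a typed model,
# so the caller can keep using isinstance() checks on the result.
parsed = TypeAdapter(DemoBody).validate_python(json.loads(raw))
assert isinstance(parsed, DemoScoreResponse)
print(parsed.score)  # 0.87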
@@ -745,14 +745,14 @@ async def build_endpoint_registry(
"score": {
"url_matcher": lambda url: url.endswith("/score"),
"handler_getter": lambda: (
serving_scores.create_score if serving_scores is not None else None
serving_scores if serving_scores is not None else None
),
"wrapper_fn": None,
},
"rerank": {
"url_matcher": lambda url: url.endswith("/rerank"),
"handler_getter": lambda: (
serving_scores.do_rerank if serving_scores is not None else None
serving_scores if serving_scores is not None else None
),
"wrapper_fn": None,
},
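For reference, a registry entry of this shape can be consulted as in the sketch below. The url_matcher/handler_getter fields mirror the diff, but the dispatch loop, the method call, and the DemoServingScores stand-in are illustrative assumptions, not code from this commit; the change above only makes handler_getter return the serving object instead of a pre-bound method, leaving the method choice to the caller.

from typing import Any


class DemoServingScores:
    """Stand-in for the serving object that handler_getter now returns."""

    def create_score(self, body: dict[str, Any]) -> dict[str, Any]:
        return {"object": "score", "num_documents": len(body.get("documents", []))}


serving_scores: DemoServingScores | None = DemoServingScores()

registry: dict[str, dict[str, Any]] = {
    "score": {
        "url_matcher": lambda url: url.endswith("/score"),
        "handler_getter": lambda: (
            serving_scores if serving_scores is not None else None
        ),
        "wrapper_fn": None,
    },
}


def dispatch(url: str, body: dict[str, Any]) -> dict[str, Any] | None:
    # Walk the registry: the first entry whose matcher accepts the URL wins,
    # then its handler_getter supplies the serving object (or None if disabled).
    for entry in registry.values():
        if entry["url_matcher"](url):
            handler = entry["handler_getter"]()
            if handler is None:
                return None
            return handler.create_score(body)
    return None


print(dispatch("http://localhost:8000/v1/score", {"documents": ["a", "b"]}))
# {'object': 'score', 'num_documents': 2}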