[Frontend][3/n] Improve pooling entrypoints | scoring. (#28631)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
@@ -11,9 +11,7 @@ from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
 import numpy as np
 from fastapi import Request
-from openai.types.responses import (
-    ToolChoiceFunction,
-)
+from openai.types.responses import ToolChoiceFunction
 from pydantic import ConfigDict, TypeAdapter, ValidationError
 from starlette.datastructures import Headers
@@ -21,9 +19,7 @@ import vllm.envs as envs
 from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
-from vllm.entrypoints.chat_utils import (
-    ChatTemplateContentFormatOption,
-)
+from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.chat_completion.protocol import (
     BatchChatCompletionRequest,
@@ -42,9 +38,7 @@ from vllm.entrypoints.openai.engine.protocol import (
     GenerationError,
 )
 from vllm.entrypoints.openai.models.serving import OpenAIServingModels
-from vllm.entrypoints.openai.responses.protocol import (
-    ResponsesRequest,
-)
+from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
 from vllm.entrypoints.openai.speech_to_text.protocol import (
     TranscriptionRequest,
     TranscriptionResponse,
@@ -56,14 +50,6 @@ from vllm.entrypoints.pooling.pooling.protocol import (
     PoolingCompletionRequest,
     PoolingResponse,
 )
-from vllm.entrypoints.pooling.score.protocol import (
-    RerankRequest,
-    ScoreDataRequest,
-    ScoreQueriesDocumentsRequest,
-    ScoreRequest,
-    ScoreResponse,
-    ScoreTextRequest,
-)
 from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
 from vllm.entrypoints.serve.tokenize.protocol import (
     DetokenizeRequest,
@@ -72,8 +58,7 @@ from vllm.entrypoints.serve.tokenize.protocol import (
     TokenizeResponse,
 )
 from vllm.entrypoints.utils import create_error_response
 from vllm.exceptions import VLLMValidationError
-from vllm.inputs import EngineInput, PromptType, TokensPrompt
+from vllm.inputs import EngineInput, PromptType
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob, PromptLogprobs
 from vllm.lora.request import LoRARequest
@@ -119,8 +104,6 @@ CompletionLikeRequest: TypeAlias = (
     CompletionRequest
     | TokenizeCompletionRequest
     | DetokenizeRequest
-    | RerankRequest
-    | ScoreRequest
     | PoolingCompletionRequest
 )
@@ -148,7 +131,6 @@ AnyResponse: TypeAlias = (
     | TranscriptionResponse
     | TokenizeResponse
     | PoolingResponse
-    | ScoreResponse
     | GenerateResponse
 )
@@ -692,88 +674,6 @@ class OpenAIServing:
             message_types.add(content_dict["type"].split("_")[0])
         return message_types

-    def _validate_input(
-        self,
-        request: object,
-        input_ids: list[int],
-        input_text: str,
-    ) -> TokensPrompt:
-        token_num = len(input_ids)
-        max_model_len = self.model_config.max_model_len
-
-        # Note: ScoreRequest doesn't have max_tokens
-        if isinstance(
-            request,
-            (
-                ScoreDataRequest,
-                ScoreTextRequest,
-                ScoreQueriesDocumentsRequest,
-                RerankRequest,
-            ),
-        ):
-            # Note: input length can be up to the entire model context length
-            # since these requests don't generate tokens.
-            if token_num > max_model_len:
-                operations: dict[type[AnyRequest], str] = {
-                    ScoreDataRequest: "score",
-                    ScoreTextRequest: "score",
-                    ScoreQueriesDocumentsRequest: "score",
-                }
-                operation = operations.get(type(request), "embedding generation")
-                raise VLLMValidationError(
-                    f"This model's maximum context length is "
-                    f"{max_model_len} tokens. However, you requested "
-                    f"{token_num} tokens in the input for {operation}. "
-                    f"Please reduce the length of the input prompt.",
-                    parameter="input_tokens",
-                    value=token_num,
-                )
-            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
-
-        # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
-        # and does not require model context length validation
-        if isinstance(
-            request,
-            (TokenizeCompletionRequest, TokenizeChatRequest, DetokenizeRequest),
-        ):
-            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
-
-        # chat completion endpoint supports max_completion_tokens
-        if isinstance(request, ChatCompletionRequest):
-            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
-            max_tokens = request.max_completion_tokens or request.max_tokens
-        else:
-            max_tokens = getattr(request, "max_tokens", None)
-
-        # Note: input length can be up to model context length - 1 for
-        # completion-like requests.
-        if token_num >= max_model_len:
-            raise VLLMValidationError(
-                f"This model's maximum context length is "
-                f"{max_model_len} tokens. However, your request has "
-                f"{token_num} input tokens. Please reduce the length of "
-                "the input messages.",
-                parameter="input_tokens",
-                value=token_num,
-            )
-
-        if max_tokens is not None and token_num + max_tokens > max_model_len:
-            raise VLLMValidationError(
-                f"This model's maximum context length is "
-                f"{max_model_len} tokens. However, you requested "
-                f"{max_tokens} output tokens and your prompt contains "
-                f"{token_num} input tokens, for a total of "
-                f"{token_num + max_tokens} tokens "
-                f"({token_num} + {max_tokens} = "
-                f"{token_num + max_tokens} > {max_model_len}). "
-                f"Please reduce the length of the input prompt or the "
-                f"number of requested output tokens.",
-                parameter="max_tokens",
-                value=max_tokens,
-            )
-
-        return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
-
     def _validate_chat_template(
         self,
         request_chat_template: str | None,
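
Note on the removal above: _validate_input enforced two context-length regimes. Score/rerank
inputs may fill the entire context window because they generate no tokens (token_num <=
max_model_len), while completion-like requests must also leave room for output tokens
(token_num + max_tokens <= max_model_len). A minimal standalone sketch of that arithmetic,
with hypothetical names (this is not the relocated vLLM implementation):

    def check_context_length(
        token_num: int,
        max_model_len: int,
        max_tokens: int | None,
        generates_tokens: bool,
    ) -> None:
        """Illustrative only: mirrors the two validation regimes shown in the diff."""
        if not generates_tokens:
            # Pooling-style inputs may use the full context window.
            if token_num > max_model_len:
                raise ValueError(
                    f"Input has {token_num} tokens; model limit is {max_model_len}."
                )
            return
        # Completion-like inputs must leave at least one slot for generation.
        if token_num >= max_model_len:
            raise ValueError(f"Input has {token_num} tokens; must be < {max_model_len}.")
        if max_tokens is not None and token_num + max_tokens > max_model_len:
            raise ValueError(
                f"{token_num} input + {max_tokens} output tokens exceed {max_model_len}."
            )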
@@ -2,6 +2,8 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import asyncio
+import contextlib
+import json
 import sys
 import tempfile
 from argparse import Namespace
@@ -13,12 +15,14 @@ from urllib.parse import urlparse
 import aiohttp
 import pybase64 as base64
+import pydantic
 import torch
 from fastapi import UploadFile
 from prometheus_client import start_http_server
 from pydantic import Field, TypeAdapter, field_validator, model_validator
 from pydantic_core.core_schema import ValidationInfo
 from starlette.datastructures import State
+from starlette.responses import JSONResponse
 from tqdm import tqdm
 from urllib3.util import parse_url
@@ -49,7 +53,7 @@ from vllm.entrypoints.pooling.embed.protocol import (
     EmbeddingRequest,
     EmbeddingResponse,
 )
-from vllm.entrypoints.pooling.score.protocol import (
+from vllm.entrypoints.pooling.scoring.protocol import (
     RerankRequest,
     RerankResponse,
     ScoreRequest,
@@ -180,6 +184,18 @@ class BatchRequestInput(OpenAIBaseModel):
         return TypeAdapter(BatchRequestInputBody).validate_python(value)


+AllResponse: TypeAlias = (
+    ChatCompletionResponse
+    | EmbeddingResponse
+    | ScoreResponse
+    | RerankResponse
+    | TranscriptionResponse
+    | TranscriptionResponseVerbose
+    | TranslationResponse
+    | TranslationResponseVerbose
+)
+
+
 class BatchResponseData(OpenAIBaseModel):
     # HTTP status code of the response.
     status_code: int = 200
@@ -188,17 +204,7 @@ class BatchResponseData(OpenAIBaseModel):
     request_id: str

     # The body of the response.
-    body: (
-        ChatCompletionResponse
-        | EmbeddingResponse
-        | ScoreResponse
-        | RerankResponse
-        | TranscriptionResponse
-        | TranscriptionResponseVerbose
-        | TranslationResponse
-        | TranslationResponseVerbose
-        | None
-    ) = None
+    body: AllResponse | None = None


 class BatchRequestOutput(OpenAIBaseModel):
@@ -536,19 +542,13 @@ async def run_request(
     except Exception as e:
         response = create_error_response(e)

-    if isinstance(
-        response,
-        (
-            ChatCompletionResponse,
-            EmbeddingResponse,
-            ScoreResponse,
-            RerankResponse,
-            TranscriptionResponse,
-            TranscriptionResponseVerbose,
-            TranslationResponse,
-            TranslationResponseVerbose,
-        ),
-    ):
+    if isinstance(response, JSONResponse):
+        with contextlib.suppress(pydantic.ValidationError):
+            response = TypeAdapter(AllResponse | ErrorResponse).validate_python(
+                json.loads(response.body)
+            )
+
+    if isinstance(response, AllResponse):
         batch_output = BatchRequestOutput(
             id=f"vllm-{random_uuid()}",
             custom_id=request.custom_id,
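
Note on the hunk above: the new logic round-trips a starlette JSONResponse body back into a
typed pydantic model via TypeAdapter over the AllResponse union, then matches it with
isinstance, which accepts X | Y union aliases on Python 3.10+. A self-contained sketch of the
pattern with simplified stand-in models (not vLLM's actual classes):

    import contextlib
    import json
    from typing import TypeAlias

    import pydantic
    from pydantic import BaseModel, TypeAdapter

    class Ok(BaseModel):
        object: str = "ok"
        value: int

    class Err(BaseModel):
        object: str = "error"
        message: str

    AnyResult: TypeAlias = Ok | Err

    # A serialized body, e.g. what JSONResponse.body would hold.
    raw = b'{"object": "ok", "value": 42}'
    result: AnyResult | None = None
    with contextlib.suppress(pydantic.ValidationError):
        result = TypeAdapter(AnyResult).validate_python(json.loads(raw))

    assert isinstance(result, Ok | Err)  # isinstance accepts PEP 604 union aliases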
@@ -745,14 +745,14 @@ async def build_endpoint_registry(
         "score": {
             "url_matcher": lambda url: url.endswith("/score"),
             "handler_getter": lambda: (
-                serving_scores.create_score if serving_scores is not None else None
+                serving_scores if serving_scores is not None else None
             ),
             "wrapper_fn": None,
         },
         "rerank": {
             "url_matcher": lambda url: url.endswith("/rerank"),
             "handler_getter": lambda: (
-                serving_scores.do_rerank if serving_scores is not None else None
+                serving_scores if serving_scores is not None else None
             ),
             "wrapper_fn": None,
         },
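
Note on the hunk above: both handler_getter entries now return the serving_scores object
itself rather than a bound method (create_score / do_rerank), which suggests the dispatcher
now resolves the concrete handler from the serving object. A minimal sketch of this registry
shape with a hypothetical resolver (not vLLM's actual dispatch code):

    from collections.abc import Callable
    from typing import Any, TypedDict

    class Endpoint(TypedDict):
        url_matcher: Callable[[str], bool]
        handler_getter: Callable[[], Any]
        wrapper_fn: Callable[..., Any] | None

    serving_scores: object | None = None  # stand-in; None when scoring is disabled

    registry: dict[str, Endpoint] = {
        "score": {
            "url_matcher": lambda url: url.endswith("/score"),
            "handler_getter": lambda: serving_scores,
            "wrapper_fn": None,
        },
    }

    def resolve(url: str) -> Any:
        # Return the handler for the first endpoint whose matcher accepts the URL.
        for entry in registry.values():
            if entry["url_matcher"](url):
                return entry["handler_getter"]()
        return None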