[Frontend][1/n] Improve pooling entrypoints | classify. (#35604)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io> Signed-off-by: wang.yuqi <noooop@126.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
@@ -7,6 +7,7 @@ import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import Counter, defaultdict
|
||||
from collections.abc import Awaitable, Callable, Iterable
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property, lru_cache, partial
|
||||
from itertools import accumulate
|
||||
from pathlib import Path
|
||||
@@ -1024,6 +1025,13 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
self._add_placeholder("video", placeholder)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatTemplateConfig:
|
||||
chat_template: str | None = None
|
||||
chat_template_content_format: ChatTemplateContentFormatOption = "auto"
|
||||
trust_request_chat_template: bool = False
|
||||
|
||||
|
||||
def validate_chat_template(chat_template: Path | str | None):
|
||||
"""Raises if the provided chat template appears invalid."""
|
||||
if chat_template is None:
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import itertools
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import cloudpickle
|
||||
@@ -40,8 +41,11 @@ from vllm.distributed.weight_transfer.base import (
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionMessageParam,
|
||||
ChatTemplateConfig,
|
||||
ChatTemplateContentFormatOption,
|
||||
load_chat_template,
|
||||
)
|
||||
from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors
|
||||
from vllm.entrypoints.pooling.score.utils import (
|
||||
ScoreData,
|
||||
ScoreMultiModalParam,
|
||||
@@ -145,6 +149,7 @@ class LLM:
|
||||
a tag name, or a commit id.
|
||||
tokenizer_revision: The specific tokenizer version to use. It can be a
|
||||
branch name, a tag name, or a commit id.
|
||||
chat_template: The chat template to apply.
|
||||
seed: The seed to initialize the random number generator for sampling.
|
||||
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
|
||||
reserve for the model weights, activations, and KV cache. Higher
|
||||
@@ -232,6 +237,7 @@ class LLM:
|
||||
quantization: QuantizationMethods | None = None,
|
||||
revision: str | None = None,
|
||||
tokenizer_revision: str | None = None,
|
||||
chat_template: Path | str | None = None,
|
||||
seed: int = 0,
|
||||
gpu_memory_utilization: float = 0.9,
|
||||
swap_space: float = 4,
|
||||
@@ -384,9 +390,16 @@ class LLM:
|
||||
|
||||
self.model_config = self.llm_engine.model_config
|
||||
self.renderer = self.llm_engine.renderer
|
||||
self.chat_template = load_chat_template(chat_template)
|
||||
self.io_processor = self.llm_engine.io_processor
|
||||
self.input_processor = self.llm_engine.input_processor
|
||||
|
||||
self.chat_template_config = ChatTemplateConfig(chat_template=self.chat_template)
|
||||
self.init_pooling_io_processors = init_pooling_io_processors(
|
||||
supported_tasks=supported_tasks,
|
||||
model_config=self.model_config,
|
||||
renderer=self.renderer,
|
||||
chat_template_config=self.chat_template_config,
|
||||
)
|
||||
# Cache for __repr__ to avoid repeated collective_rpc calls
|
||||
self._cached_repr: str | None = None
|
||||
|
||||
@@ -1086,7 +1099,7 @@ class LLM:
|
||||
"pooling model."
|
||||
)
|
||||
|
||||
if use_io_processor := (isinstance(prompts, dict) and "data" in prompts):
|
||||
if isinstance(prompts, dict) and "data" in prompts:
|
||||
if self.io_processor is None:
|
||||
raise ValueError(
|
||||
"No IOProcessor plugin installed. Please refer "
|
||||
@@ -1120,6 +1133,31 @@ class LLM:
|
||||
for p in params_seq:
|
||||
if p.task is None:
|
||||
p.task = "plugin"
|
||||
|
||||
outputs = self._run_completion(
|
||||
prompts=prompts_seq,
|
||||
params=params_seq,
|
||||
output_type=PoolingRequestOutput,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
# get the post-processed model outputs
|
||||
assert self.io_processor is not None
|
||||
processed_outputs = self.io_processor.post_process(outputs)
|
||||
|
||||
return [
|
||||
PoolingRequestOutput[Any](
|
||||
request_id="",
|
||||
outputs=processed_outputs,
|
||||
num_cached_tokens=getattr(
|
||||
processed_outputs, "num_cached_tokens", 0
|
||||
),
|
||||
prompt_token_ids=[],
|
||||
finished=True,
|
||||
)
|
||||
]
|
||||
else:
|
||||
if pooling_params is None:
|
||||
# Use default pooling params.
|
||||
@@ -1137,32 +1175,36 @@ class LLM:
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
outputs = self._run_completion(
|
||||
prompts=prompts_seq,
|
||||
params=params_seq,
|
||||
output_type=PoolingRequestOutput,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
if use_io_processor:
|
||||
# get the post-processed model outputs
|
||||
assert self.io_processor is not None
|
||||
processed_outputs = self.io_processor.post_process(outputs)
|
||||
|
||||
return [
|
||||
PoolingRequestOutput[Any](
|
||||
request_id="",
|
||||
outputs=processed_outputs,
|
||||
num_cached_tokens=getattr(
|
||||
processed_outputs, "num_cached_tokens", 0
|
||||
),
|
||||
prompt_token_ids=[],
|
||||
finished=True,
|
||||
if pooling_task in self.init_pooling_io_processors:
|
||||
io_processor = self.init_pooling_io_processors[pooling_task]
|
||||
processor_inputs = io_processor.pre_process_offline(
|
||||
prompts_seq, tokenization_kwargs
|
||||
)
|
||||
]
|
||||
seq_lora_requests = self._lora_request_to_seq(
|
||||
lora_request, len(prompts_seq)
|
||||
)
|
||||
seq_priority = self._priority_to_seq(None, len(prompts))
|
||||
|
||||
self._render_and_add_requests(
|
||||
prompts=processor_inputs,
|
||||
params=params_seq,
|
||||
lora_requests=seq_lora_requests,
|
||||
priorities=seq_priority,
|
||||
)
|
||||
|
||||
outputs = self._run_engine(
|
||||
use_tqdm=use_tqdm, output_type=PoolingRequestOutput
|
||||
)
|
||||
outputs = io_processor.post_process(outputs)
|
||||
else:
|
||||
outputs = self._run_completion(
|
||||
prompts=prompts_seq,
|
||||
params=params_seq,
|
||||
output_type=PoolingRequestOutput,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
return outputs
|
||||
|
||||
def embed(
|
||||
|
||||
@@ -62,11 +62,6 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
|
||||
TranscriptionResponse,
|
||||
TranslationRequest,
|
||||
)
|
||||
from vllm.entrypoints.pooling.classify.protocol import (
|
||||
ClassificationChatRequest,
|
||||
ClassificationCompletionRequest,
|
||||
ClassificationResponse,
|
||||
)
|
||||
from vllm.entrypoints.pooling.embed.protocol import (
|
||||
EmbeddingBytesResponse,
|
||||
EmbeddingChatRequest,
|
||||
@@ -161,7 +156,6 @@ CompletionLikeRequest: TypeAlias = (
|
||||
| TokenizeCompletionRequest
|
||||
| DetokenizeRequest
|
||||
| EmbeddingCompletionRequest
|
||||
| ClassificationCompletionRequest
|
||||
| RerankRequest
|
||||
| ScoreRequest
|
||||
| PoolingCompletionRequest
|
||||
@@ -171,7 +165,6 @@ ChatLikeRequest: TypeAlias = (
|
||||
ChatCompletionRequest
|
||||
| TokenizeChatRequest
|
||||
| EmbeddingChatRequest
|
||||
| ClassificationChatRequest
|
||||
| PoolingChatRequest
|
||||
)
|
||||
|
||||
@@ -194,12 +187,10 @@ AnyResponse: TypeAlias = (
|
||||
| TranscriptionResponse
|
||||
| TokenizeResponse
|
||||
| PoolingResponse
|
||||
| ClassificationResponse
|
||||
| ScoreResponse
|
||||
| GenerateResponse
|
||||
)
|
||||
|
||||
|
||||
RequestT = TypeVar("RequestT", bound=AnyRequest)
|
||||
|
||||
|
||||
@@ -223,8 +214,8 @@ class ServeContext(Generic[RequestT]):
|
||||
|
||||
class OpenAIServing:
|
||||
request_id_prefix: ClassVar[str] = """
|
||||
A short string prepended to every request’s ID (e.g. "embd", "classify")
|
||||
so you can easily tell “this ID came from Embedding vs Classification.”
|
||||
A short string prepended to every request’s ID (e.g. "embd")
|
||||
so you can easily tell “this ID came from Embedding.”
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -456,7 +447,7 @@ class OpenAIServing:
|
||||
) -> ErrorResponse | None:
|
||||
"""
|
||||
Default preprocessing hook. Subclasses may override
|
||||
to prepare `ctx` (classification, embedding, etc.).
|
||||
to prepare `ctx` (embedding, etc.).
|
||||
"""
|
||||
return None
|
||||
|
||||
@@ -817,7 +808,7 @@ class OpenAIServing:
|
||||
token_num = len(input_ids)
|
||||
max_model_len = self.model_config.max_model_len
|
||||
|
||||
# Note: EmbeddingRequest, ClassificationRequest,
|
||||
# Note: EmbeddingRequest,
|
||||
# and ScoreRequest doesn't have max_tokens
|
||||
if isinstance(
|
||||
request,
|
||||
@@ -828,8 +819,6 @@ class OpenAIServing:
|
||||
ScoreTextRequest,
|
||||
ScoreQueriesDocumentsRequest,
|
||||
RerankRequest,
|
||||
ClassificationCompletionRequest,
|
||||
ClassificationChatRequest,
|
||||
),
|
||||
):
|
||||
# Note: input length can be up to the entire model context length
|
||||
@@ -839,8 +828,6 @@ class OpenAIServing:
|
||||
ScoreDataRequest: "score",
|
||||
ScoreTextRequest: "score",
|
||||
ScoreQueriesDocumentsRequest: "score",
|
||||
ClassificationCompletionRequest: "classification",
|
||||
ClassificationChatRequest: "classification",
|
||||
}
|
||||
operation = operations.get(type(request), "embedding generation")
|
||||
raise VLLMValidationError(
|
||||
|
||||
189
vllm/entrypoints/pooling/base/io_processor.py
Normal file
189
vllm/entrypoints/pooling/base/io_processor.py
Normal file
@@ -0,0 +1,189 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, Final
|
||||
|
||||
from vllm import PoolingRequestOutput, PromptType
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionMessageParam,
|
||||
ChatTemplateConfig,
|
||||
ChatTemplateContentFormatOption,
|
||||
ConversationMessage,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.serving import RendererChatRequest, RendererRequest
|
||||
from vllm.inputs import ProcessorInputs, SingletonPrompt
|
||||
from vllm.renderers import BaseRenderer, merge_kwargs
|
||||
from vllm.renderers.inputs import TokPrompt
|
||||
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers import ToolParser
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
|
||||
|
||||
class PoolingIOProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
renderer: BaseRenderer,
|
||||
chat_template_config: ChatTemplateConfig,
|
||||
):
|
||||
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
self.model_config = model_config
|
||||
self.renderer = renderer
|
||||
|
||||
self.chat_template = chat_template_config.chat_template
|
||||
self.chat_template_content_format: Final = (
|
||||
chat_template_config.chat_template_content_format
|
||||
)
|
||||
self.trust_request_chat_template = (
|
||||
chat_template_config.trust_request_chat_template
|
||||
)
|
||||
|
||||
def pre_process_online(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
async def pre_process_online_async(self, *args, **kwargs):
|
||||
return self.pre_process_online(*args, **kwargs)
|
||||
|
||||
def pre_process_offline(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
async def pre_process_offline_async(self, *args, **kwargs):
|
||||
return self.pre_process_offline(*args, **kwargs)
|
||||
|
||||
def post_process(
|
||||
self, outputs: list[PoolingRequestOutput]
|
||||
) -> list[PoolingRequestOutput]:
|
||||
return outputs
|
||||
|
||||
async def post_process_async(
|
||||
self, outputs: list[PoolingRequestOutput]
|
||||
) -> list[PoolingRequestOutput]:
|
||||
return self.post_process(outputs)
|
||||
|
||||
def create_pooling_params(self, request):
|
||||
return request.to_pooling_params()
|
||||
|
||||
def _preprocess_completion_online(
|
||||
self,
|
||||
request: RendererRequest,
|
||||
prompt_input: str | list[str] | list[int] | list[list[int]] | None,
|
||||
prompt_embeds: bytes | list[bytes] | None,
|
||||
) -> list[TokPrompt]:
|
||||
renderer = self.renderer
|
||||
model_config = self.model_config
|
||||
|
||||
prompts = list[SingletonPrompt | bytes]()
|
||||
if prompt_embeds is not None: # embeds take higher priority
|
||||
prompts.extend(prompt_to_seq(prompt_embeds))
|
||||
if prompt_input is not None:
|
||||
prompts.extend(prompt_to_seq(prompt_input))
|
||||
|
||||
parsed_prompts = [
|
||||
(
|
||||
prompt
|
||||
if isinstance(prompt, bytes)
|
||||
else parse_model_prompt(model_config, prompt)
|
||||
)
|
||||
for prompt in prompts
|
||||
]
|
||||
tok_params = request.build_tok_params(model_config)
|
||||
|
||||
return renderer.render_cmpl(
|
||||
parsed_prompts,
|
||||
tok_params,
|
||||
prompt_extras={
|
||||
k: v
|
||||
for k in ("mm_processor_kwargs", "cache_salt")
|
||||
if (v := getattr(request, k, None)) is not None
|
||||
},
|
||||
)
|
||||
|
||||
def _preprocess_chat_online(
|
||||
self,
|
||||
request: RendererChatRequest,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
default_template: str | None,
|
||||
default_template_content_format: ChatTemplateContentFormatOption,
|
||||
default_template_kwargs: dict[str, Any] | None,
|
||||
tool_dicts: list[dict[str, Any]] | None = None,
|
||||
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
|
||||
) -> tuple[list[ConversationMessage], list[TokPrompt]]:
|
||||
renderer = self.renderer
|
||||
|
||||
default_template_kwargs = merge_kwargs(
|
||||
default_template_kwargs,
|
||||
dict(
|
||||
tools=tool_dicts,
|
||||
tokenize=is_mistral_tokenizer(renderer.tokenizer),
|
||||
),
|
||||
)
|
||||
|
||||
tok_params = request.build_tok_params(self.model_config)
|
||||
chat_params = request.build_chat_params(
|
||||
default_template, default_template_content_format
|
||||
).with_defaults(default_template_kwargs)
|
||||
|
||||
(conversation,), (engine_prompt,) = renderer.render_chat(
|
||||
[messages],
|
||||
chat_params,
|
||||
tok_params,
|
||||
prompt_extras={
|
||||
k: v
|
||||
for k in ("mm_processor_kwargs", "cache_salt")
|
||||
if (v := getattr(request, k, None)) is not None
|
||||
},
|
||||
)
|
||||
|
||||
return conversation, [engine_prompt]
|
||||
|
||||
def _preprocess_completion_offline(
|
||||
self,
|
||||
prompts: PromptType | Sequence[PromptType],
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
) -> Sequence[ProcessorInputs]:
|
||||
renderer = self.renderer
|
||||
model_config = self.model_config
|
||||
|
||||
prompts = prompt_to_seq(prompts)
|
||||
|
||||
parsed_prompts = [
|
||||
(
|
||||
prompt
|
||||
if isinstance(prompt, bytes)
|
||||
else parse_model_prompt(model_config, prompt)
|
||||
)
|
||||
for prompt in prompts
|
||||
]
|
||||
tok_params = renderer.default_cmpl_tok_params.with_kwargs(
|
||||
**(tokenization_kwargs or {})
|
||||
)
|
||||
|
||||
return renderer.render_cmpl(
|
||||
parsed_prompts,
|
||||
tok_params,
|
||||
)
|
||||
|
||||
def _validate_chat_template(
|
||||
self,
|
||||
request_chat_template: str | None,
|
||||
chat_template_kwargs: dict[str, Any] | None,
|
||||
trust_request_chat_template: bool,
|
||||
):
|
||||
if not trust_request_chat_template and (
|
||||
request_chat_template is not None
|
||||
or (
|
||||
chat_template_kwargs
|
||||
and chat_template_kwargs.get("chat_template") is not None
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
"Chat template is passed with request, but "
|
||||
"--trust-request-chat-template is not set. "
|
||||
"Refused request with untrusted chat template."
|
||||
)
|
||||
return None
|
||||
378
vllm/entrypoints/pooling/base/serving.py
Normal file
378
vllm/entrypoints/pooling/base/serving.py
Normal file
@@ -0,0 +1,378 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from http import HTTPStatus
|
||||
from typing import ClassVar, Generic, TypeVar
|
||||
|
||||
from fastapi import Request
|
||||
from pydantic import ConfigDict
|
||||
from starlette.datastructures import Headers
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from vllm import (
|
||||
PoolingParams,
|
||||
PoolingRequestOutput,
|
||||
PromptType,
|
||||
SamplingParams,
|
||||
envs,
|
||||
)
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatTemplateConfig,
|
||||
ChatTemplateContentFormatOption,
|
||||
)
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.pooling.typing import AnyPoolingRequest, AnyPoolingResponse
|
||||
from vllm.inputs import ProcessorInputs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.renderers import BaseRenderer
|
||||
from vllm.renderers.inputs.preprocess import extract_prompt_components
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.tracing import (
|
||||
contains_trace_headers,
|
||||
extract_trace_headers,
|
||||
log_tracing_disabled_warning,
|
||||
)
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
|
||||
from ...utils import create_error_response
|
||||
from .io_processor import PoolingIOProcessor
|
||||
|
||||
PoolingRequestT = TypeVar("PoolingRequestT", bound=AnyPoolingRequest)
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class PoolingServeContext(Generic[PoolingRequestT]):
|
||||
request: PoolingRequestT
|
||||
raw_request: Request | None = None
|
||||
model_name: str
|
||||
request_id: str
|
||||
created_time: int = field(default_factory=lambda: int(time.time()))
|
||||
lora_request: LoRARequest | None = None
|
||||
engine_prompts: list[ProcessorInputs] | None = None
|
||||
|
||||
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
|
||||
None
|
||||
)
|
||||
final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class PoolingServing:
|
||||
request_id_prefix: ClassVar[str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
chat_template: str | None = None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
|
||||
trust_request_chat_template: bool = False,
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
log_error_stack: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.engine_client = engine_client
|
||||
self.models = models
|
||||
self.model_config = models.model_config
|
||||
self.max_model_len = self.model_config.max_model_len
|
||||
self.request_logger = request_logger
|
||||
self.return_tokens_as_token_ids = return_tokens_as_token_ids
|
||||
self.log_error_stack = log_error_stack
|
||||
self.chat_template_config = ChatTemplateConfig(
|
||||
chat_template=chat_template,
|
||||
chat_template_content_format=chat_template_content_format,
|
||||
trust_request_chat_template=trust_request_chat_template,
|
||||
)
|
||||
self.io_processor = self.init_io_processor(
|
||||
model_config=models.model_config,
|
||||
renderer=models.renderer,
|
||||
chat_template_config=self.chat_template_config,
|
||||
)
|
||||
|
||||
def init_io_processor(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
renderer: BaseRenderer,
|
||||
chat_template_config: ChatTemplateConfig,
|
||||
) -> PoolingIOProcessor:
|
||||
raise NotImplementedError
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
request: AnyPoolingRequest,
|
||||
raw_request: Request,
|
||||
) -> JSONResponse:
|
||||
try:
|
||||
model_name = self.models.model_name()
|
||||
request_id = (
|
||||
f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
|
||||
)
|
||||
|
||||
await self._check_model(request)
|
||||
|
||||
ctx = PoolingServeContext(
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
model_name=model_name,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
self._validate_request(ctx)
|
||||
self._maybe_get_adapters(ctx)
|
||||
await self._preprocess(ctx)
|
||||
await self._prepare_generators(ctx)
|
||||
await self._collect_batch(ctx)
|
||||
response = await self._build_response(ctx)
|
||||
return JSONResponse(content=response.model_dump())
|
||||
except Exception as e:
|
||||
error_response = create_error_response(e)
|
||||
return JSONResponse(
|
||||
content=error_response.model_dump(),
|
||||
status_code=error_response.error.code,
|
||||
)
|
||||
|
||||
async def _preprocess(
|
||||
self,
|
||||
ctx: PoolingServeContext,
|
||||
):
|
||||
ctx.engine_prompts = await self.io_processor.pre_process_online_async(
|
||||
ctx.request
|
||||
)
|
||||
|
||||
async def _prepare_generators(
|
||||
self,
|
||||
ctx: PoolingServeContext,
|
||||
):
|
||||
if ctx.engine_prompts is None:
|
||||
raise ValueError("Engine prompts not available")
|
||||
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
|
||||
trace_headers = (
|
||||
None
|
||||
if ctx.raw_request is None
|
||||
else await self._get_trace_headers(ctx.raw_request.headers)
|
||||
)
|
||||
|
||||
pooling_params = self.io_processor.create_pooling_params(ctx.request)
|
||||
|
||||
for i, engine_prompt in enumerate(ctx.engine_prompts):
|
||||
request_id_item = f"{ctx.request_id}-{i}"
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
engine_prompt,
|
||||
params=pooling_params,
|
||||
lora_request=ctx.lora_request,
|
||||
)
|
||||
|
||||
generator = self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=ctx.lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=getattr(ctx.request, "priority", 0),
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
|
||||
ctx.result_generator = merge_async_iterators(*generators)
|
||||
|
||||
async def _collect_batch(
|
||||
self,
|
||||
ctx: PoolingServeContext,
|
||||
):
|
||||
if ctx.engine_prompts is None:
|
||||
raise ValueError("Engine prompts not available")
|
||||
|
||||
if ctx.result_generator is None:
|
||||
raise ValueError("Result generator not available")
|
||||
|
||||
num_prompts = len(ctx.engine_prompts)
|
||||
final_res_batch: list[PoolingRequestOutput | None]
|
||||
final_res_batch = [None] * num_prompts
|
||||
|
||||
async for i, res in ctx.result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
if None in final_res_batch:
|
||||
raise ValueError("Failed to generate results for all prompts")
|
||||
|
||||
ctx.final_res_batch = [res for res in final_res_batch if res is not None]
|
||||
|
||||
async def _build_response(
|
||||
self,
|
||||
ctx: PoolingServeContext,
|
||||
) -> AnyPoolingResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def _base_request_id(
|
||||
raw_request: Request | None, default: str | None = None
|
||||
) -> str | None:
|
||||
"""Pulls the request id to use from a header, if provided"""
|
||||
if raw_request is not None and (
|
||||
(req_id := raw_request.headers.get("X-Request-Id")) is not None
|
||||
):
|
||||
return req_id
|
||||
|
||||
return random_uuid() if default is None else default
|
||||
|
||||
def _is_model_supported(self, model_name: str | None) -> bool:
|
||||
if not model_name:
|
||||
return True
|
||||
return self.models.is_base_model(model_name)
|
||||
|
||||
async def _check_model(
|
||||
self,
|
||||
request: AnyPoolingRequest,
|
||||
) -> ErrorResponse | None:
|
||||
if self._is_model_supported(request.model):
|
||||
return None
|
||||
if request.model in self.models.lora_requests:
|
||||
return None
|
||||
if (
|
||||
envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING
|
||||
and request.model
|
||||
and (load_result := await self.models.resolve_lora(request.model))
|
||||
):
|
||||
if isinstance(load_result, LoRARequest):
|
||||
return None
|
||||
if (
|
||||
isinstance(load_result, ErrorResponse)
|
||||
and load_result.error.code == HTTPStatus.BAD_REQUEST.value
|
||||
):
|
||||
raise ValueError(load_result.error.message)
|
||||
return None
|
||||
|
||||
def _validate_request(self, ctx: PoolingServeContext) -> None:
|
||||
truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None)
|
||||
|
||||
if (
|
||||
truncate_prompt_tokens is not None
|
||||
and truncate_prompt_tokens > self.max_model_len
|
||||
):
|
||||
raise ValueError(
|
||||
"truncate_prompt_tokens value is "
|
||||
"greater than max_model_len."
|
||||
" Please, select a smaller truncation size."
|
||||
)
|
||||
return None
|
||||
|
||||
async def _get_trace_headers(
|
||||
self,
|
||||
headers: Headers,
|
||||
) -> Mapping[str, str] | None:
|
||||
is_tracing_enabled = await self.engine_client.is_tracing_enabled()
|
||||
|
||||
if is_tracing_enabled:
|
||||
return extract_trace_headers(headers)
|
||||
|
||||
if contains_trace_headers(headers):
|
||||
log_tracing_disabled_warning()
|
||||
|
||||
return None
|
||||
|
||||
def _maybe_get_adapters(
|
||||
self,
|
||||
ctx: PoolingServeContext,
|
||||
supports_default_mm_loras: bool = False,
|
||||
):
|
||||
request = ctx.request
|
||||
if request.model in self.models.lora_requests:
|
||||
ctx.lora_request = self.models.lora_requests[request.model]
|
||||
|
||||
# Currently only support default modality specific loras
|
||||
# if we have exactly one lora matched on the request.
|
||||
if supports_default_mm_loras:
|
||||
default_mm_lora = self._get_active_default_mm_loras(request)
|
||||
if default_mm_lora is not None:
|
||||
ctx.lora_request = default_mm_lora
|
||||
|
||||
if self._is_model_supported(request.model):
|
||||
return None
|
||||
|
||||
# if _check_model has been called earlier, this will be unreachable
|
||||
raise ValueError(f"The model `{request.model}` does not exist.")
|
||||
|
||||
def _get_active_default_mm_loras(
|
||||
self, request: AnyPoolingRequest
|
||||
) -> LoRARequest | None:
|
||||
"""Determine if there are any active default multimodal loras."""
|
||||
# TODO: Currently this is only enabled for chat completions
|
||||
# to be better aligned with only being enabled for .generate
|
||||
# when run offline. It would be nice to support additional
|
||||
# tasks types in the future.
|
||||
message_types = self._get_message_types(request)
|
||||
default_mm_loras = set()
|
||||
|
||||
for lora in self.models.lora_requests.values():
|
||||
# Best effort match for default multimodal lora adapters;
|
||||
# There is probably a better way to do this, but currently
|
||||
# this matches against the set of 'types' in any content lists
|
||||
# up until '_', e.g., to match audio_url -> audio
|
||||
if lora.lora_name in message_types:
|
||||
default_mm_loras.add(lora)
|
||||
|
||||
# Currently only support default modality specific loras if
|
||||
# we have exactly one lora matched on the request.
|
||||
if len(default_mm_loras) == 1:
|
||||
return default_mm_loras.pop()
|
||||
return None
|
||||
|
||||
def _get_message_types(self, request: AnyPoolingRequest) -> set[str]:
|
||||
"""Retrieve the set of types from message content dicts up
|
||||
until `_`; we use this to match potential multimodal data
|
||||
with default per modality loras.
|
||||
"""
|
||||
message_types: set[str] = set()
|
||||
|
||||
if not hasattr(request, "messages"):
|
||||
return message_types
|
||||
|
||||
messages = request.messages
|
||||
if messages is None or isinstance(messages, (str, bytes)):
|
||||
return message_types
|
||||
|
||||
for message in messages:
|
||||
if (
|
||||
isinstance(message, dict)
|
||||
and "content" in message
|
||||
and isinstance(message["content"], list)
|
||||
):
|
||||
for content_dict in message["content"]:
|
||||
if "type" in content_dict:
|
||||
message_types.add(content_dict["type"].split("_")[0])
|
||||
return message_types
|
||||
|
||||
def _log_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
inputs: PromptType | ProcessorInputs,
|
||||
params: SamplingParams | PoolingParams | BeamSearchParams | None,
|
||||
lora_request: LoRARequest | None,
|
||||
) -> None:
|
||||
if self.request_logger is None:
|
||||
return
|
||||
|
||||
components = extract_prompt_components(self.model_config, inputs)
|
||||
|
||||
self.request_logger.log_inputs(
|
||||
request_id,
|
||||
components.text,
|
||||
components.token_ids,
|
||||
components.embeds,
|
||||
params=params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
@@ -3,16 +3,17 @@
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from starlette.responses import JSONResponse
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.utils import validate_json_request
|
||||
from vllm.entrypoints.pooling.classify.protocol import (
|
||||
ClassificationRequest,
|
||||
ClassificationResponse,
|
||||
)
|
||||
from vllm.entrypoints.pooling.classify.serving import ServingClassification
|
||||
from vllm.entrypoints.utils import load_aware_call, with_cancellation
|
||||
from vllm.entrypoints.utils import (
|
||||
create_error_response,
|
||||
load_aware_call,
|
||||
with_cancellation,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -24,25 +25,17 @@ def classify(request: Request) -> ServingClassification | None:
|
||||
@router.post("/classify", dependencies=[Depends(validate_json_request)])
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_classify(request: ClassificationRequest, raw_request: Request):
|
||||
async def create_classify(
|
||||
request: ClassificationRequest, raw_request: Request
|
||||
) -> JSONResponse:
|
||||
handler = classify(raw_request)
|
||||
if handler is None:
|
||||
base_server = raw_request.app.state.openai_serving_tokenization
|
||||
return base_server.create_error_response(
|
||||
error_response = create_error_response(
|
||||
message="The model does not support Classification API"
|
||||
)
|
||||
|
||||
try:
|
||||
generator = await handler.create_classify(request, raw_request)
|
||||
except Exception as e:
|
||||
generator = handler.create_error_response(e)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
content=error_response.model_dump(),
|
||||
status_code=error_response.error.code,
|
||||
)
|
||||
|
||||
elif isinstance(generator, ClassificationResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
return await handler(request, raw_request)
|
||||
|
||||
50
vllm/entrypoints/pooling/classify/io_processor.py
Normal file
50
vllm/entrypoints/pooling/classify/io_processor.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from vllm import PromptType
|
||||
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
|
||||
from vllm.entrypoints.pooling.classify.protocol import (
|
||||
ClassificationChatRequest,
|
||||
ClassificationCompletionRequest,
|
||||
)
|
||||
from vllm.inputs import ProcessorInputs
|
||||
from vllm.renderers.inputs import TokPrompt
|
||||
|
||||
|
||||
class ClassifyIOProcessor(PoolingIOProcessor):
|
||||
def pre_process_online(
|
||||
self, request: ClassificationCompletionRequest | ClassificationChatRequest
|
||||
) -> list[TokPrompt] | None:
|
||||
if isinstance(request, ClassificationChatRequest):
|
||||
self._validate_chat_template(
|
||||
request_chat_template=request.chat_template,
|
||||
chat_template_kwargs=request.chat_template_kwargs,
|
||||
trust_request_chat_template=self.trust_request_chat_template,
|
||||
)
|
||||
_, engine_prompts = self._preprocess_chat_online(
|
||||
request,
|
||||
request.messages,
|
||||
default_template=self.chat_template,
|
||||
default_template_content_format=self.chat_template_content_format,
|
||||
default_template_kwargs=None,
|
||||
)
|
||||
elif isinstance(request, ClassificationCompletionRequest):
|
||||
engine_prompts = self._preprocess_completion_online(
|
||||
request,
|
||||
prompt_input=request.input,
|
||||
prompt_embeds=None,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid classification request type")
|
||||
return engine_prompts
|
||||
|
||||
def pre_process_offline(
|
||||
self,
|
||||
prompts: PromptType | Sequence[PromptType],
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
) -> Sequence[ProcessorInputs]:
|
||||
return self._preprocess_completion_offline(
|
||||
prompts=prompts, tokenization_kwargs=tokenization_kwargs
|
||||
)
|
||||
@@ -1,116 +1,57 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Final, TypeAlias
|
||||
from typing import TypeAlias
|
||||
|
||||
import jinja2
|
||||
import numpy as np
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
|
||||
from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.pooling.classify.protocol import (
|
||||
ClassificationChatRequest,
|
||||
ClassificationCompletionRequest,
|
||||
from vllm import ClassificationOutput
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateConfig
|
||||
from vllm.entrypoints.openai.engine.protocol import UsageInfo
|
||||
from vllm.entrypoints.pooling.base.serving import PoolingServeContext, PoolingServing
|
||||
from vllm.logger import init_logger
|
||||
from vllm.renderers import BaseRenderer
|
||||
|
||||
from .io_processor import ClassifyIOProcessor
|
||||
from .protocol import (
|
||||
ClassificationData,
|
||||
ClassificationRequest,
|
||||
ClassificationResponse,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import ClassificationOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
ClassificationServeContext: TypeAlias = ServeContext[ClassificationRequest]
|
||||
ClassificationServeContext: TypeAlias = PoolingServeContext[ClassificationRequest]
|
||||
|
||||
|
||||
class ServingClassification(OpenAIServing):
|
||||
class ServingClassification(PoolingServing):
|
||||
request_id_prefix = "classify"
|
||||
|
||||
def __init__(
|
||||
def init_io_processor(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
chat_template: str | None = None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
|
||||
trust_request_chat_template: bool = False,
|
||||
log_error_stack: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack,
|
||||
model_config: ModelConfig,
|
||||
renderer: BaseRenderer,
|
||||
chat_template_config: ChatTemplateConfig,
|
||||
) -> ClassifyIOProcessor:
|
||||
return ClassifyIOProcessor(
|
||||
model_config=model_config,
|
||||
renderer=renderer,
|
||||
chat_template_config=chat_template_config,
|
||||
)
|
||||
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format: Final = chat_template_content_format
|
||||
self.trust_request_chat_template = trust_request_chat_template
|
||||
|
||||
async def _preprocess(
|
||||
async def _build_response(
|
||||
self,
|
||||
ctx: ClassificationServeContext,
|
||||
) -> ErrorResponse | None:
|
||||
"""
|
||||
Process classification inputs: tokenize text, resolve adapters,
|
||||
and prepare model-specific inputs.
|
||||
"""
|
||||
try:
|
||||
ctx.lora_request = self._maybe_get_adapters(ctx.request)
|
||||
) -> ClassificationResponse:
|
||||
final_res_batch_checked = await self.io_processor.post_process_async(
|
||||
ctx.final_res_batch
|
||||
)
|
||||
|
||||
if isinstance(ctx.request, ClassificationChatRequest):
|
||||
error_check_ret = self._validate_chat_template(
|
||||
request_chat_template=ctx.request.chat_template,
|
||||
chat_template_kwargs=ctx.request.chat_template_kwargs,
|
||||
trust_request_chat_template=self.trust_request_chat_template,
|
||||
)
|
||||
if error_check_ret:
|
||||
return error_check_ret
|
||||
|
||||
_, ctx.engine_prompts = await self._preprocess_chat(
|
||||
ctx.request,
|
||||
ctx.request.messages,
|
||||
default_template=self.chat_template,
|
||||
default_template_content_format=self.chat_template_content_format,
|
||||
default_template_kwargs=None,
|
||||
)
|
||||
elif isinstance(ctx.request, ClassificationCompletionRequest):
|
||||
ctx.engine_prompts = await self._preprocess_completion(
|
||||
ctx.request,
|
||||
prompt_input=ctx.request.input,
|
||||
prompt_embeds=None,
|
||||
)
|
||||
else:
|
||||
return self.create_error_response("Invalid classification request type")
|
||||
|
||||
return None
|
||||
|
||||
except (ValueError, TypeError, jinja2.TemplateError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
def _build_response(
|
||||
self,
|
||||
ctx: ClassificationServeContext,
|
||||
) -> ClassificationResponse | ErrorResponse:
|
||||
"""
|
||||
Convert model outputs to a formatted classification response
|
||||
with probabilities and labels.
|
||||
"""
|
||||
id2label = getattr(self.model_config.hf_config, "id2label", {})
|
||||
|
||||
items: list[ClassificationData] = []
|
||||
num_prompt_tokens = 0
|
||||
|
||||
final_res_batch_checked = ctx.final_res_batch
|
||||
|
||||
items: list[ClassificationData] = []
|
||||
for idx, final_res in enumerate(final_res_batch_checked):
|
||||
classify_res = ClassificationOutput.from_base(final_res.outputs)
|
||||
|
||||
@@ -141,20 +82,3 @@ class ServingClassification(OpenAIServing):
|
||||
data=items,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
async def create_classify(
|
||||
self,
|
||||
request: ClassificationRequest,
|
||||
raw_request: Request,
|
||||
) -> ClassificationResponse | ErrorResponse:
|
||||
model_name = self.models.model_name()
|
||||
request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
|
||||
|
||||
ctx = ClassificationServeContext(
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
model_name=model_name,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
return await self.handle(ctx) # type: ignore[return-value]
|
||||
|
||||
31
vllm/entrypoints/pooling/io_processor_factories.py
Normal file
31
vllm/entrypoints/pooling/io_processor_factories.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateConfig
|
||||
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
|
||||
from vllm.renderers import BaseRenderer
|
||||
from vllm.tasks import SupportedTask
|
||||
|
||||
|
||||
def init_pooling_io_processors(
|
||||
supported_tasks: tuple[SupportedTask, ...],
|
||||
model_config: ModelConfig,
|
||||
renderer: BaseRenderer,
|
||||
chat_template_config: ChatTemplateConfig,
|
||||
) -> dict[str, PoolingIOProcessor]:
|
||||
pooling_io_processors: dict[str, PoolingIOProcessor] = {}
|
||||
|
||||
if "classify" in supported_tasks:
|
||||
from vllm.entrypoints.pooling.classify.io_processor import (
|
||||
ClassifyIOProcessor,
|
||||
)
|
||||
|
||||
pooling_io_processors["classify"] = ClassifyIOProcessor(
|
||||
model_config=model_config,
|
||||
renderer=renderer,
|
||||
chat_template_config=chat_template_config,
|
||||
)
|
||||
|
||||
return pooling_io_processors
|
||||
51
vllm/entrypoints/pooling/typing.py
Normal file
51
vllm/entrypoints/pooling/typing.py
Normal file
@@ -0,0 +1,51 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TypeAlias
|
||||
|
||||
from vllm.entrypoints.pooling.classify.protocol import (
|
||||
ClassificationChatRequest,
|
||||
ClassificationCompletionRequest,
|
||||
ClassificationResponse,
|
||||
)
|
||||
from vllm.entrypoints.pooling.embed.protocol import (
|
||||
EmbeddingBytesResponse,
|
||||
EmbeddingChatRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
EmbeddingResponse,
|
||||
)
|
||||
from vllm.entrypoints.pooling.pooling.protocol import (
|
||||
IOProcessorRequest,
|
||||
PoolingChatRequest,
|
||||
PoolingCompletionRequest,
|
||||
PoolingResponse,
|
||||
)
|
||||
from vllm.entrypoints.pooling.score.protocol import (
|
||||
RerankRequest,
|
||||
ScoreRequest,
|
||||
ScoreResponse,
|
||||
)
|
||||
|
||||
PoolingCompletionLikeRequest: TypeAlias = (
|
||||
EmbeddingCompletionRequest
|
||||
| ClassificationCompletionRequest
|
||||
| RerankRequest
|
||||
| ScoreRequest
|
||||
| PoolingCompletionRequest
|
||||
)
|
||||
|
||||
PoolingChatLikeRequest: TypeAlias = (
|
||||
EmbeddingChatRequest | ClassificationChatRequest | PoolingChatRequest
|
||||
)
|
||||
|
||||
AnyPoolingRequest: TypeAlias = (
|
||||
PoolingCompletionLikeRequest | PoolingChatLikeRequest | IOProcessorRequest
|
||||
)
|
||||
|
||||
AnyPoolingResponse: TypeAlias = (
|
||||
ClassificationResponse
|
||||
| EmbeddingResponse
|
||||
| EmbeddingBytesResponse
|
||||
| PoolingResponse
|
||||
| ScoreResponse
|
||||
)
|
||||
@@ -13,6 +13,7 @@ from fastapi.responses import JSONResponse, Response
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.engine.serving import OpenAIServing
|
||||
from vllm.entrypoints.openai.utils import validate_json_request
|
||||
from vllm.entrypoints.pooling.base.serving import PoolingServing
|
||||
from vllm.entrypoints.serve.instrumentator.basic import base
|
||||
from vllm.entrypoints.serve.instrumentator.health import health
|
||||
from vllm.tasks import POOLING_TASKS, SupportedTask
|
||||
@@ -20,7 +21,7 @@ from vllm.tasks import POOLING_TASKS, SupportedTask
|
||||
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
|
||||
# (requires typing_extensions >= 4.13)
|
||||
RequestType = Any
|
||||
GetHandlerFn = Callable[[Request], OpenAIServing | None]
|
||||
GetHandlerFn = Callable[[Request], OpenAIServing | PoolingServing | None]
|
||||
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,10 @@ import asyncio
|
||||
import dataclasses
|
||||
import functools
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from argparse import Namespace
|
||||
from http import HTTPStatus
|
||||
from logging import Logger
|
||||
from string import Template
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -17,17 +20,23 @@ from starlette.background import BackgroundTask, BackgroundTasks
|
||||
|
||||
from vllm import envs
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.logger import current_formatter_type, init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.entrypoints.openai.engine.protocol import StreamOptions
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
StreamOptions,
|
||||
)
|
||||
from vllm.entrypoints.openai.models.protocol import LoRAModulePath
|
||||
else:
|
||||
StreamOptions = object
|
||||
ErrorResponse = object
|
||||
ErrorInfo = object
|
||||
LoRAModulePath = object
|
||||
|
||||
StreamOptions = object
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -291,3 +300,59 @@ def log_version_and_model(lgr: Logger, version: str, model_name: str) -> None:
|
||||
message = logo_template.substitute(colors)
|
||||
|
||||
lgr.info(message, version, model_name)
|
||||
|
||||
|
||||
def create_error_response(
|
||||
message: str | Exception,
|
||||
err_type: str = "BadRequestError",
|
||||
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
|
||||
param: str | None = None,
|
||||
log_error_stack: bool = False,
|
||||
) -> "ErrorResponse":
|
||||
exc: Exception | None = None
|
||||
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse
|
||||
|
||||
if isinstance(message, Exception):
|
||||
exc = message
|
||||
|
||||
if isinstance(exc, VLLMValidationError):
|
||||
err_type = "BadRequestError"
|
||||
status_code = HTTPStatus.BAD_REQUEST
|
||||
param = exc.parameter
|
||||
elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)):
|
||||
# Common validation errors from user input
|
||||
err_type = "BadRequestError"
|
||||
status_code = HTTPStatus.BAD_REQUEST
|
||||
param = None
|
||||
elif isinstance(exc, NotImplementedError):
|
||||
err_type = "NotImplementedError"
|
||||
status_code = HTTPStatus.NOT_IMPLEMENTED
|
||||
param = None
|
||||
elif exc.__class__.__name__ == "TemplateError":
|
||||
# jinja2.TemplateError (avoid importing jinja2)
|
||||
err_type = "BadRequestError"
|
||||
status_code = HTTPStatus.BAD_REQUEST
|
||||
param = None
|
||||
else:
|
||||
err_type = "InternalServerError"
|
||||
status_code = HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
param = None
|
||||
|
||||
message = str(exc)
|
||||
|
||||
if log_error_stack:
|
||||
exc_type, _, _ = sys.exc_info()
|
||||
if exc_type is not None:
|
||||
traceback.print_exc()
|
||||
else:
|
||||
traceback.print_stack()
|
||||
|
||||
return ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message=sanitize_message(message),
|
||||
type=err_type,
|
||||
code=status_code.value,
|
||||
param=param,
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user