[Frontend][1/n] Improve pooling entrypoints | classify. (#35604)

Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
wang.yuqi
2026-03-03 22:05:36 +08:00
committed by GitHub
parent 440f0e7dc6
commit ea463978bb
12 changed files with 889 additions and 170 deletions

View File

@@ -7,6 +7,7 @@ import warnings
from abc import ABC, abstractmethod
from collections import Counter, defaultdict
from collections.abc import Awaitable, Callable, Iterable
from dataclasses import dataclass
from functools import cached_property, lru_cache, partial
from itertools import accumulate
from pathlib import Path
@@ -1024,6 +1025,13 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
self._add_placeholder("video", placeholder)
@dataclass
class ChatTemplateConfig:
    """Bundle of chat-template settings passed to the pooling IO processors."""

    # Chat template to apply; None when no template was supplied.
    chat_template: str | None = None
    # Content-format option forwarded when building chat parameters.
    chat_template_content_format: ChatTemplateContentFormatOption = "auto"
    # When False, a chat template supplied with a request is refused.
    trust_request_chat_template: bool = False
def validate_chat_template(chat_template: Path | str | None):
"""Raises if the provided chat template appears invalid."""
if chat_template is None:

View File

@@ -3,6 +3,7 @@
import itertools
from collections.abc import Callable, Iterable, Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any
import cloudpickle
@@ -40,8 +41,11 @@ from vllm.distributed.weight_transfer.base import (
from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateConfig,
ChatTemplateContentFormatOption,
load_chat_template,
)
from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors
from vllm.entrypoints.pooling.score.utils import (
ScoreData,
ScoreMultiModalParam,
@@ -145,6 +149,7 @@ class LLM:
a tag name, or a commit id.
tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id.
chat_template: The chat template to apply.
seed: The seed to initialize the random number generator for sampling.
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
reserve for the model weights, activations, and KV cache. Higher
@@ -232,6 +237,7 @@ class LLM:
quantization: QuantizationMethods | None = None,
revision: str | None = None,
tokenizer_revision: str | None = None,
chat_template: Path | str | None = None,
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: float = 4,
@@ -384,9 +390,16 @@ class LLM:
self.model_config = self.llm_engine.model_config
self.renderer = self.llm_engine.renderer
self.chat_template = load_chat_template(chat_template)
self.io_processor = self.llm_engine.io_processor
self.input_processor = self.llm_engine.input_processor
self.chat_template_config = ChatTemplateConfig(chat_template=self.chat_template)
self.init_pooling_io_processors = init_pooling_io_processors(
supported_tasks=supported_tasks,
model_config=self.model_config,
renderer=self.renderer,
chat_template_config=self.chat_template_config,
)
# Cache for __repr__ to avoid repeated collective_rpc calls
self._cached_repr: str | None = None
@@ -1086,7 +1099,7 @@ class LLM:
"pooling model."
)
if use_io_processor := (isinstance(prompts, dict) and "data" in prompts):
if isinstance(prompts, dict) and "data" in prompts:
if self.io_processor is None:
raise ValueError(
"No IOProcessor plugin installed. Please refer "
@@ -1120,6 +1133,31 @@ class LLM:
for p in params_seq:
if p.task is None:
p.task = "plugin"
outputs = self._run_completion(
prompts=prompts_seq,
params=params_seq,
output_type=PoolingRequestOutput,
use_tqdm=use_tqdm,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
# get the post-processed model outputs
assert self.io_processor is not None
processed_outputs = self.io_processor.post_process(outputs)
return [
PoolingRequestOutput[Any](
request_id="",
outputs=processed_outputs,
num_cached_tokens=getattr(
processed_outputs, "num_cached_tokens", 0
),
prompt_token_ids=[],
finished=True,
)
]
else:
if pooling_params is None:
# Use default pooling params.
@@ -1137,32 +1175,36 @@ class LLM:
)
raise ValueError(msg)
outputs = self._run_completion(
prompts=prompts_seq,
params=params_seq,
output_type=PoolingRequestOutput,
use_tqdm=use_tqdm,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
if use_io_processor:
# get the post-processed model outputs
assert self.io_processor is not None
processed_outputs = self.io_processor.post_process(outputs)
return [
PoolingRequestOutput[Any](
request_id="",
outputs=processed_outputs,
num_cached_tokens=getattr(
processed_outputs, "num_cached_tokens", 0
),
prompt_token_ids=[],
finished=True,
if pooling_task in self.init_pooling_io_processors:
io_processor = self.init_pooling_io_processors[pooling_task]
processor_inputs = io_processor.pre_process_offline(
prompts_seq, tokenization_kwargs
)
]
seq_lora_requests = self._lora_request_to_seq(
lora_request, len(prompts_seq)
)
seq_priority = self._priority_to_seq(None, len(prompts))
self._render_and_add_requests(
prompts=processor_inputs,
params=params_seq,
lora_requests=seq_lora_requests,
priorities=seq_priority,
)
outputs = self._run_engine(
use_tqdm=use_tqdm, output_type=PoolingRequestOutput
)
outputs = io_processor.post_process(outputs)
else:
outputs = self._run_completion(
prompts=prompts_seq,
params=params_seq,
output_type=PoolingRequestOutput,
use_tqdm=use_tqdm,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
return outputs
def embed(

View File

@@ -62,11 +62,6 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
TranscriptionResponse,
TranslationRequest,
)
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingBytesResponse,
EmbeddingChatRequest,
@@ -161,7 +156,6 @@ CompletionLikeRequest: TypeAlias = (
| TokenizeCompletionRequest
| DetokenizeRequest
| EmbeddingCompletionRequest
| ClassificationCompletionRequest
| RerankRequest
| ScoreRequest
| PoolingCompletionRequest
@@ -171,7 +165,6 @@ ChatLikeRequest: TypeAlias = (
ChatCompletionRequest
| TokenizeChatRequest
| EmbeddingChatRequest
| ClassificationChatRequest
| PoolingChatRequest
)
@@ -194,12 +187,10 @@ AnyResponse: TypeAlias = (
| TranscriptionResponse
| TokenizeResponse
| PoolingResponse
| ClassificationResponse
| ScoreResponse
| GenerateResponse
)
RequestT = TypeVar("RequestT", bound=AnyRequest)
@@ -223,8 +214,8 @@ class ServeContext(Generic[RequestT]):
class OpenAIServing:
request_id_prefix: ClassVar[str] = """
A short string prepended to every requests ID (e.g. "embd", "classify")
so you can easily tell “this ID came from Embedding vs Classification.”
A short string prepended to every requests ID (e.g. "embd")
so you can easily tell “this ID came from Embedding.”
"""
def __init__(
@@ -456,7 +447,7 @@ class OpenAIServing:
) -> ErrorResponse | None:
"""
Default preprocessing hook. Subclasses may override
to prepare `ctx` (classification, embedding, etc.).
to prepare `ctx` (embedding, etc.).
"""
return None
@@ -817,7 +808,7 @@ class OpenAIServing:
token_num = len(input_ids)
max_model_len = self.model_config.max_model_len
# Note: EmbeddingRequest, ClassificationRequest,
# Note: EmbeddingRequest,
# and ScoreRequest doesn't have max_tokens
if isinstance(
request,
@@ -828,8 +819,6 @@ class OpenAIServing:
ScoreTextRequest,
ScoreQueriesDocumentsRequest,
RerankRequest,
ClassificationCompletionRequest,
ClassificationChatRequest,
),
):
# Note: input length can be up to the entire model context length
@@ -839,8 +828,6 @@ class OpenAIServing:
ScoreDataRequest: "score",
ScoreTextRequest: "score",
ScoreQueriesDocumentsRequest: "score",
ClassificationCompletionRequest: "classification",
ClassificationChatRequest: "classification",
}
operation = operations.get(type(request), "embedding generation")
raise VLLMValidationError(

View File

@@ -0,0 +1,189 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Sequence
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Final
from vllm import PoolingRequestOutput, PromptType
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateConfig,
ChatTemplateContentFormatOption,
ConversationMessage,
)
from vllm.entrypoints.openai.engine.serving import RendererChatRequest, RendererRequest
from vllm.inputs import ProcessorInputs, SingletonPrompt
from vllm.renderers import BaseRenderer, merge_kwargs
from vllm.renderers.inputs import TokPrompt
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser
from vllm.utils.mistral import is_mistral_tokenizer
class PoolingIOProcessor:
    """Common pre/post-processing plumbing for pooling entrypoints.

    Concrete processors override ``pre_process_online`` and
    ``pre_process_offline``. The ``_preprocess_*`` helpers below render
    prompts through ``self.renderer``, and every ``*_async`` method simply
    calls its synchronous counterpart.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        renderer: BaseRenderer,
        chat_template_config: ChatTemplateConfig,
    ):
        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)

        self.model_config = model_config
        self.renderer = renderer

        # Unpack the chat-template settings into plain attributes.
        template_cfg = chat_template_config
        self.chat_template = template_cfg.chat_template
        self.chat_template_content_format: Final = (
            template_cfg.chat_template_content_format
        )
        self.trust_request_chat_template = template_cfg.trust_request_chat_template

    def pre_process_online(self, *args, **kwargs):
        """Turn an online (server) request into engine prompts; subclass hook."""
        raise NotImplementedError

    async def pre_process_online_async(self, *args, **kwargs):
        """Async entry point that delegates to ``pre_process_online``."""
        return self.pre_process_online(*args, **kwargs)

    def pre_process_offline(self, *args, **kwargs):
        """Turn offline prompts into engine inputs; subclass hook."""
        raise NotImplementedError

    async def pre_process_offline_async(self, *args, **kwargs):
        """Async entry point that delegates to ``pre_process_offline``."""
        return self.pre_process_offline(*args, **kwargs)

    def post_process(
        self, outputs: list[PoolingRequestOutput]
    ) -> list[PoolingRequestOutput]:
        """Return the engine outputs unchanged; subclasses may override."""
        return outputs

    async def post_process_async(
        self, outputs: list[PoolingRequestOutput]
    ) -> list[PoolingRequestOutput]:
        """Async entry point that delegates to ``post_process``."""
        return self.post_process(outputs)

    def create_pooling_params(self, request):
        """Build pooling params by asking the request itself."""
        return request.to_pooling_params()

    def _preprocess_completion_online(
        self,
        request: RendererRequest,
        prompt_input: str | list[str] | list[int] | list[list[int]] | None,
        prompt_embeds: bytes | list[bytes] | None,
    ) -> list[TokPrompt]:
        """Render completion-style inputs of an online request.

        Embeds, when given, are queued ahead of the regular prompt input.
        """
        collected: list[SingletonPrompt | bytes] = []
        # Embeds come first so they take priority over the plain input.
        for source in (prompt_embeds, prompt_input):
            if source is not None:
                collected.extend(prompt_to_seq(source))

        config = self.model_config
        parsed = []
        for item in collected:
            # Raw bytes (embeds) are passed through without parsing.
            if isinstance(item, bytes):
                parsed.append(item)
            else:
                parsed.append(parse_model_prompt(config, item))

        tok_params = request.build_tok_params(config)

        # Forward only the optional request extras that are actually set.
        extras = {}
        for key in ("mm_processor_kwargs", "cache_salt"):
            value = getattr(request, key, None)
            if value is not None:
                extras[key] = value

        return self.renderer.render_cmpl(parsed, tok_params, prompt_extras=extras)

    def _preprocess_chat_online(
        self,
        request: RendererChatRequest,
        messages: list[ChatCompletionMessageParam],
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
        default_template_kwargs: dict[str, Any] | None,
        tool_dicts: list[dict[str, Any]] | None = None,
        tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
    ) -> tuple[list[ConversationMessage], list[TokPrompt]]:
        """Render one chat conversation of an online request.

        Returns the rendered conversation and a single-element list holding
        the engine prompt. ``tool_parser`` is accepted but not used here.
        """
        chat_renderer = self.renderer

        template_kwargs = merge_kwargs(
            default_template_kwargs,
            {
                "tools": tool_dicts,
                "tokenize": is_mistral_tokenizer(chat_renderer.tokenizer),
            },
        )

        tok_params = request.build_tok_params(self.model_config)
        base_chat_params = request.build_chat_params(
            default_template, default_template_content_format
        )
        chat_params = base_chat_params.with_defaults(template_kwargs)

        # Forward only the optional request extras that are actually set.
        extras = {}
        for key in ("mm_processor_kwargs", "cache_salt"):
            value = getattr(request, key, None)
            if value is not None:
                extras[key] = value

        conversations, engine_prompts = chat_renderer.render_chat(
            [messages],
            chat_params,
            tok_params,
            prompt_extras=extras,
        )
        # Exactly one conversation went in, so exactly one of each comes out.
        (conversation,) = conversations
        (engine_prompt,) = engine_prompts

        return conversation, [engine_prompt]

    def _preprocess_completion_offline(
        self,
        prompts: PromptType | Sequence[PromptType],
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[ProcessorInputs]:
        """Render offline prompts using the renderer's default tokenization
        parameters, overridden by ``tokenization_kwargs`` when provided."""
        config = self.model_config
        offline_renderer = self.renderer

        parsed = []
        for item in prompt_to_seq(prompts):
            # Raw bytes are passed through without parsing.
            if isinstance(item, bytes):
                parsed.append(item)
            else:
                parsed.append(parse_model_prompt(config, item))

        overrides = tokenization_kwargs or {}
        tok_params = offline_renderer.default_cmpl_tok_params.with_kwargs(**overrides)

        return offline_renderer.render_cmpl(parsed, tok_params)

    def _validate_chat_template(
        self,
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
        trust_request_chat_template: bool,
    ):
        """Refuse request-supplied chat templates unless they are trusted.

        Raises:
            ValueError: if a template arrives with the request (directly or
                through ``chat_template_kwargs``) while untrusted.
        """
        if trust_request_chat_template:
            return None

        if request_chat_template is None:
            # No direct template: only a template hidden in kwargs can fail.
            if not chat_template_kwargs:
                return None
            if chat_template_kwargs.get("chat_template") is None:
                return None

        raise ValueError(
            "Chat template is passed with request, but "
            "--trust-request-chat-template is not set. "
            "Refused request with untrusted chat template."
        )

View File

@@ -0,0 +1,378 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from collections.abc import AsyncGenerator, Mapping
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import ClassVar, Generic, TypeVar
from fastapi import Request
from pydantic import ConfigDict
from starlette.datastructures import Headers
from starlette.responses import JSONResponse
from vllm import (
PoolingParams,
PoolingRequestOutput,
PromptType,
SamplingParams,
envs,
)
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (
ChatTemplateConfig,
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.typing import AnyPoolingRequest, AnyPoolingResponse
from vllm.inputs import ProcessorInputs
from vllm.lora.request import LoRARequest
from vllm.renderers import BaseRenderer
from vllm.renderers.inputs.preprocess import extract_prompt_components
from vllm.sampling_params import BeamSearchParams
from vllm.tracing import (
contains_trace_headers,
extract_trace_headers,
log_tracing_disabled_warning,
)
from vllm.utils import random_uuid
from vllm.utils.async_utils import merge_async_iterators
from ...utils import create_error_response
from .io_processor import PoolingIOProcessor
PoolingRequestT = TypeVar("PoolingRequestT", bound=AnyPoolingRequest)
@dataclass(kw_only=True)
class PoolingServeContext(Generic[PoolingRequestT]):
    """Per-request state carried through the stages of a pooling request."""

    # The parsed request body.
    request: PoolingRequestT
    # The underlying HTTP request, when one is available.
    raw_request: Request | None = None
    model_name: str
    request_id: str
    # Creation timestamp in whole seconds since the epoch.
    created_time: int = field(default_factory=lambda: int(time.time()))

    # LoRA adapter selected for this request, if any.
    lora_request: LoRARequest | None = None
    # Engine prompts produced by preprocessing; None until that stage ran.
    engine_prompts: list[ProcessorInputs] | None = None

    # Stream of (prompt index, output) pairs; None until generators are set up.
    result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
        None
    )
    # Collected outputs, one per engine prompt.
    final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)

    # NOTE(review): unannotated, so this is a plain class attribute rather
    # than a dataclass field; it looks like a pydantic-style setting —
    # confirm it is still needed on a stdlib dataclass.
    model_config = ConfigDict(arbitrary_types_allowed=True)
class PoolingServing:
    """Base HTTP serving pipeline shared by pooling endpoints.

    ``__call__`` runs a request through model checking, validation, adapter
    lookup, preprocessing, engine submission and result collection, then
    returns a ``JSONResponse``. Subclasses must set ``request_id_prefix`` and
    implement ``init_io_processor`` and ``_build_response``.
    """

    # Prefix prepended to every request id generated by this endpoint.
    request_id_prefix: ClassVar[str]

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        trust_request_chat_template: bool = False,
        return_tokens_as_token_ids: bool = False,
        log_error_stack: bool = False,
    ):
        super().__init__()
        self.engine_client = engine_client
        self.models = models
        self.model_config = models.model_config
        self.max_model_len = self.model_config.max_model_len
        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
        self.log_error_stack = log_error_stack
        self.chat_template_config = ChatTemplateConfig(
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            trust_request_chat_template=trust_request_chat_template,
        )
        self.io_processor = self.init_io_processor(
            model_config=models.model_config,
            renderer=models.renderer,
            chat_template_config=self.chat_template_config,
        )

    def init_io_processor(
        self,
        model_config: ModelConfig,
        renderer: BaseRenderer,
        chat_template_config: ChatTemplateConfig,
    ) -> PoolingIOProcessor:
        """Create the endpoint-specific IO processor; subclass hook."""
        raise NotImplementedError

    async def __call__(
        self,
        request: AnyPoolingRequest,
        raw_request: Request,
    ) -> JSONResponse:
        """Handle one pooling request end to end.

        Any exception raised by a stage is converted into an error
        ``JSONResponse`` through ``create_error_response``.
        """
        try:
            model_name = self.models.model_name()
            request_id = (
                f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
            )

            await self._check_model(request)

            ctx = PoolingServeContext(
                request=request,
                raw_request=raw_request,
                model_name=model_name,
                request_id=request_id,
            )

            self._validate_request(ctx)
            self._maybe_get_adapters(ctx)
            await self._preprocess(ctx)
            await self._prepare_generators(ctx)
            await self._collect_batch(ctx)
            response = await self._build_response(ctx)
            return JSONResponse(content=response.model_dump())
        except Exception as e:
            error_response = create_error_response(e)
            return JSONResponse(
                content=error_response.model_dump(),
                status_code=error_response.error.code,
            )

    async def _preprocess(
        self,
        ctx: PoolingServeContext,
    ):
        """Fill ``ctx.engine_prompts`` using the IO processor."""
        ctx.engine_prompts = await self.io_processor.pre_process_online_async(
            ctx.request
        )

    async def _prepare_generators(
        self,
        ctx: PoolingServeContext,
    ):
        """Submit every engine prompt and merge the result streams.

        Raises:
            ValueError: if preprocessing has not produced engine prompts.
        """
        if ctx.engine_prompts is None:
            raise ValueError("Engine prompts not available")

        generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []

        trace_headers = (
            None
            if ctx.raw_request is None
            else await self._get_trace_headers(ctx.raw_request.headers)
        )

        pooling_params = self.io_processor.create_pooling_params(ctx.request)

        for i, engine_prompt in enumerate(ctx.engine_prompts):
            # Each prompt gets its own id derived from the request id.
            request_id_item = f"{ctx.request_id}-{i}"

            self._log_inputs(
                request_id_item,
                engine_prompt,
                params=pooling_params,
                lora_request=ctx.lora_request,
            )

            generator = self.engine_client.encode(
                engine_prompt,
                pooling_params,
                request_id_item,
                lora_request=ctx.lora_request,
                trace_headers=trace_headers,
                priority=getattr(ctx.request, "priority", 0),
            )

            generators.append(generator)

        ctx.result_generator = merge_async_iterators(*generators)

    async def _collect_batch(
        self,
        ctx: PoolingServeContext,
    ):
        """Drain the result stream into ``ctx.final_res_batch`` by prompt index.

        Raises:
            ValueError: if prompts or the result generator are missing, or if
                any prompt ends up without a result.
        """
        if ctx.engine_prompts is None:
            raise ValueError("Engine prompts not available")

        if ctx.result_generator is None:
            raise ValueError("Result generator not available")

        num_prompts = len(ctx.engine_prompts)
        final_res_batch: list[PoolingRequestOutput | None]
        final_res_batch = [None] * num_prompts

        async for i, res in ctx.result_generator:
            final_res_batch[i] = res

        # Identity check: `None in ...` would compare each output with `==`.
        if any(res is None for res in final_res_batch):
            raise ValueError("Failed to generate results for all prompts")

        ctx.final_res_batch = [res for res in final_res_batch if res is not None]

    async def _build_response(
        self,
        ctx: PoolingServeContext,
    ) -> AnyPoolingResponse:
        """Convert collected outputs into the endpoint response; subclass hook."""
        raise NotImplementedError

    @staticmethod
    def _base_request_id(
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
        """Pulls the request id to use from a header, if provided"""
        if raw_request is not None and (
            (req_id := raw_request.headers.get("X-Request-Id")) is not None
        ):
            return req_id

        return random_uuid() if default is None else default

    def _is_model_supported(self, model_name: str | None) -> bool:
        """Return True for an empty model name or a base model."""
        if not model_name:
            return True
        return self.models.is_base_model(model_name)

    async def _check_model(
        self,
        request: AnyPoolingRequest,
    ) -> ErrorResponse | None:
        """Check the requested model, resolving a runtime LoRA if enabled.

        Raises:
            ValueError: if runtime LoRA resolution reports a bad request.
        """
        if self._is_model_supported(request.model):
            return None
        if request.model in self.models.lora_requests:
            return None
        if (
            envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING
            and request.model
            and (load_result := await self.models.resolve_lora(request.model))
        ):
            if isinstance(load_result, LoRARequest):
                return None
            if (
                isinstance(load_result, ErrorResponse)
                and load_result.error.code == HTTPStatus.BAD_REQUEST.value
            ):
                raise ValueError(load_result.error.message)
        return None

    def _validate_request(self, ctx: PoolingServeContext) -> None:
        """Reject a ``truncate_prompt_tokens`` larger than ``max_model_len``."""
        truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None)

        if (
            truncate_prompt_tokens is not None
            and truncate_prompt_tokens > self.max_model_len
        ):
            raise ValueError(
                "truncate_prompt_tokens value is "
                "greater than max_model_len."
                " Please, select a smaller truncation size."
            )
        return None

    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        """Extract trace headers when tracing is enabled, else warn if present."""
        is_tracing_enabled = await self.engine_client.is_tracing_enabled()

        if is_tracing_enabled:
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()

        return None

    def _maybe_get_adapters(
        self,
        ctx: PoolingServeContext,
        supports_default_mm_loras: bool = False,
    ):
        """Select the LoRA adapter for the request and store it on ``ctx``.

        Raises:
            ValueError: if the requested model is neither a registered LoRA
                nor accepted by ``_is_model_supported``.
        """
        request = ctx.request
        if request.model in self.models.lora_requests:
            ctx.lora_request = self.models.lora_requests[request.model]
            # A LoRA named in the request is valid and takes precedence.
            # Without this return the base-model check below would reject it
            # with "does not exist".
            return None

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
                ctx.lora_request = default_mm_lora

        if self._is_model_supported(request.model):
            return None

        # if _check_model has been called earlier, this will be unreachable
        raise ValueError(f"The model `{request.model}` does not exist.")

    def _get_active_default_mm_loras(
        self, request: AnyPoolingRequest
    ) -> LoRARequest | None:
        """Determine if there are any active default multimodal loras."""
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)
        default_mm_loras = set()

        for lora in self.models.lora_requests.values():
            # Best effort match for default multimodal lora adapters;
            # There is probably a better way to do this, but currently
            # this matches against the set of 'types' in any content lists
            # up until '_', e.g., to match audio_url -> audio
            if lora.lora_name in message_types:
                default_mm_loras.add(lora)

        # Currently only support default modality specific loras if
        # we have exactly one lora matched on the request.
        if len(default_mm_loras) == 1:
            return default_mm_loras.pop()
        return None

    def _get_message_types(self, request: AnyPoolingRequest) -> set[str]:
        """Retrieve the set of types from message content dicts up
        until `_`; we use this to match potential multimodal data
        with default per modality loras.
        """
        message_types: set[str] = set()

        if not hasattr(request, "messages"):
            return message_types

        messages = request.messages
        if messages is None or isinstance(messages, (str, bytes)):
            return message_types

        for message in messages:
            if not isinstance(message, dict):
                continue
            content = message.get("content")
            if not isinstance(content, list):
                continue
            for content_part in content:
                # Only dict parts carry a "type" key. On a plain string,
                # `"type" in part` is a substring test and indexing it
                # would raise TypeError.
                if not isinstance(content_part, dict):
                    continue
                part_type = content_part.get("type")
                if isinstance(part_type, str):
                    message_types.add(part_type.split("_")[0])
        return message_types

    def _log_inputs(
        self,
        request_id: str,
        inputs: PromptType | ProcessorInputs,
        params: SamplingParams | PoolingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
    ) -> None:
        """Log the prompt components of a request; no-op without a logger."""
        if self.request_logger is None:
            return

        components = extract_prompt_components(self.model_config, inputs)

        self.request_logger.log_inputs(
            request_id,
            components.text,
            components.token_ids,
            components.embeds,
            params=params,
            lora_request=lora_request,
        )

View File

@@ -3,16 +3,17 @@
from fastapi import APIRouter, Depends, Request
from starlette.responses import JSONResponse
from typing_extensions import assert_never
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationRequest,
ClassificationResponse,
)
from vllm.entrypoints.pooling.classify.serving import ServingClassification
from vllm.entrypoints.utils import load_aware_call, with_cancellation
from vllm.entrypoints.utils import (
create_error_response,
load_aware_call,
with_cancellation,
)
router = APIRouter()
@@ -24,25 +25,17 @@ def classify(request: Request) -> ServingClassification | None:
@router.post("/classify", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_classify(request: ClassificationRequest, raw_request: Request):
async def create_classify(
request: ClassificationRequest, raw_request: Request
) -> JSONResponse:
handler = classify(raw_request)
if handler is None:
base_server = raw_request.app.state.openai_serving_tokenization
return base_server.create_error_response(
error_response = create_error_response(
message="The model does not support Classification API"
)
try:
generator = await handler.create_classify(request, raw_request)
except Exception as e:
generator = handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
content=error_response.model_dump(),
status_code=error_response.error.code,
)
elif isinstance(generator, ClassificationResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
return await handler(request, raw_request)

View File

@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Any
from vllm import PromptType
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
)
from vllm.inputs import ProcessorInputs
from vllm.renderers.inputs import TokPrompt
class ClassifyIOProcessor(PoolingIOProcessor):
    """IO processor for classification requests (online and offline)."""

    def pre_process_online(
        self, request: ClassificationCompletionRequest | ClassificationChatRequest
    ) -> list[TokPrompt] | None:
        """Render a classification request into engine prompts.

        Chat requests have their chat template validated before rendering;
        completion requests are rendered from ``request.input``.

        Raises:
            ValueError: if ``request`` is neither supported request type.
        """
        if isinstance(request, ClassificationChatRequest):
            self._validate_chat_template(
                request_chat_template=request.chat_template,
                chat_template_kwargs=request.chat_template_kwargs,
                trust_request_chat_template=self.trust_request_chat_template,
            )
            # Only the engine prompts are needed; the conversation is dropped.
            rendered = self._preprocess_chat_online(
                request,
                request.messages,
                default_template=self.chat_template,
                default_template_content_format=self.chat_template_content_format,
                default_template_kwargs=None,
            )
            _, chat_prompts = rendered
            return chat_prompts

        if isinstance(request, ClassificationCompletionRequest):
            return self._preprocess_completion_online(
                request,
                prompt_input=request.input,
                prompt_embeds=None,
            )

        raise ValueError("Invalid classification request type")

    def pre_process_offline(
        self,
        prompts: PromptType | Sequence[PromptType],
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[ProcessorInputs]:
        """Render offline prompts via the shared completion helper."""
        processor_inputs = self._preprocess_completion_offline(
            prompts=prompts,
            tokenization_kwargs=tokenization_kwargs,
        )
        return processor_inputs

View File

@@ -1,116 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Final, TypeAlias
from typing import TypeAlias
import jinja2
import numpy as np
from fastapi import Request
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
from vllm import ClassificationOutput
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.serving import PoolingServeContext, PoolingServing
from vllm.logger import init_logger
from vllm.renderers import BaseRenderer
from .io_processor import ClassifyIOProcessor
from .protocol import (
ClassificationData,
ClassificationRequest,
ClassificationResponse,
)
from vllm.logger import init_logger
from vllm.outputs import ClassificationOutput
logger = init_logger(__name__)
ClassificationServeContext: TypeAlias = ServeContext[ClassificationRequest]
ClassificationServeContext: TypeAlias = PoolingServeContext[ClassificationRequest]
class ServingClassification(OpenAIServing):
class ServingClassification(PoolingServing):
request_id_prefix = "classify"
def __init__(
def init_io_processor(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
model_config: ModelConfig,
renderer: BaseRenderer,
chat_template_config: ChatTemplateConfig,
) -> ClassifyIOProcessor:
return ClassifyIOProcessor(
model_config=model_config,
renderer=renderer,
chat_template_config=chat_template_config,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def _preprocess(
async def _build_response(
self,
ctx: ClassificationServeContext,
) -> ErrorResponse | None:
"""
Process classification inputs: tokenize text, resolve adapters,
and prepare model-specific inputs.
"""
try:
ctx.lora_request = self._maybe_get_adapters(ctx.request)
) -> ClassificationResponse:
final_res_batch_checked = await self.io_processor.post_process_async(
ctx.final_res_batch
)
if isinstance(ctx.request, ClassificationChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=ctx.request.chat_template,
chat_template_kwargs=ctx.request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret:
return error_check_ret
_, ctx.engine_prompts = await self._preprocess_chat(
ctx.request,
ctx.request.messages,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None,
)
elif isinstance(ctx.request, ClassificationCompletionRequest):
ctx.engine_prompts = await self._preprocess_completion(
ctx.request,
prompt_input=ctx.request.input,
prompt_embeds=None,
)
else:
return self.create_error_response("Invalid classification request type")
return None
except (ValueError, TypeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
def _build_response(
self,
ctx: ClassificationServeContext,
) -> ClassificationResponse | ErrorResponse:
"""
Convert model outputs to a formatted classification response
with probabilities and labels.
"""
id2label = getattr(self.model_config.hf_config, "id2label", {})
items: list[ClassificationData] = []
num_prompt_tokens = 0
final_res_batch_checked = ctx.final_res_batch
items: list[ClassificationData] = []
for idx, final_res in enumerate(final_res_batch_checked):
classify_res = ClassificationOutput.from_base(final_res.outputs)
@@ -141,20 +82,3 @@ class ServingClassification(OpenAIServing):
data=items,
usage=usage,
)
async def create_classify(
self,
request: ClassificationRequest,
raw_request: Request,
) -> ClassificationResponse | ErrorResponse:
model_name = self.models.model_name()
request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
ctx = ClassificationServeContext(
request=request,
raw_request=raw_request,
model_name=model_name,
request_id=request_id,
)
return await self.handle(ctx) # type: ignore[return-value]

View File

@@ -0,0 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.renderers import BaseRenderer
from vllm.tasks import SupportedTask
def init_pooling_io_processors(
    supported_tasks: tuple[SupportedTask, ...],
    model_config: ModelConfig,
    renderer: BaseRenderer,
    chat_template_config: ChatTemplateConfig,
) -> dict[str, PoolingIOProcessor]:
    """Create the IO processors for the pooling tasks that are supported.

    Args:
        supported_tasks: Task names to consider; only ``"classify"`` is
            looked up here.
        model_config: Forwarded to each processor constructor.
        renderer: Forwarded to each processor constructor.
        chat_template_config: Forwarded to each processor constructor.

    Returns:
        A new dict mapping task name to its processor. It is empty when
        ``"classify"`` is not among ``supported_tasks``.
    """
    registry: dict[str, PoolingIOProcessor] = {}
    # Constructor arguments handed to the processor unchanged.
    processor_kwargs = {
        "model_config": model_config,
        "renderer": renderer,
        "chat_template_config": chat_template_config,
    }

    if "classify" in supported_tasks:
        # The import runs only when the classify task is requested.
        from vllm.entrypoints.pooling.classify.io_processor import (
            ClassifyIOProcessor,
        )

        registry["classify"] = ClassifyIOProcessor(**processor_kwargs)

    return registry

View File

@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TypeAlias
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingBytesResponse,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
IOProcessorRequest,
PoolingChatRequest,
PoolingCompletionRequest,
PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
RerankRequest,
ScoreRequest,
ScoreResponse,
)
# Request models grouped here as "completion-like": the embedding,
# classification and pooling completion requests, plus rerank and score
# requests.
PoolingCompletionLikeRequest: TypeAlias = (
EmbeddingCompletionRequest
| ClassificationCompletionRequest
| RerankRequest
| ScoreRequest
| PoolingCompletionRequest
)
# Request models grouped here as "chat-like": the embedding, classification
# and pooling chat requests.
PoolingChatLikeRequest: TypeAlias = (
EmbeddingChatRequest | ClassificationChatRequest | PoolingChatRequest
)
# Any pooling request: either of the two groups defined in this module, or an
# IOProcessorRequest.
AnyPoolingRequest: TypeAlias = (
PoolingCompletionLikeRequest | PoolingChatLikeRequest | IOProcessorRequest
)
# Any pooling response: classification, embedding (regular or bytes), pooling
# or score response.
AnyPoolingResponse: TypeAlias = (
ClassificationResponse
| EmbeddingResponse
| EmbeddingBytesResponse
| PoolingResponse
| ScoreResponse
)

View File

@@ -13,6 +13,7 @@ from fastapi.responses import JSONResponse, Response
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.entrypoints.serve.instrumentator.basic import base
from vllm.entrypoints.serve.instrumentator.health import health
from vllm.tasks import POOLING_TASKS, SupportedTask
@@ -20,7 +21,7 @@ from vllm.tasks import POOLING_TASKS, SupportedTask
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# (requires typing_extensions >= 4.13)
RequestType = Any
# Callable that takes the incoming Request and returns an OpenAIServing, a
# PoolingServing, or None.
GetHandlerFn = Callable[[Request], OpenAIServing | PoolingServing | None]
# Async endpoint callable: takes the parsed request object and the raw Request.
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]

View File

@@ -5,7 +5,10 @@ import asyncio
import dataclasses
import functools
import os
import sys
import traceback
from argparse import Namespace
from http import HTTPStatus
from logging import Logger
from string import Template
from typing import TYPE_CHECKING
@@ -17,17 +20,23 @@ from starlette.background import BackgroundTask, BackgroundTasks
from vllm import envs
from vllm.engine.arg_utils import EngineArgs
from vllm.exceptions import VLLMValidationError
from vllm.logger import current_formatter_type, init_logger
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
if TYPE_CHECKING:
from vllm.entrypoints.openai.engine.protocol import StreamOptions
from vllm.entrypoints.openai.engine.protocol import (
ErrorInfo,
ErrorResponse,
StreamOptions,
)
from vllm.entrypoints.openai.models.protocol import LoRAModulePath
else:
StreamOptions = object
ErrorResponse = object
ErrorInfo = object
LoRAModulePath = object
StreamOptions = object
logger = init_logger(__name__)
@@ -291,3 +300,59 @@ def log_version_and_model(lgr: Logger, version: str, model_name: str) -> None:
message = logo_template.substitute(colors)
lgr.info(message, version, model_name)
def create_error_response(
    message: str | Exception,
    err_type: str = "BadRequestError",
    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    param: str | None = None,
    log_error_stack: bool = False,
) -> "ErrorResponse":
    """Build an ``ErrorResponse`` from a message string or an exception.

    When ``message`` is an exception, ``err_type``, ``status_code`` and
    ``param`` are derived from the exception type (the values passed by the
    caller are ignored) and the response message is ``str(exc)``:

    * ``VLLMValidationError`` -> 400 ``BadRequestError`` with
      ``param = exc.parameter``
    * ``NotImplementedError`` -> 501 ``NotImplementedError``
    * ``ValueError`` / ``TypeError`` / ``RuntimeError`` / ``OverflowError``
      -> 400 ``BadRequestError``
    * a class named ``TemplateError`` or any subclass of one
      -> 400 ``BadRequestError``
    * anything else -> 500 ``InternalServerError``

    When ``message`` is a string, the caller-supplied values are used as is.

    Args:
        message: Error text, or the exception to report.
        err_type: Error type string used for string messages.
        status_code: HTTP status used for string messages.
        param: Offending parameter name used for string messages.
        log_error_stack: If true, print a traceback of the exception being
            handled, or the current call stack when no exception is active.

    Returns:
        An ``ErrorResponse`` wrapping an ``ErrorInfo`` whose message has been
        passed through ``sanitize_message``.
    """
    # Imported at call time: at module level these names are only bound to
    # the real classes under TYPE_CHECKING.
    from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse

    if isinstance(message, Exception):
        exc = message

        if isinstance(exc, VLLMValidationError):
            err_type = "BadRequestError"
            status_code = HTTPStatus.BAD_REQUEST
            param = exc.parameter
        elif isinstance(exc, NotImplementedError):
            # Must be tested before RuntimeError: NotImplementedError is a
            # RuntimeError subclass and would otherwise be reported as 400.
            err_type = "NotImplementedError"
            status_code = HTTPStatus.NOT_IMPLEMENTED
            param = None
        elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)):
            # Common validation errors from user input
            err_type = "BadRequestError"
            status_code = HTTPStatus.BAD_REQUEST
            param = None
        elif any(klass.__name__ == "TemplateError" for klass in type(exc).__mro__):
            # jinja2.TemplateError and its subclasses, matched by class name
            # along the MRO to avoid importing jinja2.
            err_type = "BadRequestError"
            status_code = HTTPStatus.BAD_REQUEST
            param = None
        else:
            err_type = "InternalServerError"
            status_code = HTTPStatus.INTERNAL_SERVER_ERROR
            param = None

        message = str(exc)

    if log_error_stack:
        # Prefer the traceback of the exception currently being handled;
        # outside an ``except`` block fall back to the current call stack.
        if sys.exc_info()[0] is not None:
            traceback.print_exc()
        else:
            traceback.print_stack()

    return ErrorResponse(
        error=ErrorInfo(
            message=sanitize_message(message),
            type=err_type,
            code=status_code.value,
            param=param,
        )
    )