From ea463978bb987a4c15c9b51c0013d620a722aa67 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Tue, 3 Mar 2026 22:05:36 +0800 Subject: [PATCH] [Frontend][1/n] Improve pooling entrypoints | classify. (#35604) Signed-off-by: wang.yuqi Signed-off-by: wang.yuqi Co-authored-by: Cyrus Leung --- vllm/entrypoints/chat_utils.py | 8 + vllm/entrypoints/llm.py | 94 +++-- vllm/entrypoints/openai/engine/serving.py | 21 +- vllm/entrypoints/pooling/base/io_processor.py | 189 +++++++++ vllm/entrypoints/pooling/base/serving.py | 378 ++++++++++++++++++ .../pooling/classify/api_router.py | 31 +- .../pooling/classify/io_processor.py | 50 +++ vllm/entrypoints/pooling/classify/serving.py | 132 ++---- .../pooling/io_processor_factories.py | 31 ++ vllm/entrypoints/pooling/typing.py | 51 +++ vllm/entrypoints/sagemaker/api_router.py | 3 +- vllm/entrypoints/utils.py | 71 +++- 12 files changed, 889 insertions(+), 170 deletions(-) create mode 100644 vllm/entrypoints/pooling/base/io_processor.py create mode 100644 vllm/entrypoints/pooling/base/serving.py create mode 100644 vllm/entrypoints/pooling/classify/io_processor.py create mode 100644 vllm/entrypoints/pooling/io_processor_factories.py create mode 100644 vllm/entrypoints/pooling/typing.py diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index c48d7bea9..1d10aa6b0 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -7,6 +7,7 @@ import warnings from abc import ABC, abstractmethod from collections import Counter, defaultdict from collections.abc import Awaitable, Callable, Iterable +from dataclasses import dataclass from functools import cached_property, lru_cache, partial from itertools import accumulate from pathlib import Path @@ -1024,6 +1025,13 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser): self._add_placeholder("video", placeholder) +@dataclass +class ChatTemplateConfig: + chat_template: str | None = None + chat_template_content_format: 
ChatTemplateContentFormatOption = "auto" + trust_request_chat_template: bool = False + + def validate_chat_template(chat_template: Path | str | None): """Raises if the provided chat template appears invalid.""" if chat_template is None: diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b3260f914..d5a51a6b9 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -3,6 +3,7 @@ import itertools from collections.abc import Callable, Iterable, Sequence +from pathlib import Path from typing import TYPE_CHECKING, Any import cloudpickle @@ -40,8 +41,11 @@ from vllm.distributed.weight_transfer.base import ( from vllm.engine.arg_utils import EngineArgs from vllm.entrypoints.chat_utils import ( ChatCompletionMessageParam, + ChatTemplateConfig, ChatTemplateContentFormatOption, + load_chat_template, ) +from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors from vllm.entrypoints.pooling.score.utils import ( ScoreData, ScoreMultiModalParam, @@ -145,6 +149,7 @@ class LLM: a tag name, or a commit id. tokenizer_revision: The specific tokenizer version to use. It can be a branch name, a tag name, or a commit id. + chat_template: The chat template to apply. seed: The seed to initialize the random number generator for sampling. gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache. 
Higher @@ -232,6 +237,7 @@ class LLM: quantization: QuantizationMethods | None = None, revision: str | None = None, tokenizer_revision: str | None = None, + chat_template: Path | str | None = None, seed: int = 0, gpu_memory_utilization: float = 0.9, swap_space: float = 4, @@ -384,9 +390,16 @@ class LLM: self.model_config = self.llm_engine.model_config self.renderer = self.llm_engine.renderer + self.chat_template = load_chat_template(chat_template) self.io_processor = self.llm_engine.io_processor self.input_processor = self.llm_engine.input_processor - + self.chat_template_config = ChatTemplateConfig(chat_template=self.chat_template) + self.init_pooling_io_processors = init_pooling_io_processors( + supported_tasks=supported_tasks, + model_config=self.model_config, + renderer=self.renderer, + chat_template_config=self.chat_template_config, + ) # Cache for __repr__ to avoid repeated collective_rpc calls self._cached_repr: str | None = None @@ -1086,7 +1099,7 @@ class LLM: "pooling model." ) - if use_io_processor := (isinstance(prompts, dict) and "data" in prompts): + if isinstance(prompts, dict) and "data" in prompts: if self.io_processor is None: raise ValueError( "No IOProcessor plugin installed. Please refer " @@ -1120,6 +1133,31 @@ class LLM: for p in params_seq: if p.task is None: p.task = "plugin" + + outputs = self._run_completion( + prompts=prompts_seq, + params=params_seq, + output_type=PoolingRequestOutput, + use_tqdm=use_tqdm, + lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, + ) + + # get the post-processed model outputs + assert self.io_processor is not None + processed_outputs = self.io_processor.post_process(outputs) + + return [ + PoolingRequestOutput[Any]( + request_id="", + outputs=processed_outputs, + num_cached_tokens=getattr( + processed_outputs, "num_cached_tokens", 0 + ), + prompt_token_ids=[], + finished=True, + ) + ] else: if pooling_params is None: # Use default pooling params. 
@@ -1137,32 +1175,36 @@ class LLM: ) raise ValueError(msg) - outputs = self._run_completion( - prompts=prompts_seq, - params=params_seq, - output_type=PoolingRequestOutput, - use_tqdm=use_tqdm, - lora_request=lora_request, - tokenization_kwargs=tokenization_kwargs, - ) - - if use_io_processor: - # get the post-processed model outputs - assert self.io_processor is not None - processed_outputs = self.io_processor.post_process(outputs) - - return [ - PoolingRequestOutput[Any]( - request_id="", - outputs=processed_outputs, - num_cached_tokens=getattr( - processed_outputs, "num_cached_tokens", 0 - ), - prompt_token_ids=[], - finished=True, + if pooling_task in self.init_pooling_io_processors: + io_processor = self.init_pooling_io_processors[pooling_task] + processor_inputs = io_processor.pre_process_offline( + prompts_seq, tokenization_kwargs ) - ] + seq_lora_requests = self._lora_request_to_seq( + lora_request, len(prompts_seq) + ) + seq_priority = self._priority_to_seq(None, len(prompts)) + self._render_and_add_requests( + prompts=processor_inputs, + params=params_seq, + lora_requests=seq_lora_requests, + priorities=seq_priority, + ) + + outputs = self._run_engine( + use_tqdm=use_tqdm, output_type=PoolingRequestOutput + ) + outputs = io_processor.post_process(outputs) + else: + outputs = self._run_completion( + prompts=prompts_seq, + params=params_seq, + output_type=PoolingRequestOutput, + use_tqdm=use_tqdm, + lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, + ) return outputs def embed( diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py index 3e376ba9c..e864f562e 100644 --- a/vllm/entrypoints/openai/engine/serving.py +++ b/vllm/entrypoints/openai/engine/serving.py @@ -62,11 +62,6 @@ from vllm.entrypoints.openai.speech_to_text.protocol import ( TranscriptionResponse, TranslationRequest, ) -from vllm.entrypoints.pooling.classify.protocol import ( - ClassificationChatRequest, - 
ClassificationCompletionRequest, - ClassificationResponse, -) from vllm.entrypoints.pooling.embed.protocol import ( EmbeddingBytesResponse, EmbeddingChatRequest, @@ -161,7 +156,6 @@ CompletionLikeRequest: TypeAlias = ( | TokenizeCompletionRequest | DetokenizeRequest | EmbeddingCompletionRequest - | ClassificationCompletionRequest | RerankRequest | ScoreRequest | PoolingCompletionRequest @@ -171,7 +165,6 @@ ChatLikeRequest: TypeAlias = ( ChatCompletionRequest | TokenizeChatRequest | EmbeddingChatRequest - | ClassificationChatRequest | PoolingChatRequest ) @@ -194,12 +187,10 @@ AnyResponse: TypeAlias = ( | TranscriptionResponse | TokenizeResponse | PoolingResponse - | ClassificationResponse | ScoreResponse | GenerateResponse ) - RequestT = TypeVar("RequestT", bound=AnyRequest) @@ -223,8 +214,8 @@ class ServeContext(Generic[RequestT]): class OpenAIServing: request_id_prefix: ClassVar[str] = """ - A short string prepended to every request’s ID (e.g. "embd", "classify") - so you can easily tell “this ID came from Embedding vs Classification.” + A short string prepended to every request’s ID (e.g. "embd") + so you can easily tell “this ID came from Embedding.” """ def __init__( @@ -456,7 +447,7 @@ class OpenAIServing: ) -> ErrorResponse | None: """ Default preprocessing hook. Subclasses may override - to prepare `ctx` (classification, embedding, etc.). + to prepare `ctx` (embedding, etc.). 
""" return None @@ -817,7 +808,7 @@ class OpenAIServing: token_num = len(input_ids) max_model_len = self.model_config.max_model_len - # Note: EmbeddingRequest, ClassificationRequest, + # Note: EmbeddingRequest, # and ScoreRequest doesn't have max_tokens if isinstance( request, @@ -828,8 +819,6 @@ class OpenAIServing: ScoreTextRequest, ScoreQueriesDocumentsRequest, RerankRequest, - ClassificationCompletionRequest, - ClassificationChatRequest, ), ): # Note: input length can be up to the entire model context length @@ -839,8 +828,6 @@ class OpenAIServing: ScoreDataRequest: "score", ScoreTextRequest: "score", ScoreQueriesDocumentsRequest: "score", - ClassificationCompletionRequest: "classification", - ClassificationChatRequest: "classification", } operation = operations.get(type(request), "embedding generation") raise VLLMValidationError( diff --git a/vllm/entrypoints/pooling/base/io_processor.py b/vllm/entrypoints/pooling/base/io_processor.py new file mode 100644 index 000000000..254c3d64a --- /dev/null +++ b/vllm/entrypoints/pooling/base/io_processor.py @@ -0,0 +1,189 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Callable, Sequence +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Final + +from vllm import PoolingRequestOutput, PromptType +from vllm.config import ModelConfig +from vllm.entrypoints.chat_utils import ( + ChatCompletionMessageParam, + ChatTemplateConfig, + ChatTemplateContentFormatOption, + ConversationMessage, +) +from vllm.entrypoints.openai.engine.serving import RendererChatRequest, RendererRequest +from vllm.inputs import ProcessorInputs, SingletonPrompt +from vllm.renderers import BaseRenderer, merge_kwargs +from vllm.renderers.inputs import TokPrompt +from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq +from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers import ToolParser +from 
vllm.utils.mistral import is_mistral_tokenizer + + +class PoolingIOProcessor: + def __init__( + self, + model_config: ModelConfig, + renderer: BaseRenderer, + chat_template_config: ChatTemplateConfig, + ): + self._tokenizer_executor = ThreadPoolExecutor(max_workers=1) + + self.model_config = model_config + self.renderer = renderer + + self.chat_template = chat_template_config.chat_template + self.chat_template_content_format: Final = ( + chat_template_config.chat_template_content_format + ) + self.trust_request_chat_template = ( + chat_template_config.trust_request_chat_template + ) + + def pre_process_online(self, *args, **kwargs): + raise NotImplementedError + + async def pre_process_online_async(self, *args, **kwargs): + return self.pre_process_online(*args, **kwargs) + + def pre_process_offline(self, *args, **kwargs): + raise NotImplementedError + + async def pre_process_offline_async(self, *args, **kwargs): + return self.pre_process_offline(*args, **kwargs) + + def post_process( + self, outputs: list[PoolingRequestOutput] + ) -> list[PoolingRequestOutput]: + return outputs + + async def post_process_async( + self, outputs: list[PoolingRequestOutput] + ) -> list[PoolingRequestOutput]: + return self.post_process(outputs) + + def create_pooling_params(self, request): + return request.to_pooling_params() + + def _preprocess_completion_online( + self, + request: RendererRequest, + prompt_input: str | list[str] | list[int] | list[list[int]] | None, + prompt_embeds: bytes | list[bytes] | None, + ) -> list[TokPrompt]: + renderer = self.renderer + model_config = self.model_config + + prompts = list[SingletonPrompt | bytes]() + if prompt_embeds is not None: # embeds take higher priority + prompts.extend(prompt_to_seq(prompt_embeds)) + if prompt_input is not None: + prompts.extend(prompt_to_seq(prompt_input)) + + parsed_prompts = [ + ( + prompt + if isinstance(prompt, bytes) + else parse_model_prompt(model_config, prompt) + ) + for prompt in prompts + ] + tok_params = 
request.build_tok_params(model_config) + + return renderer.render_cmpl( + parsed_prompts, + tok_params, + prompt_extras={ + k: v + for k in ("mm_processor_kwargs", "cache_salt") + if (v := getattr(request, k, None)) is not None + }, + ) + + def _preprocess_chat_online( + self, + request: RendererChatRequest, + messages: list[ChatCompletionMessageParam], + default_template: str | None, + default_template_content_format: ChatTemplateContentFormatOption, + default_template_kwargs: dict[str, Any] | None, + tool_dicts: list[dict[str, Any]] | None = None, + tool_parser: Callable[[TokenizerLike], ToolParser] | None = None, + ) -> tuple[list[ConversationMessage], list[TokPrompt]]: + renderer = self.renderer + + default_template_kwargs = merge_kwargs( + default_template_kwargs, + dict( + tools=tool_dicts, + tokenize=is_mistral_tokenizer(renderer.tokenizer), + ), + ) + + tok_params = request.build_tok_params(self.model_config) + chat_params = request.build_chat_params( + default_template, default_template_content_format + ).with_defaults(default_template_kwargs) + + (conversation,), (engine_prompt,) = renderer.render_chat( + [messages], + chat_params, + tok_params, + prompt_extras={ + k: v + for k in ("mm_processor_kwargs", "cache_salt") + if (v := getattr(request, k, None)) is not None + }, + ) + + return conversation, [engine_prompt] + + def _preprocess_completion_offline( + self, + prompts: PromptType | Sequence[PromptType], + tokenization_kwargs: dict[str, Any] | None = None, + ) -> Sequence[ProcessorInputs]: + renderer = self.renderer + model_config = self.model_config + + prompts = prompt_to_seq(prompts) + + parsed_prompts = [ + ( + prompt + if isinstance(prompt, bytes) + else parse_model_prompt(model_config, prompt) + ) + for prompt in prompts + ] + tok_params = renderer.default_cmpl_tok_params.with_kwargs( + **(tokenization_kwargs or {}) + ) + + return renderer.render_cmpl( + parsed_prompts, + tok_params, + ) + + def _validate_chat_template( + self, + 
request_chat_template: str | None, + chat_template_kwargs: dict[str, Any] | None, + trust_request_chat_template: bool, + ): + if not trust_request_chat_template and ( + request_chat_template is not None + or ( + chat_template_kwargs + and chat_template_kwargs.get("chat_template") is not None + ) + ): + raise ValueError( + "Chat template is passed with request, but " + "--trust-request-chat-template is not set. " + "Refused request with untrusted chat template." + ) + return None diff --git a/vllm/entrypoints/pooling/base/serving.py b/vllm/entrypoints/pooling/base/serving.py new file mode 100644 index 000000000..813282d3d --- /dev/null +++ b/vllm/entrypoints/pooling/base/serving.py @@ -0,0 +1,378 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import time +from collections.abc import AsyncGenerator, Mapping +from dataclasses import dataclass, field +from http import HTTPStatus +from typing import ClassVar, Generic, TypeVar + +from fastapi import Request +from pydantic import ConfigDict +from starlette.datastructures import Headers +from starlette.responses import JSONResponse + +from vllm import ( + PoolingParams, + PoolingRequestOutput, + PromptType, + SamplingParams, + envs, +) +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import ( + ChatTemplateConfig, + ChatTemplateContentFormatOption, +) +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.engine.protocol import ErrorResponse +from vllm.entrypoints.openai.models.serving import OpenAIServingModels +from vllm.entrypoints.pooling.typing import AnyPoolingRequest, AnyPoolingResponse +from vllm.inputs import ProcessorInputs +from vllm.lora.request import LoRARequest +from vllm.renderers import BaseRenderer +from vllm.renderers.inputs.preprocess import extract_prompt_components +from vllm.sampling_params import BeamSearchParams +from vllm.tracing import ( 
+ contains_trace_headers, + extract_trace_headers, + log_tracing_disabled_warning, +) +from vllm.utils import random_uuid +from vllm.utils.async_utils import merge_async_iterators + +from ...utils import create_error_response +from .io_processor import PoolingIOProcessor + +PoolingRequestT = TypeVar("PoolingRequestT", bound=AnyPoolingRequest) + + +@dataclass(kw_only=True) +class PoolingServeContext(Generic[PoolingRequestT]): + request: PoolingRequestT + raw_request: Request | None = None + model_name: str + request_id: str + created_time: int = field(default_factory=lambda: int(time.time())) + lora_request: LoRARequest | None = None + engine_prompts: list[ProcessorInputs] | None = None + + result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = ( + None + ) + final_res_batch: list[PoolingRequestOutput] = field(default_factory=list) + + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class PoolingServing: + request_id_prefix: ClassVar[str] + + def __init__( + self, + engine_client: EngineClient, + models: OpenAIServingModels, + *, + request_logger: RequestLogger | None, + chat_template: str | None = None, + chat_template_content_format: ChatTemplateContentFormatOption = "auto", + trust_request_chat_template: bool = False, + return_tokens_as_token_ids: bool = False, + log_error_stack: bool = False, + ): + super().__init__() + self.engine_client = engine_client + self.models = models + self.model_config = models.model_config + self.max_model_len = self.model_config.max_model_len + self.request_logger = request_logger + self.return_tokens_as_token_ids = return_tokens_as_token_ids + self.log_error_stack = log_error_stack + self.chat_template_config = ChatTemplateConfig( + chat_template=chat_template, + chat_template_content_format=chat_template_content_format, + trust_request_chat_template=trust_request_chat_template, + ) + self.io_processor = self.init_io_processor( + model_config=models.model_config, + renderer=models.renderer, + 
chat_template_config=self.chat_template_config, + ) + + def init_io_processor( + self, + model_config: ModelConfig, + renderer: BaseRenderer, + chat_template_config: ChatTemplateConfig, + ) -> PoolingIOProcessor: + raise NotImplementedError + + async def __call__( + self, + request: AnyPoolingRequest, + raw_request: Request, + ) -> JSONResponse: + try: + model_name = self.models.model_name() + request_id = ( + f"{self.request_id_prefix}-{self._base_request_id(raw_request)}" + ) + + await self._check_model(request) + + ctx = PoolingServeContext( + request=request, + raw_request=raw_request, + model_name=model_name, + request_id=request_id, + ) + + self._validate_request(ctx) + self._maybe_get_adapters(ctx) + await self._preprocess(ctx) + await self._prepare_generators(ctx) + await self._collect_batch(ctx) + response = await self._build_response(ctx) + return JSONResponse(content=response.model_dump()) + except Exception as e: + error_response = create_error_response(e) + return JSONResponse( + content=error_response.model_dump(), + status_code=error_response.error.code, + ) + + async def _preprocess( + self, + ctx: PoolingServeContext, + ): + ctx.engine_prompts = await self.io_processor.pre_process_online_async( + ctx.request + ) + + async def _prepare_generators( + self, + ctx: PoolingServeContext, + ): + if ctx.engine_prompts is None: + raise ValueError("Engine prompts not available") + + generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] + + trace_headers = ( + None + if ctx.raw_request is None + else await self._get_trace_headers(ctx.raw_request.headers) + ) + + pooling_params = self.io_processor.create_pooling_params(ctx.request) + + for i, engine_prompt in enumerate(ctx.engine_prompts): + request_id_item = f"{ctx.request_id}-{i}" + + self._log_inputs( + request_id_item, + engine_prompt, + params=pooling_params, + lora_request=ctx.lora_request, + ) + + generator = self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + 
lora_request=ctx.lora_request, + trace_headers=trace_headers, + priority=getattr(ctx.request, "priority", 0), + ) + + generators.append(generator) + + ctx.result_generator = merge_async_iterators(*generators) + + async def _collect_batch( + self, + ctx: PoolingServeContext, + ): + if ctx.engine_prompts is None: + raise ValueError("Engine prompts not available") + + if ctx.result_generator is None: + raise ValueError("Result generator not available") + + num_prompts = len(ctx.engine_prompts) + final_res_batch: list[PoolingRequestOutput | None] + final_res_batch = [None] * num_prompts + + async for i, res in ctx.result_generator: + final_res_batch[i] = res + + if None in final_res_batch: + raise ValueError("Failed to generate results for all prompts") + + ctx.final_res_batch = [res for res in final_res_batch if res is not None] + + async def _build_response( + self, + ctx: PoolingServeContext, + ) -> AnyPoolingResponse: + raise NotImplementedError + + @staticmethod + def _base_request_id( + raw_request: Request | None, default: str | None = None + ) -> str | None: + """Pulls the request id to use from a header, if provided""" + if raw_request is not None and ( + (req_id := raw_request.headers.get("X-Request-Id")) is not None + ): + return req_id + + return random_uuid() if default is None else default + + def _is_model_supported(self, model_name: str | None) -> bool: + if not model_name: + return True + return self.models.is_base_model(model_name) + + async def _check_model( + self, + request: AnyPoolingRequest, + ) -> ErrorResponse | None: + if self._is_model_supported(request.model): + return None + if request.model in self.models.lora_requests: + return None + if ( + envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING + and request.model + and (load_result := await self.models.resolve_lora(request.model)) + ): + if isinstance(load_result, LoRARequest): + return None + if ( + isinstance(load_result, ErrorResponse) + and load_result.error.code == HTTPStatus.BAD_REQUEST.value + ): 
+ raise ValueError(load_result.error.message) + return None + + def _validate_request(self, ctx: PoolingServeContext) -> None: + truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None) + + if ( + truncate_prompt_tokens is not None + and truncate_prompt_tokens > self.max_model_len + ): + raise ValueError( + "truncate_prompt_tokens value is " + "greater than max_model_len." + " Please, select a smaller truncation size." + ) + return None + + async def _get_trace_headers( + self, + headers: Headers, + ) -> Mapping[str, str] | None: + is_tracing_enabled = await self.engine_client.is_tracing_enabled() + + if is_tracing_enabled: + return extract_trace_headers(headers) + + if contains_trace_headers(headers): + log_tracing_disabled_warning() + + return None + + def _maybe_get_adapters( + self, + ctx: PoolingServeContext, + supports_default_mm_loras: bool = False, + ): + request = ctx.request + if request.model in self.models.lora_requests: + ctx.lora_request = self.models.lora_requests[request.model] + + # Currently only support default modality specific loras + # if we have exactly one lora matched on the request. + if supports_default_mm_loras: + default_mm_lora = self._get_active_default_mm_loras(request) + if default_mm_lora is not None: + ctx.lora_request = default_mm_lora + + if self._is_model_supported(request.model): + return None + + # if _check_model has been called earlier, this will be unreachable + raise ValueError(f"The model `{request.model}` does not exist.") + + def _get_active_default_mm_loras( + self, request: AnyPoolingRequest + ) -> LoRARequest | None: + """Determine if there are any active default multimodal loras.""" + # TODO: Currently this is only enabled for chat completions + # to be better aligned with only being enabled for .generate + # when run offline. It would be nice to support additional + # tasks types in the future. 
+ message_types = self._get_message_types(request) + default_mm_loras = set() + + for lora in self.models.lora_requests.values(): + # Best effort match for default multimodal lora adapters; + # There is probably a better way to do this, but currently + # this matches against the set of 'types' in any content lists + # up until '_', e.g., to match audio_url -> audio + if lora.lora_name in message_types: + default_mm_loras.add(lora) + + # Currently only support default modality specific loras if + # we have exactly one lora matched on the request. + if len(default_mm_loras) == 1: + return default_mm_loras.pop() + return None + + def _get_message_types(self, request: AnyPoolingRequest) -> set[str]: + """Retrieve the set of types from message content dicts up + until `_`; we use this to match potential multimodal data + with default per modality loras. + """ + message_types: set[str] = set() + + if not hasattr(request, "messages"): + return message_types + + messages = request.messages + if messages is None or isinstance(messages, (str, bytes)): + return message_types + + for message in messages: + if ( + isinstance(message, dict) + and "content" in message + and isinstance(message["content"], list) + ): + for content_dict in message["content"]: + if "type" in content_dict: + message_types.add(content_dict["type"].split("_")[0]) + return message_types + + def _log_inputs( + self, + request_id: str, + inputs: PromptType | ProcessorInputs, + params: SamplingParams | PoolingParams | BeamSearchParams | None, + lora_request: LoRARequest | None, + ) -> None: + if self.request_logger is None: + return + + components = extract_prompt_components(self.model_config, inputs) + + self.request_logger.log_inputs( + request_id, + components.text, + components.token_ids, + components.embeds, + params=params, + lora_request=lora_request, + ) diff --git a/vllm/entrypoints/pooling/classify/api_router.py b/vllm/entrypoints/pooling/classify/api_router.py index 8a1513ebc..0e99a86fe 100644 
--- a/vllm/entrypoints/pooling/classify/api_router.py +++ b/vllm/entrypoints/pooling/classify/api_router.py @@ -3,16 +3,17 @@ from fastapi import APIRouter, Depends, Request from starlette.responses import JSONResponse -from typing_extensions import assert_never -from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.utils import validate_json_request from vllm.entrypoints.pooling.classify.protocol import ( ClassificationRequest, - ClassificationResponse, ) from vllm.entrypoints.pooling.classify.serving import ServingClassification -from vllm.entrypoints.utils import load_aware_call, with_cancellation +from vllm.entrypoints.utils import ( + create_error_response, + load_aware_call, + with_cancellation, +) router = APIRouter() @@ -24,25 +25,17 @@ def classify(request: Request) -> ServingClassification | None: @router.post("/classify", dependencies=[Depends(validate_json_request)]) @with_cancellation @load_aware_call -async def create_classify(request: ClassificationRequest, raw_request: Request): +async def create_classify( + request: ClassificationRequest, raw_request: Request +) -> JSONResponse: handler = classify(raw_request) if handler is None: - base_server = raw_request.app.state.openai_serving_tokenization - return base_server.create_error_response( + error_response = create_error_response( message="The model does not support Classification API" ) - - try: - generator = await handler.create_classify(request, raw_request) - except Exception as e: - generator = handler.create_error_response(e) - - if isinstance(generator, ErrorResponse): return JSONResponse( - content=generator.model_dump(), status_code=generator.error.code + content=error_response.model_dump(), + status_code=error_response.error.code, ) - elif isinstance(generator, ClassificationResponse): - return JSONResponse(content=generator.model_dump()) - - assert_never(generator) + return await handler(request, raw_request) diff --git 
a/vllm/entrypoints/pooling/classify/io_processor.py b/vllm/entrypoints/pooling/classify/io_processor.py new file mode 100644 index 000000000..90d5b0e4f --- /dev/null +++ b/vllm/entrypoints/pooling/classify/io_processor.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Sequence +from typing import Any + +from vllm import PromptType +from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor +from vllm.entrypoints.pooling.classify.protocol import ( + ClassificationChatRequest, + ClassificationCompletionRequest, +) +from vllm.inputs import ProcessorInputs +from vllm.renderers.inputs import TokPrompt + + +class ClassifyIOProcessor(PoolingIOProcessor): + def pre_process_online( + self, request: ClassificationCompletionRequest | ClassificationChatRequest + ) -> list[TokPrompt] | None: + if isinstance(request, ClassificationChatRequest): + self._validate_chat_template( + request_chat_template=request.chat_template, + chat_template_kwargs=request.chat_template_kwargs, + trust_request_chat_template=self.trust_request_chat_template, + ) + _, engine_prompts = self._preprocess_chat_online( + request, + request.messages, + default_template=self.chat_template, + default_template_content_format=self.chat_template_content_format, + default_template_kwargs=None, + ) + elif isinstance(request, ClassificationCompletionRequest): + engine_prompts = self._preprocess_completion_online( + request, + prompt_input=request.input, + prompt_embeds=None, + ) + else: + raise ValueError("Invalid classification request type") + return engine_prompts + + def pre_process_offline( + self, + prompts: PromptType | Sequence[PromptType], + tokenization_kwargs: dict[str, Any] | None = None, + ) -> Sequence[ProcessorInputs]: + return self._preprocess_completion_offline( + prompts=prompts, tokenization_kwargs=tokenization_kwargs + ) diff --git 
a/vllm/entrypoints/pooling/classify/serving.py b/vllm/entrypoints/pooling/classify/serving.py index 8cdbbde6d..efd4be77c 100644 --- a/vllm/entrypoints/pooling/classify/serving.py +++ b/vllm/entrypoints/pooling/classify/serving.py @@ -1,116 +1,57 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Final, TypeAlias +from typing import TypeAlias -import jinja2 import numpy as np -from fastapi import Request -from vllm.engine.protocol import EngineClient -from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption -from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo -from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext -from vllm.entrypoints.openai.models.serving import OpenAIServingModels -from vllm.entrypoints.pooling.classify.protocol import ( - ClassificationChatRequest, - ClassificationCompletionRequest, +from vllm import ClassificationOutput +from vllm.config import ModelConfig +from vllm.entrypoints.chat_utils import ChatTemplateConfig +from vllm.entrypoints.openai.engine.protocol import UsageInfo +from vllm.entrypoints.pooling.base.serving import PoolingServeContext, PoolingServing +from vllm.logger import init_logger +from vllm.renderers import BaseRenderer + +from .io_processor import ClassifyIOProcessor +from .protocol import ( ClassificationData, ClassificationRequest, ClassificationResponse, ) -from vllm.logger import init_logger -from vllm.outputs import ClassificationOutput logger = init_logger(__name__) -ClassificationServeContext: TypeAlias = ServeContext[ClassificationRequest] +ClassificationServeContext: TypeAlias = PoolingServeContext[ClassificationRequest] -class ServingClassification(OpenAIServing): +class ServingClassification(PoolingServing): request_id_prefix = "classify" - def __init__( + def init_io_processor( self, - engine_client: EngineClient, - 
models: OpenAIServingModels, - *, - request_logger: RequestLogger | None, - chat_template: str | None = None, - chat_template_content_format: ChatTemplateContentFormatOption = "auto", - trust_request_chat_template: bool = False, - log_error_stack: bool = False, - ) -> None: - super().__init__( - engine_client=engine_client, - models=models, - request_logger=request_logger, - log_error_stack=log_error_stack, + model_config: ModelConfig, + renderer: BaseRenderer, + chat_template_config: ChatTemplateConfig, + ) -> ClassifyIOProcessor: + return ClassifyIOProcessor( + model_config=model_config, + renderer=renderer, + chat_template_config=chat_template_config, ) - self.chat_template = chat_template - self.chat_template_content_format: Final = chat_template_content_format - self.trust_request_chat_template = trust_request_chat_template - - async def _preprocess( + async def _build_response( self, ctx: ClassificationServeContext, - ) -> ErrorResponse | None: - """ - Process classification inputs: tokenize text, resolve adapters, - and prepare model-specific inputs. 
- """ - try: - ctx.lora_request = self._maybe_get_adapters(ctx.request) + ) -> ClassificationResponse: + final_res_batch_checked = await self.io_processor.post_process_async( + ctx.final_res_batch + ) - if isinstance(ctx.request, ClassificationChatRequest): - error_check_ret = self._validate_chat_template( - request_chat_template=ctx.request.chat_template, - chat_template_kwargs=ctx.request.chat_template_kwargs, - trust_request_chat_template=self.trust_request_chat_template, - ) - if error_check_ret: - return error_check_ret - - _, ctx.engine_prompts = await self._preprocess_chat( - ctx.request, - ctx.request.messages, - default_template=self.chat_template, - default_template_content_format=self.chat_template_content_format, - default_template_kwargs=None, - ) - elif isinstance(ctx.request, ClassificationCompletionRequest): - ctx.engine_prompts = await self._preprocess_completion( - ctx.request, - prompt_input=ctx.request.input, - prompt_embeds=None, - ) - else: - return self.create_error_response("Invalid classification request type") - - return None - - except (ValueError, TypeError, jinja2.TemplateError) as e: - logger.exception("Error in preprocessing prompt inputs") - return self.create_error_response(str(e)) - - def _build_response( - self, - ctx: ClassificationServeContext, - ) -> ClassificationResponse | ErrorResponse: - """ - Convert model outputs to a formatted classification response - with probabilities and labels. 
def init_pooling_io_processors(
    supported_tasks: tuple[SupportedTask, ...],
    model_config: ModelConfig,
    renderer: BaseRenderer,
    chat_template_config: ChatTemplateConfig,
) -> dict[str, PoolingIOProcessor]:
    """Create the pooling IO processors for the supported tasks.

    Returns a new dict keyed by task name. Only ``"classify"`` is handled
    here; when it is not in ``supported_tasks`` the dict is empty.
    """
    # Every processor is constructed from the same three collaborators.
    shared_kwargs = dict(
        model_config=model_config,
        renderer=renderer,
        chat_template_config=chat_template_config,
    )
    processors: dict[str, PoolingIOProcessor] = {}

    if "classify" in supported_tasks:
        # Imported only when the task is actually supported.
        from vllm.entrypoints.pooling.classify.io_processor import ClassifyIOProcessor

        processors["classify"] = ClassifyIOProcessor(**shared_kwargs)

    return processors
vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.entrypoints.openai.utils import validate_json_request +from vllm.entrypoints.pooling.base.serving import PoolingServing from vllm.entrypoints.serve.instrumentator.basic import base from vllm.entrypoints.serve.instrumentator.health import health from vllm.tasks import POOLING_TASKS, SupportedTask @@ -20,7 +21,7 @@ from vllm.tasks import POOLING_TASKS, SupportedTask # TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers # (requires typing_extensions >= 4.13) RequestType = Any -GetHandlerFn = Callable[[Request], OpenAIServing | None] +GetHandlerFn = Callable[[Request], OpenAIServing | PoolingServing | None] EndpointFn = Callable[[RequestType, Request], Awaitable[Any]] diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index 34df85f37..6390a72ce 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -5,7 +5,10 @@ import asyncio import dataclasses import functools import os +import sys +import traceback from argparse import Namespace +from http import HTTPStatus from logging import Logger from string import Template from typing import TYPE_CHECKING @@ -17,17 +20,23 @@ from starlette.background import BackgroundTask, BackgroundTasks from vllm import envs from vllm.engine.arg_utils import EngineArgs +from vllm.exceptions import VLLMValidationError from vllm.logger import current_formatter_type, init_logger from vllm.platforms import current_platform from vllm.utils.argparse_utils import FlexibleArgumentParser if TYPE_CHECKING: - from vllm.entrypoints.openai.engine.protocol import StreamOptions + from vllm.entrypoints.openai.engine.protocol import ( + ErrorInfo, + ErrorResponse, + StreamOptions, + ) from vllm.entrypoints.openai.models.protocol import LoRAModulePath else: - StreamOptions = object + ErrorResponse = object + ErrorInfo = object LoRAModulePath = object - + StreamOptions = object logger = init_logger(__name__) @@ -291,3 +300,59 @@ def 
def create_error_response(
    message: str | Exception,
    err_type: str = "BadRequestError",
    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    param: str | None = None,
    log_error_stack: bool = False,
) -> "ErrorResponse":
    """Build an ``ErrorResponse`` from a message string or an exception.

    If ``message`` is a string, ``err_type``, ``status_code`` and ``param``
    are used as passed. If ``message`` is an exception, those three
    arguments are overridden according to the exception's type:

    * ``VLLMValidationError``: ``BadRequestError`` / 400, with ``param``
      set to ``exc.parameter``.
    * ``NotImplementedError``: ``NotImplementedError`` / 501.
    * ``ValueError``, ``TypeError``, ``RuntimeError``, ``OverflowError``:
      ``BadRequestError`` / 400.
    * a class named exactly ``TemplateError``: ``BadRequestError`` / 400.
    * anything else: ``InternalServerError`` / 500.

    Args:
        message: Error text, or the exception to report. An exception is
            rendered with ``str(exc)``.
        err_type: Value of the ``type`` field when ``message`` is a string.
        status_code: HTTP status whose numeric value fills the ``code``
            field when ``message`` is a string.
        param: Value of the ``param`` field when ``message`` is a string.
        log_error_stack: If true, print the traceback of the exception
            currently being handled, or the current call stack when no
            exception is being handled.

    Returns:
        An ``ErrorResponse`` wrapping an ``ErrorInfo`` whose message has
        been passed through ``sanitize_message``.
    """
    # Local import: outside TYPE_CHECKING the module-level names of these
    # classes are only ``object`` placeholders.
    from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse

    if isinstance(message, Exception):
        exc = message

        if isinstance(exc, VLLMValidationError):
            err_type = "BadRequestError"
            status_code = HTTPStatus.BAD_REQUEST
            param = exc.parameter
        elif isinstance(exc, NotImplementedError):
            # Must be tested before RuntimeError: NotImplementedError is a
            # RuntimeError subclass, so the broader check below would
            # otherwise claim it and report 400 instead of 501.
            err_type = "NotImplementedError"
            status_code = HTTPStatus.NOT_IMPLEMENTED
            param = None
        elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)):
            # Common validation errors from user input
            err_type = "BadRequestError"
            status_code = HTTPStatus.BAD_REQUEST
            param = None
        elif exc.__class__.__name__ == "TemplateError":
            # jinja2.TemplateError (avoid importing jinja2). Only the exact
            # class name matches; subclasses with a different name fall
            # through to the internal-error branch.
            err_type = "BadRequestError"
            status_code = HTTPStatus.BAD_REQUEST
            param = None
        else:
            err_type = "InternalServerError"
            status_code = HTTPStatus.INTERNAL_SERVER_ERROR
            param = None

        message = str(exc)

    if log_error_stack:
        # Inside an ``except`` block there is an active exception to print;
        # otherwise fall back to showing where this call came from.
        exc_type, _, _ = sys.exc_info()
        if exc_type is not None:
            traceback.print_exc()
        else:
            traceback.print_stack()

    return ErrorResponse(
        error=ErrorInfo(
            message=sanitize_message(message),
            type=err_type,
            code=status_code.value,
            param=param,
        )
    )