[OpenAI] Extend VLLMValidationError to additional validation parameters (#31870)
Signed-off-by: Rehan Khan <Rehan.Khan7@ibm.com>
@@ -911,7 +911,7 @@ def build_app(args: Namespace) -> FastAPI:
     @app.exception_handler(RequestValidationError)
     async def validation_exception_handler(_: Request, exc: RequestValidationError):
-        from vllm.entrypoints.openai.protocol import VLLMValidationError
+        from vllm.exceptions import VLLMValidationError

         param = None
         for error in exc.errors():
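Note: the handler body above is truncated in this diff. Below is a minimal sketch of what such a handler could look like end to end; it is illustrative, not the PR's code, and the response shape is assumed from the ErrorResponse/ErrorInfo models in protocol.py.

# Illustrative sketch only; response field names are an assumption.
from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse


async def validation_exception_handler(_: Request, exc: RequestValidationError):
    param = None
    for error in exc.errors():
        # Pydantic reports the failing field in the `loc` tuple, e.g.
        # ("body", "temperature"); the last element names the parameter.
        loc = error.get("loc")
        if loc:
            param = str(loc[-1])
            break
    return JSONResponse(
        status_code=400,
        content={
            "error": {
                "message": str(exc),
                "type": "BadRequestError",
                "param": param,
                "code": 400,
            }
        },
    )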
@@ -72,6 +72,7 @@ from pydantic import (
 )

 from vllm.entrypoints.chat_utils import ChatCompletionMessageParam, make_tool_call_id
+from vllm.exceptions import VLLMValidationError
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob
 from vllm.sampling_params import (
@@ -131,36 +132,6 @@ class ErrorResponse(OpenAIBaseModel):
     error: ErrorInfo


-class VLLMValidationError(ValueError):
-    """vLLM-specific validation error for request validation failures.
-
-    Args:
-        message: The error message describing the validation failure.
-        parameter: Optional parameter name that failed validation.
-        value: Optional value that was rejected during validation.
-    """
-
-    def __init__(
-        self,
-        message: str,
-        *,
-        parameter: str | None = None,
-        value: Any = None,
-    ) -> None:
-        super().__init__(message)
-        self.parameter = parameter
-        self.value = value
-
-    def __str__(self):
-        base = super().__str__()
-        extras = []
-        if self.parameter is not None:
-            extras.append(f"parameter={self.parameter}")
-        if self.value is not None:
-            extras.append(f"value={self.value}")
-        return f"{base} ({', '.join(extras)})" if extras else base
-
-
 class ModelPermission(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
     object: str = "model_permission"
@@ -140,16 +140,16 @@ class OpenAIServingCompletion(OpenAIServing):
             )
         except ValueError as e:
             logger.exception("Error in preprocessing prompt inputs")
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)
         except TypeError as e:
             logger.exception("Error in preprocessing prompt inputs")
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)
         except RuntimeError as e:
             logger.exception("Error in preprocessing prompt inputs")
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)
         except jinja2.TemplateError as e:
             logger.exception("Error in preprocessing prompt inputs")
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)

         # Extract data_parallel_rank from header (router can inject it)
         data_parallel_rank = self._get_data_parallel_rank(raw_request)
@@ -754,7 +754,7 @@ class OpenAIServing:
         if isinstance(message, Exception):
             exc = message

-            from vllm.entrypoints.openai.protocol import VLLMValidationError
+            from vllm.exceptions import VLLMValidationError

             if isinstance(exc, VLLMValidationError):
                 err_type = "BadRequestError"
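Note: sketched in isolation, the dispatch this hunk introduces behaves as below. The status codes are assumed; the actual response construction lives in create_error_response.

# Sketch of the dispatch above (assumed status codes): VLLMValidationError
# is a client error, so it maps to BadRequestError/400; everything else
# stays on the generic InternalServerError/500 path.
from vllm.exceptions import VLLMValidationError


def classify_exception(exc: Exception) -> tuple[str, int]:
    if isinstance(exc, VLLMValidationError):
        return "BadRequestError", 400
    return "InternalServerError", 500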
@@ -373,7 +373,7 @@ class OpenAIServingResponses(OpenAIServing):
             NotImplementedError,
         ) as e:
             logger.exception("Error in preprocessing prompt inputs")
-            return self.create_error_response(f"{e} {e.__cause__}")
+            return self.create_error_response(e)

         request_metadata = RequestResponseMetadata(request_id=request.request_id)
         if raw_request:
@@ -12,7 +12,7 @@ import torch
 from pydantic import Field

 from vllm.config import ModelConfig
-from vllm.entrypoints.openai.protocol import VLLMValidationError
+from vllm.exceptions import VLLMValidationError
 from vllm.inputs.data import EmbedsPrompt, TextPrompt, TokensPrompt
 from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
 from vllm.tokenizers import TokenizerLike
vllm/exceptions.py (new file, 36 lines)
@@ -0,0 +1,36 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""Custom exceptions for vLLM."""
+
+from typing import Any
+
+
+class VLLMValidationError(ValueError):
+    """vLLM-specific validation error for request validation failures.
+
+    Args:
+        message: The error message describing the validation failure.
+        parameter: Optional parameter name that failed validation.
+        value: Optional value that was rejected during validation.
+    """
+
+    def __init__(
+        self,
+        message: str,
+        *,
+        parameter: str | None = None,
+        value: Any = None,
+    ) -> None:
+        super().__init__(message)
+        self.parameter = parameter
+        self.value = value
+
+    def __str__(self):
+        base = super().__str__()
+        extras = []
+        if self.parameter is not None:
+            extras.append(f"parameter={self.parameter}")
+        if self.value is not None:
+            extras.append(f"value={self.value}")
+        return f"{base} ({', '.join(extras)})" if extras else base
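Note: behavior of the new exception, following the definition above:

# Behavior of the class as defined above.
from vllm.exceptions import VLLMValidationError

err = VLLMValidationError(
    "temperature must be non-negative, got -1.0.",
    parameter="temperature",
    value=-1.0,
)
print(err)
# temperature must be non-negative, got -1.0. (parameter=temperature, value=-1.0)

# Because it subclasses ValueError, existing `except ValueError` handlers
# (e.g. in the serving_completion hunk above) keep catching it unchanged.
assert isinstance(err, ValueError)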
@@ -11,6 +11,7 @@ from typing import Annotated, Any
 import msgspec
 from pydantic.dataclasses import dataclass

+from vllm.exceptions import VLLMValidationError
 from vllm.logger import init_logger
 from vllm.logits_process import LogitsProcessor
 from vllm.tokenizers import TokenizerLike
@@ -393,11 +394,17 @@ class SamplingParams(
                 f"{self.repetition_penalty}."
             )
         if self.temperature < 0.0:
-            raise ValueError(
-                f"temperature must be non-negative, got {self.temperature}."
+            raise VLLMValidationError(
+                f"temperature must be non-negative, got {self.temperature}.",
+                parameter="temperature",
+                value=self.temperature,
             )
         if not 0.0 < self.top_p <= 1.0:
-            raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
+            raise VLLMValidationError(
+                f"top_p must be in (0, 1], got {self.top_p}.",
+                parameter="top_p",
+                value=self.top_p,
+            )
         # quietly accept -1 as disabled, but prefer 0
         if self.top_k < -1:
             raise ValueError(
@@ -410,7 +417,11 @@ class SamplingParams(
         if not 0.0 <= self.min_p <= 1.0:
             raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
         if self.max_tokens is not None and self.max_tokens < 1:
-            raise ValueError(f"max_tokens must be at least 1, got {self.max_tokens}.")
+            raise VLLMValidationError(
+                f"max_tokens must be at least 1, got {self.max_tokens}.",
+                parameter="max_tokens",
+                value=self.max_tokens,
+            )
         if self.min_tokens < 0:
             raise ValueError(
                 f"min_tokens must be greater than or equal to 0, got {self.min_tokens}."
@@ -421,24 +432,30 @@ class SamplingParams(
                 f"max_tokens={self.max_tokens}, got {self.min_tokens}."
             )
         if self.logprobs is not None and self.logprobs != -1 and self.logprobs < 0:
-            raise ValueError(
-                f"logprobs must be non-negative or -1, got {self.logprobs}."
+            raise VLLMValidationError(
+                f"logprobs must be non-negative or -1, got {self.logprobs}.",
+                parameter="logprobs",
+                value=self.logprobs,
             )
         if (
             self.prompt_logprobs is not None
             and self.prompt_logprobs != -1
             and self.prompt_logprobs < 0
         ):
-            raise ValueError(
+            raise VLLMValidationError(
                 f"prompt_logprobs must be non-negative or -1, got "
-                f"{self.prompt_logprobs}."
+                f"{self.prompt_logprobs}.",
+                parameter="prompt_logprobs",
+                value=self.prompt_logprobs,
             )
         if self.truncate_prompt_tokens is not None and (
             self.truncate_prompt_tokens == 0 or self.truncate_prompt_tokens < -1
         ):
-            raise ValueError(
+            raise VLLMValidationError(
                 f"truncate_prompt_tokens must be an integer >= 1 or -1, "
-                f"got {self.truncate_prompt_tokens}"
+                f"got {self.truncate_prompt_tokens}",
+                parameter="truncate_prompt_tokens",
+                value=self.truncate_prompt_tokens,
             )
         assert isinstance(self.stop_token_ids, list)
         if not all(isinstance(st_id, int) for st_id in self.stop_token_ids):
@@ -516,12 +533,14 @@ class SamplingParams(
             if token_id < 0 or token_id > tokenizer.max_token_id
         ]
         if len(invalid_token_ids) > 0:
-            raise ValueError(
+            raise VLLMValidationError(
                 f"The model vocabulary size is {tokenizer.max_token_id + 1},"
                 f" but the following tokens"
                 f" were specified as bad: {invalid_token_ids}."
                 f" All token id values should be integers satisfying:"
-                f" 0 <= token_id <= {tokenizer.max_token_id}."
+                f" 0 <= token_id <= {tokenizer.max_token_id}.",
+                parameter="bad_words",
+                value=self.bad_words,
             )

     @cached_property
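Note: with these hunks, callers can recover the offending parameter programmatically. A small example, assuming SamplingParams still validates at construction time:

# Small example; assumes validation runs when SamplingParams is built.
from vllm import SamplingParams
from vllm.exceptions import VLLMValidationError

try:
    SamplingParams(temperature=-1.0)
except VLLMValidationError as e:
    print(e.parameter, e.value)  # temperature -1.0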
@@ -6,6 +6,7 @@ from collections.abc import Mapping
 from typing import Any, Literal, cast

 from vllm.config import VllmConfig
+from vllm.exceptions import VLLMValidationError
 from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
 from vllm.inputs.parse import split_enc_dec_inputs
 from vllm.inputs.preprocess import InputPreprocessor
@@ -83,9 +84,11 @@ class InputProcessor:
         if num_logprobs == -1:
             num_logprobs = self.model_config.get_vocab_size()
         if num_logprobs > max_logprobs:
-            raise ValueError(
+            raise VLLMValidationError(
                 f"Requested sample logprobs of {num_logprobs}, "
-                f"which is greater than max allowed: {max_logprobs}"
+                f"which is greater than max allowed: {max_logprobs}",
+                parameter="logprobs",
+                value=num_logprobs,
             )

         # Validate prompt logprobs.
@@ -94,9 +97,11 @@ class InputProcessor:
         if num_prompt_logprobs == -1:
             num_prompt_logprobs = self.model_config.get_vocab_size()
         if num_prompt_logprobs > max_logprobs:
-            raise ValueError(
+            raise VLLMValidationError(
                 f"Requested prompt logprobs of {num_prompt_logprobs}, "
-                f"which is greater than max allowed: {max_logprobs}"
+                f"which is greater than max allowed: {max_logprobs}",
+                parameter="prompt_logprobs",
+                value=num_prompt_logprobs,
             )

     def _validate_sampling_params(
@@ -134,9 +139,11 @@ class InputProcessor:
                 invalid_token_ids.append(token_id)

         if invalid_token_ids:
-            raise ValueError(
+            raise VLLMValidationError(
                 f"token_id(s) {invalid_token_ids} in logit_bias contain "
-                f"out-of-vocab token ids. Vocabulary size: {vocab_size}"
+                f"out-of-vocab token ids. Vocabulary size: {vocab_size}",
+                parameter="logit_bias",
+                value=invalid_token_ids,
             )

     def _validate_supported_sampling_params(
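Note: end to end, a request rejected by one of these checks should now reach API clients as a 400 with the parameter identified. An illustrative check follows; the response schema is assumed from the handler and ErrorResponse model above, not verified against a running server.

# Illustrative client-side check; response schema is an assumption.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={"model": "my-model", "prompt": "hi", "temperature": -1.0},
)
print(resp.status_code)               # expected: 400
print(resp.json()["error"]["param"])  # expected: "temperature"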