[OpenAI] Extend VLLMValidationError to additional validation parameters (#31870)

Signed-off-by: Rehan Khan <Rehan.Khan7@ibm.com>
This commit is contained in:
R3hankhan
2026-01-07 20:15:49 +05:30
committed by GitHub
parent b665bbc2d4
commit 1ab055efe6
9 changed files with 89 additions and 56 deletions

View File

@@ -911,7 +911,7 @@ def build_app(args: Namespace) -> FastAPI:
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_: Request, exc: RequestValidationError):
-from vllm.entrypoints.openai.protocol import VLLMValidationError
+from vllm.exceptions import VLLMValidationError
param = None
for error in exc.errors():

View File

@@ -72,6 +72,7 @@ from pydantic import (
)
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam, make_tool_call_id
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.sampling_params import (
@@ -131,36 +132,6 @@ class ErrorResponse(OpenAIBaseModel):
error: ErrorInfo
class VLLMValidationError(ValueError):
    """vLLM-specific validation error for request validation failures.

    Args:
        message: The error message describing the validation failure.
        parameter: Optional parameter name that failed validation.
        value: Optional value that was rejected during validation.
    """

    def __init__(
        self,
        message: str,
        *,
        parameter: str | None = None,
        value: Any = None,
    ) -> None:
        super().__init__(message)
        # Stored so callers (e.g. the API error handler) can report which
        # request field failed and what value was rejected.
        self.parameter = parameter
        self.value = value

    def __str__(self) -> str:
        """Render the base message, appending parameter/value context if set."""
        base = super().__str__()
        extras = []
        if self.parameter is not None:
            extras.append(f"parameter={self.parameter}")
        # NOTE: only None is omitted — falsy values like 0 are still reported.
        if self.value is not None:
            extras.append(f"value={self.value}")
        return f"{base} ({', '.join(extras)})" if extras else base
class ModelPermission(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
object: str = "model_permission"

View File

@@ -140,16 +140,16 @@ class OpenAIServingCompletion(OpenAIServing):
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
-                return self.create_error_response(str(e))
+                return self.create_error_response(e)
except TypeError as e:
logger.exception("Error in preprocessing prompt inputs")
-                return self.create_error_response(str(e))
+                return self.create_error_response(e)
except RuntimeError as e:
logger.exception("Error in preprocessing prompt inputs")
-                return self.create_error_response(str(e))
+                return self.create_error_response(e)
except jinja2.TemplateError as e:
logger.exception("Error in preprocessing prompt inputs")
-                return self.create_error_response(str(e))
+                return self.create_error_response(e)
# Extract data_parallel_rank from header (router can inject it)
data_parallel_rank = self._get_data_parallel_rank(raw_request)

View File

@@ -754,7 +754,7 @@ class OpenAIServing:
if isinstance(message, Exception):
exc = message
-from vllm.entrypoints.openai.protocol import VLLMValidationError
+from vllm.exceptions import VLLMValidationError
if isinstance(exc, VLLMValidationError):
err_type = "BadRequestError"

View File

@@ -373,7 +373,7 @@ class OpenAIServingResponses(OpenAIServing):
NotImplementedError,
) as e:
logger.exception("Error in preprocessing prompt inputs")
-            return self.create_error_response(f"{e} {e.__cause__}")
+            return self.create_error_response(e)
request_metadata = RequestResponseMetadata(request_id=request.request_id)
if raw_request:

View File

@@ -12,7 +12,7 @@ import torch
from pydantic import Field
from vllm.config import ModelConfig
-from vllm.entrypoints.openai.protocol import VLLMValidationError
+from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
from vllm.tokenizers import TokenizerLike

36
vllm/exceptions.py Normal file
View File

@@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Custom exceptions for vLLM."""
from typing import Any
class VLLMValidationError(ValueError):
"""vLLM-specific validation error for request validation failures.
Args:
message: The error message describing the validation failure.
parameter: Optional parameter name that failed validation.
value: Optional value that was rejected during validation.
"""
def __init__(
self,
message: str,
*,
parameter: str | None = None,
value: Any = None,
) -> None:
super().__init__(message)
self.parameter = parameter
self.value = value
def __str__(self):
base = super().__str__()
extras = []
if self.parameter is not None:
extras.append(f"parameter={self.parameter}")
if self.value is not None:
extras.append(f"value={self.value}")
return f"{base} ({', '.join(extras)})" if extras else base

View File

@@ -11,6 +11,7 @@ from typing import Annotated, Any
import msgspec
from pydantic.dataclasses import dataclass
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor
from vllm.tokenizers import TokenizerLike
@@ -393,11 +394,17 @@ class SamplingParams(
f"{self.repetition_penalty}."
)
if self.temperature < 0.0:
raise ValueError(
f"temperature must be non-negative, got {self.temperature}."
raise VLLMValidationError(
f"temperature must be non-negative, got {self.temperature}.",
parameter="temperature",
value=self.temperature,
)
if not 0.0 < self.top_p <= 1.0:
raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
raise VLLMValidationError(
f"top_p must be in (0, 1], got {self.top_p}.",
parameter="top_p",
value=self.top_p,
)
# quietly accept -1 as disabled, but prefer 0
if self.top_k < -1:
raise ValueError(
@@ -410,7 +417,11 @@ class SamplingParams(
if not 0.0 <= self.min_p <= 1.0:
raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
if self.max_tokens is not None and self.max_tokens < 1:
raise ValueError(f"max_tokens must be at least 1, got {self.max_tokens}.")
raise VLLMValidationError(
f"max_tokens must be at least 1, got {self.max_tokens}.",
parameter="max_tokens",
value=self.max_tokens,
)
if self.min_tokens < 0:
raise ValueError(
f"min_tokens must be greater than or equal to 0, got {self.min_tokens}."
@@ -421,24 +432,30 @@ class SamplingParams(
f"max_tokens={self.max_tokens}, got {self.min_tokens}."
)
if self.logprobs is not None and self.logprobs != -1 and self.logprobs < 0:
raise ValueError(
f"logprobs must be non-negative or -1, got {self.logprobs}."
raise VLLMValidationError(
f"logprobs must be non-negative or -1, got {self.logprobs}.",
parameter="logprobs",
value=self.logprobs,
)
if (
self.prompt_logprobs is not None
and self.prompt_logprobs != -1
and self.prompt_logprobs < 0
):
raise ValueError(
raise VLLMValidationError(
f"prompt_logprobs must be non-negative or -1, got "
f"{self.prompt_logprobs}."
f"{self.prompt_logprobs}.",
parameter="prompt_logprobs",
value=self.prompt_logprobs,
)
if self.truncate_prompt_tokens is not None and (
self.truncate_prompt_tokens == 0 or self.truncate_prompt_tokens < -1
):
raise ValueError(
raise VLLMValidationError(
f"truncate_prompt_tokens must be an integer >= 1 or -1, "
f"got {self.truncate_prompt_tokens}"
f"got {self.truncate_prompt_tokens}",
parameter="truncate_prompt_tokens",
value=self.truncate_prompt_tokens,
)
assert isinstance(self.stop_token_ids, list)
if not all(isinstance(st_id, int) for st_id in self.stop_token_ids):
@@ -516,12 +533,14 @@ class SamplingParams(
if token_id < 0 or token_id > tokenizer.max_token_id
]
if len(invalid_token_ids) > 0:
raise ValueError(
raise VLLMValidationError(
f"The model vocabulary size is {tokenizer.max_token_id + 1},"
f" but the following tokens"
f" were specified as bad: {invalid_token_ids}."
f" All token id values should be integers satisfying:"
f" 0 <= token_id <= {tokenizer.max_token_id}."
f" 0 <= token_id <= {tokenizer.max_token_id}.",
parameter="bad_words",
value=self.bad_words,
)
@cached_property

View File

@@ -6,6 +6,7 @@ from collections.abc import Mapping
from typing import Any, Literal, cast
from vllm.config import VllmConfig
from vllm.exceptions import VLLMValidationError
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
@@ -83,9 +84,11 @@ class InputProcessor:
if num_logprobs == -1:
num_logprobs = self.model_config.get_vocab_size()
if num_logprobs > max_logprobs:
raise ValueError(
raise VLLMValidationError(
f"Requested sample logprobs of {num_logprobs}, "
f"which is greater than max allowed: {max_logprobs}"
f"which is greater than max allowed: {max_logprobs}",
parameter="logprobs",
value=num_logprobs,
)
# Validate prompt logprobs.
@@ -94,9 +97,11 @@ class InputProcessor:
if num_prompt_logprobs == -1:
num_prompt_logprobs = self.model_config.get_vocab_size()
if num_prompt_logprobs > max_logprobs:
raise ValueError(
raise VLLMValidationError(
f"Requested prompt logprobs of {num_prompt_logprobs}, "
f"which is greater than max allowed: {max_logprobs}"
f"which is greater than max allowed: {max_logprobs}",
parameter="prompt_logprobs",
value=num_prompt_logprobs,
)
def _validate_sampling_params(
@@ -134,9 +139,11 @@ class InputProcessor:
invalid_token_ids.append(token_id)
if invalid_token_ids:
raise ValueError(
raise VLLMValidationError(
f"token_id(s) {invalid_token_ids} in logit_bias contain "
f"out-of-vocab token ids. Vocabulary size: {vocab_size}"
f"out-of-vocab token ids. Vocabulary size: {vocab_size}",
parameter="logit_bias",
value=invalid_token_ids,
)
def _validate_supported_sampling_params(