[Refactor] Move validation to params definitions (#34362)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-02-12 11:33:15 +08:00
committed by GitHub
parent e1d97c38f8
commit ced2a92f40
3 changed files with 264 additions and 245 deletions

View File

@@ -72,7 +72,7 @@ class PoolingParams(
"""Returns a deep copy of the PoolingParams instance."""
return deepcopy(self)
def verify(self, model_config: "ModelConfig") -> None:
def verify(self, model_config: ModelConfig) -> None:
# plugin task uses io_processor.parse_request to verify inputs,
# skipping PoolingParams verify
if self.task == "plugin":
@@ -87,12 +87,7 @@ class PoolingParams(
self._set_default_parameters(model_config)
self._verify_valid_parameters()
def _merge_default_parameters(
self, model_config: "ModelConfig | None" = None
) -> None:
if model_config is None:
return
def _merge_default_parameters(self, model_config: ModelConfig) -> None:
pooler_config = model_config.pooler_config
if pooler_config is None:
return
@@ -119,7 +114,9 @@ class PoolingParams(
self._verify_step_pooling(pooler_config, valid_parameters)
def _verify_step_pooling(
self, pooler_config: "PoolerConfig", valid_parameters: list[str]
self,
pooler_config: PoolerConfig,
valid_parameters: list[str],
):
step_pooling_parameters = ["step_tag_id", "returned_token_ids"]
if pooler_config.tok_pooling_type != "STEP":
@@ -142,12 +139,12 @@ class PoolingParams(
if getattr(self, k, None) is None:
setattr(self, k, getattr(pooler_config, k))
def _set_default_parameters(self, model_config: "ModelConfig | None"):
def _set_default_parameters(self, model_config: ModelConfig):
if self.task in ["embed", "token_embed"]:
if self.use_activation is None:
self.use_activation = True
if self.dimensions is not None and model_config is not None:
if self.dimensions is not None:
if not model_config.is_matryoshka:
raise ValueError(
f'Model "{model_config.served_model_name}" does not '

View File

@@ -3,6 +3,7 @@
"""Sampling parameters for text generation."""
import copy
import json
from dataclasses import field
from enum import Enum, IntEnum
from functools import cached_property
@@ -11,6 +12,7 @@ from typing import Annotated, Any
import msgspec
from pydantic.dataclasses import dataclass
from vllm.config import ModelConfig, SpeculativeConfig, StructuredOutputsConfig
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor
@@ -453,6 +455,11 @@ class SamplingParams(
parameter="prompt_logprobs",
value=self.prompt_logprobs,
)
if self.logits_processors:
# TODO: Remove `logits_processors` attribute
raise ValueError(
"vLLM V1 does not support per request user-provided logits processors."
)
if self.truncate_prompt_tokens is not None and (
self.truncate_prompt_tokens == 0 or self.truncate_prompt_tokens < -1
):
@@ -589,6 +596,237 @@ class SamplingParams(
)
return copy.deepcopy(self, memo=logit_processor_refs)
def verify(
    self,
    model_config: ModelConfig,
    speculative_config: SpeculativeConfig | None,
    structured_outputs_config: StructuredOutputsConfig | None,
    tokenizer: TokenizerLike | None,
) -> None:
    """Validate these sampling parameters against the engine configuration.

    Runs each per-request check in a fixed order; the first failing check
    raises (``VLLMValidationError`` or ``ValueError``) and aborts the
    request. May mutate ``self.structured_outputs`` (backend resolution).
    """
    self._validate_logprobs(model_config)
    self._validate_logit_bias(model_config)
    self._validate_allowed_token_ids(tokenizer)
    self._validate_spec_decode(speculative_config)
    self._validate_structured_outputs(structured_outputs_config, tokenizer)
def _validate_logprobs(self, model_config: ModelConfig) -> None:
max_logprobs = model_config.max_logprobs
if max_logprobs == -1:
max_logprobs = model_config.get_vocab_size()
# Validate sample logprobs.
if num_logprobs := self.logprobs:
if num_logprobs == -1:
num_logprobs = model_config.get_vocab_size()
if num_logprobs > max_logprobs:
raise VLLMValidationError(
f"Requested sample logprobs of {num_logprobs}, "
f"which is greater than max allowed: {max_logprobs}",
parameter="logprobs",
value=num_logprobs,
)
# Validate prompt logprobs.
if num_prompt_logprobs := self.prompt_logprobs:
if num_prompt_logprobs == -1:
num_prompt_logprobs = model_config.get_vocab_size()
if num_prompt_logprobs > max_logprobs:
raise VLLMValidationError(
f"Requested prompt logprobs of {num_prompt_logprobs}, "
f"which is greater than max allowed: {max_logprobs}",
parameter="prompt_logprobs",
value=num_prompt_logprobs,
)
def _validate_logit_bias(self, model_config: ModelConfig) -> None:
"""Validate logit_bias token IDs are within vocabulary range."""
if not self.logit_bias:
return
vocab_size = model_config.get_vocab_size()
invalid_token_ids = [
token_id
for token_id in self.logit_bias
if token_id < 0 or token_id >= vocab_size
]
if invalid_token_ids:
raise VLLMValidationError(
f"token_id(s) {invalid_token_ids} in logit_bias contain "
f"out-of-vocab token ids. Vocabulary size: {vocab_size}",
parameter="logit_bias",
value=invalid_token_ids,
)
def _validate_allowed_token_ids(self, tokenizer: TokenizerLike | None) -> None:
allowed_token_ids = self.allowed_token_ids
if allowed_token_ids is None:
return
if len(allowed_token_ids) == 0:
raise VLLMValidationError(
"allowed_token_ids is not None and empty!",
parameter="allowed_token_ids",
value=allowed_token_ids,
)
if tokenizer is not None:
vocab_size = len(tokenizer)
invalid_token_ids = [
token_id
for token_id in allowed_token_ids
if token_id < 0 or token_id >= vocab_size
]
if invalid_token_ids:
raise VLLMValidationError(
"allowed_token_ids contains out-of-vocab token id!",
parameter="allowed_token_ids",
value=invalid_token_ids,
)
def _validate_spec_decode(
self,
speculative_config: SpeculativeConfig | None,
) -> None:
if speculative_config is None:
return
# Some sampling parameters are not yet compatible with spec decoding.
if self.min_tokens > 1 or self.min_p > _SAMPLING_EPS or self.logit_bias:
raise ValueError(
"The min_tokens, min_p, and logit_bias sampling parameters "
"are not yet supported with speculative decoding."
)
def _validate_structured_outputs(
    self,
    structured_outputs_config: StructuredOutputsConfig | None,
    tokenizer: TokenizerLike | None,
) -> None:
    """Validate structured-output settings and resolve the effective backend.

    Mutates ``self.structured_outputs``: ``_backend`` is filled in from the
    engine config (or by the "auto" fallback chain below) and
    ``_backend_was_auto`` records that the backend was chosen automatically.

    Raises:
        ValueError: on request-level backend mismatch, invalid request
            content (empty choice list / empty grammar string), or grammar
            validation failure in the selected backend.
    """
    # Nothing to do when structured outputs are disabled engine-side or
    # not requested by this particular request.
    if structured_outputs_config is None or self.structured_outputs is None:
        return
    if tokenizer is None:
        raise ValueError(
            "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'"  # noqa: E501
        )

    backend = structured_outputs_config.backend
    if _backend := self.structured_outputs._backend:
        # Request-level backend selection is not supported.
        # The values may differ if `params` is reused and was set
        # to a specific backend based on `auto` behavior in a previous
        # request. We remember that it was set as a result of `auto`
        # using the `_backend_was_auto` field set in the params.
        if backend != _backend and not (
            backend == "auto" and self.structured_outputs._backend_was_auto
        ):
            raise ValueError(
                "Request-level structured output backend selection is not "
                f"supported. The request specified '{_backend}', but vLLM "
                f"was initialised with '{backend}'. This error can be "
                "resolved by removing '_backend' from the request."
            )
    else:
        self.structured_outputs._backend = backend

    # Request content validation
    if (
        isinstance(self.structured_outputs.choice, list)
        and not self.structured_outputs.choice
    ):
        # It is invalid for choice to be an empty list
        raise ValueError(
            f"Choice '{self.structured_outputs.choice}' cannot be an empty list"  # noqa: E501
        )
    # Reject empty string grammar early to avoid engine-side crashes
    if (
        isinstance(self.structured_outputs.grammar, str)
        and self.structured_outputs.grammar.strip() == ""
    ):
        raise ValueError("structured_outputs.grammar cannot be an empty string")

    # Imported lazily; presumably to avoid pulling the heavy structured
    # output backends in at module import time — TODO confirm.
    from vllm.tokenizers.mistral import MistralTokenizer
    from vllm.v1.structured_output.backend_guidance import (
        has_guidance_unsupported_json_features,
        validate_guidance_grammar,
    )
    from vllm.v1.structured_output.backend_lm_format_enforcer import (
        validate_structured_output_request_lm_format_enforcer,
    )
    from vllm.v1.structured_output.backend_outlines import (
        validate_structured_output_request_outlines,
    )
    from vllm.v1.structured_output.backend_xgrammar import validate_xgrammar_grammar

    if backend.startswith("xgrammar"):
        # xgrammar with no fallback
        validate_xgrammar_grammar(self)
    elif backend.startswith("guidance"):
        # TODO: ideally we would have the LLTokenizer here as Lark syntax
        # allows <|special_token|> and similar, see
        # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
        # Without tokenizer these are disallowed in grammars.
        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "Mistral tokenizer is not supported for the 'guidance' "
                "structured output backend. Please use ['xgrammar', 'outlines'] "
                "backends or tokenizer_mode='hf' instead."
            )
        validate_guidance_grammar(self, tokenizer=None)
    elif backend == "outlines":
        # outlines backend
        validate_structured_output_request_outlines(self)
    elif backend == "lm-format-enforcer":
        # lm format enforcer backend
        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "Mistral tokenizer is not supported for the 'lm-format-enforcer' "
                "structured output backend. Please use ['xgrammar', 'outlines'] "
                "backends or tokenizer_mode='hf' instead."
            )
        validate_structured_output_request_lm_format_enforcer(self)
    else:
        # NOTE: backend must be "auto" here, because we have
        # checked supported_backends above.
        # In this mode, we set opinionated defaults based on what we think
        # will satisfy the most use cases without having to worry about
        # this setting. We include fallback behavior here, but not with any
        # other setting where a specific backend was specified.
        try:
            validate_xgrammar_grammar(self)
            self.structured_outputs._backend = "xgrammar"
        except ValueError:
            # The request either failed validation
            # or includes some jsonschema feature(s) that
            # are not supported in xgrammar.
            # Check if schema has features unsupported by guidance
            so_params = self.structured_outputs
            skip_guidance = False
            if so_params.json:
                if isinstance(so_params.json, str):
                    schema = json.loads(so_params.json)
                else:
                    schema = so_params.json
                skip_guidance = has_guidance_unsupported_json_features(schema)

            if isinstance(tokenizer, MistralTokenizer) or skip_guidance:
                # Fall back to outlines if the tokenizer is Mistral
                # or if schema contains features unsupported by guidance
                validate_structured_output_request_outlines(self)
                self.structured_outputs._backend = "outlines"
            else:
                # Fall back to guidance by default.
                validate_guidance_grammar(self, tokenizer=None)
                self.structured_outputs._backend = "guidance"
        # Remember that this backend was set automatically
        self.structured_outputs._backend_was_auto = True

    # Run post-init validation. This is also important to ensure subsequent
    # roundtrip serialization/deserialization won't fail.
    self.structured_outputs.__post_init__()
def __repr__(self) -> str:
return (
f"SamplingParams(n={self.n}, "

View File

@@ -6,7 +6,6 @@ from collections.abc import Mapping
from typing import Any, Literal, cast
from vllm.config import VllmConfig
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import (
ProcessorInputs,
PromptType,
@@ -30,25 +29,13 @@ from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.sampling_params import _SAMPLING_EPS, SamplingParams
from vllm.sampling_params import SamplingParams
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.metrics.stats import MultiModalCacheStats
from vllm.v1.structured_output.backend_guidance import (
has_guidance_unsupported_json_features,
validate_guidance_grammar,
)
from vllm.v1.structured_output.backend_lm_format_enforcer import (
validate_structured_output_request_lm_format_enforcer,
)
from vllm.v1.structured_output.backend_outlines import (
validate_structured_output_request_outlines,
)
from vllm.v1.structured_output.backend_xgrammar import validate_xgrammar_grammar
logger = init_logger(__name__)
@@ -64,6 +51,7 @@ class InputProcessor:
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.scheduler_config = vllm_config.scheduler_config
self.speculative_config = vllm_config.speculative_config
self.structured_outputs_config = vllm_config.structured_outputs_config
self.observability_config = vllm_config.observability_config
@@ -101,101 +89,6 @@ class InputProcessor:
def renderer(self) -> BaseRenderer:
return self.input_preprocessor.renderer
def _validate_logprobs(
    self,
    params: SamplingParams,
) -> None:
    """Ensure requested sample/prompt logprobs do not exceed the model cap.

    A value of -1 (either requested or configured as ``max_logprobs``)
    stands for the full vocabulary size.

    Raises:
        VLLMValidationError: if either requested count exceeds the maximum.
    """
    max_logprobs = self.model_config.max_logprobs
    if max_logprobs == -1:
        # -1 means the cap is the entire vocabulary.
        max_logprobs = self.model_config.get_vocab_size()

    # Validate sample logprobs.
    if params.logprobs:
        num_logprobs = params.logprobs
        if num_logprobs == -1:
            num_logprobs = self.model_config.get_vocab_size()
        if num_logprobs > max_logprobs:
            raise VLLMValidationError(
                f"Requested sample logprobs of {num_logprobs}, "
                f"which is greater than max allowed: {max_logprobs}",
                parameter="logprobs",
                value=num_logprobs,
            )

    # Validate prompt logprobs.
    if params.prompt_logprobs:
        num_prompt_logprobs = params.prompt_logprobs
        if num_prompt_logprobs == -1:
            num_prompt_logprobs = self.model_config.get_vocab_size()
        if num_prompt_logprobs > max_logprobs:
            raise VLLMValidationError(
                f"Requested prompt logprobs of {num_prompt_logprobs}, "
                f"which is greater than max allowed: {max_logprobs}",
                parameter="prompt_logprobs",
                value=num_prompt_logprobs,
            )
def _validate_sampling_params(
    self,
    params: SamplingParams,
) -> None:
    """Validate structured outputs, logit_bias, and allowed_token_ids.

    Raises:
        ValueError: if allowed_token_ids is empty or out-of-vocab, or if
            a delegated check fails.
    """
    self._validate_structured_output(params)
    self._validate_logit_bias(params)

    if params.allowed_token_ids is None:
        return
    if not params.allowed_token_ids:
        raise ValueError("allowed_token_ids is not None and empty!")
    if self.tokenizer is None:
        # When skip_tokenizer_init=True, we can't validate token IDs
        # Skip validation and let the model handle invalid tokens
        return
    vocab_size = len(self.tokenizer)
    if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
        raise ValueError("allowed_token_ids contains out-of-vocab token id!")
def _validate_logit_bias(
    self,
    params: SamplingParams,
) -> None:
    """Validate logit_bias token IDs are within vocabulary range.

    Raises:
        VLLMValidationError: listing every out-of-range token ID.
    """
    if not params.logit_bias:
        # No bias requested (None or empty): nothing to validate.
        return

    vocab_size = self.model_config.get_vocab_size()
    invalid_token_ids = []

    for token_id in params.logit_bias:
        if token_id < 0 or token_id >= vocab_size:
            invalid_token_ids.append(token_id)

    if invalid_token_ids:
        raise VLLMValidationError(
            f"token_id(s) {invalid_token_ids} in logit_bias contain "
            f"out-of-vocab token ids. Vocabulary size: {vocab_size}",
            parameter="logit_bias",
            value=invalid_token_ids,
        )
def _validate_supported_sampling_params(
    self,
    params: SamplingParams,
) -> None:
    """Reject sampling parameters this engine version does not support.

    Raises:
        ValueError: if per-request logits processors are supplied, or if
            min_tokens/min_p/logit_bias is combined with speculative
            decoding.
    """
    # Logits processors not supported.
    if params.logits_processors:
        raise ValueError(
            "vLLM V1 does not support per request user-provided logits processors."
        )
    # Some sampling parameters are not yet compatible with spec decoding.
    if self.vllm_config.speculative_config is not None and (
        params.min_tokens > 1 or params.min_p > _SAMPLING_EPS or params.logit_bias
    ):
        raise ValueError(
            "The min_tokens, min_p, and logit_bias sampling parameters "
            "are not yet supported with speculative decoding."
        )
def _validate_params(
self,
params: SamplingParams | PoolingParams,
@@ -203,11 +96,15 @@ class InputProcessor:
# is passed to all `process_inputs` calls
supported_tasks: tuple[SupportedTask, ...] | None,
):
"""
Validate supported SamplingParam.
Should raise ValueError if unsupported for API Server.
"""
if isinstance(params, PoolingParams):
"""Raise `ValueError` if SamplingParams or PoolingParams is not valid."""
if isinstance(params, SamplingParams):
params.verify(
self.model_config,
self.speculative_config,
self.structured_outputs_config,
self.tokenizer,
)
elif isinstance(params, PoolingParams):
if supported_tasks is None:
raise RuntimeError("`supported_tasks` must be passed for pooling")
@@ -233,12 +130,11 @@ class InputProcessor:
)
params.verify(self.model_config)
return
self._validate_logprobs(params)
self._validate_sampling_params(params)
self._validate_supported_sampling_params(params)
else:
raise TypeError(
f"params must be either SamplingParams or PoolingParams, "
f"but got {type(params).__name__}"
)
def _parse_mm_items(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
mm_processor = self.input_preprocessor._get_mm_processor()
@@ -334,120 +230,6 @@ class InputProcessor:
"[lora_path]` to use the LoRA tokenizer."
)
def _validate_structured_output(self, params: SamplingParams) -> None:
    """Validate a request's structured-output settings and pick the backend.

    Mutates ``params.structured_outputs``: ``_backend`` is filled in from
    the engine config (or by the "auto" fallback chain below) and
    ``_backend_was_auto`` records that the backend was chosen automatically.

    Raises:
        ValueError: on request-level backend mismatch, invalid request
            content, or grammar validation failure.
    """
    # Nothing to do when structured outputs are disabled or not requested.
    if not params.structured_outputs or not self.structured_outputs_config:
        return

    if self.model_config.skip_tokenizer_init and params.structured_outputs:
        raise ValueError(
            "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'"  # noqa: E501
        )

    backend = self.structured_outputs_config.backend
    if _backend := params.structured_outputs._backend:
        # Request-level backend selection is not supported.
        # The values may differ if `params` is reused and was set
        # to a specific backend based on `auto` behavior in a previous
        # request. We remember that it was set as a result of `auto`
        # using the `_backend_was_auto` field set in the params.
        if backend != _backend and not (
            backend == "auto" and params.structured_outputs._backend_was_auto
        ):
            raise ValueError(
                "Request-level structured output backend selection is not "
                f"supported. The request specified '{_backend}', but vLLM "
                f"was initialised with '{backend}'. This error can be "
                "resolved by removing '_backend' from the request."
            )
    else:
        params.structured_outputs._backend = backend

    # Request content validation
    if (
        isinstance(params.structured_outputs.choice, list)
        and not params.structured_outputs.choice
    ):
        # It is invalid for choice to be an empty list
        raise ValueError(
            f"Choice '{params.structured_outputs.choice}' cannot be an empty list"  # noqa: E501
        )
    # Reject empty string grammar early to avoid engine-side crashes
    if (
        isinstance(params.structured_outputs.grammar, str)
        and params.structured_outputs.grammar.strip() == ""
    ):
        raise ValueError("structured_outputs.grammar cannot be an empty string")

    if backend.startswith("xgrammar"):
        # xgrammar with no fallback
        validate_xgrammar_grammar(params)
    elif backend.startswith("guidance"):
        # TODO: ideally we would have the LLTokenizer here as Lark syntax
        # allows <|special_token|> and similar, see
        # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
        # Without tokenizer these are disallowed in grammars.
        if isinstance(self.tokenizer, MistralTokenizer):
            raise ValueError(
                "Mistral tokenizer is not supported for the 'guidance' "
                "structured output backend. Please use ['xgrammar', 'outlines'] "
                "backends or tokenizer_mode='hf' instead."
            )
        validate_guidance_grammar(params, tokenizer=None)
    elif backend == "outlines":
        # outlines backend
        validate_structured_output_request_outlines(params)
    elif backend == "lm-format-enforcer":
        # lm format enforcer backend
        if isinstance(self.tokenizer, MistralTokenizer):
            raise ValueError(
                "Mistral tokenizer is not supported for the 'lm-format-enforcer' "
                "structured output backend. Please use ['xgrammar', 'outlines'] "
                "backends or tokenizer_mode='hf' instead."
            )
        validate_structured_output_request_lm_format_enforcer(params)
    else:
        # NOTE: backend must be "auto" here, because we have
        # checked supported_backends above.
        # In this mode, we set opinionated defaults based on what we think
        # will satisfy the most use cases without having to worry about
        # this setting. We include fallback behavior here, but not with any
        # other setting where a specific backend was specified.
        try:
            validate_xgrammar_grammar(params)
            params.structured_outputs._backend = "xgrammar"
        except ValueError:
            # The request either failed validation
            # or includes some jsonschema feature(s) that
            # are not supported in xgrammar.
            # Check if schema has features unsupported by guidance
            so_params = params.structured_outputs
            skip_guidance = False
            if so_params.json:
                if isinstance(so_params.json, str):
                    import json

                    schema = json.loads(so_params.json)
                else:
                    schema = so_params.json
                skip_guidance = has_guidance_unsupported_json_features(schema)

            if isinstance(self.tokenizer, MistralTokenizer) or skip_guidance:
                # Fall back to outlines if the tokenizer is Mistral
                # or if schema contains features unsupported by guidance
                validate_structured_output_request_outlines(params)
                params.structured_outputs._backend = "outlines"
            else:
                # Fall back to guidance by default.
                validate_guidance_grammar(params, tokenizer=None)
                params.structured_outputs._backend = "guidance"
        # Remember that this backend was set automatically
        params.structured_outputs._backend_was_auto = True

    # Run post-init validation. This is also important to ensure subsequent
    # roundtrip serialization/deserialization won't fail.
    params.structured_outputs.__post_init__()
def _extract_singleton_mm_data(
self, prompt: SingletonPrompt
) -> MultiModalDataDict | None:
@@ -618,8 +400,10 @@ class InputProcessor:
prompt_token_ids, prompt_embeds
)
sampling_params.max_tokens = self.model_config.max_model_len - seq_len
sampling_params.update_from_generation_config(
self.generation_config_fields, eos_token_id
self.generation_config_fields,
None if self.tokenizer is None else self.tokenizer.eos_token_id,
)
if self.tokenizer is not None:
sampling_params.update_from_tokenizer(self.tokenizer)