[Refactor] Move validation to params definitions (#34362)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -72,7 +72,7 @@ class PoolingParams(
|
||||
"""Returns a deep copy of the PoolingParams instance."""
|
||||
return deepcopy(self)
|
||||
|
||||
def verify(self, model_config: "ModelConfig") -> None:
|
||||
def verify(self, model_config: ModelConfig) -> None:
|
||||
# plugin task uses io_processor.parse_request to verify inputs,
|
||||
# skipping PoolingParams verify
|
||||
if self.task == "plugin":
|
||||
@@ -87,12 +87,7 @@ class PoolingParams(
|
||||
self._set_default_parameters(model_config)
|
||||
self._verify_valid_parameters()
|
||||
|
||||
def _merge_default_parameters(
|
||||
self, model_config: "ModelConfig | None" = None
|
||||
) -> None:
|
||||
if model_config is None:
|
||||
return
|
||||
|
||||
def _merge_default_parameters(self, model_config: ModelConfig) -> None:
|
||||
pooler_config = model_config.pooler_config
|
||||
if pooler_config is None:
|
||||
return
|
||||
@@ -119,7 +114,9 @@ class PoolingParams(
|
||||
self._verify_step_pooling(pooler_config, valid_parameters)
|
||||
|
||||
def _verify_step_pooling(
|
||||
self, pooler_config: "PoolerConfig", valid_parameters: list[str]
|
||||
self,
|
||||
pooler_config: PoolerConfig,
|
||||
valid_parameters: list[str],
|
||||
):
|
||||
step_pooling_parameters = ["step_tag_id", "returned_token_ids"]
|
||||
if pooler_config.tok_pooling_type != "STEP":
|
||||
@@ -142,12 +139,12 @@ class PoolingParams(
|
||||
if getattr(self, k, None) is None:
|
||||
setattr(self, k, getattr(pooler_config, k))
|
||||
|
||||
def _set_default_parameters(self, model_config: "ModelConfig | None"):
|
||||
def _set_default_parameters(self, model_config: ModelConfig):
|
||||
if self.task in ["embed", "token_embed"]:
|
||||
if self.use_activation is None:
|
||||
self.use_activation = True
|
||||
|
||||
if self.dimensions is not None and model_config is not None:
|
||||
if self.dimensions is not None:
|
||||
if not model_config.is_matryoshka:
|
||||
raise ValueError(
|
||||
f'Model "{model_config.served_model_name}" does not '
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
"""Sampling parameters for text generation."""
|
||||
|
||||
import copy
|
||||
import json
|
||||
from dataclasses import field
|
||||
from enum import Enum, IntEnum
|
||||
from functools import cached_property
|
||||
@@ -11,6 +12,7 @@ from typing import Annotated, Any
|
||||
import msgspec
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config import ModelConfig, SpeculativeConfig, StructuredOutputsConfig
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logits_process import LogitsProcessor
|
||||
@@ -453,6 +455,11 @@ class SamplingParams(
|
||||
parameter="prompt_logprobs",
|
||||
value=self.prompt_logprobs,
|
||||
)
|
||||
if self.logits_processors:
|
||||
# TODO: Remove `logits_processors` attribute
|
||||
raise ValueError(
|
||||
"vLLM V1 does not support per request user-provided logits processors."
|
||||
)
|
||||
if self.truncate_prompt_tokens is not None and (
|
||||
self.truncate_prompt_tokens == 0 or self.truncate_prompt_tokens < -1
|
||||
):
|
||||
@@ -589,6 +596,237 @@ class SamplingParams(
|
||||
)
|
||||
return copy.deepcopy(self, memo=logit_processor_refs)
|
||||
|
||||
def verify(
    self,
    model_config: ModelConfig,
    speculative_config: SpeculativeConfig | None,
    structured_outputs_config: StructuredOutputsConfig | None,
    tokenizer: TokenizerLike | None,
) -> None:
    """Validate these sampling parameters against the engine configuration.

    Runs every validation helper in a fixed order; the first failing check
    raises and the remaining checks are not executed.

    Args:
        model_config: Forwarded to the logprobs and logit_bias checks.
        speculative_config: Forwarded to the speculative-decoding check,
            which is a no-op when this is ``None``.
        structured_outputs_config: Forwarded to the structured-outputs
            check, which is a no-op when this is ``None``.
        tokenizer: Forwarded to the allowed_token_ids and
            structured-outputs checks.

    Raises:
        VLLMValidationError: If a logprobs, logit_bias or allowed_token_ids
            value is rejected by its helper.
        ValueError: If the speculative-decoding or structured-outputs
            helper rejects the request.
    """
    self._validate_logprobs(model_config)
    self._validate_logit_bias(model_config)
    self._validate_allowed_token_ids(tokenizer)
    self._validate_spec_decode(speculative_config)
    # NOTE: this last step may also record the selected backend on
    # `self.structured_outputs`, so it is not a pure check.
    self._validate_structured_outputs(structured_outputs_config, tokenizer)
|
||||
|
||||
def _validate_logprobs(self, model_config: ModelConfig) -> None:
    """Check requested sample/prompt logprobs against the configured maximum.

    A value of ``-1`` (for the configured maximum or for a request) is
    replaced by ``model_config.get_vocab_size()`` before comparing.
    Falsy requests (``None`` or ``0``) are not checked.

    Args:
        model_config: Supplies ``max_logprobs`` and ``get_vocab_size()``.

    Raises:
        VLLMValidationError: If ``logprobs`` or ``prompt_logprobs`` exceeds
            the maximum allowed.
    """
    max_logprobs = model_config.max_logprobs
    if max_logprobs == -1:
        max_logprobs = model_config.get_vocab_size()

    # (parameter name, label used in the error message, requested value).
    # Sample logprobs are checked first, then prompt logprobs.
    requests = (
        ("logprobs", "sample", self.logprobs),
        ("prompt_logprobs", "prompt", self.prompt_logprobs),
    )
    for parameter, label, requested in requests:
        if not requested:
            continue
        if requested == -1:
            requested = model_config.get_vocab_size()
        if requested > max_logprobs:
            raise VLLMValidationError(
                f"Requested {label} logprobs of {requested}, "
                f"which is greater than max allowed: {max_logprobs}",
                parameter=parameter,
                value=requested,
            )
|
||||
|
||||
def _validate_logit_bias(self, model_config: ModelConfig) -> None:
    """Validate logit_bias token IDs are within vocabulary range.

    Does nothing when ``self.logit_bias`` is falsy (``None`` or empty).

    Args:
        model_config: Supplies the vocabulary size via ``get_vocab_size()``.

    Raises:
        VLLMValidationError: If any key of ``logit_bias`` is negative or
            not smaller than the vocabulary size. All offending ids are
            reported together.
    """
    if not self.logit_bias:
        return

    vocab_size = model_config.get_vocab_size()
    # Iterating the mapping yields its keys, i.e. the token ids.
    invalid_token_ids = [
        token_id
        for token_id in self.logit_bias
        if token_id < 0 or token_id >= vocab_size
    ]

    if invalid_token_ids:
        raise VLLMValidationError(
            f"token_id(s) {invalid_token_ids} in logit_bias contain "
            f"out-of-vocab token ids. Vocabulary size: {vocab_size}",
            parameter="logit_bias",
            value=invalid_token_ids,
        )
|
||||
|
||||
def _validate_allowed_token_ids(self, tokenizer: TokenizerLike | None) -> None:
    """Validate ``allowed_token_ids`` when it has been set.

    ``None`` means "not set" and is accepted. An empty collection is always
    rejected. The range check against ``len(tokenizer)`` only runs when a
    tokenizer is supplied; without one the ids are not range-checked here.

    Args:
        tokenizer: Object whose ``len()`` is used as the vocabulary size,
            or ``None`` to skip the range check.

    Raises:
        VLLMValidationError: If ``allowed_token_ids`` is empty, or if any
            id is negative or not smaller than ``len(tokenizer)``.
    """
    allowed_token_ids = self.allowed_token_ids
    if allowed_token_ids is None:
        return

    if len(allowed_token_ids) == 0:
        raise VLLMValidationError(
            "allowed_token_ids is not None and empty!",
            parameter="allowed_token_ids",
            value=allowed_token_ids,
        )

    if tokenizer is not None:
        vocab_size = len(tokenizer)
        invalid_token_ids = [
            token_id
            for token_id in allowed_token_ids
            if token_id < 0 or token_id >= vocab_size
        ]
        if invalid_token_ids:
            raise VLLMValidationError(
                "allowed_token_ids contains out-of-vocab token id!",
                parameter="allowed_token_ids",
                value=invalid_token_ids,
            )
|
||||
|
||||
def _validate_spec_decode(
    self,
    speculative_config: SpeculativeConfig | None,
) -> None:
    """Reject sampling options that cannot be combined with spec decoding.

    Args:
        speculative_config: The speculative decoding configuration. Only
            its presence matters here; ``None`` skips the check.

    Raises:
        ValueError: If speculative decoding is configured and the request
            sets ``min_tokens`` above 1, ``min_p`` above ``_SAMPLING_EPS``,
            or a non-empty ``logit_bias``.
    """
    if speculative_config is None:
        return

    # Some sampling parameters are not yet compatible with spec decoding.
    if self.min_tokens > 1 or self.min_p > _SAMPLING_EPS or self.logit_bias:
        raise ValueError(
            "The min_tokens, min_p, and logit_bias sampling parameters "
            "are not yet supported with speculative decoding."
        )
|
||||
|
||||
def _validate_structured_outputs(
    self,
    structured_outputs_config: StructuredOutputsConfig | None,
    tokenizer: TokenizerLike | None,
) -> None:
    """Validate the structured-outputs request and settle on a backend.

    Does nothing when either the config or ``self.structured_outputs`` is
    ``None``. Otherwise this method both validates and mutates
    ``self.structured_outputs``: ``_backend`` is filled in from the config
    when the request did not set it, and in the fallback ("auto") branch it
    is set to whichever backend accepted the request, together with
    ``_backend_was_auto = True``.

    Args:
        structured_outputs_config: Supplies the engine-level ``backend``.
        tokenizer: Required for structured outputs; also inspected to
            detect a Mistral tokenizer, which some backends reject.

    Raises:
        ValueError: If no tokenizer is available, the request names a
            backend that conflicts with the engine's, ``choice`` is an
            empty list, ``grammar`` is blank, the tokenizer is unsupported
            by the chosen backend, or the backend-specific validator
            rejects the request.
    """
    if structured_outputs_config is None or self.structured_outputs is None:
        return

    if tokenizer is None:
        raise ValueError(
            "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'"  # noqa: E501
        )

    backend = structured_outputs_config.backend
    if _backend := self.structured_outputs._backend:
        # Request-level backend selection is not supported.
        # The values may differ if `params` is reused and was set
        # to a specific backend based on `auto` behavior in a previous
        # request. We remember that it was set as a result of `auto`
        # using the `_backend_was_auto` field set in the params.
        if backend != _backend and not (
            backend == "auto" and self.structured_outputs._backend_was_auto
        ):
            raise ValueError(
                "Request-level structured output backend selection is not "
                f"supported. The request specified '{_backend}', but vLLM "
                f"was initialised with '{backend}'. This error can be "
                "resolved by removing '_backend' from the request."
            )
    else:
        self.structured_outputs._backend = backend

    # Request content validation
    if (
        isinstance(self.structured_outputs.choice, list)
        and not self.structured_outputs.choice
    ):
        # It is invalid for choice to be an empty list
        raise ValueError(
            f"Choice '{self.structured_outputs.choice}' cannot be an empty list"  # noqa: E501
        )
    # Reject empty string grammar early to avoid engine-side crashes
    if (
        isinstance(self.structured_outputs.grammar, str)
        and self.structured_outputs.grammar.strip() == ""
    ):
        raise ValueError("structured_outputs.grammar cannot be an empty string")

    # Backend validators are imported here, inside the function, rather than
    # at module level (presumably to avoid an import cycle or import cost --
    # confirm before moving them).
    from vllm.tokenizers.mistral import MistralTokenizer
    from vllm.v1.structured_output.backend_guidance import (
        has_guidance_unsupported_json_features,
        validate_guidance_grammar,
    )
    from vllm.v1.structured_output.backend_lm_format_enforcer import (
        validate_structured_output_request_lm_format_enforcer,
    )
    from vllm.v1.structured_output.backend_outlines import (
        validate_structured_output_request_outlines,
    )
    from vllm.v1.structured_output.backend_xgrammar import validate_xgrammar_grammar

    if backend.startswith("xgrammar"):
        # xgrammar with no fallback
        validate_xgrammar_grammar(self)
    elif backend.startswith("guidance"):
        # TODO: ideally we would have the LLTokenizer here as Lark syntax
        # allows <|special_token|> and similar, see
        # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
        # Without tokenizer these are disallowed in grammars.
        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "Mistral tokenizer is not supported for the 'guidance' "
                "structured output backend. Please use ['xgrammar', 'outlines'] "
                "backends or tokenizer_mode='hf' instead."
            )
        validate_guidance_grammar(self, tokenizer=None)
    elif backend == "outlines":
        # outlines backend
        validate_structured_output_request_outlines(self)
    elif backend == "lm-format-enforcer":
        # lm format enforcer backend
        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "Mistral tokenizer is not supported for the 'lm-format-enforcer' "
                "structured output backend. Please use ['xgrammar', 'outlines'] "
                "backends or tokenizer_mode='hf' instead."
            )
        validate_structured_output_request_lm_format_enforcer(self)
    else:
        # NOTE: every other backend value lands here. It is expected to be
        # "auto"; nothing in this method enforces that (presumably the set
        # of allowed backends is validated when the config is built --
        # confirm).
        # In this mode, we set opinionated defaults based on what we think
        # will satisfy the most use cases without having to worry about
        # this setting. We include fallback behavior here, but not with any
        # other setting where a specific backend was specified.
        try:
            validate_xgrammar_grammar(self)
            self.structured_outputs._backend = "xgrammar"
        except ValueError:
            # The request either failed validation
            # or includes some jsonschema feature(s) that
            # are not supported in xgrammar.

            # Check if schema has features unsupported by guidance
            so_params = self.structured_outputs
            skip_guidance = False
            if so_params.json:
                if isinstance(so_params.json, str):
                    schema = json.loads(so_params.json)
                else:
                    schema = so_params.json
                skip_guidance = has_guidance_unsupported_json_features(schema)

            if isinstance(tokenizer, MistralTokenizer) or skip_guidance:
                # Fall back to outlines if the tokenizer is Mistral
                # or if schema contains features unsupported by guidance
                validate_structured_output_request_outlines(self)
                self.structured_outputs._backend = "outlines"
            else:
                # Fall back to guidance by default.
                validate_guidance_grammar(self, tokenizer=None)
                self.structured_outputs._backend = "guidance"
        # Remember that this backend was set automatically
        self.structured_outputs._backend_was_auto = True

    # Run post-init validation. This is also important to ensure subsequent
    # roundtrip serialization/deserialization won't fail.
    self.structured_outputs.__post_init__()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"SamplingParams(n={self.n}, "
|
||||
|
||||
@@ -6,7 +6,6 @@ from collections.abc import Mapping
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs.data import (
|
||||
ProcessorInputs,
|
||||
PromptType,
|
||||
@@ -30,25 +29,13 @@ from vllm.multimodal.utils import argsort_mm_positions
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers import BaseRenderer
|
||||
from vllm.renderers.inputs import DictPrompt, TokPrompt
|
||||
from vllm.sampling_params import _SAMPLING_EPS, SamplingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.tasks import POOLING_TASKS, SupportedTask
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
|
||||
from vllm.utils.torch_utils import set_default_torch_num_threads
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.metrics.stats import MultiModalCacheStats
|
||||
from vllm.v1.structured_output.backend_guidance import (
|
||||
has_guidance_unsupported_json_features,
|
||||
validate_guidance_grammar,
|
||||
)
|
||||
from vllm.v1.structured_output.backend_lm_format_enforcer import (
|
||||
validate_structured_output_request_lm_format_enforcer,
|
||||
)
|
||||
from vllm.v1.structured_output.backend_outlines import (
|
||||
validate_structured_output_request_outlines,
|
||||
)
|
||||
from vllm.v1.structured_output.backend_xgrammar import validate_xgrammar_grammar
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -64,6 +51,7 @@ class InputProcessor:
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.structured_outputs_config = vllm_config.structured_outputs_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
|
||||
@@ -101,101 +89,6 @@ class InputProcessor:
|
||||
def renderer(self) -> BaseRenderer:
|
||||
return self.input_preprocessor.renderer
|
||||
|
||||
def _validate_logprobs(
|
||||
self,
|
||||
params: SamplingParams,
|
||||
) -> None:
|
||||
max_logprobs = self.model_config.max_logprobs
|
||||
if max_logprobs == -1:
|
||||
max_logprobs = self.model_config.get_vocab_size()
|
||||
|
||||
# Validate sample logprobs.
|
||||
if params.logprobs:
|
||||
num_logprobs = params.logprobs
|
||||
if num_logprobs == -1:
|
||||
num_logprobs = self.model_config.get_vocab_size()
|
||||
if num_logprobs > max_logprobs:
|
||||
raise VLLMValidationError(
|
||||
f"Requested sample logprobs of {num_logprobs}, "
|
||||
f"which is greater than max allowed: {max_logprobs}",
|
||||
parameter="logprobs",
|
||||
value=num_logprobs,
|
||||
)
|
||||
|
||||
# Validate prompt logprobs.
|
||||
if params.prompt_logprobs:
|
||||
num_prompt_logprobs = params.prompt_logprobs
|
||||
if num_prompt_logprobs == -1:
|
||||
num_prompt_logprobs = self.model_config.get_vocab_size()
|
||||
if num_prompt_logprobs > max_logprobs:
|
||||
raise VLLMValidationError(
|
||||
f"Requested prompt logprobs of {num_prompt_logprobs}, "
|
||||
f"which is greater than max allowed: {max_logprobs}",
|
||||
parameter="prompt_logprobs",
|
||||
value=num_prompt_logprobs,
|
||||
)
|
||||
|
||||
def _validate_sampling_params(
|
||||
self,
|
||||
params: SamplingParams,
|
||||
) -> None:
|
||||
self._validate_structured_output(params)
|
||||
self._validate_logit_bias(params)
|
||||
|
||||
if params.allowed_token_ids is None:
|
||||
return
|
||||
if not params.allowed_token_ids:
|
||||
raise ValueError("allowed_token_ids is not None and empty!")
|
||||
if self.tokenizer is None:
|
||||
# When skip_tokenizer_init=True, we can't validate token IDs
|
||||
# Skip validation and let the model handle invalid tokens
|
||||
return
|
||||
vocab_size = len(self.tokenizer)
|
||||
if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
|
||||
raise ValueError("allowed_token_ids contains out-of-vocab token id!")
|
||||
|
||||
def _validate_logit_bias(
|
||||
self,
|
||||
params: SamplingParams,
|
||||
) -> None:
|
||||
"""Validate logit_bias token IDs are within vocabulary range."""
|
||||
if not params.logit_bias:
|
||||
return
|
||||
|
||||
vocab_size = self.model_config.get_vocab_size()
|
||||
invalid_token_ids = []
|
||||
|
||||
for token_id in params.logit_bias:
|
||||
if token_id < 0 or token_id >= vocab_size:
|
||||
invalid_token_ids.append(token_id)
|
||||
|
||||
if invalid_token_ids:
|
||||
raise VLLMValidationError(
|
||||
f"token_id(s) {invalid_token_ids} in logit_bias contain "
|
||||
f"out-of-vocab token ids. Vocabulary size: {vocab_size}",
|
||||
parameter="logit_bias",
|
||||
value=invalid_token_ids,
|
||||
)
|
||||
|
||||
def _validate_supported_sampling_params(
|
||||
self,
|
||||
params: SamplingParams,
|
||||
) -> None:
|
||||
# Logits processors not supported.
|
||||
if params.logits_processors:
|
||||
raise ValueError(
|
||||
"vLLM V1 does not support per request user-provided logits processors."
|
||||
)
|
||||
|
||||
# Some sampling parameters are not yet compatible with spec decoding.
|
||||
if self.vllm_config.speculative_config is not None and (
|
||||
params.min_tokens > 1 or params.min_p > _SAMPLING_EPS or params.logit_bias
|
||||
):
|
||||
raise ValueError(
|
||||
"The min_tokens, min_p, and logit_bias sampling parameters "
|
||||
"are not yet supported with speculative decoding."
|
||||
)
|
||||
|
||||
def _validate_params(
|
||||
self,
|
||||
params: SamplingParams | PoolingParams,
|
||||
@@ -203,11 +96,15 @@ class InputProcessor:
|
||||
# is passed to all `process_inputs` calls
|
||||
supported_tasks: tuple[SupportedTask, ...] | None,
|
||||
):
|
||||
"""
|
||||
Validate supported SamplingParam.
|
||||
Should raise ValueError if unsupported for API Server.
|
||||
"""
|
||||
if isinstance(params, PoolingParams):
|
||||
"""Raise `ValueError` if SamplingParams or PoolingParams is not valid."""
|
||||
if isinstance(params, SamplingParams):
|
||||
params.verify(
|
||||
self.model_config,
|
||||
self.speculative_config,
|
||||
self.structured_outputs_config,
|
||||
self.tokenizer,
|
||||
)
|
||||
elif isinstance(params, PoolingParams):
|
||||
if supported_tasks is None:
|
||||
raise RuntimeError("`supported_tasks` must be passed for pooling")
|
||||
|
||||
@@ -233,12 +130,11 @@ class InputProcessor:
|
||||
)
|
||||
|
||||
params.verify(self.model_config)
|
||||
|
||||
return
|
||||
|
||||
self._validate_logprobs(params)
|
||||
self._validate_sampling_params(params)
|
||||
self._validate_supported_sampling_params(params)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"params must be either SamplingParams or PoolingParams, "
|
||||
f"but got {type(params).__name__}"
|
||||
)
|
||||
|
||||
def _parse_mm_items(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
|
||||
mm_processor = self.input_preprocessor._get_mm_processor()
|
||||
@@ -334,120 +230,6 @@ class InputProcessor:
|
||||
"[lora_path]` to use the LoRA tokenizer."
|
||||
)
|
||||
|
||||
def _validate_structured_output(self, params: SamplingParams) -> None:
|
||||
if not params.structured_outputs or not self.structured_outputs_config:
|
||||
return
|
||||
|
||||
if self.model_config.skip_tokenizer_init and params.structured_outputs:
|
||||
raise ValueError(
|
||||
"Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'" # noqa: E501
|
||||
)
|
||||
|
||||
backend = self.structured_outputs_config.backend
|
||||
if _backend := params.structured_outputs._backend:
|
||||
# Request-level backend selection is not supported.
|
||||
# The values may differ if `params` is reused and was set
|
||||
# to a specific backend based on `auto` behavior in a previous
|
||||
# request. We remember that it was set as a result of `auto`
|
||||
# using the `_backend_was_auto` field set in the params.
|
||||
if backend != _backend and not (
|
||||
backend == "auto" and params.structured_outputs._backend_was_auto
|
||||
):
|
||||
raise ValueError(
|
||||
"Request-level structured output backend selection is not "
|
||||
f"supported. The request specified '{_backend}', but vLLM "
|
||||
f"was initialised with '{backend}'. This error can be "
|
||||
"resolved by removing '_backend' from the request."
|
||||
)
|
||||
else:
|
||||
params.structured_outputs._backend = backend
|
||||
|
||||
# Request content validation
|
||||
if (
|
||||
isinstance(params.structured_outputs.choice, list)
|
||||
and not params.structured_outputs.choice
|
||||
):
|
||||
# It is invalid for choice to be an empty list
|
||||
raise ValueError(
|
||||
f"Choice '{params.structured_outputs.choice}' cannot be an empty list" # noqa: E501
|
||||
)
|
||||
# Reject empty string grammar early to avoid engine-side crashes
|
||||
if (
|
||||
isinstance(params.structured_outputs.grammar, str)
|
||||
and params.structured_outputs.grammar.strip() == ""
|
||||
):
|
||||
raise ValueError("structured_outputs.grammar cannot be an empty string")
|
||||
|
||||
if backend.startswith("xgrammar"):
|
||||
# xgrammar with no fallback
|
||||
validate_xgrammar_grammar(params)
|
||||
elif backend.startswith("guidance"):
|
||||
# TODO: ideally we would have the LLTokenizer here as Lark syntax
|
||||
# allows <|special_token|> and similar, see
|
||||
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
|
||||
# Without tokenizer these are disallowed in grammars.
|
||||
if isinstance(self.tokenizer, MistralTokenizer):
|
||||
raise ValueError(
|
||||
"Mistral tokenizer is not supported for the 'guidance' "
|
||||
"structured output backend. Please use ['xgrammar', 'outlines'] "
|
||||
"backends or tokenizer_mode='hf' instead."
|
||||
)
|
||||
validate_guidance_grammar(params, tokenizer=None)
|
||||
elif backend == "outlines":
|
||||
# outlines backend
|
||||
validate_structured_output_request_outlines(params)
|
||||
elif backend == "lm-format-enforcer":
|
||||
# lm format enforcer backend
|
||||
if isinstance(self.tokenizer, MistralTokenizer):
|
||||
raise ValueError(
|
||||
"Mistral tokenizer is not supported for the 'lm-format-enforcer' "
|
||||
"structured output backend. Please use ['xgrammar', 'outlines'] "
|
||||
"backends or tokenizer_mode='hf' instead."
|
||||
)
|
||||
validate_structured_output_request_lm_format_enforcer(params)
|
||||
else:
|
||||
# NOTE: backend must be "auto" here, because we have
|
||||
# checked supported_backends above.
|
||||
# In this mode, we set opinionated defaults based on what we think
|
||||
# will satisfy the most use cases without having to worry about
|
||||
# this setting. We include fallback behavior here, but not with any
|
||||
# other setting where a specific backend was specified.
|
||||
try:
|
||||
validate_xgrammar_grammar(params)
|
||||
params.structured_outputs._backend = "xgrammar"
|
||||
except ValueError:
|
||||
# The request either failed validation
|
||||
# or includes some jsonschema feature(s) that
|
||||
# are not supported in xgrammar.
|
||||
|
||||
# Check if schema has features unsupported by guidance
|
||||
so_params = params.structured_outputs
|
||||
skip_guidance = False
|
||||
if so_params.json:
|
||||
if isinstance(so_params.json, str):
|
||||
import json
|
||||
|
||||
schema = json.loads(so_params.json)
|
||||
else:
|
||||
schema = so_params.json
|
||||
skip_guidance = has_guidance_unsupported_json_features(schema)
|
||||
|
||||
if isinstance(self.tokenizer, MistralTokenizer) or skip_guidance:
|
||||
# Fall back to outlines if the tokenizer is Mistral
|
||||
# or if schema contains features unsupported by guidance
|
||||
validate_structured_output_request_outlines(params)
|
||||
params.structured_outputs._backend = "outlines"
|
||||
else:
|
||||
# Fall back to guidance by default.
|
||||
validate_guidance_grammar(params, tokenizer=None)
|
||||
params.structured_outputs._backend = "guidance"
|
||||
# Remember that this backend was set automatically
|
||||
params.structured_outputs._backend_was_auto = True
|
||||
|
||||
# Run post-init validation. This is also important to ensure subsequent
|
||||
# roundtrip serialization/deserialization won't fail.
|
||||
params.structured_outputs.__post_init__()
|
||||
|
||||
def _extract_singleton_mm_data(
|
||||
self, prompt: SingletonPrompt
|
||||
) -> MultiModalDataDict | None:
|
||||
@@ -618,8 +400,10 @@ class InputProcessor:
|
||||
prompt_token_ids, prompt_embeds
|
||||
)
|
||||
sampling_params.max_tokens = self.model_config.max_model_len - seq_len
|
||||
|
||||
sampling_params.update_from_generation_config(
|
||||
self.generation_config_fields, eos_token_id
|
||||
self.generation_config_fields,
|
||||
None if self.tokenizer is None else self.tokenizer.eos_token_id,
|
||||
)
|
||||
if self.tokenizer is not None:
|
||||
sampling_params.update_from_tokenizer(self.tokenizer)
|
||||
|
||||
Reference in New Issue
Block a user