# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Sampling parameters for text generation."""
import copy
import json
from dataclasses import field
from enum import Enum, IntEnum
from functools import cached_property
from typing import Annotated, Any
import msgspec
from pydantic.dataclasses import dataclass
from vllm.config import ModelConfig, SpeculativeConfig, StructuredOutputsConfig
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.v1.serial_utils import PydanticMsgspecMixin
logger = init_logger(__name__)
_SAMPLING_EPS = 1e-5
_MAX_TEMP = 1e-2
class SamplingType(IntEnum):
GREEDY = 0
RANDOM = 1
RANDOM_SEED = 2
# maybe make msgspec?
@dataclass
class StructuredOutputsParams:
# One of these fields will be used to build a logit processor.
json: str | dict | None = None
regex: str | None = None
choice: list[str] | None = None
grammar: str | None = None
json_object: bool | None = None
# These are other options that can be set.
disable_fallback: bool = False
disable_any_whitespace: bool = False
disable_additional_properties: bool = False
whitespace_pattern: str | None = None
structural_tag: str | None = None
_backend: str | None = field(default=None, init=False)
"""CAUTION: Should only be set by Processor._validate_structured_output"""
_backend_was_auto: bool = field(default=False, init=False)
"""CAUTION: Should only be set by Processor._validate_structured_output"""
def __post_init__(self):
"""Validate that some fields are mutually exclusive."""
count = sum(
[
self.json is not None,
self.regex is not None,
self.choice is not None,
self.grammar is not None,
self.json_object is not None,
self.structural_tag is not None,
]
)
if count > 1:
raise ValueError(
"You can only use one kind of structured outputs constraint "
f"but multiple are specified: {self.__dict__}"
)
if count < 1:
raise ValueError(
"You must use one kind of structured outputs constraint "
f"but none are specified: {self.__dict__}"
)
def all_constraints_none(self) -> bool:
"""
Returns True if all structured-output constraint fields are None.
"""
return all(
getattr(self, field) is None
for field in (
"json",
"regex",
"choice",
"grammar",
"json_object",
"structural_tag",
)
)
def all_non_structural_tag_constraints_none(self) -> bool:
"""
        Returns True if all structured-output constraint fields other than
        structural_tag are None.
"""
return all(
getattr(self, field) is None
for field in (
"json",
"regex",
"choice",
"grammar",
"json_object",
)
)
class RequestOutputKind(Enum):
# Return entire output so far in every RequestOutput
CUMULATIVE = 0
# Return only deltas in each RequestOutput
DELTA = 1
# Do not return intermediate RequestOutput
FINAL_ONLY = 2
class SamplingParams(
PydanticMsgspecMixin,
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
# required for @cached_property.
dict=True,
): # type: ignore[call-arg]
"""Sampling parameters for text generation.
Overall, we follow the sampling parameters from the OpenAI text completion
API (https://platform.openai.com/docs/api-reference/completions/create).
In addition, we support beam search, which is not supported by OpenAI.
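
    A minimal illustrative example (the values shown here are arbitrary, not
    recommended defaults):

        params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=128)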
"""
n: int = 1
"""Number of outputs to return for the given prompt request.
NOTE:
`AsyncLLM` streams outputs by default. When `n > 1`, all `n` outputs
are generated and streamed cumulatively per request. To see all `n`
outputs upon completion, use `output_kind=RequestOutputKind.FINAL_ONLY`
in `SamplingParams`."""
presence_penalty: float = 0.0
"""Penalizes new tokens based on whether they appear in the generated text
so far. Values > 0 encourage the model to use new tokens, while values < 0
encourage the model to repeat tokens."""
frequency_penalty: float = 0.0
"""Penalizes new tokens based on their frequency in the generated text so
far. Values > 0 encourage the model to use new tokens, while values < 0
encourage the model to repeat tokens."""
repetition_penalty: float = 1.0
"""Penalizes new tokens based on whether they appear in the prompt and the
generated text so far. Values > 1 encourage the model to use new tokens,
while values < 1 encourage the model to repeat tokens."""
temperature: float = 1.0
"""Controls the randomness of the sampling. Lower values make the model
more deterministic, while higher values make the model more random. Zero
means greedy sampling."""
top_p: float = 1.0
"""Controls the cumulative probability of the top tokens to consider. Must
be in (0, 1]. Set to 1 to consider all tokens."""
top_k: int = 0
"""Controls the number of top tokens to consider. Set to 0 (or -1) to
consider all tokens."""
min_p: float = 0.0
"""Represents the minimum probability for a token to be considered,
relative to the probability of the most likely token. Must be in [0, 1].
Set to 0 to disable this."""
seed: int | None = None
"""Random seed to use for the generation."""
stop: str | list[str] | None = None
"""String(s) that stop the generation when they are generated. The returned
output will not contain the stop strings."""
stop_token_ids: list[int] | None = None
"""Token IDs that stop the generation when they are generated. The returned
output will contain the stop tokens unless the stop tokens are special
tokens."""
ignore_eos: bool = False
"""Whether to ignore the EOS token and continue generating
tokens after the EOS token is generated."""
max_tokens: int | None = 16
"""Maximum number of tokens to generate per output sequence."""
min_tokens: int = 0
"""Minimum number of tokens to generate per output sequence before EOS or
`stop_token_ids` can be generated"""
logprobs: int | None = None
"""Number of log probabilities to return per output token. When set to
`None`, no probability is returned. If set to a non-`None` value, the
result includes the log probabilities of the specified number of most
likely tokens, as well as the chosen tokens. Note that the implementation
follows the OpenAI API: The API will always return the log probability of
the sampled token, so there may be up to `logprobs+1` elements in the
response. When set to -1, return all `vocab_size` log probabilities."""
prompt_logprobs: int | None = None
"""Number of log probabilities to return per prompt token.
When set to -1, return all `vocab_size` log probabilities."""
flat_logprobs: bool = False
"""Whether to return logprobs in flatten format (i.e. FlatLogprob)
for better performance.
NOTE: GC costs of FlatLogprobs is significantly smaller than
list[dict[int, Logprob]]. After enabled, PromptLogprobs and
SampleLogprobs would populated as FlatLogprobs."""
# NOTE: This parameter is only exposed at the engine level for now.
# It is not exposed in the OpenAI API server, as the OpenAI API does
# not support returning only a list of token IDs.
detokenize: bool = True
"""Whether to detokenize the output."""
skip_special_tokens: bool = True
"""Whether to skip special tokens in the output."""
spaces_between_special_tokens: bool = True
"""Whether to add spaces between special tokens in the output."""
include_stop_str_in_output: bool = False
"""Whether to include the stop strings in output text."""
truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None
"""If set to -1, will use the truncation size supported by the model. If
set to an integer k, will use only the last k tokens from the prompt
(i.e., left truncation). If set to `None`, truncation is disabled."""
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE
skip_clone: bool = False
"""Internal flag indicating that this SamplingParams instance is safe to
reuse without cloning. When True, clone() will return self without
performing a deep copy. This should only be set when the params object
is guaranteed to be dedicated to a single request and won't be modified
in ways that would affect other uses."""
# The below fields are not supposed to be used as an input.
# They are set in post_init.
output_text_buffer_length: int = 0
_eos_token_id: int | None = None
_all_stop_token_ids: set[int] = msgspec.field(default_factory=set)
# Fields used to construct logits processors
structured_outputs: StructuredOutputsParams | None = None
"""Parameters for configuring structured outputs."""
logit_bias: dict[int, float] | None = None
"""If provided, the engine will construct a logits processor that applies
these logit biases."""
allowed_token_ids: list[int] | None = None
"""If provided, the engine will construct a logits processor which only
retains scores for the given token ids."""
extra_args: dict[str, Any] | None = None
"""Arbitrary additional args, that can be used by custom sampling
implementations, plugins, etc. Not used by any in-tree sampling
implementations."""
# Fields used for bad words
bad_words: list[str] | None = None
"""Words that are not allowed to be generated. More precisely, only the
last token of a corresponding token sequence is not allowed when the next
generated token can complete the sequence."""
_bad_words_token_ids: list[list[int]] | None = None
skip_reading_prefix_cache: bool | None = None
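    # Convenience constructor: `None` values fall back to the field defaults,
    # and OpenAI-style `logit_bias` dicts keyed by string token ids are
    # normalised to int keys with biases clamped to [-100, 100]. Illustrative
    # sketch (the token id below is a made-up example):
    #
    #     SamplingParams.from_optional(temperature=None, logit_bias={"50256": -1.0})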
@staticmethod
def from_optional(
n: int | None = 1,
presence_penalty: float | None = 0.0,
frequency_penalty: float | None = 0.0,
repetition_penalty: float | None = 1.0,
temperature: float | None = 1.0,
top_p: float | None = 1.0,
top_k: int = 0,
min_p: float = 0.0,
seed: int | None = None,
stop: str | list[str] | None = None,
stop_token_ids: list[int] | None = None,
bad_words: list[str] | None = None,
include_stop_str_in_output: bool = False,
ignore_eos: bool = False,
max_tokens: int | None = 16,
min_tokens: int = 0,
logprobs: int | None = None,
prompt_logprobs: int | None = None,
detokenize: bool = True,
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True,
truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None,
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
structured_outputs: StructuredOutputsParams | None = None,
logit_bias: dict[int, float] | dict[str, float] | None = None,
allowed_token_ids: list[int] | None = None,
extra_args: dict[str, Any] | None = None,
skip_clone: bool = False,
) -> "SamplingParams":
if logit_bias is not None:
# Convert token_id to integer
# Clamp the bias between -100 and 100 per OpenAI API spec
logit_bias = {
int(token): min(100.0, max(-100.0, bias))
for token, bias in logit_bias.items()
}
return SamplingParams(
n=1 if n is None else n,
presence_penalty=0.0 if presence_penalty is None else presence_penalty,
frequency_penalty=0.0 if frequency_penalty is None else frequency_penalty,
repetition_penalty=1.0
if repetition_penalty is None
else repetition_penalty,
temperature=1.0 if temperature is None else temperature,
top_p=1.0 if top_p is None else top_p,
top_k=top_k,
min_p=min_p,
seed=seed,
stop=stop,
stop_token_ids=stop_token_ids,
bad_words=bad_words,
include_stop_str_in_output=include_stop_str_in_output,
ignore_eos=ignore_eos,
max_tokens=max_tokens,
min_tokens=min_tokens,
logprobs=logprobs,
prompt_logprobs=prompt_logprobs,
detokenize=detokenize,
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
truncate_prompt_tokens=truncate_prompt_tokens,
output_kind=output_kind,
structured_outputs=structured_outputs,
logit_bias=logit_bias,
allowed_token_ids=allowed_token_ids,
extra_args=extra_args,
skip_clone=skip_clone,
)
def __post_init__(self) -> None:
if 0 < self.temperature < _MAX_TEMP:
logger.warning(
"temperature %s is less than %s, which may cause numerical "
"errors nan or inf in tensors. We have maxed it out to %s.",
self.temperature,
_MAX_TEMP,
_MAX_TEMP,
)
self.temperature = max(self.temperature, _MAX_TEMP)
if self.seed == -1:
self.seed = None
if self.stop is None:
self.stop = []
elif isinstance(self.stop, str):
self.stop = [self.stop]
if self.stop_token_ids is None:
self.stop_token_ids = []
if self.bad_words is None:
self.bad_words = []
if self.logprobs is True:
self.logprobs = 1
if self.prompt_logprobs is True:
self.prompt_logprobs = 1
# Number of characters to hold back for stop string evaluation
# until sequence is finished.
if self.stop and not self.include_stop_str_in_output:
self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
self._verify_args()
if self.temperature < _SAMPLING_EPS:
# Zero temperature means greedy sampling.
self.top_p = 1.0
self.top_k = 0
self.min_p = 0.0
self._verify_greedy_sampling()
# eos_token_id is added to this by the engine
self._all_stop_token_ids.update(self.stop_token_ids)
if self.skip_reading_prefix_cache is None:
            # If prefix caching is enabled, the number of prompt logprobs
            # returned may be less than n_prompt_tokens, so we need to skip
            # reading the prefix cache for this request.
self.skip_reading_prefix_cache = self.prompt_logprobs is not None
def _verify_args(self) -> None:
if not isinstance(self.n, int):
raise ValueError(f"n must be an int, but is of type {type(self.n)}")
if self.n < 1:
raise ValueError(f"n must be at least 1, got {self.n}.")
if not -2.0 <= self.presence_penalty <= 2.0:
raise ValueError(
f"presence_penalty must be in [-2, 2], got {self.presence_penalty}."
)
if not -2.0 <= self.frequency_penalty <= 2.0:
raise ValueError(
f"frequency_penalty must be in [-2, 2], got {self.frequency_penalty}."
)
if self.repetition_penalty <= 0.0:
raise ValueError(
"repetition_penalty must be greater than zero, got "
f"{self.repetition_penalty}."
)
if self.temperature < 0.0:
raise VLLMValidationError(
f"temperature must be non-negative, got {self.temperature}.",
parameter="temperature",
value=self.temperature,
)
if not 0.0 < self.top_p <= 1.0:
raise VLLMValidationError(
f"top_p must be in (0, 1], got {self.top_p}.",
parameter="top_p",
value=self.top_p,
)
        if not isinstance(self.top_k, int):
            raise TypeError(
                f"top_k must be an integer, got {type(self.top_k).__name__}"
            )
        # quietly accept -1 as disabled, but prefer 0
        if self.top_k < -1:
            raise ValueError(
                f"top_k must be 0 (disable), or at least 1, got {self.top_k}."
            )
if not 0.0 <= self.min_p <= 1.0:
raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
if self.max_tokens is not None and self.max_tokens < 1:
raise VLLMValidationError(
f"max_tokens must be at least 1, got {self.max_tokens}.",
parameter="max_tokens",
value=self.max_tokens,
)
if self.min_tokens < 0:
raise ValueError(
f"min_tokens must be greater than or equal to 0, got {self.min_tokens}."
)
if self.max_tokens is not None and self.min_tokens > self.max_tokens:
raise ValueError(
f"min_tokens must be less than or equal to "
f"max_tokens={self.max_tokens}, got {self.min_tokens}."
)
if self.logprobs is not None and self.logprobs != -1 and self.logprobs < 0:
raise VLLMValidationError(
f"logprobs must be non-negative or -1, got {self.logprobs}.",
parameter="logprobs",
value=self.logprobs,
)
if (
self.prompt_logprobs is not None
and self.prompt_logprobs != -1
and self.prompt_logprobs < 0
):
raise VLLMValidationError(
f"prompt_logprobs must be non-negative or -1, got "
f"{self.prompt_logprobs}.",
parameter="prompt_logprobs",
value=self.prompt_logprobs,
)
if self.truncate_prompt_tokens is not None and (
self.truncate_prompt_tokens == 0 or self.truncate_prompt_tokens < -1
):
raise VLLMValidationError(
f"truncate_prompt_tokens must be an integer >= 1 or -1, "
f"got {self.truncate_prompt_tokens}",
parameter="truncate_prompt_tokens",
value=self.truncate_prompt_tokens,
)
assert isinstance(self.stop_token_ids, list)
if not all(isinstance(st_id, int) for st_id in self.stop_token_ids):
raise ValueError(
f"stop_token_ids must contain only integers, got {self.stop_token_ids}."
)
assert isinstance(self.stop, list)
if any(not stop_str for stop_str in self.stop):
raise ValueError("stop cannot contain an empty string.")
if self.stop and not self.detokenize:
raise ValueError(
"stop strings are only supported when detokenize is True. "
"Set detokenize=True to use stop."
)
def _verify_greedy_sampling(self) -> None:
if self.n > 1:
raise ValueError(f"n must be 1 when using greedy sampling, got {self.n}.")
def update_from_generation_config(
self,
generation_config: dict[str, Any],
eos_token_id: int | None = None,
) -> None:
"""Update if there are non-default values from generation_config"""
if not self.ignore_eos:
self._eos_token_id = eos_token_id
if eos_token_id is not None:
# Add the eos token id into the sampling_params to support
# min_tokens processing.
self._all_stop_token_ids.add(eos_token_id)
# Update eos_token_id for generation
if (eos_ids := generation_config.get("eos_token_id")) is not None:
# it can be either int or list of int
eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
if eos_token_id is not None:
# We don't need to include the primary eos_token_id in
# stop_token_ids since it's handled separately for stopping
# purposes.
eos_ids.discard(eos_token_id)
if eos_ids:
self._all_stop_token_ids.update(eos_ids)
if not self.ignore_eos:
eos_ids.update(self.stop_token_ids)
self.stop_token_ids = list(eos_ids)
def update_from_tokenizer(self, tokenizer: TokenizerLike) -> None:
if not self.bad_words:
return
self._bad_words_token_ids = []
for bad_word in self.bad_words:
# To prohibit words both at the beginning
# and in the middle of text
# (related to add_prefix_space tokenizer parameter)
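            # e.g. for bad_word="hello", both "hello" and " hello" are encoded;
            # the prefixed variant is kept only if it yields a genuinely
            # different leading token (illustrative; actual token ids depend on
            # the tokenizer).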
for add_prefix_space in [False, True]:
prefix = " " if add_prefix_space else ""
prompt = prefix + bad_word.lstrip()
prompt_token_ids = tokenizer.encode(
text=prompt, add_special_tokens=False
)
# If no space at the beginning
# or if prefix space produces a new word token
if (not add_prefix_space) or (
add_prefix_space
and prompt_token_ids[0] != self._bad_words_token_ids[-1][0]
and len(prompt_token_ids) == len(self._bad_words_token_ids[-1])
):
self._bad_words_token_ids.append(prompt_token_ids)
invalid_token_ids = [
token_id
for bad_words_token_ids in self._bad_words_token_ids
for token_id in bad_words_token_ids
if token_id < 0 or token_id > tokenizer.max_token_id
]
if len(invalid_token_ids) > 0:
raise VLLMValidationError(
f"The model vocabulary size is {tokenizer.max_token_id + 1},"
f" but the following tokens"
f" were specified as bad: {invalid_token_ids}."
f" All token id values should be integers satisfying:"
f" 0 <= token_id <= {tokenizer.max_token_id}.",
parameter="bad_words",
value=self.bad_words,
)
@cached_property
def sampling_type(self) -> SamplingType:
if self.temperature < _SAMPLING_EPS:
return SamplingType.GREEDY
if self.seed is not None:
return SamplingType.RANDOM_SEED
return SamplingType.RANDOM
@property
def eos_token_id(self) -> int | None:
return self._eos_token_id
@property
def all_stop_token_ids(self) -> set[int]:
return self._all_stop_token_ids
@property
def bad_words_token_ids(self) -> list[list[int]] | None:
# For internal use only. Backward compatibility not guaranteed
return self._bad_words_token_ids
def clone(self) -> "SamplingParams":
"""If skip_clone is True, uses shallow copy instead of deep copy."""
if self.skip_clone:
return copy.copy(self)
return copy.deepcopy(self)
def verify(
self,
model_config: ModelConfig,
speculative_config: SpeculativeConfig | None,
structured_outputs_config: StructuredOutputsConfig | None,
tokenizer: TokenizerLike | None,
) -> None:
self._validate_logprobs(model_config)
self._validate_logit_bias(model_config)
self._validate_logits_processors(model_config)
self._validate_allowed_token_ids(tokenizer)
self._validate_spec_decode(speculative_config)
self._validate_structured_outputs(structured_outputs_config, tokenizer)
def _validate_logprobs(self, model_config: ModelConfig) -> None:
max_logprobs = model_config.max_logprobs
if max_logprobs == -1:
max_logprobs = model_config.get_vocab_size()
# Validate sample logprobs.
if num_logprobs := self.logprobs:
if num_logprobs == -1:
num_logprobs = model_config.get_vocab_size()
if num_logprobs > max_logprobs:
raise VLLMValidationError(
f"Requested sample logprobs of {num_logprobs}, "
f"which is greater than max allowed: {max_logprobs}",
parameter="logprobs",
value=num_logprobs,
)
# Validate prompt logprobs.
if num_prompt_logprobs := self.prompt_logprobs:
if num_prompt_logprobs == -1:
num_prompt_logprobs = model_config.get_vocab_size()
if num_prompt_logprobs > max_logprobs:
raise VLLMValidationError(
f"Requested prompt logprobs of {num_prompt_logprobs}, "
f"which is greater than max allowed: {max_logprobs}",
parameter="prompt_logprobs",
value=num_prompt_logprobs,
)
def _validate_logit_bias(self, model_config: ModelConfig) -> None:
"""Validate logit_bias token IDs are within vocabulary range."""
if not self.logit_bias:
return
vocab_size = model_config.get_vocab_size()
invalid_token_ids = [
token_id
for token_id in self.logit_bias
if token_id < 0 or token_id >= vocab_size
]
if invalid_token_ids:
raise VLLMValidationError(
f"token_id(s) {invalid_token_ids} in logit_bias contain "
f"out-of-vocab token ids. Vocabulary size: {vocab_size}",
parameter="logit_bias",
value=invalid_token_ids,
)
def _validate_logits_processors(self, model_config: ModelConfig) -> None:
from vllm.v1.sample.logits_processor import (
validate_logits_processors_parameters,
)
validate_logits_processors_parameters(model_config.logits_processors, self)
def _validate_allowed_token_ids(self, tokenizer: TokenizerLike | None) -> None:
allowed_token_ids = self.allowed_token_ids
if allowed_token_ids is None:
return
if len(allowed_token_ids) == 0:
raise VLLMValidationError(
"allowed_token_ids is not None and empty!",
parameter="allowed_token_ids",
value=allowed_token_ids,
)
if tokenizer is not None:
vocab_size = len(tokenizer)
invalid_token_ids = [
token_id
for token_id in allowed_token_ids
if token_id < 0 or token_id >= vocab_size
]
if invalid_token_ids:
raise VLLMValidationError(
"allowed_token_ids contains out-of-vocab token id!",
parameter="allowed_token_ids",
value=invalid_token_ids,
)
def _validate_spec_decode(
self,
speculative_config: SpeculativeConfig | None,
) -> None:
if speculative_config is None:
return
# Some sampling parameters are not yet compatible with spec decoding.
if self.min_tokens > 1 or self.min_p > _SAMPLING_EPS or self.logit_bias:
raise ValueError(
"The min_tokens, min_p, and logit_bias sampling parameters "
"are not yet supported with speculative decoding."
)
def _validate_structured_outputs(
self,
structured_outputs_config: StructuredOutputsConfig | None,
tokenizer: TokenizerLike | None,
) -> None:
if structured_outputs_config is None or self.structured_outputs is None:
return
if tokenizer is None:
raise ValueError(
"Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'" # noqa: E501
)
backend = structured_outputs_config.backend
if _backend := self.structured_outputs._backend:
# Request-level backend selection is not supported.
# The values may differ if `params` is reused and was set
# to a specific backend based on `auto` behavior in a previous
# request. We remember that it was set as a result of `auto`
# using the `_backend_was_auto` field set in the params.
if backend != _backend and not (
backend == "auto" and self.structured_outputs._backend_was_auto
):
raise ValueError(
"Request-level structured output backend selection is not "
f"supported. The request specified '{_backend}', but vLLM "
f"was initialised with '{backend}'. This error can be "
"resolved by removing '_backend' from the request."
)
else:
self.structured_outputs._backend = backend
# Request content validation
if (
isinstance(self.structured_outputs.choice, list)
and not self.structured_outputs.choice
):
# It is invalid for choice to be an empty list
raise ValueError(
f"Choice '{self.structured_outputs.choice}' cannot be an empty list" # noqa: E501
)
# Reject empty string grammar early to avoid engine-side crashes
if (
isinstance(self.structured_outputs.grammar, str)
and self.structured_outputs.grammar.strip() == ""
):
raise ValueError("structured_outputs.grammar cannot be an empty string")
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.v1.structured_output.backend_guidance import (
has_guidance_unsupported_json_features,
validate_guidance_grammar,
)
from vllm.v1.structured_output.backend_lm_format_enforcer import (
validate_structured_output_request_lm_format_enforcer,
)
from vllm.v1.structured_output.backend_outlines import (
validate_structured_output_request_outlines,
)
from vllm.v1.structured_output.backend_xgrammar import validate_xgrammar_grammar
if backend.startswith("xgrammar"):
# xgrammar with no fallback
validate_xgrammar_grammar(self)
elif backend.startswith("guidance"):
# TODO: ideally we would have the LLTokenizer here as Lark syntax
# allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars.
if isinstance(tokenizer, MistralTokenizer):
raise ValueError(
"Mistral tokenizer is not supported for the 'guidance' "
"structured output backend. Please use ['xgrammar', 'outlines'] "
"backends or tokenizer_mode='hf' instead."
)
validate_guidance_grammar(self, tokenizer=None)
elif backend == "outlines":
# outlines backend
validate_structured_output_request_outlines(self)
elif backend == "lm-format-enforcer":
# lm format enforcer backend
if isinstance(tokenizer, MistralTokenizer):
raise ValueError(
"Mistral tokenizer is not supported for the 'lm-format-enforcer' "
"structured output backend. Please use ['xgrammar', 'outlines'] "
"backends or tokenizer_mode='hf' instead."
)
validate_structured_output_request_lm_format_enforcer(self)
else:
# NOTE: backend must be "auto" here, because we have
# checked supported_backends above.
# In this mode, we set opinionated defaults based on what we think
# will satisfy the most use cases without having to worry about
# this setting. We include fallback behavior here, but not with any
# other setting where a specific backend was specified.
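            # Fallback order used below: try xgrammar first; if its validation
            # fails, fall back to guidance, or to outlines when the tokenizer
            # is a MistralTokenizer or the JSON schema uses features guidance
            # does not support.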
try:
validate_xgrammar_grammar(self)
self.structured_outputs._backend = "xgrammar"
except ValueError:
# The request either failed validation
# or includes some jsonschema feature(s) that
# are not supported in xgrammar.
# Check if schema has features unsupported by guidance
so_params = self.structured_outputs
skip_guidance = False
if so_params.json:
if isinstance(so_params.json, str):
schema = json.loads(so_params.json)
else:
schema = so_params.json
skip_guidance = has_guidance_unsupported_json_features(schema)
if isinstance(tokenizer, MistralTokenizer) or skip_guidance:
# Fall back to outlines if the tokenizer is Mistral
# or if schema contains features unsupported by guidance
validate_structured_output_request_outlines(self)
self.structured_outputs._backend = "outlines"
else:
# Fall back to guidance by default.
validate_guidance_grammar(self, tokenizer=None)
self.structured_outputs._backend = "guidance"
# Remember that this backend was set automatically
self.structured_outputs._backend_was_auto = True
# Run post-init validation. This is also important to ensure subsequent
# roundtrip serialization/deserialization won't fail.
self.structured_outputs.__post_init__()
def __repr__(self) -> str:
return (
f"SamplingParams(n={self.n}, "
f"presence_penalty={self.presence_penalty}, "
f"frequency_penalty={self.frequency_penalty}, "
f"repetition_penalty={self.repetition_penalty}, "
f"temperature={self.temperature}, "
f"top_p={self.top_p}, "
f"top_k={self.top_k}, "
f"min_p={self.min_p}, "
f"seed={self.seed}, "
f"stop={self.stop}, "
f"stop_token_ids={self.stop_token_ids}, "
f"bad_words={self.bad_words}, "
f"include_stop_str_in_output={self.include_stop_str_in_output}, "
f"ignore_eos={self.ignore_eos}, "
f"max_tokens={self.max_tokens}, "
f"min_tokens={self.min_tokens}, "
f"logprobs={self.logprobs}, "
f"prompt_logprobs={self.prompt_logprobs}, "
f"skip_special_tokens={self.skip_special_tokens}, "
"spaces_between_special_tokens="
f"{self.spaces_between_special_tokens}, "
f"truncate_prompt_tokens={self.truncate_prompt_tokens}, "
f"structured_outputs={self.structured_outputs}, "
f"extra_args={self.extra_args})"
)
class BeamSearchParams(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
# required for @cached_property.
dict=True,
): # type: ignore[call-arg]
"""Beam search parameters for text generation."""
beam_width: int
max_tokens: int
ignore_eos: bool = False
temperature: float = 0.0
length_penalty: float = 1.0
include_stop_str_in_output: bool = False
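
# Illustrative BeamSearchParams usage (a sketch; the values are arbitrary examples):
#     BeamSearchParams(beam_width=4, max_tokens=64, length_penalty=1.0)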