Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
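
Context: PEP 604 (Python 3.10+) lets the `X | Y` syntax replace `typing.Optional` and `typing.Union`. The two spellings compare equal at runtime, so this is a pure spelling change; a quick sanity check (assumes Python 3.10 or newer):

    from typing import Optional, Union

    # PEP 604 unions compare equal to their typing-module equivalents.
    assert (str | None) == Optional[str]
    assert (str | dict | None) == Optional[Union[str, dict]]
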
@@ -7,7 +7,7 @@ import warnings
 from dataclasses import field
 from enum import Enum, IntEnum
 from functools import cached_property
-from typing import Annotated, Any, Optional, Union
+from typing import Annotated, Any
 
 import msgspec
 from pydantic.dataclasses import dataclass
@@ -32,19 +32,19 @@ class SamplingType(IntEnum):
 @dataclass
 class StructuredOutputsParams:
     # One of these fields will be used to build a logit processor.
-    json: Optional[Union[str, dict]] = None
-    regex: Optional[str] = None
-    choice: Optional[list[str]] = None
-    grammar: Optional[str] = None
-    json_object: Optional[bool] = None
+    json: str | dict | None = None
+    regex: str | None = None
+    choice: list[str] | None = None
+    grammar: str | None = None
+    json_object: bool | None = None
     # These are other options that can be set.
     disable_fallback: bool = False
     disable_any_whitespace: bool = False
     disable_additional_properties: bool = False
-    whitespace_pattern: Optional[str] = None
-    structural_tag: Optional[str] = None
+    whitespace_pattern: str | None = None
+    structural_tag: str | None = None
 
-    _backend: Optional[str] = field(default=None, init=False)
+    _backend: str | None = field(default=None, init=False)
     """CAUTION: Should only be set by Processor._validate_structured_output"""
     _backend_was_auto: bool = field(default=False, init=False)
     """CAUTION: Should only be set by Processor._validate_structured_output"""
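
The new annotations accept exactly the same values as before; pydantic dataclasses validate `X | None` fields the same way they validated `Optional[X]`. A minimal sketch of the pattern, using a hypothetical `Demo` class rather than the real `StructuredOutputsParams`:

    from pydantic.dataclasses import dataclass

    @dataclass
    class Demo:  # hypothetical stand-in, not from this diff
        regex: str | None = None
        choice: list[str] | None = None

    Demo(regex="[a-z]+")        # validates fine
    Demo(choice=["yes", "no"])  # validates fine
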
@@ -110,12 +110,12 @@ class SamplingParams(
     are generated and streamed cumulatively per request. To see all `n`
     outputs upon completion, use `output_kind=RequestOutputKind.FINAL_ONLY`
     in `SamplingParams`."""
-    best_of: Optional[int] = None
+    best_of: int | None = None
     """Number of output sequences that are generated from the prompt. From
     these `best_of` sequences, the top `n` sequences are returned. `best_of`
     must be greater than or equal to `n`. By default, `best_of` is set to `n`.
     Warning, this is only supported in V0."""
-    _real_n: Optional[int] = None
+    _real_n: int | None = None
     presence_penalty: float = 0.0
     """Penalizes new tokens based on whether they appear in the generated text
     so far. Values > 0 encourage the model to use new tokens, while values < 0
@@ -142,24 +142,24 @@ class SamplingParams(
     """Represents the minimum probability for a token to be considered,
     relative to the probability of the most likely token. Must be in [0, 1].
     Set to 0 to disable this."""
-    seed: Optional[int] = None
+    seed: int | None = None
     """Random seed to use for the generation."""
-    stop: Optional[Union[str, list[str]]] = None
+    stop: str | list[str] | None = None
     """String(s) that stop the generation when they are generated. The returned
     output will not contain the stop strings."""
-    stop_token_ids: Optional[list[int]] = None
+    stop_token_ids: list[int] | None = None
     """Token IDs that stop the generation when they are generated. The returned
     output will contain the stop tokens unless the stop tokens are special
     tokens."""
     ignore_eos: bool = False
     """Whether to ignore the EOS token and continue generating
     tokens after the EOS token is generated."""
-    max_tokens: Optional[int] = 16
+    max_tokens: int | None = 16
     """Maximum number of tokens to generate per output sequence."""
     min_tokens: int = 0
     """Minimum number of tokens to generate per output sequence before EOS or
     `stop_token_ids` can be generated"""
-    logprobs: Optional[int] = None
+    logprobs: int | None = None
     """Number of log probabilities to return per output token. When set to
     `None`, no probability is returned. If set to a non-`None` value, the
     result includes the log probabilities of the specified number of most
@@ -167,7 +167,7 @@ class SamplingParams(
     follows the OpenAI API: The API will always return the log probability of
     the sampled token, so there may be up to `logprobs+1` elements in the
     response. When set to -1, return all `vocab_size` log probabilities."""
-    prompt_logprobs: Optional[int] = None
+    prompt_logprobs: int | None = None
     """Number of log probabilities to return per prompt token.
     When set to -1, return all `vocab_size` log probabilities."""
     # NOTE: This parameter is only exposed at the engine level for now.
@@ -179,14 +179,14 @@ class SamplingParams(
     """Whether to skip special tokens in the output."""
     spaces_between_special_tokens: bool = True
     """Whether to add spaces between special tokens in the output."""
-    # Optional[list[LogitsProcessor]] type. We use Any here because
-    # Optional[list[LogitsProcessor]] type is not supported by msgspec.
-    logits_processors: Optional[Any] = None
+    # `list[LogitsProcessor] | None` type. We use Any here because
+    # `list[LogitsProcessor] | None` type is not supported by msgspec.
+    logits_processors: Any | None = None
     """Functions that modify logits based on previously generated tokens, and
     optionally prompt tokens as a first argument."""
     include_stop_str_in_output: bool = False
     """Whether to include the stop strings in output text."""
-    truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=-1)]] = None
+    truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None
     """If set to -1, will use the truncation size supported by the model. If
     set to an integer k, will use only the last k tokens from the prompt
     (i.e., left truncation). If set to `None`, truncation is disabled."""
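
Note that `Annotated[int, msgspec.Meta(ge=-1)] | None` preserves the `ge=-1` bound; msgspec checks such constraints during decode/convert rather than on attribute assignment. A small sketch with a hypothetical struct:

    from typing import Annotated
    import msgspec

    class Demo(msgspec.Struct):  # hypothetical, not from this diff
        truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None

    msgspec.convert({"truncate_prompt_tokens": -1}, Demo)  # ok: -1 is allowed
    msgspec.convert({"truncate_prompt_tokens": -2}, Demo)  # msgspec.ValidationError
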
@@ -198,60 +198,60 @@ class SamplingParams(
     _all_stop_token_ids: set[int] = msgspec.field(default_factory=set)
 
     # Fields used to construct logits processors
-    structured_outputs: Optional[StructuredOutputsParams] = None
+    structured_outputs: StructuredOutputsParams | None = None
     """Parameters for configuring structured outputs."""
-    guided_decoding: Optional[GuidedDecodingParams] = None
+    guided_decoding: GuidedDecodingParams | None = None
     """Deprecated alias for structured_outputs."""
-    logit_bias: Optional[dict[int, float]] = None
+    logit_bias: dict[int, float] | None = None
     """If provided, the engine will construct a logits processor that applies
     these logit biases."""
-    allowed_token_ids: Optional[list[int]] = None
+    allowed_token_ids: list[int] | None = None
     """If provided, the engine will construct a logits processor which only
     retains scores for the given token ids."""
-    extra_args: Optional[dict[str, Any]] = None
+    extra_args: dict[str, Any] | None = None
     """Arbitrary additional args, that can be used by custom sampling
     implementations, plugins, etc. Not used by any in-tree sampling
     implementations."""
 
     # Fields used for bad words
-    bad_words: Optional[list[str]] = None
+    bad_words: list[str] | None = None
     """Words that are not allowed to be generated. More precisely, only the
     last token of a corresponding token sequence is not allowed when the next
     generated token can complete the sequence."""
-    _bad_words_token_ids: Optional[list[list[int]]] = None
+    _bad_words_token_ids: list[list[int]] | None = None
 
     @staticmethod
     def from_optional(
-        n: Optional[int] = 1,
-        best_of: Optional[int] = None,
-        presence_penalty: Optional[float] = 0.0,
-        frequency_penalty: Optional[float] = 0.0,
-        repetition_penalty: Optional[float] = 1.0,
-        temperature: Optional[float] = 1.0,
-        top_p: Optional[float] = 1.0,
+        n: int | None = 1,
+        best_of: int | None = None,
+        presence_penalty: float | None = 0.0,
+        frequency_penalty: float | None = 0.0,
+        repetition_penalty: float | None = 1.0,
+        temperature: float | None = 1.0,
+        top_p: float | None = 1.0,
         top_k: int = 0,
         min_p: float = 0.0,
-        seed: Optional[int] = None,
-        stop: Optional[Union[str, list[str]]] = None,
-        stop_token_ids: Optional[list[int]] = None,
-        bad_words: Optional[list[str]] = None,
+        seed: int | None = None,
+        stop: str | list[str] | None = None,
+        stop_token_ids: list[int] | None = None,
+        bad_words: list[str] | None = None,
         include_stop_str_in_output: bool = False,
         ignore_eos: bool = False,
-        max_tokens: Optional[int] = 16,
+        max_tokens: int | None = 16,
         min_tokens: int = 0,
-        logprobs: Optional[int] = None,
-        prompt_logprobs: Optional[int] = None,
+        logprobs: int | None = None,
+        prompt_logprobs: int | None = None,
         detokenize: bool = True,
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
-        logits_processors: Optional[list[LogitsProcessor]] = None,
-        truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=-1)]] = None,
+        logits_processors: list[LogitsProcessor] | None = None,
+        truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None,
         output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
-        structured_outputs: Optional[StructuredOutputsParams] = None,
-        guided_decoding: Optional[GuidedDecodingParams] = None,
-        logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
-        allowed_token_ids: Optional[list[int]] = None,
-        extra_args: Optional[dict[str, Any]] = None,
+        structured_outputs: StructuredOutputsParams | None = None,
+        guided_decoding: GuidedDecodingParams | None = None,
+        logit_bias: dict[int, float] | dict[str, float] | None = None,
+        allowed_token_ids: list[int] | None = None,
+        extra_args: dict[str, Any] | None = None,
     ) -> "SamplingParams":
         if logit_bias is not None:
             # Convert token_id to integer
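
As the signature suggests, `from_optional` also accepts `None` for parameters such as `temperature` that are plain floats on `SamplingParams` itself, presumably substituting the documented defaults. An illustrative call (the argument values are made up):

    params = SamplingParams.from_optional(
        temperature=None,              # presumably falls back to the default 1.0
        max_tokens=256,
        stop=["\n\n"],
        logit_bias={"50256": -100.0},  # str keys are converted to int per the body above
    )
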
@@ -483,7 +483,7 @@ class SamplingParams(
     def update_from_generation_config(
         self,
         generation_config: dict[str, Any],
-        model_eos_token_id: Optional[int] = None,
+        model_eos_token_id: int | None = None,
     ) -> None:
         """Update if there are non-default values from generation_config"""
 
@@ -559,7 +559,7 @@ class SamplingParams(
         return self._all_stop_token_ids
 
     @property
-    def bad_words_token_ids(self) -> Optional[list[list[int]]]:
+    def bad_words_token_ids(self) -> list[list[int]] | None:
         # For internal use only. Backward compatibility not guaranteed
         return self._bad_words_token_ids
 