[Chore] Cleanup guided namespace, move to structured outputs config (#22772)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -2,13 +2,13 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Sampling parameters for text generation."""
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field
|
||||
from enum import Enum, IntEnum
|
||||
from functools import cached_property
|
||||
from typing import Annotated, Any, Optional, Union
|
||||
|
||||
import msgspec
|
||||
from pydantic import BaseModel
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logits_process import LogitsProcessor
|
||||
@@ -28,60 +28,35 @@ class SamplingType(IntEnum):
|
||||
|
||||
# maybe make msgspec?
|
||||
@dataclass
|
||||
class GuidedDecodingParams:
|
||||
"""One of these fields will be used to build a logit processor."""
|
||||
class StructuredOutputsParams:
|
||||
# One of these fields will be used to build a logit processor.
|
||||
json: Optional[Union[str, dict]] = None
|
||||
regex: Optional[str] = None
|
||||
choice: Optional[list[str]] = None
|
||||
grammar: Optional[str] = None
|
||||
json_object: Optional[bool] = None
|
||||
"""These are other options that can be set"""
|
||||
backend: Optional[str] = None
|
||||
backend_was_auto: bool = False
|
||||
# These are other options that can be set.
|
||||
disable_fallback: bool = False
|
||||
disable_any_whitespace: bool = False
|
||||
disable_additional_properties: bool = False
|
||||
whitespace_pattern: Optional[str] = None
|
||||
structural_tag: Optional[str] = None
|
||||
|
||||
@staticmethod
|
||||
def from_optional(
|
||||
json: Optional[Union[dict, BaseModel, str]] = None,
|
||||
regex: Optional[str] = None,
|
||||
choice: Optional[list[str]] = None,
|
||||
grammar: Optional[str] = None,
|
||||
json_object: Optional[bool] = None,
|
||||
backend: Optional[str] = None,
|
||||
whitespace_pattern: Optional[str] = None,
|
||||
structural_tag: Optional[str] = None,
|
||||
) -> Optional["GuidedDecodingParams"]:
|
||||
if all(arg is None for arg in (json, regex, choice, grammar,
|
||||
json_object, structural_tag)):
|
||||
return None
|
||||
# Extract json schemas from pydantic models
|
||||
if isinstance(json, (BaseModel, type(BaseModel))):
|
||||
json = json.model_json_schema()
|
||||
return GuidedDecodingParams(
|
||||
json=json,
|
||||
regex=regex,
|
||||
choice=choice,
|
||||
grammar=grammar,
|
||||
json_object=json_object,
|
||||
backend=backend,
|
||||
whitespace_pattern=whitespace_pattern,
|
||||
structural_tag=structural_tag,
|
||||
)
|
||||
_backend: Optional[str] = field(default=None, init=False)
|
||||
"""CAUTION: Should only be set by Processor._validate_structured_output"""
|
||||
_backend_was_auto: bool = field(default=False, init=False)
|
||||
"""CAUTION: Should only be set by Processor._validate_structured_output"""
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate that some fields are mutually exclusive."""
|
||||
guide_count = sum([
|
||||
count = sum([
|
||||
self.json is not None, self.regex is not None, self.choice
|
||||
is not None, self.grammar is not None, self.json_object is not None
|
||||
])
|
||||
if guide_count > 1:
|
||||
if count > 1:
|
||||
raise ValueError(
|
||||
"You can only use one kind of guided decoding but multiple are "
|
||||
f"specified: {self.__dict__}")
|
||||
"You can only use one kind of structured outputs constraint "
|
||||
f"but multiple are specified: {self.__dict__}")
|
||||
|
||||
|
||||
class RequestOutputKind(Enum):
|
||||
@@ -196,9 +171,8 @@ class SamplingParams(
|
||||
_all_stop_token_ids: set[int] = msgspec.field(default_factory=set)
|
||||
|
||||
# Fields used to construct logits processors
|
||||
guided_decoding: Optional[GuidedDecodingParams] = None
|
||||
"""If provided, the engine will construct a guided decoding logits
|
||||
processor from these parameters."""
|
||||
structured_outputs: Optional[StructuredOutputsParams] = None
|
||||
"""Parameters for configuring structured outputs."""
|
||||
logit_bias: Optional[dict[int, float]] = None
|
||||
"""If provided, the engine will construct a logits processor that applies
|
||||
these logit biases."""
|
||||
@@ -246,7 +220,7 @@ class SamplingParams(
|
||||
msgspec.Meta(
|
||||
ge=-1)]] = None,
|
||||
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
|
||||
guided_decoding: Optional[GuidedDecodingParams] = None,
|
||||
structured_outputs: Optional[StructuredOutputsParams] = None,
|
||||
logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
|
||||
allowed_token_ids: Optional[list[int]] = None,
|
||||
extra_args: Optional[dict[str, Any]] = None,
|
||||
@@ -288,7 +262,7 @@ class SamplingParams(
|
||||
logits_processors=logits_processors,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
output_kind=output_kind,
|
||||
guided_decoding=guided_decoding,
|
||||
structured_outputs=structured_outputs,
|
||||
logit_bias=logit_bias,
|
||||
allowed_token_ids=allowed_token_ids,
|
||||
extra_args=extra_args,
|
||||
@@ -559,7 +533,7 @@ class SamplingParams(
|
||||
"spaces_between_special_tokens="
|
||||
f"{self.spaces_between_special_tokens}, "
|
||||
f"truncate_prompt_tokens={self.truncate_prompt_tokens}, "
|
||||
f"guided_decoding={self.guided_decoding}, "
|
||||
f"structured_outputs={self.structured_outputs}, "
|
||||
f"extra_args={self.extra_args})")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user