[Frontend][Core] Move guided decoding params into sampling params (#8252)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com> Co-authored-by: Nick Hill <nickhill@us.ibm.com>
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
"""Sampling parameters for text generation."""
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, IntEnum
|
||||
from functools import cached_property
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Union
|
||||
|
||||
import msgspec
|
||||
import torch
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Annotated
|
||||
|
||||
import vllm.envs as envs
|
||||
@@ -34,6 +36,54 @@ first argument, and returns a modified tensor of logits
|
||||
to sample from."""
|
||||
|
||||
|
||||
# maybe make msgspec?
@dataclass
class GuidedDecodingParams:
    """One of these fields will be used to build a logit processor."""
    # Exactly one of the following five "guide" fields may be set;
    # __post_init__ enforces mutual exclusion.
    json: Optional[Union[str, Dict]] = None
    regex: Optional[str] = None
    choice: Optional[List[str]] = None
    grammar: Optional[str] = None
    json_object: Optional[bool] = None
    # These are other options that can be set (they tune how the guide is
    # applied rather than selecting a guide themselves).
    backend: Optional[str] = None
    whitespace_pattern: Optional[str] = None

    @staticmethod
    def from_optional(
        json: Optional[Union[Dict, BaseModel, str]],
        regex: Optional[str] = None,
        choice: Optional[List[str]] = None,
        grammar: Optional[str] = None,
        json_object: Optional[bool] = None,
        backend: Optional[str] = None,
        whitespace_pattern: Optional[str] = None,
    ) -> "GuidedDecodingParams":
        """Build a ``GuidedDecodingParams``, normalizing pydantic input.

        A pydantic ``BaseModel`` instance or model class passed as ``json``
        is converted to its JSON schema dict, so downstream consumers only
        ever see ``str``/``dict`` schemas.
        """
        # Extract json schemas from pydantic models. The isinstance check
        # matches both model instances (BaseModel) and model classes
        # (instances of pydantic's metaclass, type(BaseModel)).
        if isinstance(json, (BaseModel, type(BaseModel))):
            json = json.model_json_schema()
        return GuidedDecodingParams(
            json=json,
            regex=regex,
            choice=choice,
            grammar=grammar,
            json_object=json_object,
            backend=backend,
            whitespace_pattern=whitespace_pattern,
        )

    def __post_init__(self):
        """Validate that the guide fields are mutually exclusive.

        Raises:
            ValueError: if more than one of json/regex/choice/grammar/
                json_object is set.
        """
        guide_count = sum(
            field is not None
            for field in (self.json, self.regex, self.choice, self.grammar,
                          self.json_object))
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding but multiple are "
                f"specified: {self.__dict__}")
|
||||
|
||||
|
||||
class RequestOutputKind(Enum):
|
||||
# Return entire output so far in every RequestOutput
|
||||
CUMULATIVE = 0
|
||||
@@ -124,6 +174,13 @@ class SamplingParams(
|
||||
truncate_prompt_tokens: If set to an integer k, will use only the last k
|
||||
tokens from the prompt (i.e., left truncation). Defaults to None
|
||||
(i.e., no truncation).
|
||||
guided_decoding: If provided, the engine will construct a guided
|
||||
decoding logits processor from these parameters. Defaults to None.
|
||||
logit_bias: If provided, the engine will construct a logits processor
|
||||
that applies these logit biases. Defaults to None.
|
||||
allowed_token_ids: If provided, the engine will construct a logits
|
||||
processor which only retains scores for the given token ids.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
n: int = 1
|
||||
@@ -164,6 +221,11 @@ class SamplingParams(
|
||||
output_text_buffer_length: int = 0
|
||||
_all_stop_token_ids: Set[int] = msgspec.field(default_factory=set)
|
||||
|
||||
# Fields used to construct logits processors
|
||||
guided_decoding: Optional[GuidedDecodingParams] = None
|
||||
logit_bias: Optional[Dict[int, float]] = None
|
||||
allowed_token_ids: Optional[List[int]] = None
|
||||
|
||||
@staticmethod
|
||||
def from_optional(
|
||||
n: Optional[int] = 1,
|
||||
@@ -194,7 +256,16 @@ class SamplingParams(
|
||||
truncate_prompt_tokens: Optional[Annotated[int,
|
||||
msgspec.Meta(ge=1)]] = None,
|
||||
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
|
||||
guided_decoding: Optional[GuidedDecodingParams] = None,
|
||||
logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]] = None,
|
||||
allowed_token_ids: Optional[List[int]] = None,
|
||||
) -> "SamplingParams":
|
||||
if logit_bias is not None:
|
||||
logit_bias = {
|
||||
int(token): bias
|
||||
for token, bias in logit_bias.items()
|
||||
}
|
||||
|
||||
return SamplingParams(
|
||||
n=1 if n is None else n,
|
||||
best_of=best_of,
|
||||
@@ -226,6 +297,9 @@ class SamplingParams(
|
||||
logits_processors=logits_processors,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
output_kind=output_kind,
|
||||
guided_decoding=guided_decoding,
|
||||
logit_bias=logit_bias,
|
||||
allowed_token_ids=allowed_token_ids,
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
@@ -454,4 +528,5 @@ class SamplingParams(
|
||||
f"skip_special_tokens={self.skip_special_tokens}, "
|
||||
"spaces_between_special_tokens="
|
||||
f"{self.spaces_between_special_tokens}, "
|
||||
f"truncate_prompt_tokens={self.truncate_prompt_tokens})")
|
||||
f"truncate_prompt_tokens={self.truncate_prompt_tokens}), "
|
||||
f"guided_decoding={self.guided_decoding}")
|
||||
|
||||
Reference in New Issue
Block a user