[Frontend][Core] Move guided decoding params into sampling params (#8252)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
This commit is contained in:
Joe Runde
2024-09-30 19:34:25 -06:00
committed by GitHub
parent bce324487a
commit 062c89e7c9
16 changed files with 441 additions and 281 deletions

View File

@@ -1,11 +1,13 @@
"""Sampling parameters for text generation."""
import copy
from dataclasses import dataclass
from enum import Enum, IntEnum
from functools import cached_property
from typing import Any, Callable, Dict, List, Optional, Set, Union
import msgspec
import torch
from pydantic import BaseModel
from typing_extensions import Annotated
import vllm.envs as envs
@@ -34,6 +36,54 @@ first argument, and returns a modified tensor of logits
to sample from."""
# maybe make msgspec?
@dataclass
class GuidedDecodingParams:
    """Options from which a guided-decoding logits processor is built.

    At most one of the guide fields (``json``, ``regex``, ``choice``,
    ``grammar``, ``json_object``) may be set; ``__post_init__`` raises
    ``ValueError`` when more than one is not ``None``.
    """
    # Guide kinds: one of these fields will be used to build a logit
    # processor.
    json: Optional[Union[str, Dict]] = None
    regex: Optional[str] = None
    choice: Optional[List[str]] = None
    grammar: Optional[str] = None
    json_object: Optional[bool] = None
    # These are other options that can be set alongside a guide.
    backend: Optional[str] = None
    whitespace_pattern: Optional[str] = None

    @staticmethod
    def from_optional(
        json: Optional[Union[Dict, BaseModel, str]] = None,
        regex: Optional[str] = None,
        choice: Optional[List[str]] = None,
        grammar: Optional[str] = None,
        json_object: Optional[bool] = None,
        backend: Optional[str] = None,
        whitespace_pattern: Optional[str] = None,
    ) -> "GuidedDecodingParams":
        """Build a ``GuidedDecodingParams`` from optional arguments.

        Every argument defaults to ``None`` so callers only need to pass
        the guide they actually use. A pydantic model (instance or class)
        given as ``json`` is replaced by its JSON schema.

        Raises:
            ValueError: if more than one guide kind is specified (raised
                by ``__post_init__``).
        """
        # Extract json schemas from pydantic models. ``type(BaseModel)`` is
        # the model metaclass, so model classes match as well as instances.
        if json is not None and isinstance(
                json, (BaseModel, type(BaseModel))):
            json = json.model_json_schema()
        return GuidedDecodingParams(
            json=json,
            regex=regex,
            choice=choice,
            grammar=grammar,
            json_object=json_object,
            backend=backend,
            whitespace_pattern=whitespace_pattern,
        )

    def __post_init__(self):
        """Validate that some fields are mutually exclusive."""
        # A field counts as "set" whenever it is not None, so e.g.
        # ``json_object=False`` still counts as a guide.
        guide_count = sum(
            guide is not None
            for guide in (self.json, self.regex, self.choice, self.grammar,
                          self.json_object))
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding but multiple are "
                f"specified: {self.__dict__}")
class RequestOutputKind(Enum):
# Return entire output so far in every RequestOutput
CUMULATIVE = 0
@@ -124,6 +174,13 @@ class SamplingParams(
truncate_prompt_tokens: If set to an integer k, will use only the last k
tokens from the prompt (i.e., left truncation). Defaults to None
(i.e., no truncation).
guided_decoding: If provided, the engine will construct a guided
decoding logits processor from these parameters. Defaults to None.
logit_bias: If provided, the engine will construct a logits processor
that applies these logit biases. Defaults to None.
allowed_token_ids: If provided, the engine will construct a logits
processor which only retains scores for the given token ids.
Defaults to None.
"""
n: int = 1
@@ -164,6 +221,11 @@ class SamplingParams(
output_text_buffer_length: int = 0
_all_stop_token_ids: Set[int] = msgspec.field(default_factory=set)
# Fields used to construct logits processors
guided_decoding: Optional[GuidedDecodingParams] = None
logit_bias: Optional[Dict[int, float]] = None
allowed_token_ids: Optional[List[int]] = None
@staticmethod
def from_optional(
n: Optional[int] = 1,
@@ -194,7 +256,16 @@ class SamplingParams(
truncate_prompt_tokens: Optional[Annotated[int,
msgspec.Meta(ge=1)]] = None,
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
guided_decoding: Optional[GuidedDecodingParams] = None,
logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]] = None,
allowed_token_ids: Optional[List[int]] = None,
) -> "SamplingParams":
if logit_bias is not None:
logit_bias = {
int(token): bias
for token, bias in logit_bias.items()
}
return SamplingParams(
n=1 if n is None else n,
best_of=best_of,
@@ -226,6 +297,9 @@ class SamplingParams(
logits_processors=logits_processors,
truncate_prompt_tokens=truncate_prompt_tokens,
output_kind=output_kind,
guided_decoding=guided_decoding,
logit_bias=logit_bias,
allowed_token_ids=allowed_token_ids,
)
def __post_init__(self) -> None:
@@ -454,4 +528,5 @@ class SamplingParams(
f"skip_special_tokens={self.skip_special_tokens}, "
"spaces_between_special_tokens="
f"{self.spaces_between_special_tokens}, "
f"truncate_prompt_tokens={self.truncate_prompt_tokens})")
f"truncate_prompt_tokens={self.truncate_prompt_tokens}), "
f"guided_decoding={self.guided_decoding}")