[Core] Add engine option to return only deltas or final output (#7381)

This commit is contained in:
Nick Hill
2024-09-12 20:02:00 +01:00
committed by GitHub
parent a6c0f3658d
commit 551ce01078
10 changed files with 371 additions and 137 deletions

View File

@@ -1,6 +1,6 @@
"""Sampling parameters for text generation."""
import copy
from enum import IntEnum
from enum import Enum, IntEnum
from functools import cached_property
from typing import Any, Callable, Dict, List, Optional, Set, Union
@@ -33,6 +33,15 @@ first argument, and returns a modified tensor of logits
to sample from."""
class RequestOutputKind(Enum):
# Return entire output so far in every RequestOutput
CUMULATIVE = 0
# Return only deltas in each RequestOutput
DELTA = 1
# Do not return intermediate RequestOuputs
FINAL_ONLY = 2
class SamplingParams(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
@@ -147,6 +156,7 @@ class SamplingParams(
logits_processors: Optional[Any] = None
include_stop_str_in_output: bool = False
truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE
# The below fields are not supposed to be used as an input.
# They are set in post_init.
@@ -182,6 +192,7 @@ class SamplingParams(
logits_processors: Optional[List[LogitsProcessor]] = None,
truncate_prompt_tokens: Optional[Annotated[int,
msgspec.Meta(ge=1)]] = None,
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
) -> "SamplingParams":
return SamplingParams(
n=1 if n is None else n,
@@ -213,6 +224,7 @@ class SamplingParams(
spaces_between_special_tokens=spaces_between_special_tokens,
logits_processors=logits_processors,
truncate_prompt_tokens=truncate_prompt_tokens,
output_kind=output_kind,
)
def __post_init__(self) -> None:
@@ -317,6 +329,9 @@ class SamplingParams(
raise ValueError(
"stop strings are only supported when detokenize is True. "
"Set detokenize=True to use stop.")
if self.best_of != self.n and self.output_kind == (
RequestOutputKind.DELTA):
raise ValueError("best_of must equal n to use output_kind=DELTA")
def _verify_beam_search(self) -> None:
if self.best_of == 1: