[Core] Add engine option to return only deltas or final output (#7381)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
"""Sampling parameters for text generation."""
|
||||
import copy
|
||||
from enum import IntEnum
|
||||
from enum import Enum, IntEnum
|
||||
from functools import cached_property
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Union
|
||||
|
||||
@@ -33,6 +33,15 @@ first argument, and returns a modified tensor of logits
|
||||
to sample from."""
|
||||
|
||||
|
||||
class RequestOutputKind(Enum):
|
||||
# Return entire output so far in every RequestOutput
|
||||
CUMULATIVE = 0
|
||||
# Return only deltas in each RequestOutput
|
||||
DELTA = 1
|
||||
# Do not return intermediate RequestOutputs
|
||||
FINAL_ONLY = 2
|
||||
|
||||
|
||||
class SamplingParams(
|
||||
msgspec.Struct,
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
@@ -147,6 +156,7 @@ class SamplingParams(
|
||||
logits_processors: Optional[Any] = None
|
||||
include_stop_str_in_output: bool = False
|
||||
truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
|
||||
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE
|
||||
|
||||
# The below fields are not supposed to be used as an input.
|
||||
# They are set in post_init.
|
||||
@@ -182,6 +192,7 @@ class SamplingParams(
|
||||
logits_processors: Optional[List[LogitsProcessor]] = None,
|
||||
truncate_prompt_tokens: Optional[Annotated[int,
|
||||
msgspec.Meta(ge=1)]] = None,
|
||||
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
|
||||
) -> "SamplingParams":
|
||||
return SamplingParams(
|
||||
n=1 if n is None else n,
|
||||
@@ -213,6 +224,7 @@ class SamplingParams(
|
||||
spaces_between_special_tokens=spaces_between_special_tokens,
|
||||
logits_processors=logits_processors,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
output_kind=output_kind,
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
@@ -317,6 +329,9 @@ class SamplingParams(
|
||||
raise ValueError(
|
||||
"stop strings are only supported when detokenize is True. "
|
||||
"Set detokenize=True to use stop.")
|
||||
if self.best_of != self.n and self.output_kind == (
|
||||
RequestOutputKind.DELTA):
|
||||
raise ValueError("best_of must equal n to use output_kind=DELTA")
|
||||
|
||||
def _verify_beam_search(self) -> None:
|
||||
if self.best_of == 1:
|
||||
|
||||
Reference in New Issue
Block a user