[misc] hide best_of from engine (#9261)

Co-authored-by: Brendan Wong <bjwpokemon@gmail.com>
This commit is contained in:
youkaichao
2024-10-10 21:30:44 -07:00
committed by GitHub
parent 94bf9ae4e9
commit cbc2ef5529
14 changed files with 46 additions and 73 deletions

View File

@@ -106,9 +106,8 @@ class SamplingParams(
n: Number of output sequences to return for the given prompt.
best_of: Number of output sequences that are generated from the prompt.
From these `best_of` sequences, the top `n` sequences are returned.
`best_of` must be greater than or equal to `n`. This is treated as
the beam width when `use_beam_search` is True. By default, `best_of`
is set to `n`.
`best_of` must be greater than or equal to `n`. By default,
`best_of` is set to `n`.
presence_penalty: Float that penalizes new tokens based on whether they
appear in the generated text so far. Values > 0 encourage the model
to use new tokens, while values < 0 encourage the model to repeat
@@ -173,6 +172,7 @@ class SamplingParams(
n: int = 1
best_of: Optional[int] = None
_real_n: Optional[int] = None
presence_penalty: float = 0.0
frequency_penalty: float = 0.0
repetition_penalty: float = 1.0
@@ -282,7 +282,19 @@ class SamplingParams(
)
def __post_init__(self) -> None:
self.best_of = self.best_of or self.n
# how we deal with `best_of``:
# if `best_of`` is not set, we default to `n`;
# if `best_of`` is set, we set `n`` to `best_of`,
# and set `_real_n`` to the original `n`.
# when we return the result, we will check
# if we need to return `n` or `_real_n` results
if self.best_of:
if self.best_of < self.n:
raise ValueError(
f"best_of must be greater than or equal to n, "
f"got n={self.n} and best_of={self.best_of}.")
self._real_n = self.n
self.n = self.best_of
if 0 < self.temperature < _MAX_TEMP:
logger.warning(
"temperature %s is less than %s, which may cause numerical "
@@ -329,12 +341,6 @@ class SamplingParams(
f"type {type(self.n)}")
if self.n < 1:
raise ValueError(f"n must be at least 1, got {self.n}.")
if not isinstance(self.best_of, int):
raise ValueError(f"best_of must be an int, but is of "
f"type {type(self.best_of)}")
if self.best_of < self.n:
raise ValueError(f"best_of must be greater than or equal to n, "
f"got n={self.n} and best_of={self.best_of}.")
if not -2.0 <= self.presence_penalty <= 2.0:
raise ValueError("presence_penalty must be in [-2, 2], got "
f"{self.presence_penalty}.")
@@ -385,7 +391,7 @@ class SamplingParams(
raise ValueError(
"stop strings are only supported when detokenize is True. "
"Set detokenize=True to use stop.")
if self.best_of != self.n and self.output_kind == (
if self.best_of != self._real_n and self.output_kind == (
RequestOutputKind.DELTA):
raise ValueError("best_of must equal n to use output_kind=DELTA")
@@ -393,10 +399,6 @@ class SamplingParams(
if self.n > 1:
raise ValueError("n must be 1 when using greedy sampling, "
f"got {self.n}.")
assert isinstance(self.best_of, int)
if self.best_of > 1:
raise ValueError("best_of must be 1 when using greedy sampling, "
f"got {self.best_of}.")
def update_from_generation_config(
self,
@@ -453,7 +455,6 @@ class SamplingParams(
def __repr__(self) -> str:
return (
f"SamplingParams(n={self.n}, "
f"best_of={self.best_of}, "
f"presence_penalty={self.presence_penalty}, "
f"frequency_penalty={self.frequency_penalty}, "
f"repetition_penalty={self.repetition_penalty}, "