[misc] hide best_of from engine (#9261)
Co-authored-by: Brendan Wong <bjwpokemon@gmail.com>
This commit is contained in:
@@ -106,9 +106,8 @@ class SamplingParams(
|
||||
n: Number of output sequences to return for the given prompt.
|
||||
best_of: Number of output sequences that are generated from the prompt.
|
||||
From these `best_of` sequences, the top `n` sequences are returned.
|
||||
`best_of` must be greater than or equal to `n`. This is treated as
|
||||
the beam width when `use_beam_search` is True. By default, `best_of`
|
||||
is set to `n`.
|
||||
`best_of` must be greater than or equal to `n`. By default,
|
||||
`best_of` is set to `n`.
|
||||
presence_penalty: Float that penalizes new tokens based on whether they
|
||||
appear in the generated text so far. Values > 0 encourage the model
|
||||
to use new tokens, while values < 0 encourage the model to repeat
|
||||
@@ -173,6 +172,7 @@ class SamplingParams(
|
||||
|
||||
n: int = 1
|
||||
best_of: Optional[int] = None
|
||||
_real_n: Optional[int] = None
|
||||
presence_penalty: float = 0.0
|
||||
frequency_penalty: float = 0.0
|
||||
repetition_penalty: float = 1.0
|
||||
@@ -282,7 +282,19 @@ class SamplingParams(
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.best_of = self.best_of or self.n
|
||||
# how we deal with `best_of`:
|
||||
# if `best_of` is not set, we default to `n`;
|
||||
# if `best_of` is set, we set `n` to `best_of`,
|
||||
# and set `_real_n` to the original `n`.
|
||||
# when we return the result, we will check
|
||||
# if we need to return `n` or `_real_n` results
|
||||
if self.best_of:
|
||||
if self.best_of < self.n:
|
||||
raise ValueError(
|
||||
f"best_of must be greater than or equal to n, "
|
||||
f"got n={self.n} and best_of={self.best_of}.")
|
||||
self._real_n = self.n
|
||||
self.n = self.best_of
|
||||
if 0 < self.temperature < _MAX_TEMP:
|
||||
logger.warning(
|
||||
"temperature %s is less than %s, which may cause numerical "
|
||||
@@ -329,12 +341,6 @@ class SamplingParams(
|
||||
f"type {type(self.n)}")
|
||||
if self.n < 1:
|
||||
raise ValueError(f"n must be at least 1, got {self.n}.")
|
||||
if not isinstance(self.best_of, int):
|
||||
raise ValueError(f"best_of must be an int, but is of "
|
||||
f"type {type(self.best_of)}")
|
||||
if self.best_of < self.n:
|
||||
raise ValueError(f"best_of must be greater than or equal to n, "
|
||||
f"got n={self.n} and best_of={self.best_of}.")
|
||||
if not -2.0 <= self.presence_penalty <= 2.0:
|
||||
raise ValueError("presence_penalty must be in [-2, 2], got "
|
||||
f"{self.presence_penalty}.")
|
||||
@@ -385,7 +391,7 @@ class SamplingParams(
|
||||
raise ValueError(
|
||||
"stop strings are only supported when detokenize is True. "
|
||||
"Set detokenize=True to use stop.")
|
||||
if self.best_of != self.n and self.output_kind == (
|
||||
if self.best_of != self._real_n and self.output_kind == (
|
||||
RequestOutputKind.DELTA):
|
||||
raise ValueError("best_of must equal n to use output_kind=DELTA")
|
||||
|
||||
@@ -393,10 +399,6 @@ class SamplingParams(
|
||||
if self.n > 1:
|
||||
raise ValueError("n must be 1 when using greedy sampling, "
|
||||
f"got {self.n}.")
|
||||
assert isinstance(self.best_of, int)
|
||||
if self.best_of > 1:
|
||||
raise ValueError("best_of must be 1 when using greedy sampling, "
|
||||
f"got {self.best_of}.")
|
||||
|
||||
def update_from_generation_config(
|
||||
self,
|
||||
@@ -453,7 +455,6 @@ class SamplingParams(
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"SamplingParams(n={self.n}, "
|
||||
f"best_of={self.best_of}, "
|
||||
f"presence_penalty={self.presence_penalty}, "
|
||||
f"frequency_penalty={self.frequency_penalty}, "
|
||||
f"repetition_penalty={self.repetition_penalty}, "
|
||||
|
||||
Reference in New Issue
Block a user