[V0 Deprecation] Remove best_of (#29090)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -144,12 +144,6 @@ class SamplingParams(
|
||||
are generated and streamed cumulatively per request. To see all `n`
|
||||
outputs upon completion, use `output_kind=RequestOutputKind.FINAL_ONLY`
|
||||
in `SamplingParams`."""
|
||||
best_of: int | None = None
|
||||
"""Number of output sequences that are generated from the prompt. From
|
||||
these `best_of` sequences, the top `n` sequences are returned. `best_of`
|
||||
must be greater than or equal to `n`. By default, `best_of` is set to `n`.
|
||||
Warning, this is only supported in V0."""
|
||||
_real_n: int | None = None
|
||||
presence_penalty: float = 0.0
|
||||
"""Penalizes new tokens based on whether they appear in the generated text
|
||||
so far. Values > 0 encourage the model to use new tokens, while values < 0
|
||||
@@ -265,7 +259,6 @@ class SamplingParams(
|
||||
@staticmethod
|
||||
def from_optional(
|
||||
n: int | None = 1,
|
||||
best_of: int | None = None,
|
||||
presence_penalty: float | None = 0.0,
|
||||
frequency_penalty: float | None = 0.0,
|
||||
repetition_penalty: float | None = 1.0,
|
||||
@@ -315,7 +308,6 @@ class SamplingParams(
|
||||
|
||||
return SamplingParams(
|
||||
n=1 if n is None else n,
|
||||
best_of=best_of,
|
||||
presence_penalty=0.0 if presence_penalty is None else presence_penalty,
|
||||
frequency_penalty=0.0 if frequency_penalty is None else frequency_penalty,
|
||||
repetition_penalty=1.0
|
||||
@@ -348,22 +340,6 @@ class SamplingParams(
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# how we deal with `best_of`:
|
||||
# if `best_of` is not set, we default to `n`;
|
||||
# if `best_of` is set, we set `n` to `best_of`,
|
||||
# and set `_real_n` to the original `n`.
|
||||
# when we return the result, we will check
|
||||
# if we need to return `n` or `_real_n` results
|
||||
if self.best_of:
|
||||
if self.best_of < self.n:
|
||||
raise ValueError(
|
||||
f"best_of must be greater than or equal to n, "
|
||||
f"got n={self.n} and best_of={self.best_of}."
|
||||
)
|
||||
if not self._real_n:
|
||||
self._real_n = self.n
|
||||
self.n = self.best_of
|
||||
|
||||
if 0 < self.temperature < _MAX_TEMP:
|
||||
logger.warning(
|
||||
"temperature %s is less than %s, which may cause numerical "
|
||||
@@ -433,18 +409,6 @@ class SamplingParams(
|
||||
raise ValueError(f"n must be an int, but is of type {type(self.n)}")
|
||||
if self.n < 1:
|
||||
raise ValueError(f"n must be at least 1, got {self.n}.")
|
||||
if self.best_of is not None:
|
||||
if not isinstance(self.best_of, int):
|
||||
raise ValueError(
|
||||
f"best_of must be an integer, got {type(self.best_of)}"
|
||||
)
|
||||
if self.best_of < 1:
|
||||
raise ValueError(f"best_of must be at least 1, got {self.best_of}")
|
||||
if self.best_of < self.n:
|
||||
raise ValueError(
|
||||
f"best_of must be greater than or equal to n, "
|
||||
f"got n={self.n} and best_of={self.best_of}."
|
||||
)
|
||||
if not -2.0 <= self.presence_penalty <= 2.0:
|
||||
raise ValueError(
|
||||
f"presence_penalty must be in [-2, 2], got {self.presence_penalty}."
|
||||
@@ -519,10 +483,6 @@ class SamplingParams(
|
||||
"stop strings are only supported when detokenize is True. "
|
||||
"Set detokenize=True to use stop."
|
||||
)
|
||||
if self.best_of != self._real_n and self.output_kind == (
|
||||
RequestOutputKind.DELTA
|
||||
):
|
||||
raise ValueError("best_of must equal n to use output_kind=DELTA")
|
||||
|
||||
def _verify_greedy_sampling(self) -> None:
|
||||
if self.n > 1:
|
||||
|
||||
Reference in New Issue
Block a user