Align vLLM's beam search implementation with HF generate (#857)
This commit is contained in:
@@ -34,6 +34,15 @@ class SamplingParams:
|
||||
top_k: Integer that controls the number of top tokens to consider. Set
|
||||
to -1 to consider all tokens.
|
||||
use_beam_search: Whether to use beam search instead of sampling.
|
||||
length_penalty: Float that penalizes sequences based on their length.
|
||||
Used in beam search.
|
||||
early_stopping: Controls the stopping condition for beam search. It
|
||||
accepts the following values: `True`, where the generation stops as
|
||||
soon as there are `best_of` complete candidates; `False`, where a
|
||||
heuristic is applied and the generation stops when it is very
|
||||
unlikely to find better candidates; `"never"`, where the beam search
|
||||
procedure only stops when there cannot be better candidates
|
||||
(canonical beam search algorithm).
|
||||
stop: List of strings that stop the generation when they are generated.
|
||||
The returned output will not contain the stop strings.
|
||||
ignore_eos: Whether to ignore the EOS token and continue generating
|
||||
@@ -52,6 +61,8 @@ class SamplingParams:
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
use_beam_search: bool = False,
|
||||
length_penalty: float = 1.0,
|
||||
early_stopping: Union[bool, str] = False,
|
||||
stop: Union[None, str, List[str]] = None,
|
||||
ignore_eos: bool = False,
|
||||
max_tokens: int = 16,
|
||||
@@ -65,6 +76,8 @@ class SamplingParams:
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
self.use_beam_search = use_beam_search
|
||||
self.length_penalty = length_penalty
|
||||
self.early_stopping = early_stopping
|
||||
if stop is None:
|
||||
self.stop = []
|
||||
elif isinstance(stop, str):
|
||||
@@ -78,9 +91,11 @@ class SamplingParams:
|
||||
self._verify_args()
|
||||
if self.use_beam_search:
|
||||
self._verify_beam_search()
|
||||
elif self.temperature < _SAMPLING_EPS:
|
||||
# Zero temperature means greedy sampling.
|
||||
self._verify_greedy_sampling()
|
||||
else:
|
||||
self._verify_non_beam_search()
|
||||
if self.temperature < _SAMPLING_EPS:
|
||||
# Zero temperature means greedy sampling.
|
||||
self._verify_greedy_sampling()
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
if self.n < 1:
|
||||
@@ -119,6 +134,20 @@ class SamplingParams:
|
||||
raise ValueError("top_p must be 1 when using beam search.")
|
||||
if self.top_k != -1:
|
||||
raise ValueError("top_k must be -1 when using beam search.")
|
||||
if self.early_stopping not in [True, False, "never"]:
|
||||
raise ValueError(
|
||||
f"early_stopping must be True, False, or 'never', "
|
||||
f"got {self.early_stopping}.")
|
||||
|
||||
def _verify_non_beam_search(self) -> None:
    """Reject beam-search-only options when beam search is disabled.

    Raises:
        ValueError: If ``early_stopping`` is anything other than the
            ``False`` singleton, or if ``length_penalty`` lies outside
            ``1.0 +/- _SAMPLING_EPS``.
    """
    # Identity check on purpose: only the literal ``False`` is accepted.
    early_stopping_is_default = self.early_stopping is False
    if not early_stopping_is_default:
        raise ValueError("early_stopping is not effective and must be "
                         "False when not using beam search.")

    # Compare against a small tolerance band around the default of 1.0.
    penalty = self.length_penalty
    below_default = penalty < 1.0 - _SAMPLING_EPS
    above_default = penalty > 1.0 + _SAMPLING_EPS
    if below_default or above_default:
        raise ValueError(
            "length_penalty is not effective and must be the "
            "default value of 1.0 when not using beam search.")
|
||||
|
||||
def _verify_greedy_sampling(self) -> None:
|
||||
if self.best_of > 1:
|
||||
@@ -138,6 +167,8 @@ class SamplingParams:
|
||||
f"top_p={self.top_p}, "
|
||||
f"top_k={self.top_k}, "
|
||||
f"use_beam_search={self.use_beam_search}, "
|
||||
f"length_penalty={self.length_penalty}, "
|
||||
f"early_stopping={self.early_stopping}, "
|
||||
f"stop={self.stop}, "
|
||||
f"ignore_eos={self.ignore_eos}, "
|
||||
f"max_tokens={self.max_tokens}, "
|
||||
|
||||
Reference in New Issue
Block a user