Align vLLM's beam search implementation with HF generate (#857)

This commit is contained in:
Zhuohan Li
2023-09-04 17:29:42 -07:00
committed by GitHub
parent e15932bb60
commit 002800f081
24 changed files with 596 additions and 260 deletions

View File

@@ -34,6 +34,15 @@ class SamplingParams:
top_k: Integer that controls the number of top tokens to consider. Set
to -1 to consider all tokens.
use_beam_search: Whether to use beam search instead of sampling.
length_penalty: Float that penalizes sequences based on their length.
Used in beam search.
early_stopping: Controls the stopping condition for beam search. It
accepts the following values: `True`, where the generation stops as
soon as there are `best_of` complete candidates; `False`, where a
heuristic is applied and the generation stops when it is very
unlikely to find better candidates; `"never"`, where the beam search
procedure only stops when there cannot be better candidates
(canonical beam search algorithm).
stop: List of strings that stop the generation when they are generated.
The returned output will not contain the stop strings.
ignore_eos: Whether to ignore the EOS token and continue generating
@@ -52,6 +61,8 @@ class SamplingParams:
top_p: float = 1.0,
top_k: int = -1,
use_beam_search: bool = False,
length_penalty: float = 1.0,
early_stopping: Union[bool, str] = False,
stop: Union[None, str, List[str]] = None,
ignore_eos: bool = False,
max_tokens: int = 16,
@@ -65,6 +76,8 @@ class SamplingParams:
self.top_p = top_p
self.top_k = top_k
self.use_beam_search = use_beam_search
self.length_penalty = length_penalty
self.early_stopping = early_stopping
if stop is None:
self.stop = []
elif isinstance(stop, str):
@@ -78,9 +91,11 @@ class SamplingParams:
self._verify_args()
if self.use_beam_search:
self._verify_beam_search()
elif self.temperature < _SAMPLING_EPS:
# Zero temperature means greedy sampling.
self._verify_greedy_sampling()
else:
self._verify_non_beam_search()
if self.temperature < _SAMPLING_EPS:
# Zero temperature means greedy sampling.
self._verify_greedy_sampling()
def _verify_args(self) -> None:
if self.n < 1:
@@ -119,6 +134,20 @@ class SamplingParams:
raise ValueError("top_p must be 1 when using beam search.")
if self.top_k != -1:
raise ValueError("top_k must be -1 when using beam search.")
if self.early_stopping not in [True, False, "never"]:
raise ValueError(
f"early_stopping must be True, False, or 'never', "
f"got {self.early_stopping}.")
def _verify_non_beam_search(self) -> None:
    """Reject beam-search-only options when beam search is not in use.

    Raises:
        ValueError: If `early_stopping` is anything other than the
            literal `False`, or if `length_penalty` falls outside
            `1.0 +/- _SAMPLING_EPS`.
    """
    # Identity comparison on purpose: only the literal False is accepted.
    early_stopping_requested = self.early_stopping is not False
    if early_stopping_requested:
        raise ValueError("early_stopping is not effective and must be "
                         "False when not using beam search.")

    # Tolerate floating-point noise around the default value of 1.0.
    lower_bound = 1.0 - _SAMPLING_EPS
    upper_bound = 1.0 + _SAMPLING_EPS
    penalty = self.length_penalty
    below_default = penalty < lower_bound
    if below_default or penalty > upper_bound:
        raise ValueError(
            "length_penalty is not effective and must be the "
            "default value of 1.0 when not using beam search.")
def _verify_greedy_sampling(self) -> None:
if self.best_of > 1:
@@ -138,6 +167,8 @@ class SamplingParams:
f"top_p={self.top_p}, "
f"top_k={self.top_k}, "
f"use_beam_search={self.use_beam_search}, "
f"length_penalty={self.length_penalty}, "
f"early_stopping={self.early_stopping}, "
f"stop={self.stop}, "
f"ignore_eos={self.ignore_eos}, "
f"max_tokens={self.max_tokens}, "