[core] remove beam search from the core (#9105)
This commit is contained in:
@@ -10,7 +10,6 @@ import torch
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Annotated
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -23,7 +22,6 @@ class SamplingType(IntEnum):
|
||||
GREEDY = 0
|
||||
RANDOM = 1
|
||||
RANDOM_SEED = 2
|
||||
BEAM = 3
|
||||
|
||||
|
||||
LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor],
|
||||
@@ -134,16 +132,6 @@ class SamplingParams(
|
||||
considered, relative to the probability of the most likely token.
|
||||
Must be in [0, 1]. Set to 0 to disable this.
|
||||
seed: Random seed to use for the generation.
|
||||
use_beam_search: Whether to use beam search instead of sampling.
|
||||
length_penalty: Float that penalizes sequences based on their length.
|
||||
Used in beam search.
|
||||
early_stopping: Controls the stopping condition for beam search. It
|
||||
accepts the following values: `True`, where the generation stops as
|
||||
soon as there are `best_of` complete candidates; `False`, where an
|
||||
heuristic is applied and the generation stops when it is very
|
||||
unlikely to find better candidates; `"never"`, where the beam search
|
||||
procedure only stops when there cannot be better candidates
|
||||
(canonical beam search algorithm).
|
||||
stop: List of strings that stop the generation when they are generated.
|
||||
The returned output will not contain the stop strings.
|
||||
stop_token_ids: List of tokens that stop the generation when they are
|
||||
@@ -193,9 +181,6 @@ class SamplingParams(
|
||||
top_k: int = -1
|
||||
min_p: float = 0.0
|
||||
seed: Optional[int] = None
|
||||
use_beam_search: bool = False
|
||||
length_penalty: float = 1.0
|
||||
early_stopping: Union[bool, str] = False
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
stop_token_ids: Optional[List[int]] = None
|
||||
ignore_eos: bool = False
|
||||
@@ -238,9 +223,6 @@ class SamplingParams(
|
||||
top_k: int = -1,
|
||||
min_p: float = 0.0,
|
||||
seed: Optional[int] = None,
|
||||
use_beam_search: bool = False,
|
||||
length_penalty: float = 1.0,
|
||||
early_stopping: Union[bool, str] = False,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
include_stop_str_in_output: bool = False,
|
||||
@@ -280,9 +262,6 @@ class SamplingParams(
|
||||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
seed=seed,
|
||||
use_beam_search=use_beam_search,
|
||||
length_penalty=length_penalty,
|
||||
early_stopping=early_stopping,
|
||||
stop=stop,
|
||||
stop_token_ids=stop_token_ids,
|
||||
include_stop_str_in_output=include_stop_str_in_output,
|
||||
@@ -334,20 +313,13 @@ class SamplingParams(
|
||||
self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
|
||||
|
||||
self._verify_args()
|
||||
if self.use_beam_search:
|
||||
if not envs.VLLM_ALLOW_DEPRECATED_BEAM_SEARCH:
|
||||
raise ValueError(
|
||||
"Using beam search as a sampling parameter is deprecated, and will be removed in the future release. Please use the `vllm.LLM.use_beam_search` method for dedicated beam search instead, or set the environment variable `VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1` to suppress this error. For more details, see https://github.com/vllm-project/vllm/issues/8306 ." # noqa
|
||||
)
|
||||
self._verify_beam_search()
|
||||
else:
|
||||
self._verify_non_beam_search()
|
||||
if self.temperature < _SAMPLING_EPS:
|
||||
# Zero temperature means greedy sampling.
|
||||
self.top_p = 1.0
|
||||
self.top_k = -1
|
||||
self.min_p = 0.0
|
||||
self._verify_greedy_sampling()
|
||||
|
||||
if self.temperature < _SAMPLING_EPS:
|
||||
# Zero temperature means greedy sampling.
|
||||
self.top_p = 1.0
|
||||
self.top_k = -1
|
||||
self.min_p = 0.0
|
||||
self._verify_greedy_sampling()
|
||||
# eos_token_id is added to this by the engine
|
||||
self._all_stop_token_ids = set(self.stop_token_ids)
|
||||
|
||||
@@ -417,31 +389,6 @@ class SamplingParams(
|
||||
RequestOutputKind.DELTA):
|
||||
raise ValueError("best_of must equal n to use output_kind=DELTA")
|
||||
|
||||
def _verify_beam_search(self) -> None:
|
||||
if self.best_of == 1:
|
||||
raise ValueError("best_of must be greater than 1 when using beam "
|
||||
f"search. Got {self.best_of}.")
|
||||
if self.temperature > _SAMPLING_EPS:
|
||||
raise ValueError("temperature must be 0 when using beam search.")
|
||||
if self.top_p < 1.0 - _SAMPLING_EPS:
|
||||
raise ValueError("top_p must be 1 when using beam search.")
|
||||
if self.top_k != -1:
|
||||
raise ValueError("top_k must be -1 when using beam search.")
|
||||
if self.early_stopping not in [True, False, "never"]:
|
||||
raise ValueError(
|
||||
f"early_stopping must be True, False, or 'never', "
|
||||
f"got {self.early_stopping}.")
|
||||
|
||||
def _verify_non_beam_search(self) -> None:
|
||||
if self.early_stopping is not False:
|
||||
raise ValueError("early_stopping is not effective and must be "
|
||||
"False when not using beam search.")
|
||||
if (self.length_penalty < 1.0 - _SAMPLING_EPS
|
||||
or self.length_penalty > 1.0 + _SAMPLING_EPS):
|
||||
raise ValueError(
|
||||
"length_penalty is not effective and must be the "
|
||||
"default value of 1.0 when not using beam search.")
|
||||
|
||||
def _verify_greedy_sampling(self) -> None:
|
||||
assert isinstance(self.best_of, int)
|
||||
if self.best_of > 1:
|
||||
@@ -476,8 +423,6 @@ class SamplingParams(
|
||||
|
||||
@cached_property
|
||||
def sampling_type(self) -> SamplingType:
|
||||
if self.use_beam_search:
|
||||
return SamplingType.BEAM
|
||||
if self.temperature < _SAMPLING_EPS:
|
||||
return SamplingType.GREEDY
|
||||
if self.seed is not None:
|
||||
@@ -514,9 +459,6 @@ class SamplingParams(
|
||||
f"top_k={self.top_k}, "
|
||||
f"min_p={self.min_p}, "
|
||||
f"seed={self.seed}, "
|
||||
f"use_beam_search={self.use_beam_search}, "
|
||||
f"length_penalty={self.length_penalty}, "
|
||||
f"early_stopping={self.early_stopping}, "
|
||||
f"stop={self.stop}, "
|
||||
f"stop_token_ids={self.stop_token_ids}, "
|
||||
f"include_stop_str_in_output={self.include_stop_str_in_output}, "
|
||||
@@ -542,3 +484,4 @@ class BeamSearchParams(
|
||||
max_tokens: int
|
||||
ignore_eos: bool = False
|
||||
temperature: float = 0.0
|
||||
length_penalty: float = 1.0
|
||||
|
||||
Reference in New Issue
Block a user