[Frontend] API support for beam search (#9087)
Co-authored-by: youkaichao <youkaichao@126.com>
This commit is contained in:
@@ -22,8 +22,8 @@ from vllm.model_executor.guided_decoding.guided_fields import (
|
||||
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind,
|
||||
SamplingParams)
|
||||
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
|
||||
RequestOutputKind, SamplingParams)
|
||||
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
|
||||
get_cached_tokenizer)
|
||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||
@@ -394,10 +394,7 @@ class LLM:
|
||||
def beam_search(
|
||||
self,
|
||||
prompts: List[Union[str, List[int]]],
|
||||
beam_width: int,
|
||||
max_tokens: int,
|
||||
ignore_eos: bool = False,
|
||||
temperature: float = 0.0,
|
||||
params: BeamSearchParams,
|
||||
) -> List[BeamSearchOutput]:
|
||||
"""
|
||||
Generate sequences using beam search.
|
||||
@@ -405,14 +402,17 @@ class LLM:
|
||||
Args:
|
||||
prompts: A list of prompts. Each prompt can be a string or a list
|
||||
of token IDs.
|
||||
beam_width: The number of beams to keep at each step.
|
||||
max_tokens: The max number of tokens to generate for each prompt.
|
||||
temperature: The temperature to use for generation.
|
||||
|
||||
params: The beam search parameters.
|
||||
|
||||
TODO: how does beam search work together with length penalty, frequency
|
||||
penalty, and stopping criteria, etc.?
|
||||
"""
|
||||
|
||||
beam_width = params.beam_width
|
||||
max_tokens = params.max_tokens
|
||||
temperature = params.temperature
|
||||
ignore_eos = params.ignore_eos
|
||||
|
||||
tokenizer = self.get_tokenizer()
|
||||
# generate 2 * beam_width candidates at each step
|
||||
# following the huggingface transformers implementation
|
||||
|
||||
Reference in New Issue
Block a user