[Frontend] API support for beam search (#9087)

Co-authored-by: youkaichao <youkaichao@126.com>
This commit is contained in:
Brendan Wong
2024-10-05 23:39:03 -07:00
committed by GitHub
parent 23fea8714a
commit 168cab6bbf
12 changed files with 275 additions and 68 deletions

View File

@@ -22,8 +22,8 @@ from vllm.model_executor.guided_decoding.guided_fields import (
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind,
SamplingParams)
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
RequestOutputKind, SamplingParams)
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@@ -394,10 +394,7 @@ class LLM:
def beam_search(
self,
prompts: List[Union[str, List[int]]],
beam_width: int,
max_tokens: int,
ignore_eos: bool = False,
temperature: float = 0.0,
params: BeamSearchParams,
) -> List[BeamSearchOutput]:
"""
Generate sequences using beam search.
@@ -405,14 +402,17 @@ class LLM:
Args:
prompts: A list of prompts. Each prompt can be a string or a list
of token IDs.
beam_width: The number of beams to keep at each step.
max_tokens: The max number of tokens to generate for each prompt.
temperature: The temperature to use for generation.
params: The beam search parameters.
TODO: how does beam search work together with length penalty, frequency
penalty, and stopping criteria, etc.?
"""
beam_width = params.beam_width
max_tokens = params.max_tokens
temperature = params.temperature
ignore_eos = params.ignore_eos
tokenizer = self.get_tokenizer()
# generate 2 * beam_width candidates at each step
# following the huggingface transformers implementation