re-implement beam search on top of vllm core (#8726)
Co-authored-by: Brendan Wong <bjwpokemon@gmail.com>
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import itertools
|
||||
from contextlib import contextmanager
|
||||
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Union, cast,
|
||||
overload)
|
||||
from dataclasses import dataclass
|
||||
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
|
||||
Union, cast, overload)
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -30,6 +32,37 @@ from vllm.utils import Counter, deprecate_kwargs, is_list_of
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class BeamSearchSequence:
    """One partial hypothesis tracked during beam search.

    Stores the token ids (prompt included) together with the cumulative
    log probability of the sequence. ``text`` stays ``None`` until the
    sequence is about to be returned to the user, at which point it is
    filled with the decoded text.
    """
    # Prompt tokens followed by the generated tokens.
    tokens: List[int]
    # Sum of the log probabilities of the generated tokens so far.
    cum_logprob: float = 0.0
    # Decoded text; populated only for the final, returned beams.
    text: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
class BeamSearchOutput:
    """Final result of beam search for a single prompt.

    Holds the best beam search sequences, best first; the list's length
    equals the beam width.
    """
    sequences: List[BeamSearchSequence]
|
||||
|
||||
|
||||
class BeamSearchInstance:
    """Per-prompt beam search state: the live beams plus the finished ones."""

    def __init__(self, prompt_tokens: List[int]):
        # Search starts from a single beam holding only the prompt tokens.
        initial_beam = BeamSearchSequence(tokens=prompt_tokens)
        self.beams: List[BeamSearchSequence] = [initial_beam]
        # Sequences retired from the search (e.g. after emitting EOS).
        self.completed: List[BeamSearchSequence] = []
|
||||
|
||||
|
||||
class LLM:
|
||||
"""An LLM for generating texts from given prompts and sampling parameters.
|
||||
|
||||
@@ -354,6 +387,105 @@ class LLM:
|
||||
outputs = self._run_engine(use_tqdm=use_tqdm)
|
||||
return LLMEngine.validate_outputs(outputs, RequestOutput)
|
||||
|
||||
def beam_search(
|
||||
self,
|
||||
prompts: List[Union[str, List[int]]],
|
||||
beam_width: int,
|
||||
max_tokens: int,
|
||||
ignore_eos: bool = False,
|
||||
) -> List[BeamSearchOutput]:
|
||||
"""
|
||||
Generate sequences using beam search.
|
||||
|
||||
Args:
|
||||
prompts: A list of prompts. Each prompt can be a string or a list
|
||||
of token IDs.
|
||||
beam_width: The number of beams to keep at each step.
|
||||
max_tokens: The max number of tokens to generate for each prompt.
|
||||
|
||||
TODO: how does beam search work together with length penalty, frequency
|
||||
penalty, and stopping criteria, etc.?
|
||||
"""
|
||||
|
||||
tokenizer = self.get_tokenizer()
|
||||
# generate 2 * beam_width candidates at each step
|
||||
# following the huggingface transformers implementation
|
||||
# at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
|
||||
beam_search_params = SamplingParams(logprobs=2 * beam_width,
|
||||
max_tokens=1,
|
||||
temperature=0.0)
|
||||
instances: List[BeamSearchInstance] = []
|
||||
|
||||
for prompt in prompts:
|
||||
prompt_tokens = prompt if isinstance(
|
||||
prompt, list) else tokenizer.encode(prompt)
|
||||
instances.append(BeamSearchInstance(prompt_tokens))
|
||||
|
||||
for _ in range(max_tokens):
|
||||
all_beams: List[BeamSearchSequence] = list(
|
||||
sum((instance.beams for instance in instances), []))
|
||||
pos = [0] + list(
|
||||
itertools.accumulate(
|
||||
len(instance.beams) for instance in instances))
|
||||
instance_start_and_end: List[Tuple[int, int]] = list(
|
||||
zip(pos[:-1], pos[1:]))
|
||||
|
||||
if len(all_beams) == 0:
|
||||
break
|
||||
|
||||
prompts_batch = [
|
||||
TokensPrompt(prompt_token_ids=beam.tokens)
|
||||
for beam in all_beams
|
||||
]
|
||||
|
||||
# only runs for one step
|
||||
# we don't need to use tqdm here
|
||||
output = self.generate(prompts_batch,
|
||||
sampling_params=beam_search_params,
|
||||
use_tqdm=False)
|
||||
|
||||
for (start, end), instance in zip(instance_start_and_end,
|
||||
instances):
|
||||
instance_new_beams = []
|
||||
for i in range(start, end):
|
||||
current_beam = all_beams[i]
|
||||
result = output[i]
|
||||
|
||||
if result.outputs[0].logprobs is not None:
|
||||
# if `result.outputs[0].logprobs` is None, it means
|
||||
# the sequence is completed because of the max-model-len
|
||||
# or abortion. we don't need to add it to the new beams.
|
||||
logprobs = result.outputs[0].logprobs[0]
|
||||
for token_id, logprob_obj in logprobs.items():
|
||||
new_beam = BeamSearchSequence(
|
||||
tokens=current_beam.tokens + [token_id],
|
||||
cum_logprob=current_beam.cum_logprob +
|
||||
logprob_obj.logprob)
|
||||
|
||||
if token_id == tokenizer.eos_token_id and \
|
||||
not ignore_eos:
|
||||
instance.completed.append(new_beam)
|
||||
else:
|
||||
instance_new_beams.append(new_beam)
|
||||
sorted_beams = sorted(instance_new_beams,
|
||||
key=lambda x: x.cum_logprob,
|
||||
reverse=True)
|
||||
instance.beams = sorted_beams[:beam_width]
|
||||
|
||||
outputs = []
|
||||
for instance in instances:
|
||||
instance.completed.extend(instance.beams)
|
||||
sorted_completed = sorted(instance.completed,
|
||||
key=lambda x: x.cum_logprob,
|
||||
reverse=True)
|
||||
best_beams = sorted_completed[:beam_width]
|
||||
|
||||
for beam in best_beams:
|
||||
beam.text = tokenizer.decode(beam.tokens)
|
||||
outputs.append(BeamSearchOutput(sequences=best_beams))
|
||||
|
||||
return outputs
|
||||
|
||||
def chat(
|
||||
self,
|
||||
messages: List[ChatCompletionMessageParam],
|
||||
|
||||
Reference in New Issue
Block a user