[core] remove beam search from the core (#9105)

This commit is contained in:
youkaichao
2024-10-06 22:47:04 -07:00
committed by GitHub
parent c8f26bb636
commit 18b296fdb2
25 changed files with 98 additions and 596 deletions

View File

@@ -28,7 +28,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_kwargs, is_list_of
from vllm.utils import (Counter, deprecate_kwargs, get_beam_search_score,
is_list_of)
logger = init_logger(__name__)
@@ -404,6 +405,12 @@ class LLM:
max_tokens = params.max_tokens
temperature = params.temperature
ignore_eos = params.ignore_eos
length_penalty = params.length_penalty
def sort_beams_key(x: BeamSearchSequence) -> float:
return get_beam_search_score(x.tokens, x.cum_logprob,
tokenizer.eos_token_id,
length_penalty)
tokenizer = self.get_tokenizer()
# generate 2 * beam_width candidates at each step
@@ -466,7 +473,7 @@ class LLM:
else:
instance_new_beams.append(new_beam)
sorted_beams = sorted(instance_new_beams,
key=lambda x: x.cum_logprob,
key=sort_beams_key,
reverse=True)
instance.beams = sorted_beams[:beam_width]
@@ -474,7 +481,7 @@ class LLM:
for instance in instances:
instance.completed.extend(instance.beams)
sorted_completed = sorted(instance.completed,
key=lambda x: x.cum_logprob,
key=sort_beams_key,
reverse=True)
best_beams = sorted_completed[:beam_width]