[core] remove beam search from the core (#9105)
This commit is contained in:
@@ -28,7 +28,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
|
||||
get_cached_tokenizer)
|
||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import Counter, deprecate_kwargs, is_list_of
|
||||
from vllm.utils import (Counter, deprecate_kwargs, get_beam_search_score,
|
||||
is_list_of)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -404,6 +405,12 @@ class LLM:
|
||||
max_tokens = params.max_tokens
|
||||
temperature = params.temperature
|
||||
ignore_eos = params.ignore_eos
|
||||
length_penalty = params.length_penalty
|
||||
|
||||
def sort_beams_key(x: BeamSearchSequence) -> float:
|
||||
return get_beam_search_score(x.tokens, x.cum_logprob,
|
||||
tokenizer.eos_token_id,
|
||||
length_penalty)
|
||||
|
||||
tokenizer = self.get_tokenizer()
|
||||
# generate 2 * beam_width candidates at each step
|
||||
@@ -466,7 +473,7 @@ class LLM:
|
||||
else:
|
||||
instance_new_beams.append(new_beam)
|
||||
sorted_beams = sorted(instance_new_beams,
|
||||
key=lambda x: x.cum_logprob,
|
||||
key=sort_beams_key,
|
||||
reverse=True)
|
||||
instance.beams = sorted_beams[:beam_width]
|
||||
|
||||
@@ -474,7 +481,7 @@ class LLM:
|
||||
for instance in instances:
|
||||
instance.completed.extend(instance.beams)
|
||||
sorted_completed = sorted(instance.completed,
|
||||
key=lambda x: x.cum_logprob,
|
||||
key=sort_beams_key,
|
||||
reverse=True)
|
||||
best_beams = sorted_completed[:beam_width]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user