[core] remove beam search from the core (#9105)
This commit is contained in:
@@ -33,7 +33,7 @@ from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
|
||||
random_uuid, weak_bind)
|
||||
get_beam_search_score, random_uuid, weak_bind)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
|
||||
@@ -1050,6 +1050,12 @@ class AsyncLLMEngine:
|
||||
max_tokens = params.max_tokens
|
||||
ignore_eos = params.ignore_eos
|
||||
temperature = params.temperature
|
||||
length_penalty = params.length_penalty
|
||||
|
||||
def sort_beams_key(x: BeamSearchSequence) -> float:
|
||||
return get_beam_search_score(x.tokens, x.cum_logprob,
|
||||
tokenizer.eos_token_id,
|
||||
length_penalty)
|
||||
|
||||
tokenizer = await self.get_tokenizer()
|
||||
tokenizedPrompt = prompt if isinstance(
|
||||
@@ -1103,15 +1109,11 @@ class AsyncLLMEngine:
|
||||
else:
|
||||
new_beams.append(new_beam)
|
||||
|
||||
sorted_beams = sorted(new_beams,
|
||||
key=lambda x: x.cum_logprob,
|
||||
reverse=True)
|
||||
sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
|
||||
all_beams = sorted_beams[:beam_width]
|
||||
|
||||
completed.extend(all_beams)
|
||||
sorted_completed = sorted(completed,
|
||||
key=lambda x: x.cum_logprob,
|
||||
reverse=True)
|
||||
sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
|
||||
best_beams = sorted_completed[:beam_width]
|
||||
|
||||
for beam in best_beams:
|
||||
|
||||
Reference in New Issue
Block a user