Align vLLM's beam search implementation with HF generate (#857)
This commit is contained in:
@@ -11,7 +11,8 @@ from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupMetadata,
|
||||
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
|
||||
SequenceGroupMetadata, SequenceOutputs,
|
||||
SequenceStatus)
|
||||
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
|
||||
get_tokenizer)
|
||||
@@ -258,14 +259,11 @@ class LLMEngine:
|
||||
|
||||
# Create the sequences.
|
||||
block_size = self.cache_config.block_size
|
||||
seqs: List[Sequence] = []
|
||||
for _ in range(sampling_params.best_of):
|
||||
seq_id = next(self.seq_counter)
|
||||
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
|
||||
seqs.append(seq)
|
||||
seq_id = next(self.seq_counter)
|
||||
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
|
||||
|
||||
# Create the sequence group.
|
||||
seq_group = SequenceGroup(request_id, seqs, sampling_params,
|
||||
seq_group = SequenceGroup(request_id, [seq], sampling_params,
|
||||
arrival_time)
|
||||
|
||||
# Add the sequence group to the scheduler.
|
||||
@@ -303,22 +301,230 @@ class LLMEngine:
|
||||
]
|
||||
return seq_group_metadata_list, scheduler_outputs, None
|
||||
|
||||
def _process_worker_outputs(
|
||||
self, output,
|
||||
scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
|
||||
# Update the scheduler with the model outputs.
|
||||
seq_groups = self.scheduler.update(output)
|
||||
def _check_beam_search_early_stopping(
|
||||
self,
|
||||
early_stopping: Union[bool, str],
|
||||
sampling_params: SamplingParams,
|
||||
best_running_seq: Sequence,
|
||||
current_worst_seq: Sequence,
|
||||
) -> bool:
|
||||
assert sampling_params.use_beam_search
|
||||
length_penalty = sampling_params.length_penalty
|
||||
if early_stopping is True:
|
||||
return True
|
||||
|
||||
current_worst_score = (current_worst_seq.get_beam_search_score(
|
||||
length_penalty=length_penalty,
|
||||
eos_token_id=self.tokenizer.eos_token_id))
|
||||
if early_stopping is False:
|
||||
highest_attainable_score = (best_running_seq.get_beam_search_score(
|
||||
length_penalty=length_penalty,
|
||||
eos_token_id=self.tokenizer.eos_token_id))
|
||||
else:
|
||||
assert early_stopping == "never"
|
||||
if length_penalty > 0.0:
|
||||
# If length_penalty > 0.0, beam search will prefer longer
|
||||
# sequences. The highest attainable score calculation is
|
||||
# based on the longest possible sequence length in this case.
|
||||
max_possible_length = max(
|
||||
best_running_seq.get_prompt_len() +
|
||||
sampling_params.max_tokens,
|
||||
self.scheduler_config.max_model_len)
|
||||
highest_attainable_score = (
|
||||
best_running_seq.get_beam_search_score(
|
||||
length_penalty=length_penalty,
|
||||
eos_token_id=self.tokenizer.eos_token_id,
|
||||
seq_len=max_possible_length))
|
||||
else:
|
||||
# Otherwise, beam search will prefer shorter sequences. The
|
||||
# highest attainable score calculation is based on the current
|
||||
# sequence length.
|
||||
highest_attainable_score = (
|
||||
best_running_seq.get_beam_search_score(
|
||||
length_penalty=length_penalty,
|
||||
eos_token_id=self.tokenizer.eos_token_id))
|
||||
return current_worst_score >= highest_attainable_score
|
||||
|
||||
def _process_sequence_group_samples(
|
||||
self, seq_group: SequenceGroup,
|
||||
samples: List[SequenceOutputs]) -> None:
|
||||
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
|
||||
existing_finished_seqs = seq_group.get_finished_seqs()
|
||||
parent_child_dict = {
|
||||
parent_seq.seq_id: []
|
||||
for parent_seq in parent_seqs
|
||||
}
|
||||
for sample in samples:
|
||||
parent_child_dict[sample.parent_seq_id].append(sample)
|
||||
# List of (child, parent)
|
||||
child_seqs: List[Tuple[Sequence, Sequence]] = []
|
||||
|
||||
# Process the child samples for each parent sequence
|
||||
for parent in parent_seqs:
|
||||
child_samples: List[SequenceOutputs] = parent_child_dict[
|
||||
parent.seq_id]
|
||||
if len(child_samples) == 0:
|
||||
# This parent sequence has no children samples. Remove
|
||||
# the parent sequence from the sequence group since it will
|
||||
# not be used in the future iterations.
|
||||
parent.status = SequenceStatus.FINISHED_ABORTED
|
||||
seq_group.remove(parent.seq_id)
|
||||
self.scheduler.free_seq(parent)
|
||||
continue
|
||||
# Fork the parent sequence if there are multiple child samples.
|
||||
for child_sample in child_samples[:-1]:
|
||||
new_child_seq_id = next(self.seq_counter)
|
||||
child = parent.fork(new_child_seq_id)
|
||||
child.append_token_id(child_sample.output_token,
|
||||
child_sample.logprobs)
|
||||
child_seqs.append((child, parent))
|
||||
# Continue the parent sequence for the last child sample.
|
||||
# We reuse the parent sequence here to reduce redundant memory
|
||||
# copies, especially when using non-beam search sampling methods.
|
||||
last_child_sample = child_samples[-1]
|
||||
parent.append_token_id(last_child_sample.output_token,
|
||||
last_child_sample.logprobs)
|
||||
child_seqs.append((parent, parent))
|
||||
|
||||
for seq, _ in child_seqs:
|
||||
self._decode_sequence(seq)
|
||||
self._check_stop(seq, seq_group.sampling_params)
|
||||
|
||||
# Non-beam search case
|
||||
if not seq_group.sampling_params.use_beam_search:
|
||||
# For newly created child sequences, add them to the sequence group
|
||||
# and fork them in block manager if they are not finished.
|
||||
for seq, parent in child_seqs:
|
||||
if seq is not parent:
|
||||
seq_group.add(seq)
|
||||
if not seq.is_finished():
|
||||
self.scheduler.fork_seq(parent, seq)
|
||||
|
||||
# Free the finished and selected parent sequences' memory in block
|
||||
# manager. Keep them in the sequence group as candidate output.
|
||||
# NOTE: we need to fork the new sequences before freeing the
|
||||
# old sequences.
|
||||
for seq, parent in child_seqs:
|
||||
if seq is parent and seq.is_finished():
|
||||
self.scheduler.free_seq(seq)
|
||||
return
|
||||
|
||||
# Beam search case
|
||||
# Select the child sequences to keep in the sequence group.
|
||||
selected_child_seqs = []
|
||||
unselected_child_seqs = []
|
||||
beam_width = seq_group.sampling_params.best_of
|
||||
length_penalty = seq_group.sampling_params.length_penalty
|
||||
|
||||
# Select the newly finished sequences with the highest scores
|
||||
# to replace existing finished sequences.
|
||||
# Tuple of (seq, parent, is_new)
|
||||
existing_finished_seqs = [(seq, None, False)
|
||||
for seq in existing_finished_seqs]
|
||||
new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
|
||||
if seq.is_finished()]
|
||||
all_finished_seqs = existing_finished_seqs + new_finished_seqs
|
||||
# Sort the finished sequences by their scores.
|
||||
all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
|
||||
length_penalty=length_penalty,
|
||||
eos_token_id=self.tokenizer.eos_token_id),
|
||||
reverse=True)
|
||||
for seq, parent, is_new in all_finished_seqs[:beam_width]:
|
||||
if is_new:
|
||||
# A newly generated child sequence finishes and has a high
|
||||
# score, so we will add it into the sequence group.
|
||||
selected_child_seqs.append((seq, parent))
|
||||
for seq, parent, is_new in all_finished_seqs[beam_width:]:
|
||||
if is_new:
|
||||
# A newly generated child sequence finishes but has a low
|
||||
# score, so we will not add it into the sequence group.
|
||||
# Additionally, if this sequence is a continuation of a
|
||||
# parent sequence, we will need remove the parent sequence
|
||||
# from the sequence group.
|
||||
unselected_child_seqs.append((seq, parent))
|
||||
else:
|
||||
# An existing finished sequence has a low score, so we will
|
||||
# remove it from the sequence group.
|
||||
seq_group.remove(seq.seq_id)
|
||||
|
||||
# select the top beam_width sequences from the running
|
||||
# sequences for the next iteration to continue the beam
|
||||
# search.
|
||||
running_child_seqs = [(seq, parent) for seq, parent in child_seqs
|
||||
if not seq.is_finished()]
|
||||
# Sort the running sequences by their scores.
|
||||
running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
|
||||
length_penalty=length_penalty,
|
||||
eos_token_id=self.tokenizer.eos_token_id),
|
||||
reverse=True)
|
||||
|
||||
# Check if we can stop the beam search.
|
||||
if len(running_child_seqs) == 0:
|
||||
# No running sequences, stop the beam search.
|
||||
stop_beam_search = True
|
||||
elif len(all_finished_seqs) < beam_width:
|
||||
# Not enough finished sequences, continue the beam search.
|
||||
stop_beam_search = False
|
||||
else:
|
||||
# Check the early stopping criteria
|
||||
best_running_seq = running_child_seqs[0][0]
|
||||
current_worst_seq = all_finished_seqs[beam_width - 1][0]
|
||||
stop_beam_search = self._check_beam_search_early_stopping(
|
||||
seq_group.sampling_params.early_stopping,
|
||||
seq_group.sampling_params, best_running_seq, current_worst_seq)
|
||||
|
||||
if stop_beam_search:
|
||||
# Stop the beam search and remove all the running sequences from
|
||||
# the sequence group.
|
||||
unselected_child_seqs.extend(running_child_seqs)
|
||||
else:
|
||||
# Continue the beam search and select the top beam_width sequences
|
||||
# to continue the beam search.
|
||||
selected_child_seqs.extend(running_child_seqs[:beam_width])
|
||||
# The remaining running sequences will not be used in the next
|
||||
# iteration. Again, if these sequences are continuations of
|
||||
# parent sequences, we will need to remove the parent sequences
|
||||
# from the sequence group.
|
||||
unselected_child_seqs.extend(running_child_seqs[beam_width:])
|
||||
|
||||
# For newly created child sequences, add them to the sequence group
|
||||
# and fork them in block manager if they are not finished.
|
||||
for seq, parent in selected_child_seqs:
|
||||
if seq is not parent:
|
||||
seq_group.add(seq)
|
||||
if not seq.is_finished():
|
||||
self.scheduler.fork_seq(parent, seq)
|
||||
|
||||
# Free the finished and selected parent sequences' memory in block
|
||||
# manager. Keep them in the sequence group as candidate output.
|
||||
for seq, parent in selected_child_seqs:
|
||||
if seq is parent and seq.is_finished():
|
||||
self.scheduler.free_seq(seq)
|
||||
|
||||
# Remove the unselected parent sequences from the sequence group and
|
||||
# free their memory in block manager.
|
||||
for seq, parent in unselected_child_seqs:
|
||||
if seq is parent:
|
||||
# Remove the parent sequence if it is not selected for next
|
||||
# iteration
|
||||
seq_group.remove(seq.seq_id)
|
||||
self.scheduler.free_seq(seq)
|
||||
|
||||
def _process_model_outputs(
|
||||
self, output: SamplerOutput,
|
||||
scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
|
||||
# Update the scheduled sequence groups with the model outputs.
|
||||
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
|
||||
for seq_group, samples in zip(scheduled_seq_groups, output):
|
||||
self._process_sequence_group_samples(seq_group, samples)
|
||||
|
||||
# Decode the sequences.
|
||||
self._decode_sequences(seq_groups)
|
||||
# Stop the sequences that meet the stopping criteria.
|
||||
self._stop_sequences(seq_groups)
|
||||
# Free the finished sequence groups.
|
||||
self.scheduler.free_finished_seq_groups()
|
||||
|
||||
# Create the outputs.
|
||||
request_outputs: List[RequestOutput] = []
|
||||
for seq_group in seq_groups + scheduler_outputs.ignored_seq_groups:
|
||||
for seq_group in (scheduled_seq_groups +
|
||||
scheduler_outputs.ignored_seq_groups):
|
||||
request_output = RequestOutput.from_seq_group(seq_group)
|
||||
request_outputs.append(request_output)
|
||||
|
||||
@@ -351,7 +557,7 @@ class LLMEngine:
|
||||
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
||||
)
|
||||
|
||||
return self._process_worker_outputs(output, scheduler_outputs)
|
||||
return self._process_model_outputs(output, scheduler_outputs)
|
||||
|
||||
def _log_system_stats(
|
||||
self,
|
||||
@@ -416,55 +622,44 @@ class LLMEngine:
|
||||
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
|
||||
self.last_logging_time = now
|
||||
|
||||
def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
|
||||
"""Decodes the sequence outputs."""
|
||||
for seq_group in seq_groups:
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
new_token, new_output_text = detokenize_incrementally(
|
||||
self.tokenizer,
|
||||
seq.output_tokens,
|
||||
seq.get_last_token_id(),
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
if new_token is not None:
|
||||
seq.output_tokens.append(new_token)
|
||||
seq.output_text = new_output_text
|
||||
def _decode_sequence(self, seq: Sequence) -> None:
|
||||
"""Decodes the new token for a sequence."""
|
||||
new_token, new_output_text = detokenize_incrementally(
|
||||
self.tokenizer,
|
||||
seq.output_tokens,
|
||||
seq.get_last_token_id(),
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
if new_token is not None:
|
||||
seq.output_tokens.append(new_token)
|
||||
seq.output_text = new_output_text
|
||||
|
||||
def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
|
||||
def _check_stop(self, seq: Sequence,
|
||||
sampling_params: SamplingParams) -> None:
|
||||
"""Stop the finished sequences."""
|
||||
for seq_group in seq_groups:
|
||||
sampling_params = seq_group.sampling_params
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
# Check if the sequence has generated a stop string.
|
||||
stopped = False
|
||||
for stop_str in sampling_params.stop:
|
||||
if seq.output_text.endswith(stop_str):
|
||||
# Truncate the output text so that the stop string is
|
||||
# not included in the output.
|
||||
seq.output_text = seq.output_text[:-len(stop_str)]
|
||||
self.scheduler.free_seq(
|
||||
seq, SequenceStatus.FINISHED_STOPPED)
|
||||
stopped = True
|
||||
break
|
||||
if stopped:
|
||||
continue
|
||||
for stop_str in sampling_params.stop:
|
||||
if seq.output_text.endswith(stop_str):
|
||||
# Truncate the output text so that the stop string is
|
||||
# not included in the output.
|
||||
seq.output_text = seq.output_text[:-len(stop_str)]
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
return
|
||||
|
||||
# Check if the sequence has reached max_model_len.
|
||||
if seq.get_len() > self.scheduler_config.max_model_len:
|
||||
self.scheduler.free_seq(
|
||||
seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
|
||||
continue
|
||||
# Check if the sequence has reached max_tokens.
|
||||
if seq.get_output_len() == sampling_params.max_tokens:
|
||||
self.scheduler.free_seq(
|
||||
seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
|
||||
continue
|
||||
# Check if the sequence has generated the EOS token.
|
||||
if not sampling_params.ignore_eos:
|
||||
if seq.get_last_token_id() == self.tokenizer.eos_token_id:
|
||||
self.scheduler.free_seq(
|
||||
seq, SequenceStatus.FINISHED_STOPPED)
|
||||
continue
|
||||
# Check if the sequence has reached max_model_len.
|
||||
if seq.get_len() > self.scheduler_config.max_model_len:
|
||||
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
|
||||
return
|
||||
|
||||
# Check if the sequence has reached max_tokens.
|
||||
if seq.get_output_len() == sampling_params.max_tokens:
|
||||
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
|
||||
return
|
||||
|
||||
# Check if the sequence has generated the EOS token.
|
||||
if ((not sampling_params.ignore_eos)
|
||||
and seq.get_last_token_id() == self.tokenizer.eos_token_id):
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
return
|
||||
|
||||
def _run_workers(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user