[Speculative decoding 6/9] Integrate speculative decoding with LLMEngine (#3894)
@@ -1,5 +1,5 @@
 import time
-from typing import Iterable, List, Optional, Tuple, Type, Union
+from typing import Iterable, List, Optional, Type, Union
 
 from transformers import PreTrainedTokenizer
 
@@ -11,6 +11,10 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
 from vllm.core.scheduler import Scheduler, SchedulerOutputs
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.metrics import StatLogger, Stats
+from vllm.engine.output_processor.interfaces import (
+    SequenceGroupOutputProcessor)
+from vllm.engine.output_processor.stop_checker import StopChecker
+from vllm.engine.output_processor.util import create_output_by_sequence_group
 from vllm.engine.ray_utils import initialize_ray_cluster
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
@@ -18,8 +22,7 @@ from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
-                           SequenceGroup, SequenceGroupOutput, SequenceOutput,
-                           SequenceStatus)
+                           SequenceGroup)
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                      get_tokenizer_group)
@@ -187,6 +190,21 @@ class LLMEngine:
                 labels=dict(model_name=model_config.model))
             self.stat_logger.info("cache_config", self.cache_config)
 
+        # Create sequence output processor, e.g. for beam search or
+        # speculative decoding.
+        self.output_processor = (
+            SequenceGroupOutputProcessor.create_output_processor(
+                self.scheduler_config,
+                self.detokenizer,
+                self.scheduler,
+                self.seq_counter,
+                self.get_tokenizer_for_seq,
+                stop_checker=StopChecker(
+                    self.scheduler_config.max_model_len,
+                    self.get_tokenizer_for_seq,
+                ),
+            ))
+
     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).
 
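The block above hands per-step output handling to a pluggable SequenceGroupOutputProcessor, so the same engine loop can drive both regular single-token decoding and speculative multi-token decoding. A rough, self-contained sketch of that delegation pattern, using toy class names rather than the real vLLM implementation:

# Illustrative sketch only: toy stand-ins, not the vLLM API.
from abc import ABC, abstractmethod
from typing import List


class ToySequenceGroupOutputProcessor(ABC):
    """Hypothetical interface mirroring process_outputs() in the diff."""

    @abstractmethod
    def process_outputs(self, seq_group: str, outputs: List[int]) -> None:
        ...

    @staticmethod
    def create_output_processor(num_lookahead_slots: int):
        # The real factory keys off scheduler config; this toy keys off
        # lookahead slots: 0 -> single-step, >0 -> multi-step (speculative).
        if num_lookahead_slots == 0:
            return ToySingleStepProcessor()
        return ToyMultiStepProcessor()


class ToySingleStepProcessor(ToySequenceGroupOutputProcessor):
    def process_outputs(self, seq_group, outputs):
        assert len(outputs) == 1  # one sampled token per engine step
        print(f"{seq_group}: append {outputs[0]}")


class ToyMultiStepProcessor(ToySequenceGroupOutputProcessor):
    def process_outputs(self, seq_group, outputs):
        # A speculative step can accept several tokens at once.
        print(f"{seq_group}: append {outputs}")


processor = ToySequenceGroupOutputProcessor.create_output_processor(
    num_lookahead_slots=4)
processor.process_outputs("request-0", [11, 42, 7])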
@@ -412,240 +430,32 @@ class LLMEngine:
         """Returns True if there are unfinished requests."""
         return self.scheduler.has_unfinished_seqs()
 
-    def _check_beam_search_early_stopping(
-        self,
-        early_stopping: Union[bool, str],
-        sampling_params: SamplingParams,
-        best_running_seq: Sequence,
-        current_worst_seq: Sequence,
-    ) -> bool:
-        assert sampling_params.use_beam_search
-        length_penalty = sampling_params.length_penalty
-        if early_stopping is True:
-            return True
-
-        current_worst_score = current_worst_seq.get_beam_search_score(
-            length_penalty=length_penalty,
-            eos_token_id=current_worst_seq.eos_token_id)
-        if early_stopping is False:
-            highest_attainable_score = best_running_seq.get_beam_search_score(
-                length_penalty=length_penalty,
-                eos_token_id=best_running_seq.eos_token_id)
-        else:
-            assert early_stopping == "never"
-            if length_penalty > 0.0:
-                # If length_penalty > 0.0, beam search will prefer longer
-                # sequences. The highest attainable score calculation is
-                # based on the longest possible sequence length in this case.
-                max_possible_length = max(
-                    best_running_seq.get_prompt_len() +
-                    sampling_params.max_tokens,
-                    self.scheduler_config.max_model_len)
-                highest_attainable_score = (
-                    best_running_seq.get_beam_search_score(
-                        length_penalty=length_penalty,
-                        eos_token_id=best_running_seq.eos_token_id,
-                        seq_len=max_possible_length))
-            else:
-                # Otherwise, beam search will prefer shorter sequences. The
-                # highest attainable score calculation is based on the current
-                # sequence length.
-                highest_attainable_score = (
-                    best_running_seq.get_beam_search_score(
-                        length_penalty=length_penalty,
-                        eos_token_id=best_running_seq.eos_token_id))
-        return current_worst_score >= highest_attainable_score
-
-    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
-                                        outputs: SequenceGroupOutput) -> None:
-
-        # Process prompt logprobs
-        prompt_logprobs = outputs.prompt_logprobs
-        if prompt_logprobs is not None and seq_group.sampling_params.detokenize:
-            self.detokenizer.decode_prompt_logprobs_inplace(
-                seq_group, prompt_logprobs)
-            seq_group.prompt_logprobs = prompt_logprobs
-
-        # Process samples
-        samples = outputs.samples
-        parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
-        existing_finished_seqs = seq_group.get_finished_seqs()
-        parent_child_dict = {
-            parent_seq.seq_id: []
-            for parent_seq in parent_seqs
-        }
-        for sample in samples:
-            parent_child_dict[sample.parent_seq_id].append(sample)
-        # List of (child, parent)
-        child_seqs: List[Tuple[Sequence, Sequence]] = []
-
-        # Process the child samples for each parent sequence
-        for parent in parent_seqs:
-            child_samples: List[SequenceOutput] = parent_child_dict[
-                parent.seq_id]
-            if len(child_samples) == 0:
-                # This parent sequence has no children samples. Remove
-                # the parent sequence from the sequence group since it will
-                # not be used in the future iterations.
-                parent.status = SequenceStatus.FINISHED_ABORTED
-                seq_group.remove(parent.seq_id)
-                self.scheduler.free_seq(parent)
-                continue
-            # Fork the parent sequence if there are multiple child samples.
-            for child_sample in child_samples[:-1]:
-                new_child_seq_id = next(self.seq_counter)
-                child = parent.fork(new_child_seq_id)
-                child.append_token_id(child_sample.output_token,
-                                      child_sample.logprobs)
-                child_seqs.append((child, parent))
-            # Continue the parent sequence for the last child sample.
-            # We reuse the parent sequence here to reduce redundant memory
-            # copies, especially when using non-beam search sampling methods.
-            last_child_sample = child_samples[-1]
-            parent.append_token_id(last_child_sample.output_token,
-                                   last_child_sample.logprobs)
-            child_seqs.append((parent, parent))
-
-        for seq, _ in child_seqs:
-            if seq_group.sampling_params.detokenize:
-                new_char_count = self.detokenizer.decode_sequence_inplace(
-                    seq, seq_group.sampling_params)
-            else:
-                new_char_count = 0
-            self._check_stop(seq, new_char_count, seq_group.sampling_params)
-
-        # Non-beam search case
-        if not seq_group.sampling_params.use_beam_search:
-            # For newly created child sequences, add them to the sequence group
-            # and fork them in block manager if they are not finished.
-            for seq, parent in child_seqs:
-                if seq is not parent:
-                    seq_group.add(seq)
-                    if not seq.is_finished():
-                        self.scheduler.fork_seq(parent, seq)
-
-            # Free the finished and selected parent sequences' memory in block
-            # manager. Keep them in the sequence group as candidate output.
-            # NOTE: we need to fork the new sequences before freeing the
-            # old sequences.
-            for seq, parent in child_seqs:
-                if seq is parent and seq.is_finished():
-                    self.scheduler.free_seq(seq)
-            return
-
-        # Beam search case
-        # Select the child sequences to keep in the sequence group.
-        selected_child_seqs = []
-        unselected_child_seqs = []
-        beam_width = seq_group.sampling_params.best_of
-        length_penalty = seq_group.sampling_params.length_penalty
-
-        # Select the newly finished sequences with the highest scores
-        # to replace existing finished sequences.
-        # Tuple of (seq, parent, is_new)
-        existing_finished_seqs = [(seq, None, False)
-                                  for seq in existing_finished_seqs]
-        new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
-                             if seq.is_finished()]
-        all_finished_seqs = existing_finished_seqs + new_finished_seqs
-        # Sort the finished sequences by their scores.
-        all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
-            length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
-                               reverse=True)
-        for seq, parent, is_new in all_finished_seqs[:beam_width]:
-            if is_new:
-                # A newly generated child sequence finishes and has a high
-                # score, so we will add it into the sequence group.
-                selected_child_seqs.append((seq, parent))
-        for seq, parent, is_new in all_finished_seqs[beam_width:]:
-            if is_new:
-                # A newly generated child sequence finishes but has a low
-                # score, so we will not add it into the sequence group.
-                # Additionally, if this sequence is a continuation of a
-                # parent sequence, we will need remove the parent sequence
-                # from the sequence group.
-                unselected_child_seqs.append((seq, parent))
-            else:
-                # An existing finished sequence has a low score, so we will
-                # remove it from the sequence group.
-                seq_group.remove(seq.seq_id)
-
-        # select the top beam_width sequences from the running
-        # sequences for the next iteration to continue the beam
-        # search.
-        running_child_seqs = [(seq, parent) for seq, parent in child_seqs
-                              if not seq.is_finished()]
-        # Sort the running sequences by their scores.
-        running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
-            length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
-                                reverse=True)
-
-        # Check if we can stop the beam search.
-        if len(running_child_seqs) == 0:
-            # No running sequences, stop the beam search.
-            stop_beam_search = True
-        elif len(all_finished_seqs) < beam_width:
-            # Not enough finished sequences, continue the beam search.
-            stop_beam_search = False
-        else:
-            # Check the early stopping criteria
-            best_running_seq = running_child_seqs[0][0]
-            current_worst_seq = all_finished_seqs[beam_width - 1][0]
-            stop_beam_search = self._check_beam_search_early_stopping(
-                seq_group.sampling_params.early_stopping,
-                seq_group.sampling_params, best_running_seq, current_worst_seq)
-
-        if stop_beam_search:
-            # Stop the beam search and remove all the running sequences from
-            # the sequence group.
-            unselected_child_seqs.extend(running_child_seqs)
-        else:
-            # Continue the beam search and select the top beam_width sequences
-            # to continue the beam search.
-            selected_child_seqs.extend(running_child_seqs[:beam_width])
-            # The remaining running sequences will not be used in the next
-            # iteration. Again, if these sequences are continuations of
-            # parent sequences, we will need to remove the parent sequences
-            # from the sequence group.
-            unselected_child_seqs.extend(running_child_seqs[beam_width:])
-
-        # For newly created child sequences, add them to the sequence group
-        # and fork them in block manager if they are not finished.
-        for seq, parent in selected_child_seqs:
-            if seq is not parent:
-                seq_group.add(seq)
-                if not seq.is_finished():
-                    self.scheduler.fork_seq(parent, seq)
-
-        # Free the finished and selected parent sequences' memory in block
-        # manager. Keep them in the sequence group as candidate output.
-        for seq, parent in selected_child_seqs:
-            if seq is parent and seq.is_finished():
-                self.scheduler.free_seq(seq)
-
-        # Remove the unselected parent sequences from the sequence group and
-        # free their memory in block manager.
-        for seq, parent in unselected_child_seqs:
-            if seq is parent:
-                # Remove the parent sequence if it is not selected for next
-                # iteration
-                seq_group.remove(seq.seq_id)
-                self.scheduler.free_seq(seq)
-
     def _process_model_outputs(
-            self, output: SamplerOutput,
-            scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
+            self, output: List[SamplerOutput],
+            scheduled_seq_groups: List[SequenceGroup],
+            ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]:
+        """Apply the model output to the sequences in the scheduled seq groups.
+
+        Returns RequestOutputs that can be returned to the client.
+        """
+
         now = time.time()
 
+        # Organize outputs by [sequence group][step] instead of
+        # [step][sequence group].
+        output_by_sequence_group = create_output_by_sequence_group(
+            sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups))
+
         # Update the scheduled sequence groups with the model outputs.
-        scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
-        for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output):
+        for scheduled_seq_group, outputs in zip(scheduled_seq_groups,
+                                                output_by_sequence_group):
             seq_group = scheduled_seq_group.seq_group
             seq_group.update_num_computed_tokens(
                 scheduled_seq_group.token_chunk_size)
             # If uncomputed tokens > 0, it means prefill is chunked.
             # We don't need to process outputs in that case.
             if seq_group.get_num_uncomputed_tokens() == 0:
-                self._process_sequence_group_outputs(seq_group, outputs)
+                self.output_processor.process_outputs(seq_group, outputs)
 
         # Free the finished sequence groups.
         self.scheduler.free_finished_seq_groups()
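The regrouping introduced above (create_output_by_sequence_group) turns the executor's per-step outputs into a per-sequence-group list of steps, so a speculative run that sampled several tokens in one engine step can hand all of them to the output processor at once. A minimal sketch of that transposition with plain lists standing in for SamplerOutput objects (illustrative only, not the vLLM helper itself):

from typing import List

# Toy stand-in: each "step" holds one output per scheduled sequence group.
def regroup_by_sequence_group(sampler_outputs: List[List[int]],
                              num_seq_groups: int) -> List[List[int]]:
    """Transpose [step][sequence group] into [sequence group][step]."""
    return [[step_output[i] for step_output in sampler_outputs]
            for i in range(num_seq_groups)]

# Two scheduler steps' worth of outputs for three sequence groups,
# e.g. a speculative step that accepted two tokens per sequence.
steps = [[101, 201, 301],   # step 0
         [102, 202, 302]]   # step 1
print(regroup_by_sequence_group(steps, num_seq_groups=3))
# [[101, 102], [201, 202], [301, 302]]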
@@ -657,13 +467,9 @@ class LLMEngine:
             seq_group.maybe_set_first_token_time(now)
             request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)
-        for seq_group in scheduler_outputs.ignored_seq_groups:
+        for seq_group in ignored_seq_groups:
             request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)
-
-        # Log stats.
-        if self.log_stats:
-            self.stat_logger.log(self._get_stats(scheduler_outputs))
         return request_outputs
 
     def step(self) -> List[RequestOutput]:
@@ -721,13 +527,23 @@ class LLMEngine:
 
         if not scheduler_outputs.is_empty():
             output = self.model_executor.execute_model(
-                seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
-                scheduler_outputs.blocks_to_swap_out,
-                scheduler_outputs.blocks_to_copy)
+                seq_group_metadata_list=seq_group_metadata_list,
+                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
+                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
+                blocks_to_copy=scheduler_outputs.blocks_to_copy,
+                num_lookahead_slots=scheduler_outputs.num_lookahead_slots)
         else:
            output = []
 
-        return self._process_model_outputs(output, scheduler_outputs)
+        request_outputs = self._process_model_outputs(
+            output, scheduler_outputs.scheduled_seq_groups,
+            scheduler_outputs.ignored_seq_groups)
+
+        # Log stats.
+        if self.log_stats:
+            self.stat_logger.log(self._get_stats(scheduler_outputs))
+
+        return request_outputs
 
     def do_log_stats(self) -> None:
         """Forced log when no requests active."""
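The new num_lookahead_slots argument tells the worker how many extra KV-cache slots a step may consume beyond the usual single sampled token, which is what speculative decoding needs in order to append its proposed tokens. A toy back-of-the-envelope under assumed block-size accounting, not vLLM's actual scheduler logic:

# Illustrative sketch: assumed accounting, hypothetical helper name.
def blocks_needed_for_step(current_len: int,
                           num_lookahead_slots: int,
                           block_size: int = 16) -> int:
    """Blocks a sequence may occupy after one engine step.

    With speculative decoding, one step can append up to
    num_lookahead_slots extra tokens on top of the regular one.
    """
    max_len_after_step = current_len + 1 + num_lookahead_slots
    return -(-max_len_after_step // block_size)  # ceiling division

# A sequence at 30 tokens with 4 lookahead slots may end the step at
# 35 tokens, i.e. it needs 3 blocks of size 16 instead of 2.
print(blocks_needed_for_step(30, num_lookahead_slots=0))  # 2
print(blocks_needed_for_step(30, num_lookahead_slots=4))  # 3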
@@ -807,87 +623,6 @@ class LLMEngine:
             time_e2e_requests=time_e2e_requests,
         )
 
-    def _check_stop(self, seq: Sequence, new_char_count: int,
-                    sampling_params: SamplingParams) -> None:
-        """Stop the finished sequences.
-
-        new_char_count is the number of chars added to the
-        sequence's output text for the newly generated token
-        """
-
-        # Check if the minimum number of tokens has been generated yet;
-        # skip the stop string/token checks if not
-        if seq.get_output_len() < sampling_params.min_tokens:
-            return
-
-        # Check if the sequence has generated the EOS token.
-        if ((not sampling_params.ignore_eos)
-                and seq.get_last_token_id() == seq.eos_token_id):
-            seq.status = SequenceStatus.FINISHED_STOPPED
-            return
-
-        # Check if a stop token was encountered.
-        # This assumes a single token produced per step.
-        last_token_id = seq.get_last_token_id()
-        if last_token_id in sampling_params.stop_token_ids:
-            if new_char_count and (
-                    not sampling_params.include_stop_str_in_output):
-                # Remove last token
-                seq.output_text = seq.output_text[:-new_char_count]
-            seq.status = SequenceStatus.FINISHED_STOPPED
-            seq.stop_reason = last_token_id
-            return
-
-        # Check if any stop strings are matched.
-        stop_str = self._check_stop_strings(seq, new_char_count,
-                                            sampling_params)
-        if stop_str is not None:
-            seq.status = SequenceStatus.FINISHED_STOPPED
-            seq.stop_reason = stop_str
-            return
-
-        # Check if the sequence has reached max_model_len.
-        if seq.get_len() > self.scheduler_config.max_model_len:
-            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
-            return
-
-        # Check if the sequence has reached max_tokens.
-        if seq.get_output_len() == sampling_params.max_tokens:
-            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
-            return
-
-    @staticmethod
-    def _check_stop_strings(seq: Sequence, new_char_count: int,
-                            sampling_params: SamplingParams) -> Optional[str]:
-        """Check if any stop strings are matched and truncate sequence
-        output text accordingly.
-
-        Returns the stop string if matched or else None.
-        """
-        if not new_char_count:
-            return None
-
-        for stop_str in sampling_params.stop:
-            stop_string_len = len(stop_str)
-            # Avoid searching already-searched text.
-            stop_index = seq.output_text.find(
-                stop_str, -new_char_count - stop_string_len)
-            if stop_index == -1:
-                continue
-
-            if sampling_params.include_stop_str_in_output:
-                # Truncate to end of stop string.
-                stop_index += stop_string_len
-                if stop_index >= len(seq.output_text):
-                    # No truncation required.
-                    return stop_str
-
-            # Truncate the output text to either the beginning
-            # or end of the stop string.
-            seq.output_text = seq.output_text[:stop_index]
-            return stop_str
-        return None
-
     def add_lora(self, lora_request: LoRARequest) -> bool:
         return self.model_executor.add_lora(lora_request)