[Speculative decoding 7/9] Speculative decoding end-to-end correctness tests. (#3951)

commit 62b8aebc6f
parent 050f285ff6
Author: Cade Daniel
Date:   2024-04-23 01:02:36 -07:00
Committed by: GitHub

22 changed files with 1164 additions and 175 deletions

vllm/engine/llm_engine.py

@@ -22,7 +22,7 @@ from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
-                           SequenceGroup)
+                           SequenceGroup, SequenceStage)
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                      get_tokenizer_group)
@@ -480,9 +480,12 @@ class LLMEngine:
             seq_group = scheduled_seq_group.seq_group
             seq_group.update_num_computed_tokens(
                 scheduled_seq_group.token_chunk_size)
-            # If uncomputed tokens > 0, it means prefill is chunked.
-            # We don't need to process outputs in that case.
-            if seq_group.get_num_uncomputed_tokens() == 0:
+
+            # If all sequences in the sequence group are in DECODE, then we can
+            # process the output tokens. Otherwise, they are (chunked) prefill
+            # samples and should not be processed.
+            stages = [seq.data._stage for seq in seq_group.seqs_dict.values()]
+            if all(stage == SequenceStage.DECODE for stage in stages):
                 self.output_processor.process_outputs(seq_group, outputs)

         # Free the finished sequence groups.
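
The gating change in this hunk reads cleanly on its own: outputs are now processed only once every sequence in the group has reached the decode stage, instead of inferring that from the uncomputed-token count (which conflates chunked prefill with finished prefill). A minimal, self-contained sketch of the pattern (SequenceStage mirrors the enum imported above; Seq and SeqGroup are simplified stand-ins, not the real vllm.sequence classes):

from dataclasses import dataclass, field
from enum import Enum
from typing import Dict


class SequenceStage(Enum):
    PREFILL = "prefill"
    DECODE = "decode"


@dataclass
class Seq:  # simplified stand-in for vllm.sequence.Sequence
    stage: SequenceStage


@dataclass
class SeqGroup:  # simplified stand-in for vllm.sequence.SequenceGroup
    seqs_dict: Dict[int, Seq] = field(default_factory=dict)


def should_process_outputs(seq_group: SeqGroup) -> bool:
    """Process sampled tokens only once every sequence is decoding.

    A chunked prefill step also yields a "sample" per chunk, but those are
    not real generated tokens, so any group with a PREFILL sequence is
    skipped.
    """
    stages = [seq.stage for seq in seq_group.seqs_dict.values()]
    return all(stage == SequenceStage.DECODE for stage in stages)


# A group still mid-prefill is skipped; a fully decoding group is processed.
mixed = SeqGroup({0: Seq(SequenceStage.PREFILL), 1: Seq(SequenceStage.DECODE)})
decoding = SeqGroup({0: Seq(SequenceStage.DECODE)})
assert not should_process_outputs(mixed)
assert should_process_outputs(decoding)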
@@ -569,7 +572,8 @@ class LLMEngine:
         # Log stats.
         if self.log_stats:
-            self.stat_logger.log(self._get_stats(scheduler_outputs))
+            self.stat_logger.log(
+                self._get_stats(scheduler_outputs, model_output=output))

         return request_outputs
@@ -578,9 +582,18 @@ class LLMEngine:
         if self.log_stats:
             self.stat_logger.log(self._get_stats(scheduler_outputs=None))

-    def _get_stats(self,
-                   scheduler_outputs: Optional[SchedulerOutputs]) -> Stats:
-        """Get Stats to be Logged to Prometheus."""
+    def _get_stats(
+            self,
+            scheduler_outputs: Optional[SchedulerOutputs],
+            model_output: Optional[List[SamplerOutput]] = None) -> Stats:
+        """Get Stats to be Logged to Prometheus.
+
+        Args:
+            scheduler_outputs: Optional, used to populate metrics related to
+                the scheduled batch,
+            model_output: Optional, used to emit speculative decoding metrics
+                which are created by the workers.
+        """
         now = time.time()

         # KV Cache Usage in %.
@@ -637,6 +650,14 @@ class LLMEngine:
         time_to_first_tokens = time_last_iters if prompt_run else []
         time_per_output_tokens = [] if prompt_run else time_last_iters

+        # Spec decode, if enabled, emits specialized metrics from the worker in
+        # sampler output.
+        if model_output and (model_output[0].spec_decode_worker_metrics
+                             is not None):
+            spec_decode_metrics = model_output[0].spec_decode_worker_metrics
+        else:
+            spec_decode_metrics = None
+
         return Stats(
             now=now,
             num_running=num_running,
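
Note that only the first SamplerOutput in the list is consulted: the speculative-decoding worker attaches its metrics at the step level, not per sequence, and the field is simply None when spec decode is disabled or no fresh metrics are available. A self-contained sketch of that extraction (SamplerOutput and SpecDecodeWorkerMetrics are reduced stand-ins carrying only what this path touches, not the real vLLM classes):

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SpecDecodeWorkerMetrics:  # reduced stand-in; field names are assumed
    draft_tokens: int     # tokens proposed by the draft model
    accepted_tokens: int  # proposals the target model accepted


@dataclass
class SamplerOutput:  # reduced stand-in for vllm.sequence.SamplerOutput
    spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None


def extract_spec_metrics(
    model_output: Optional[List[SamplerOutput]],
) -> Optional[SpecDecodeWorkerMetrics]:
    # Mirrors the hunk above: metrics ride on the first sampler output of
    # the step; every other case falls through to None.
    if model_output and model_output[0].spec_decode_worker_metrics is not None:
        return model_output[0].spec_decode_worker_metrics
    return None


assert extract_spec_metrics(None) is None
assert extract_spec_metrics([SamplerOutput()]) is None
found = extract_spec_metrics(
    [SamplerOutput(SpecDecodeWorkerMetrics(draft_tokens=5, accepted_tokens=4))])
assert found is not None and found.accepted_tokens == 4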
@@ -649,6 +670,7 @@ class LLMEngine:
             time_to_first_tokens=time_to_first_tokens,
             time_per_output_tokens=time_per_output_tokens,
             time_e2e_requests=time_e2e_requests,
+            spec_decode_metrics=spec_decode_metrics,
         )

     def add_lora(self, lora_request: LoRARequest) -> bool:
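
For context on what these worker metrics summarize: the headline numbers for speculative decoding are the draft acceptance rate (accepted draft tokens over proposed draft tokens) and system efficiency (tokens actually emitted relative to the ideal of k accepted drafts plus one target-model token per step). The helpers below sketch those standard definitions; they are illustrative formulas, not a quote of vLLM's implementation:

def draft_acceptance_rate(accepted_tokens: int, draft_tokens: int) -> float:
    """Fraction of draft-model proposals the target model accepted."""
    return accepted_tokens / draft_tokens if draft_tokens else 0.0


def system_efficiency(emitted_tokens: int, num_steps: int, k: int) -> float:
    """Emitted tokens relative to the best case of k accepted drafts plus
    one target-model token on every step (lookahead of k)."""
    best_case = num_steps * (k + 1)
    return emitted_tokens / best_case if best_case else 0.0


# With k=4: 16 of 20 proposals accepted, and 4 steps emitting 17 tokens
# out of an ideal 20, gives an 80% acceptance rate at 85% efficiency.
assert draft_acceptance_rate(accepted_tokens=16, draft_tokens=20) == 0.8
assert system_efficiency(emitted_tokens=17, num_steps=4, k=4) == 0.85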