[Speculative decoding 7/9] Speculative decoding end-to-end correctness tests. (#3951)
Engine-side changes in vllm/engine/llm_engine.py:
@@ -22,7 +22,7 @@ from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
-                           SequenceGroup)
+                           SequenceGroup, SequenceStage)
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                      get_tokenizer_group)
@@ -480,9 +480,12 @@ class LLMEngine:
             seq_group = scheduled_seq_group.seq_group
             seq_group.update_num_computed_tokens(
                 scheduled_seq_group.token_chunk_size)
-            # If uncomputed tokens > 0, it means prefill is chunked.
-            # We don't need to process outputs in that case.
-            if seq_group.get_num_uncomputed_tokens() == 0:
+
+            # If all sequences in the sequence group are in DECODE, then we can
+            # process the output tokens. Otherwise, they are (chunked) prefill
+            # samples and should not be processed.
+            stages = [seq.data._stage for seq in seq_group.seqs_dict.values()]
+            if all(stage == SequenceStage.DECODE for stage in stages):
                 self.output_processor.process_outputs(seq_group, outputs)

             # Free the finished sequence groups.
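A note on this hunk: with chunked prefill, a sequence group can reach this point while some of its sequences are still mid-prefill, so the old group-level check on uncomputed tokens is replaced by an explicit per-sequence stage check. A minimal, self-contained sketch of that gating, using toy stand-ins for the vllm.sequence types and a hypothetical should_process_outputs helper:

# Toy stand-ins for vllm.sequence types; should_process_outputs is a
# hypothetical helper mirroring the inline check added above.
from dataclasses import dataclass
from enum import Enum, auto
from typing import Dict


class SequenceStage(Enum):
    PREFILL = auto()
    DECODE = auto()


@dataclass
class SequenceData:
    _stage: SequenceStage


@dataclass
class Sequence:
    data: SequenceData


def should_process_outputs(seqs_dict: Dict[int, Sequence]) -> bool:
    # Sampled tokens are only meaningful once every sequence in the group
    # has finished (chunked) prefill and entered the decode stage.
    stages = [seq.data._stage for seq in seqs_dict.values()]
    return all(stage == SequenceStage.DECODE for stage in stages)


# A group with any sequence still in prefill is skipped; a fully
# decoding group has its outputs processed.
assert not should_process_outputs({
    0: Sequence(SequenceData(SequenceStage.PREFILL)),
    1: Sequence(SequenceData(SequenceStage.DECODE)),
})
assert should_process_outputs({0: Sequence(SequenceData(SequenceStage.DECODE))})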
@@ -569,7 +572,8 @@ class LLMEngine:

         # Log stats.
         if self.log_stats:
-            self.stat_logger.log(self._get_stats(scheduler_outputs))
+            self.stat_logger.log(
+                self._get_stats(scheduler_outputs, model_output=output))

         return request_outputs
@@ -578,9 +582,18 @@ class LLMEngine:
         if self.log_stats:
             self.stat_logger.log(self._get_stats(scheduler_outputs=None))

-    def _get_stats(self,
-                   scheduler_outputs: Optional[SchedulerOutputs]) -> Stats:
-        """Get Stats to be Logged to Prometheus."""
+    def _get_stats(
+            self,
+            scheduler_outputs: Optional[SchedulerOutputs],
+            model_output: Optional[List[SamplerOutput]] = None) -> Stats:
+        """Get Stats to be Logged to Prometheus.
+
+        Args:
+            scheduler_outputs: Optional, used to populate metrics related to
+                the scheduled batch,
+            model_output: Optional, used to emit speculative decoding metrics
+                which are created by the workers.
+        """
         now = time.time()

         # KV Cache Usage in %.
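The defaulted optional parameter keeps both call sites valid: step() forwards the sampler output it already holds, while the idle-logging path above (scheduler_outputs=None) omits it. A simplified sketch of the plumbing, with toy stand-ins (not the real vLLM classes) for SpecDecodeWorkerMetrics, SamplerOutput, and Stats:

# Simplified stand-ins; the real SamplerOutput lives in vllm.sequence and
# Stats in the engine metrics module.
import time
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class SpecDecodeWorkerMetrics:
    draft_acceptance_rate: float


@dataclass
class SamplerOutput:
    spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None


@dataclass
class Stats:
    now: float
    spec_decode_metrics: Optional[SpecDecodeWorkerMetrics] = None


def get_stats(scheduler_outputs: Optional[Any],
              model_output: Optional[List[SamplerOutput]] = None) -> Stats:
    # Defaulting model_output keeps older call sites working unchanged.
    spec_decode_metrics = None
    if model_output and model_output[0].spec_decode_worker_metrics is not None:
        spec_decode_metrics = model_output[0].spec_decode_worker_metrics
    return Stats(now=time.time(), spec_decode_metrics=spec_decode_metrics)


# step()-style call with worker output, and the idle path without it.
assert get_stats(None, [SamplerOutput(SpecDecodeWorkerMetrics(0.8))]
                 ).spec_decode_metrics is not None
assert get_stats(scheduler_outputs=None).spec_decode_metrics is None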
@@ -637,6 +650,14 @@ class LLMEngine:
             time_to_first_tokens = time_last_iters if prompt_run else []
             time_per_output_tokens = [] if prompt_run else time_last_iters

+        # Spec decode, if enabled, emits specialized metrics from the worker in
+        # sampler output.
+        if model_output and (model_output[0].spec_decode_worker_metrics
+                             is not None):
+            spec_decode_metrics = model_output[0].spec_decode_worker_metrics
+        else:
+            spec_decode_metrics = None
+
         return Stats(
             now=now,
             num_running=num_running,
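On the consumer side, the guard above implies the worker attaches these metrics to the first sampler output only on iterations where it has fresh aggregates, so anything downstream must treat the field as usually None. A hypothetical formatter along those lines; the metric field names here (draft_acceptance_rate, draft_tokens, emitted_tokens) are illustrative assumptions, not the exact vLLM schema:

# Hypothetical consumer sketch; field names are illustrative assumptions.
from dataclasses import dataclass
from typing import Optional


@dataclass
class SpecDecodeWorkerMetrics:
    draft_acceptance_rate: float
    draft_tokens: int
    emitted_tokens: int


@dataclass
class Stats:
    spec_decode_metrics: Optional[SpecDecodeWorkerMetrics] = None


def format_spec_decode(stats: Stats) -> str:
    m = stats.spec_decode_metrics
    if m is None:
        # Most iterations carry no fresh worker metrics; emit nothing.
        return ""
    return (f"spec decode: acceptance={m.draft_acceptance_rate:.2f}, "
            f"draft tokens={m.draft_tokens}, emitted={m.emitted_tokens}")


print(format_spec_decode(Stats(SpecDecodeWorkerMetrics(0.75, 400, 340))))
print(format_spec_decode(Stats()) or "(no spec decode metrics this step)")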
@@ -649,6 +670,7 @@ class LLMEngine:
             time_to_first_tokens=time_to_first_tokens,
             time_per_output_tokens=time_per_output_tokens,
             time_e2e_requests=time_e2e_requests,
+            spec_decode_metrics=spec_decode_metrics,
         )

     def add_lora(self, lora_request: LoRARequest) -> bool: