[Speculative decoding 7/9] Speculative decoding end-to-end correctness tests. (#3951)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Protocol
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Protocol
|
||||
|
||||
import numpy as np
|
||||
from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
|
||||
@@ -8,6 +8,9 @@ from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
disable_created_metrics()
|
||||
@@ -118,6 +121,8 @@ class Stats:
|
||||
time_per_output_tokens: List[float]
|
||||
time_e2e_requests: List[float]
|
||||
|
||||
spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
|
||||
|
||||
|
||||
class SupportsMetricsInfo(Protocol):
|
||||
|
||||
@@ -235,3 +240,19 @@ class StatLogger:
|
||||
self.num_prompt_tokens = []
|
||||
self.num_generation_tokens = []
|
||||
self.last_local_log = stats.now
|
||||
|
||||
if stats.spec_decode_metrics is not None:
|
||||
logger.info(
|
||||
self._format_spec_decode_metrics_str(
|
||||
stats.spec_decode_metrics))
|
||||
|
||||
def _format_spec_decode_metrics_str(
        self, metrics: "SpecDecodeWorkerMetrics") -> str:
    """Render speculative-decoding metrics as a single log-friendly line.

    Args:
        metrics: Object exposing ``draft_acceptance_rate``,
            ``system_efficiency``, ``num_spec_tokens``, ``accepted_tokens``,
            ``draft_tokens`` and ``emitted_tokens`` attributes.

    Returns:
        A string starting with ``"Speculative metrics: "`` followed by the
        comma-separated metric labels and values. The two rate values are
        formatted with three decimal places; the counts are rendered as-is.
    """
    # Labels for the draft/emitted counts previously read "tokens tokens";
    # the duplicated word is removed so the log line reads correctly.
    return ("Speculative metrics: "
            f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
            f"System efficiency: {metrics.system_efficiency:.3f}, "
            f"Number of speculative tokens: {metrics.num_spec_tokens}, "
            f"Number of accepted tokens: {metrics.accepted_tokens}, "
            f"Number of draft tokens: {metrics.draft_tokens}, "
            f"Number of emitted tokens: {metrics.emitted_tokens}.")
|
||||
|
||||
Reference in New Issue
Block a user