[Speculative decoding 7/9] Speculative decoding end-to-end correctness tests. (#3951)

This commit is contained in:
Cade Daniel
2024-04-23 01:02:36 -07:00
committed by GitHub
parent 050f285ff6
commit 62b8aebc6f
22 changed files with 1164 additions and 175 deletions

View File

@@ -1,6 +1,6 @@
import time
from dataclasses import dataclass
from typing import Dict, List, Protocol
from typing import TYPE_CHECKING, Dict, List, Optional, Protocol
import numpy as np
from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
@@ -8,6 +8,9 @@ from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
logger = init_logger(__name__)
disable_created_metrics()
@@ -118,6 +121,8 @@ class Stats:
time_per_output_tokens: List[float]
time_e2e_requests: List[float]
spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
class SupportsMetricsInfo(Protocol):
@@ -235,3 +240,19 @@ class StatLogger:
self.num_prompt_tokens = []
self.num_generation_tokens = []
self.last_local_log = stats.now
if stats.spec_decode_metrics is not None:
logger.info(
self._format_spec_decode_metrics_str(
stats.spec_decode_metrics))
def _format_spec_decode_metrics_str(
self, metrics: "SpecDecodeWorkerMetrics") -> str:
return ("Speculative metrics: "
f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
f"System efficiency: {metrics.system_efficiency:.3f}, "
f"Number of speculative tokens: {metrics.num_spec_tokens}, "
f"Number of accepted tokens: {metrics.accepted_tokens}, "
f"Number of draft tokens tokens: {metrics.draft_tokens}, "
f"Number of emitted tokens tokens: {metrics.emitted_tokens}.")