[Spec Decode][UX] Add acceptance stats to vllm bench serve report (#31739)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>
Co-authored-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
Matthew Bonanni
2026-01-06 16:21:42 -05:00
committed by GitHub
parent dba95378a6
commit d49899732e

View File

@@ -86,6 +86,76 @@ async def get_first_model_from_server(
) from e
@dataclass
class SpecDecodeMetrics:
    """Snapshot of speculative decoding counters scraped from the server.

    The values are sums over the ``vllm:spec_decode*`` samples exposed by the
    server's Prometheus ``/metrics`` endpoint at the time of the scrape.
    """

    # Sum of the samples whose metric matches "num_drafts".
    num_drafts: int
    # Sum of the samples whose metric matches "num_draft_tokens".
    num_draft_tokens: int
    # Sum of the samples whose metric matches "num_accepted_tokens".
    num_accepted_tokens: int
    # Value of the ``position`` label -> summed "num_accepted_tokens_per_pos"
    # samples carrying that label value.
    accepted_per_pos: dict[int, int]
def _parse_spec_decode_metrics(text: str) -> SpecDecodeMetrics | None:
    """Aggregate ``vllm:spec_decode*`` samples from Prometheus exposition text.

    Args:
        text: Body of a ``/metrics`` response in the Prometheus text format,
            i.e. one ``name[{labels}] value [timestamp]`` sample per line.

    Returns:
        The summed counters, or ``None`` when no line starts with
        ``vllm:spec_decode``.
    """
    num_drafts = 0
    num_draft_tokens = 0
    num_accepted_tokens = 0
    accepted_per_pos: dict[int, int] = {}
    found_spec_decode = False
    pos_label = 'position="'

    for raw_line in text.split("\n"):
        line = raw_line.strip()
        # Skip blanks, "# HELP"/"# TYPE" comments and unrelated metrics.
        if not line or line.startswith("#"):
            continue
        if not line.startswith("vllm:spec_decode"):
            continue
        found_spec_decode = True

        # Split the sample into metric name, label text and trailing fields.
        # Classifying on the name alone keeps label values (which are free
        # text) from being mistaken for a metric name.
        brace = line.find("{")
        if brace == -1:
            name, *fields = line.split()
            labels = ""
        else:
            name = line[:brace].strip()
            close = line.rfind("}")
            if close < brace:
                # Malformed sample: unterminated label set.
                continue
            labels = line[brace + 1 : close]
            fields = line[close + 1 :].split()

        # "*_created" series carry a Unix timestamp, not a count; adding them
        # would corrupt the totals whenever a series appears between scrapes.
        if not fields or name.endswith("_created"):
            continue

        # The value is the first field after the labels; an optional
        # timestamp may follow it. NaN raises ValueError and +/-Inf raises
        # OverflowError in int(); both are skipped rather than propagated.
        with contextlib.suppress(ValueError, OverflowError):
            value = int(float(fields[0]))
            if "num_drafts" in name:
                num_drafts += value
            elif "num_draft_tokens" in name:
                num_draft_tokens += value
            elif "num_accepted_tokens_per_pos" in name:
                # Must be tested before the plain "num_accepted_tokens"
                # check, which is a substring of this name.
                if pos_label in labels:
                    start = labels.index(pos_label) + len(pos_label)
                    end = labels.index('"', start)
                    pos = int(labels[start:end])
                    accepted_per_pos[pos] = accepted_per_pos.get(pos, 0) + value
            elif "num_accepted_tokens" in name:
                num_accepted_tokens += value

    if not found_spec_decode:
        return None
    return SpecDecodeMetrics(
        num_drafts=num_drafts,
        num_draft_tokens=num_draft_tokens,
        num_accepted_tokens=num_accepted_tokens,
        accepted_per_pos=accepted_per_pos,
    )


async def fetch_spec_decode_metrics(
    base_url: str, session: aiohttp.ClientSession
) -> SpecDecodeMetrics | None:
    """Fetch speculative decoding metrics from the server's Prometheus endpoint.

    Args:
        base_url: Server base URL; ``/metrics`` is appended to it.
        session: Client session used to issue the GET request.

    Returns:
        The aggregated counters, or ``None`` if the request fails, the
        response status is not 200, or the response contains no
        ``vllm:spec_decode*`` samples (e.g. speculative decoding is not
        enabled).
    """
    metrics_url = f"{base_url}/metrics"
    # Only the network round-trip can raise the errors handled here; parsing
    # happens afterwards on the already-downloaded text.
    try:
        async with session.get(metrics_url) as response:
            if response.status != 200:
                return None
            text = await response.text()
    except (aiohttp.ClientError, asyncio.TimeoutError):
        return None
    return _parse_spec_decode_metrics(text)
class TaskType(Enum):
GENERATION = "generation"
POOLING = "pooling"
@@ -685,6 +755,8 @@ async def benchmark(
print(f"Burstiness factor: {burstiness} ({distribution})")
print(f"Maximum request concurrency: {max_concurrency}")
spec_decode_metrics_before = await fetch_spec_decode_metrics(base_url, session)
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
semaphore = (
@@ -768,6 +840,48 @@ async def benchmark(
benchmark_duration = time.perf_counter() - benchmark_start_time
spec_decode_metrics_after = await fetch_spec_decode_metrics(base_url, session)
spec_decode_stats: dict[str, Any] | None = None
if spec_decode_metrics_before is not None and spec_decode_metrics_after is not None:
delta_drafts = (
spec_decode_metrics_after.num_drafts - spec_decode_metrics_before.num_drafts
)
delta_draft_tokens = (
spec_decode_metrics_after.num_draft_tokens
- spec_decode_metrics_before.num_draft_tokens
)
delta_accepted = (
spec_decode_metrics_after.num_accepted_tokens
- spec_decode_metrics_before.num_accepted_tokens
)
per_pos_rates: list[float] = []
if delta_drafts > 0:
positions = sorted(
set(spec_decode_metrics_before.accepted_per_pos.keys())
| set(spec_decode_metrics_after.accepted_per_pos.keys())
)
for pos in positions:
before_val = spec_decode_metrics_before.accepted_per_pos.get(pos, 0)
after_val = spec_decode_metrics_after.accepted_per_pos.get(
pos, before_val
)
delta_pos = after_val - before_val
per_pos_rates.append(delta_pos / delta_drafts)
if delta_draft_tokens > 0:
acceptance_rate = (delta_accepted / delta_draft_tokens) * 100
acceptance_length = (
1 + delta_accepted / delta_drafts if delta_drafts > 0 else 0.0
)
spec_decode_stats = {
"num_drafts": delta_drafts,
"draft_tokens": delta_draft_tokens,
"accepted_tokens": delta_accepted,
"acceptance_rate": acceptance_rate,
"acceptance_length": acceptance_length,
"per_position_acceptance_rates": per_pos_rates,
}
if task_type == TaskType.GENERATION:
metrics, actual_output_lens = calculate_metrics(
input_requests=input_requests,
@@ -863,6 +977,18 @@ async def benchmark(
if rps_change_events:
result["rps_change_events"] = rps_change_events
if spec_decode_stats is not None:
result["spec_decode_acceptance_rate"] = spec_decode_stats["acceptance_rate"]
result["spec_decode_acceptance_length"] = spec_decode_stats["acceptance_length"]
result["spec_decode_num_drafts"] = int(spec_decode_stats["num_drafts"])
result["spec_decode_draft_tokens"] = int(spec_decode_stats["draft_tokens"])
result["spec_decode_accepted_tokens"] = int(
spec_decode_stats["accepted_tokens"]
)
result["spec_decode_per_position_acceptance_rates"] = spec_decode_stats.get(
"per_position_acceptance_rates", []
)
def process_one_metric(
# E.g., "ttft"
metric_attribute_name: str,
@@ -908,6 +1034,35 @@ async def benchmark(
process_one_metric("itl", "ITL", "Inter-token Latency")
process_one_metric("e2el", "E2EL", "End-to-end Latency")
if spec_decode_stats is not None:
print("{s:{c}^{n}}".format(s="Speculative Decoding", n=50, c="-"))
print(
"{:<40} {:<10.2f}".format(
"Acceptance rate (%):", spec_decode_stats["acceptance_rate"]
)
)
print(
"{:<40} {:<10.2f}".format(
"Acceptance length:", spec_decode_stats["acceptance_length"]
)
)
print("{:<40} {:<10}".format("Drafts:", int(spec_decode_stats["num_drafts"])))
print(
"{:<40} {:<10}".format(
"Draft tokens:", int(spec_decode_stats["draft_tokens"])
)
)
print(
"{:<40} {:<10}".format(
"Accepted tokens:", int(spec_decode_stats["accepted_tokens"])
)
)
per_pos = spec_decode_stats.get("per_position_acceptance_rates", [])
if per_pos:
print("Per-position acceptance (%):")
for i, rate in enumerate(per_pos):
print("{:<40} {:<10.2f}".format(f" Position {i}:", rate * 100))
print("=" * 50)
if profile: