[Spec Decode][UX] Add acceptance stats to vllm bench serve report (#31739)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com> Co-authored-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
@@ -86,6 +86,76 @@ async def get_first_model_from_server(
|
||||
) from e
|
||||
|
||||
|
||||
@dataclass
class SpecDecodeMetrics:
    """Speculative decoding metrics from the server's Prometheus endpoint.

    Holds one snapshot of the ``vllm:spec_decode*`` sample values, summed
    over every matching sample line in a single ``/metrics`` response.
    """

    # Sum of the ``num_drafts`` samples.
    num_drafts: int
    # Sum of the ``num_draft_tokens`` samples.
    num_draft_tokens: int
    # Sum of the ``num_accepted_tokens`` samples.
    num_accepted_tokens: int
    # Sum of the ``num_accepted_tokens_per_pos`` samples, keyed by the
    # integer value of each sample's ``position`` label.
    accepted_per_pos: dict[int, int]
|
||||
|
||||
|
||||
def _parse_spec_decode_metrics(text: str) -> SpecDecodeMetrics | None:
    """Aggregate ``vllm:spec_decode*`` samples from Prometheus exposition text.

    Args:
        text: Body of a ``/metrics`` response in the Prometheus text format.

    Returns:
        The summed counters, or None when no line starts with
        ``vllm:spec_decode``.
    """
    num_drafts = 0
    num_draft_tokens = 0
    num_accepted_tokens = 0
    accepted_per_pos: dict[int, int] = {}
    found_spec_decode = False
    pos_label = 'position="'

    for raw_line in text.split("\n"):
        line = raw_line.strip()
        # Skip blanks and "# HELP" / "# TYPE" comment lines.
        if not line or line.startswith("#"):
            continue
        if not line.startswith("vllm:spec_decode"):
            continue
        found_spec_decode = True

        # Classify on the metric name only (the text before the label set or
        # the value) so that label values, e.g. a model name, can never
        # influence which counter a sample is added to.
        metric_name = line.split("{", 1)[0].split(None, 1)[0]

        # Prometheus clients may expose a "<counter>_created" sample holding
        # a Unix timestamp; it is not a count and must not be summed.
        if metric_name.endswith("_created"):
            continue

        # ValueError: malformed number or position label (incl. NaN).
        # OverflowError: int(float("+Inf")).
        # A bad sample is skipped without discarding the rest.
        with contextlib.suppress(ValueError, OverflowError):
            value = int(float(line.rsplit(None, 1)[-1]))

            if "num_drafts" in metric_name:
                num_drafts += value
            elif "num_draft_tokens" in metric_name:
                num_draft_tokens += value
            # The per-position check must precede the plain accepted-tokens
            # check, because the latter name is a prefix of the former.
            elif "num_accepted_tokens_per_pos" in metric_name:
                if pos_label in line:
                    start = line.index(pos_label) + len(pos_label)
                    end = line.index('"', start)
                    pos = int(line[start:end])
                    accepted_per_pos[pos] = accepted_per_pos.get(pos, 0) + value
            elif "num_accepted_tokens" in metric_name:
                num_accepted_tokens += value

    if not found_spec_decode:
        return None

    return SpecDecodeMetrics(
        num_drafts=num_drafts,
        num_draft_tokens=num_draft_tokens,
        num_accepted_tokens=num_accepted_tokens,
        accepted_per_pos=accepted_per_pos,
    )


async def fetch_spec_decode_metrics(
    base_url: str, session: aiohttp.ClientSession
) -> SpecDecodeMetrics | None:
    """Fetch speculative decoding metrics from the server's Prometheus endpoint.

    Sends a GET request to ``{base_url}/metrics`` and sums the
    ``vllm:spec_decode*`` samples found in the response body.

    Args:
        base_url: Server base URL; ``/metrics`` is appended to it verbatim.
        session: aiohttp session used to issue the request.

    Returns:
        The aggregated metrics, or None if the response status is not 200,
        the request raises ``aiohttp.ClientError`` or
        ``asyncio.TimeoutError``, or the body contains no
        ``vllm:spec_decode*`` sample (speculative decoding is not enabled
        or metrics are not available).
    """
    metrics_url = f"{base_url}/metrics"
    # Only the network I/O can raise the handled exceptions, so the try body
    # is limited to it; parsing happens afterwards.
    try:
        async with session.get(metrics_url) as response:
            if response.status != 200:
                return None
            text = await response.text()
    except (aiohttp.ClientError, asyncio.TimeoutError):
        return None

    return _parse_spec_decode_metrics(text)
|
||||
|
||||
|
||||
class TaskType(Enum):
|
||||
GENERATION = "generation"
|
||||
POOLING = "pooling"
|
||||
@@ -685,6 +755,8 @@ async def benchmark(
|
||||
print(f"Burstiness factor: {burstiness} ({distribution})")
|
||||
print(f"Maximum request concurrency: {max_concurrency}")
|
||||
|
||||
spec_decode_metrics_before = await fetch_spec_decode_metrics(base_url, session)
|
||||
|
||||
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
|
||||
|
||||
semaphore = (
|
||||
@@ -768,6 +840,48 @@ async def benchmark(
|
||||
|
||||
benchmark_duration = time.perf_counter() - benchmark_start_time
|
||||
|
||||
spec_decode_metrics_after = await fetch_spec_decode_metrics(base_url, session)
|
||||
spec_decode_stats: dict[str, Any] | None = None
|
||||
if spec_decode_metrics_before is not None and spec_decode_metrics_after is not None:
|
||||
delta_drafts = (
|
||||
spec_decode_metrics_after.num_drafts - spec_decode_metrics_before.num_drafts
|
||||
)
|
||||
delta_draft_tokens = (
|
||||
spec_decode_metrics_after.num_draft_tokens
|
||||
- spec_decode_metrics_before.num_draft_tokens
|
||||
)
|
||||
delta_accepted = (
|
||||
spec_decode_metrics_after.num_accepted_tokens
|
||||
- spec_decode_metrics_before.num_accepted_tokens
|
||||
)
|
||||
per_pos_rates: list[float] = []
|
||||
if delta_drafts > 0:
|
||||
positions = sorted(
|
||||
set(spec_decode_metrics_before.accepted_per_pos.keys())
|
||||
| set(spec_decode_metrics_after.accepted_per_pos.keys())
|
||||
)
|
||||
for pos in positions:
|
||||
before_val = spec_decode_metrics_before.accepted_per_pos.get(pos, 0)
|
||||
after_val = spec_decode_metrics_after.accepted_per_pos.get(
|
||||
pos, before_val
|
||||
)
|
||||
delta_pos = after_val - before_val
|
||||
per_pos_rates.append(delta_pos / delta_drafts)
|
||||
|
||||
if delta_draft_tokens > 0:
|
||||
acceptance_rate = (delta_accepted / delta_draft_tokens) * 100
|
||||
acceptance_length = (
|
||||
1 + delta_accepted / delta_drafts if delta_drafts > 0 else 0.0
|
||||
)
|
||||
spec_decode_stats = {
|
||||
"num_drafts": delta_drafts,
|
||||
"draft_tokens": delta_draft_tokens,
|
||||
"accepted_tokens": delta_accepted,
|
||||
"acceptance_rate": acceptance_rate,
|
||||
"acceptance_length": acceptance_length,
|
||||
"per_position_acceptance_rates": per_pos_rates,
|
||||
}
|
||||
|
||||
if task_type == TaskType.GENERATION:
|
||||
metrics, actual_output_lens = calculate_metrics(
|
||||
input_requests=input_requests,
|
||||
@@ -863,6 +977,18 @@ async def benchmark(
|
||||
if rps_change_events:
|
||||
result["rps_change_events"] = rps_change_events
|
||||
|
||||
if spec_decode_stats is not None:
|
||||
result["spec_decode_acceptance_rate"] = spec_decode_stats["acceptance_rate"]
|
||||
result["spec_decode_acceptance_length"] = spec_decode_stats["acceptance_length"]
|
||||
result["spec_decode_num_drafts"] = int(spec_decode_stats["num_drafts"])
|
||||
result["spec_decode_draft_tokens"] = int(spec_decode_stats["draft_tokens"])
|
||||
result["spec_decode_accepted_tokens"] = int(
|
||||
spec_decode_stats["accepted_tokens"]
|
||||
)
|
||||
result["spec_decode_per_position_acceptance_rates"] = spec_decode_stats.get(
|
||||
"per_position_acceptance_rates", []
|
||||
)
|
||||
|
||||
def process_one_metric(
|
||||
# E.g., "ttft"
|
||||
metric_attribute_name: str,
|
||||
@@ -908,6 +1034,35 @@ async def benchmark(
|
||||
process_one_metric("itl", "ITL", "Inter-token Latency")
|
||||
process_one_metric("e2el", "E2EL", "End-to-end Latency")
|
||||
|
||||
if spec_decode_stats is not None:
|
||||
print("{s:{c}^{n}}".format(s="Speculative Decoding", n=50, c="-"))
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
"Acceptance rate (%):", spec_decode_stats["acceptance_rate"]
|
||||
)
|
||||
)
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
"Acceptance length:", spec_decode_stats["acceptance_length"]
|
||||
)
|
||||
)
|
||||
print("{:<40} {:<10}".format("Drafts:", int(spec_decode_stats["num_drafts"])))
|
||||
print(
|
||||
"{:<40} {:<10}".format(
|
||||
"Draft tokens:", int(spec_decode_stats["draft_tokens"])
|
||||
)
|
||||
)
|
||||
print(
|
||||
"{:<40} {:<10}".format(
|
||||
"Accepted tokens:", int(spec_decode_stats["accepted_tokens"])
|
||||
)
|
||||
)
|
||||
per_pos = spec_decode_stats.get("per_position_acceptance_rates", [])
|
||||
if per_pos:
|
||||
print("Per-position acceptance (%):")
|
||||
for i, rate in enumerate(per_pos):
|
||||
print("{:<40} {:<10.2f}".format(f" Position {i}:", rate * 100))
|
||||
|
||||
print("=" * 50)
|
||||
|
||||
if profile:
|
||||
|
||||
Reference in New Issue
Block a user