[Misc] Fix up attention benchmarks (#33810)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Lucas Wilkinson
2026-02-09 06:42:03 -08:00
committed by GitHub
parent 9562912cea
commit d0d97e2974
5 changed files with 218 additions and 94 deletions

View File

@@ -12,6 +12,7 @@ from typing import Any
import numpy as np
import torch
from batch_spec import get_batch_type, parse_batch_spec
from rich.console import Console
from rich.table import Table
@@ -316,12 +317,14 @@ class ResultsFormatter:
backends: List of backend names being compared
compare_to_fastest: Show percentage comparison to fastest
"""
# Group by batch spec
# Group by batch spec, preserving first-occurrence order
by_spec = {}
specs_order = []
for r in results:
spec = r.config.batch_spec
if spec not in by_spec:
by_spec[spec] = {}
specs_order.append(spec)
by_spec[spec][r.config.backend] = r
# Create shortened backend names for display
@@ -337,6 +340,8 @@ class ResultsFormatter:
table = Table(title="Attention Benchmark Results")
table.add_column("Batch\nSpec", no_wrap=True)
table.add_column("Type", no_wrap=True)
table.add_column("Batch\nSize", justify="right", no_wrap=True)
multi = len(backends) > 1
for backend in backends:
@@ -350,12 +355,14 @@ class ResultsFormatter:
table.add_column(col_rel, justify="right", no_wrap=False)
# Add rows
for spec in sorted(by_spec.keys()):
for spec in specs_order:
spec_results = by_spec[spec]
times = {b: r.mean_time for b, r in spec_results.items() if r.success}
best_time = min(times.values()) if times else 0.0
row = [spec]
batch_type = get_batch_type(spec)
batch_size = len(parse_batch_spec(spec))
row = [spec, batch_type, str(batch_size)]
for backend in backends:
if backend in spec_results:
r = spec_results[backend]