[Misc] Fix up attention benchmarks (#33810)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Lucas Wilkinson
2026-02-09 06:42:03 -08:00
committed by GitHub
parent 9562912cea
commit d0d97e2974
5 changed files with 218 additions and 94 deletions

View File

@@ -229,3 +229,40 @@ def get_batch_stats(requests: list[BatchRequest]) -> dict:
sum(r.kv_len for r in requests) / len(requests) if requests else 0
),
}
def get_batch_type(batch_spec: str, spec_decode_threshold: int = 8) -> str:
"""
Classify a batch spec into a type string.
Args:
batch_spec: Batch specification string (e.g., "q2k", "8q1s1k", "2q2k_8q1s1k")
spec_decode_threshold: Max q_len to be considered spec-decode vs extend
Returns:
Type string: "prefill", "decode", "spec-decode", "extend", or "mixed (types...)"
"""
requests = parse_batch_spec(batch_spec)
# Classify each request
types_present = set()
for req in requests:
if req.is_decode:
types_present.add("decode")
elif req.is_prefill:
types_present.add("prefill")
elif req.is_extend:
# Distinguish spec-decode (small q_len) from extend (chunked prefill)
if req.q_len <= spec_decode_threshold:
types_present.add("spec-decode")
else:
types_present.add("extend")
if len(types_present) == 1:
return types_present.pop()
elif len(types_present) > 1:
# Sort for consistent output
sorted_types = sorted(types_present)
return f"mixed ({'+'.join(sorted_types)})"
else:
return "unknown"