[Misc] Fix up attention benchmarks (#33810)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Lucas Wilkinson
2026-02-09 06:42:03 -08:00
committed by GitHub
parent 9562912cea
commit d0d97e2974
5 changed files with 218 additions and 94 deletions

View File

@@ -229,3 +229,40 @@ def get_batch_stats(requests: list[BatchRequest]) -> dict:
sum(r.kv_len for r in requests) / len(requests) if requests else 0
),
}
def get_batch_type(batch_spec: str, spec_decode_threshold: int = 8) -> str:
"""
Classify a batch spec into a type string.
Args:
batch_spec: Batch specification string (e.g., "q2k", "8q1s1k", "2q2k_8q1s1k")
spec_decode_threshold: Max q_len to be considered spec-decode vs extend
Returns:
Type string: "prefill", "decode", "spec-decode", "extend", or "mixed (types...)"
"""
requests = parse_batch_spec(batch_spec)
# Classify each request
types_present = set()
for req in requests:
if req.is_decode:
types_present.add("decode")
elif req.is_prefill:
types_present.add("prefill")
elif req.is_extend:
# Distinguish spec-decode (small q_len) from extend (chunked prefill)
if req.q_len <= spec_decode_threshold:
types_present.add("spec-decode")
else:
types_present.add("extend")
if len(types_present) == 1:
return types_present.pop()
elif len(types_present) > 1:
# Sort for consistent output
sorted_types = sorted(types_present)
return f"mixed ({'+'.join(sorted_types)})"
else:
return "unknown"

View File

@@ -12,6 +12,7 @@ from typing import Any
import numpy as np
import torch
from batch_spec import get_batch_type, parse_batch_spec
from rich.console import Console
from rich.table import Table
@@ -316,12 +317,14 @@ class ResultsFormatter:
backends: List of backend names being compared
compare_to_fastest: Show percentage comparison to fastest
"""
# Group by batch spec
# Group by batch spec, preserving first-occurrence order
by_spec = {}
specs_order = []
for r in results:
spec = r.config.batch_spec
if spec not in by_spec:
by_spec[spec] = {}
specs_order.append(spec)
by_spec[spec][r.config.backend] = r
# Create shortened backend names for display
@@ -337,6 +340,8 @@ class ResultsFormatter:
table = Table(title="Attention Benchmark Results")
table.add_column("Batch\nSpec", no_wrap=True)
table.add_column("Type", no_wrap=True)
table.add_column("Batch\nSize", justify="right", no_wrap=True)
multi = len(backends) > 1
for backend in backends:
@@ -350,12 +355,14 @@ class ResultsFormatter:
table.add_column(col_rel, justify="right", no_wrap=False)
# Add rows
for spec in sorted(by_spec.keys()):
for spec in specs_order:
spec_results = by_spec[spec]
times = {b: r.mean_time for b, r in spec_results.items() if r.success}
best_time = min(times.values()) if times else 0.0
row = [spec]
batch_type = get_batch_type(spec)
batch_size = len(parse_batch_spec(spec))
row = [spec, batch_type, str(batch_size)]
for backend in backends:
if backend in spec_results:
r = spec_results[backend]

View File

@@ -25,10 +25,18 @@ batch_specs:
- "4q1k_16q1s2k" # 4 prefill + 16 decode
- "2q4k_32q1s1k" # 2 large prefill + 32 decode
# Context extension
- "q1ks2k" # 1k query, 2k sequence (chunked prefill)
# Speculative decode (q <= 8)
- "16q2s1k" # 16 requests, 2 spec tokens, 1k KV cache
- "16q4s1k" # 16 requests, 4 spec tokens, 1k KV cache
- "16q8s1k" # 16 requests, 8 spec tokens, 1k KV cache
- "32q4s2k" # 32 requests, 4 spec tokens, 2k KV cache
- "8q8s4k" # 8 requests, 8 spec tokens, 4k KV cache
# Context extension (chunked prefill)
- "q1ks2k" # 1k query, 2k sequence
- "2q1ks4k" # 2 requests: 1k query, 4k sequence
# Available backends: flash, triton, flashinfer
backends:
- flash
- triton

View File

@@ -8,7 +8,9 @@ This module provides helpers for running standard attention backends
(FlashAttention, Triton, FlashInfer) with real vLLM integration.
"""
import logging
import types
from contextlib import contextmanager
import numpy as np
import torch
@@ -24,8 +26,13 @@ from vllm.config import (
ParallelConfig,
SchedulerConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
get_kv_cache_layout,
set_kv_cache_layout,
)
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec
# ============================================================================
@@ -37,22 +44,14 @@ _BACKEND_CONFIG = {
"flash": {
"module": "vllm.v1.attention.backends.flash_attn",
"backend_class": "FlashAttentionBackend",
"dtype": torch.float16,
"cache_layout": "standard",
# ^ [2, num_blocks, block_size, num_kv_heads, head_dim]
},
"triton": {
"module": "vllm.v1.attention.backends.triton_attn",
"backend_class": "TritonAttentionBackend",
"dtype": torch.float32,
"cache_layout": "standard",
},
"flashinfer": {
"module": "vllm.v1.attention.backends.flashinfer",
"backend_class": "FlashInferBackend",
"dtype": torch.float16,
"cache_layout": "flashinfer",
# ^ [num_blocks, 2, block_size, num_kv_heads, head_dim]
},
}
@@ -66,6 +65,18 @@ def _get_backend_config(backend: str) -> dict:
return _BACKEND_CONFIG[backend]
@contextmanager
def log_warnings_and_errors_only():
"""Temporarily set vLLM logger to WARNING level."""
logger = logging.getLogger("vllm")
old_level = logger.level
logger.setLevel(logging.WARNING)
try:
yield
finally:
logger.setLevel(old_level)
# ============================================================================
# Metadata Building Helpers
# ============================================================================
@@ -88,11 +99,7 @@ def _build_common_attn_metadata(
query_start_loc_cpu = query_start_loc.cpu()
seq_lens = torch.tensor(kv_lens, dtype=torch.int32, device=device)
seq_lens_cpu = seq_lens.cpu()
max_seq_len = int(seq_lens_cpu.max())
context_lens = [kv - q for kv, q in zip(kv_lens, q_lens)]
num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)
max_seq_len = int(seq_lens.max().item())
max_blocks = (max(kv_lens) + block_size - 1) // block_size
num_blocks = batch_size * max_blocks
@@ -107,8 +114,6 @@ def _build_common_attn_metadata(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens,
seq_lens_cpu=seq_lens_cpu,
num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=batch_size,
num_actual_tokens=total_tokens,
max_query_len=max_query_len,
@@ -121,7 +126,6 @@ def _build_common_attn_metadata(
def _create_vllm_config(
config: BenchmarkConfig,
dtype: torch.dtype,
max_num_blocks: int,
) -> VllmConfig:
"""Create a VllmConfig for benchmarking with mock model methods."""
@@ -129,7 +133,7 @@ def _create_vllm_config(
model="meta-llama/Meta-Llama-3-8B",
tokenizer="meta-llama/Meta-Llama-3-8B",
trust_remote_code=False,
dtype=dtype,
dtype="auto", # Use model's native dtype
seed=0,
max_model_len=1024,
)
@@ -198,6 +202,7 @@ def _create_backend_impl(
backend_cfg: dict,
config: BenchmarkConfig,
device: torch.device,
dtype: torch.dtype,
):
"""Create backend implementation instance."""
import importlib
@@ -206,7 +211,6 @@ def _create_backend_impl(
backend_class = getattr(backend_module, backend_cfg["backend_class"])
scale = get_attention_scale(config.head_dim)
dtype = backend_cfg["dtype"]
impl = backend_class.get_impl_cls()(
num_heads=config.num_q_heads,
@@ -227,7 +231,7 @@ def _create_backend_impl(
layer = MockLayer(device, kv_cache_spec=kv_cache_spec)
return backend_class, impl, layer, dtype
return backend_class, impl, layer
def _create_metadata_builder(
@@ -235,11 +239,44 @@ def _create_metadata_builder(
kv_cache_spec: FullAttentionSpec,
vllm_config: VllmConfig,
device: torch.device,
backend_name: str = "",
):
"""Create metadata builder instance."""
return backend_class.get_builder_cls()(
layer_names = ["layer_0"]
builder_cls = backend_class.get_builder_cls()
# Flashinfer needs get_per_layer_parameters mocked since we don't have
# real model layers registered
if backend_name == "flashinfer":
import unittest.mock
from vllm.v1.attention.backends.utils import PerLayerParameters
def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls):
head_size = vllm_config.model_config.get_head_size()
return {
layer_name: PerLayerParameters(
window_left=-1, # No sliding window
logits_soft_cap=0.0, # No soft cap
sm_scale=1.0 / (head_size**0.5), # Standard scale
)
for layer_name in layer_names
}
with unittest.mock.patch(
"vllm.v1.attention.backends.flashinfer.get_per_layer_parameters",
mock_get_per_layer_parameters,
):
return builder_cls(
kv_cache_spec=kv_cache_spec,
layer_names=layer_names,
vllm_config=vllm_config,
device=device,
)
return builder_cls(
kv_cache_spec=kv_cache_spec,
layer_names=["layer_0"],
layer_names=layer_names,
vllm_config=vllm_config,
device=device,
)
@@ -281,39 +318,44 @@ def _create_input_tensors(
def _create_kv_cache(
config: BenchmarkConfig,
max_num_blocks: int,
cache_layout: str,
backend_class,
device: torch.device,
dtype: torch.dtype,
) -> list:
"""Create KV cache tensors for all layers."""
if cache_layout == "flashinfer":
# FlashInfer layout: [num_blocks, 2, block_size, num_kv_heads, head_dim]
cache_list = [
torch.zeros(
max_num_blocks,
2,
config.block_size,
config.num_kv_heads,
config.head_dim,
device=device,
dtype=dtype,
)
for _ in range(config.num_layers)
]
else:
# Standard layout: [2, num_blocks, block_size, num_kv_heads, head_dim]
cache_list = [
torch.zeros(
2,
max_num_blocks,
config.block_size,
config.num_kv_heads,
config.head_dim,
device=device,
dtype=dtype,
)
for _ in range(config.num_layers)
]
"""Create KV cache tensors for all layers using the backend's methods.
Uses the backend's get_kv_cache_shape() and get_kv_cache_stride_order()
to create the cache with the correct shape and memory layout.
"""
# Get the logical shape from the backend
cache_shape = backend_class.get_kv_cache_shape(
num_blocks=max_num_blocks,
block_size=config.block_size,
num_kv_heads=config.num_kv_heads,
head_size=config.head_dim,
)
# Get the stride order for custom memory layout
try:
stride_order = backend_class.get_kv_cache_stride_order()
assert len(stride_order) == len(cache_shape)
except (AttributeError, NotImplementedError):
stride_order = tuple(range(len(cache_shape)))
# Permute shape to physical layout order
physical_shape = tuple(cache_shape[i] for i in stride_order)
# Compute inverse permutation to get back to logical view
inv_order = [stride_order.index(i) for i in range(len(stride_order))]
cache_list = []
for _ in range(config.num_layers):
# Allocate in physical layout order (contiguous in memory)
cache = torch.zeros(*physical_shape, device=device, dtype=dtype)
# Permute to logical view
cache = cache.permute(*inv_order)
cache_list.append(cache)
return cache_list
@@ -418,53 +460,72 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
kv_lens = [r.kv_len for r in requests]
total_q = sum(q_lens)
max_kv = max(kv_lens)
batch_size = len(q_lens)
max_num_blocks = (max_kv + config.block_size - 1) // config.block_size
# Calculate total blocks needed: batch_size * max_blocks_per_request
max_blocks_per_request = (max_kv + config.block_size - 1) // config.block_size
max_num_blocks = batch_size * max_blocks_per_request
backend_class, impl, layer, dtype = _create_backend_impl(
backend_cfg, config, device
)
# Suppress vLLM logs during setup to reduce spam
with log_warnings_and_errors_only():
# Create vllm_config first - uses model's native dtype via "auto"
vllm_config = _create_vllm_config(config, max_num_blocks)
dtype = vllm_config.model_config.dtype
common_metadata = _build_common_attn_metadata(
q_lens, kv_lens, config.block_size, device
)
# Wrap everything in set_current_vllm_config context
# This is required for backends like flashinfer that need global config
with set_current_vllm_config(vllm_config):
backend_class, impl, layer = _create_backend_impl(
backend_cfg, config, device, dtype
)
kv_cache_spec = FullAttentionSpec(
block_size=config.block_size,
num_kv_heads=config.num_kv_heads,
head_size=config.head_dim,
dtype=dtype,
)
# Set KV cache layout if the backend requires a specific one
# (e.g., FlashInfer requires HND on SM100/Blackwell for TRTLLM attention)
required_layout = backend_class.get_required_kv_cache_layout()
if required_layout is not None:
set_kv_cache_layout(required_layout)
get_kv_cache_layout.cache_clear()
vllm_config = _create_vllm_config(config, dtype, max_num_blocks)
common_metadata = _build_common_attn_metadata(
q_lens, kv_lens, config.block_size, device
)
builder = _create_metadata_builder(
backend_class, kv_cache_spec, vllm_config, device
)
kv_cache_spec = FullAttentionSpec(
block_size=config.block_size,
num_kv_heads=config.num_kv_heads,
head_size=config.head_dim,
dtype=dtype,
)
attn_metadata = builder.build(
common_prefix_len=0,
common_attn_metadata=common_metadata,
)
builder = _create_metadata_builder(
backend_class, kv_cache_spec, vllm_config, device, config.backend
)
q_list, k_list, v_list = _create_input_tensors(config, total_q, device, dtype)
attn_metadata = builder.build(
common_prefix_len=0,
common_attn_metadata=common_metadata,
)
cache_list = _create_kv_cache(
config, max_num_blocks, backend_cfg["cache_layout"], device, dtype
)
q_list, k_list, v_list = _create_input_tensors(
config, total_q, device, dtype
)
times, mem_stats = _run_single_benchmark(
config,
impl,
layer,
q_list,
k_list,
v_list,
cache_list,
attn_metadata,
device,
dtype,
)
cache_list = _create_kv_cache(
config, max_num_blocks, backend_class, device, dtype
)
times, mem_stats = _run_single_benchmark(
config,
impl,
layer,
q_list,
k_list,
v_list,
cache_list,
attn_metadata,
device,
dtype,
)
mean_time = np.mean(times)
throughput = total_q / mean_time if mean_time > 0 else 0