[Misc] Fix up attention benchmarks (#33810)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -229,3 +229,40 @@ def get_batch_stats(requests: list[BatchRequest]) -> dict:
|
||||
sum(r.kv_len for r in requests) / len(requests) if requests else 0
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def get_batch_type(batch_spec: str, spec_decode_threshold: int = 8) -> str:
    """
    Return a label describing which kinds of requests a batch spec contains.

    Args:
        batch_spec: Batch specification string
            (e.g., "q2k", "8q1s1k", "2q2k_8q1s1k")
        spec_decode_threshold: Largest q_len for which an extend request is
            labelled "spec-decode"; longer extend requests are labelled
            "extend"

    Returns:
        "decode", "prefill", "spec-decode" or "extend" when every labelled
        request shares one label; "mixed (a+b+...)" with the labels sorted
        alphabetically when several labels occur; "unknown" when no request
        received a label.
    """
    labels = set()
    for request in parse_batch_spec(batch_spec):
        if request.is_decode:
            label = "decode"
        elif request.is_prefill:
            label = "prefill"
        elif request.is_extend:
            # A short query on top of existing context is treated as
            # speculative decoding; a long one as chunked prefill.
            is_short = request.q_len <= spec_decode_threshold
            label = "spec-decode" if is_short else "extend"
        else:
            # Requests matching none of the flags contribute no label.
            continue
        labels.add(label)

    if not labels:
        return "unknown"
    if len(labels) == 1:
        (only_label,) = labels
        return only_label
    # Sorting keeps the combined label stable across runs.
    return "mixed (" + "+".join(sorted(labels)) + ")"
|
||||
|
||||
@@ -12,6 +12,7 @@ from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from batch_spec import get_batch_type, parse_batch_spec
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
@@ -316,12 +317,14 @@ class ResultsFormatter:
|
||||
backends: List of backend names being compared
|
||||
compare_to_fastest: Show percentage comparison to fastest
|
||||
"""
|
||||
# Group by batch spec
|
||||
# Group by batch spec, preserving first-occurrence order
|
||||
by_spec = {}
|
||||
specs_order = []
|
||||
for r in results:
|
||||
spec = r.config.batch_spec
|
||||
if spec not in by_spec:
|
||||
by_spec[spec] = {}
|
||||
specs_order.append(spec)
|
||||
by_spec[spec][r.config.backend] = r
|
||||
|
||||
# Create shortened backend names for display
|
||||
@@ -337,6 +340,8 @@ class ResultsFormatter:
|
||||
|
||||
table = Table(title="Attention Benchmark Results")
|
||||
table.add_column("Batch\nSpec", no_wrap=True)
|
||||
table.add_column("Type", no_wrap=True)
|
||||
table.add_column("Batch\nSize", justify="right", no_wrap=True)
|
||||
|
||||
multi = len(backends) > 1
|
||||
for backend in backends:
|
||||
@@ -350,12 +355,14 @@ class ResultsFormatter:
|
||||
table.add_column(col_rel, justify="right", no_wrap=False)
|
||||
|
||||
# Add rows
|
||||
for spec in sorted(by_spec.keys()):
|
||||
for spec in specs_order:
|
||||
spec_results = by_spec[spec]
|
||||
times = {b: r.mean_time for b, r in spec_results.items() if r.success}
|
||||
best_time = min(times.values()) if times else 0.0
|
||||
|
||||
row = [spec]
|
||||
batch_type = get_batch_type(spec)
|
||||
batch_size = len(parse_batch_spec(spec))
|
||||
row = [spec, batch_type, str(batch_size)]
|
||||
for backend in backends:
|
||||
if backend in spec_results:
|
||||
r = spec_results[backend]
|
||||
|
||||
@@ -25,10 +25,18 @@ batch_specs:
|
||||
- "4q1k_16q1s2k" # 4 prefill + 16 decode
|
||||
- "2q4k_32q1s1k" # 2 large prefill + 32 decode
|
||||
|
||||
# Context extension
|
||||
- "q1ks2k" # 1k query, 2k sequence (chunked prefill)
|
||||
# Speculative decode (q <= 8)
|
||||
- "16q2s1k" # 16 requests, 2 spec tokens, 1k KV cache
|
||||
- "16q4s1k" # 16 requests, 4 spec tokens, 1k KV cache
|
||||
- "16q8s1k" # 16 requests, 8 spec tokens, 1k KV cache
|
||||
- "32q4s2k" # 32 requests, 4 spec tokens, 2k KV cache
|
||||
- "8q8s4k" # 8 requests, 8 spec tokens, 4k KV cache
|
||||
|
||||
# Context extension (chunked prefill)
|
||||
- "q1ks2k" # 1k query, 2k sequence
|
||||
- "2q1ks4k" # 2 requests: 1k query, 4k sequence
|
||||
|
||||
# Available backends: flash, triton, flashinfer
|
||||
backends:
|
||||
- flash
|
||||
- triton
|
||||
|
||||
@@ -8,7 +8,9 @@ This module provides helpers for running standard attention backends
|
||||
(FlashAttention, Triton, FlashInfer) with real vLLM integration.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import types
|
||||
from contextlib import contextmanager
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -24,8 +26,13 @@ from vllm.config import (
|
||||
ParallelConfig,
|
||||
SchedulerConfig,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
CommonAttentionMetadata,
|
||||
get_kv_cache_layout,
|
||||
set_kv_cache_layout,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
|
||||
# ============================================================================
|
||||
@@ -37,22 +44,14 @@ _BACKEND_CONFIG = {
|
||||
"flash": {
|
||||
"module": "vllm.v1.attention.backends.flash_attn",
|
||||
"backend_class": "FlashAttentionBackend",
|
||||
"dtype": torch.float16,
|
||||
"cache_layout": "standard",
|
||||
# ^ [2, num_blocks, block_size, num_kv_heads, head_dim]
|
||||
},
|
||||
"triton": {
|
||||
"module": "vllm.v1.attention.backends.triton_attn",
|
||||
"backend_class": "TritonAttentionBackend",
|
||||
"dtype": torch.float32,
|
||||
"cache_layout": "standard",
|
||||
},
|
||||
"flashinfer": {
|
||||
"module": "vllm.v1.attention.backends.flashinfer",
|
||||
"backend_class": "FlashInferBackend",
|
||||
"dtype": torch.float16,
|
||||
"cache_layout": "flashinfer",
|
||||
# ^ [num_blocks, 2, block_size, num_kv_heads, head_dim]
|
||||
},
|
||||
}
|
||||
|
||||
@@ -66,6 +65,18 @@ def _get_backend_config(backend: str) -> dict:
|
||||
return _BACKEND_CONFIG[backend]
|
||||
|
||||
|
||||
@contextmanager
def log_warnings_and_errors_only():
    """Temporarily restrict the "vllm" logger to WARNING and above.

    On entry the logger's own level is remembered and, if its effective
    level is more verbose than WARNING, raised to WARNING. A logger that is
    already at WARNING or stricter (for example ERROR) is left untouched, so
    entering the context never makes logging *more* verbose.

    On exit, whether the body finished normally or raised, the logger's
    original level is restored.
    """
    logger = logging.getLogger("vllm")
    old_level = logger.level
    # getEffectiveLevel() resolves NOTSET through the parent chain, so an
    # inherited ERROR level is respected as well as an explicit one.
    if logger.getEffectiveLevel() < logging.WARNING:
        logger.setLevel(logging.WARNING)
    try:
        yield
    finally:
        logger.setLevel(old_level)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Metadata Building Helpers
|
||||
# ============================================================================
|
||||
@@ -88,11 +99,7 @@ def _build_common_attn_metadata(
|
||||
query_start_loc_cpu = query_start_loc.cpu()
|
||||
|
||||
seq_lens = torch.tensor(kv_lens, dtype=torch.int32, device=device)
|
||||
seq_lens_cpu = seq_lens.cpu()
|
||||
max_seq_len = int(seq_lens_cpu.max())
|
||||
|
||||
context_lens = [kv - q for kv, q in zip(kv_lens, q_lens)]
|
||||
num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)
|
||||
max_seq_len = int(seq_lens.max().item())
|
||||
|
||||
max_blocks = (max(kv_lens) + block_size - 1) // block_size
|
||||
num_blocks = batch_size * max_blocks
|
||||
@@ -107,8 +114,6 @@ def _build_common_attn_metadata(
|
||||
query_start_loc=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_cpu=seq_lens_cpu,
|
||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
num_reqs=batch_size,
|
||||
num_actual_tokens=total_tokens,
|
||||
max_query_len=max_query_len,
|
||||
@@ -121,7 +126,6 @@ def _build_common_attn_metadata(
|
||||
|
||||
def _create_vllm_config(
|
||||
config: BenchmarkConfig,
|
||||
dtype: torch.dtype,
|
||||
max_num_blocks: int,
|
||||
) -> VllmConfig:
|
||||
"""Create a VllmConfig for benchmarking with mock model methods."""
|
||||
@@ -129,7 +133,7 @@ def _create_vllm_config(
|
||||
model="meta-llama/Meta-Llama-3-8B",
|
||||
tokenizer="meta-llama/Meta-Llama-3-8B",
|
||||
trust_remote_code=False,
|
||||
dtype=dtype,
|
||||
dtype="auto", # Use model's native dtype
|
||||
seed=0,
|
||||
max_model_len=1024,
|
||||
)
|
||||
@@ -198,6 +202,7 @@ def _create_backend_impl(
|
||||
backend_cfg: dict,
|
||||
config: BenchmarkConfig,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""Create backend implementation instance."""
|
||||
import importlib
|
||||
@@ -206,7 +211,6 @@ def _create_backend_impl(
|
||||
backend_class = getattr(backend_module, backend_cfg["backend_class"])
|
||||
|
||||
scale = get_attention_scale(config.head_dim)
|
||||
dtype = backend_cfg["dtype"]
|
||||
|
||||
impl = backend_class.get_impl_cls()(
|
||||
num_heads=config.num_q_heads,
|
||||
@@ -227,7 +231,7 @@ def _create_backend_impl(
|
||||
|
||||
layer = MockLayer(device, kv_cache_spec=kv_cache_spec)
|
||||
|
||||
return backend_class, impl, layer, dtype
|
||||
return backend_class, impl, layer
|
||||
|
||||
|
||||
def _create_metadata_builder(
|
||||
@@ -235,11 +239,44 @@ def _create_metadata_builder(
|
||||
kv_cache_spec: FullAttentionSpec,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
backend_name: str = "",
|
||||
):
|
||||
"""Create metadata builder instance."""
|
||||
return backend_class.get_builder_cls()(
|
||||
layer_names = ["layer_0"]
|
||||
builder_cls = backend_class.get_builder_cls()
|
||||
|
||||
# Flashinfer needs get_per_layer_parameters mocked since we don't have
|
||||
# real model layers registered
|
||||
if backend_name == "flashinfer":
|
||||
import unittest.mock
|
||||
|
||||
from vllm.v1.attention.backends.utils import PerLayerParameters
|
||||
|
||||
def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls):
|
||||
head_size = vllm_config.model_config.get_head_size()
|
||||
return {
|
||||
layer_name: PerLayerParameters(
|
||||
window_left=-1, # No sliding window
|
||||
logits_soft_cap=0.0, # No soft cap
|
||||
sm_scale=1.0 / (head_size**0.5), # Standard scale
|
||||
)
|
||||
for layer_name in layer_names
|
||||
}
|
||||
|
||||
with unittest.mock.patch(
|
||||
"vllm.v1.attention.backends.flashinfer.get_per_layer_parameters",
|
||||
mock_get_per_layer_parameters,
|
||||
):
|
||||
return builder_cls(
|
||||
kv_cache_spec=kv_cache_spec,
|
||||
layer_names=layer_names,
|
||||
vllm_config=vllm_config,
|
||||
device=device,
|
||||
)
|
||||
|
||||
return builder_cls(
|
||||
kv_cache_spec=kv_cache_spec,
|
||||
layer_names=["layer_0"],
|
||||
layer_names=layer_names,
|
||||
vllm_config=vllm_config,
|
||||
device=device,
|
||||
)
|
||||
@@ -281,39 +318,44 @@ def _create_input_tensors(
|
||||
def _create_kv_cache(
|
||||
config: BenchmarkConfig,
|
||||
max_num_blocks: int,
|
||||
cache_layout: str,
|
||||
backend_class,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
) -> list:
|
||||
"""Create KV cache tensors for all layers."""
|
||||
if cache_layout == "flashinfer":
|
||||
# FlashInfer layout: [num_blocks, 2, block_size, num_kv_heads, head_dim]
|
||||
cache_list = [
|
||||
torch.zeros(
|
||||
max_num_blocks,
|
||||
2,
|
||||
config.block_size,
|
||||
config.num_kv_heads,
|
||||
config.head_dim,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
for _ in range(config.num_layers)
|
||||
]
|
||||
else:
|
||||
# Standard layout: [2, num_blocks, block_size, num_kv_heads, head_dim]
|
||||
cache_list = [
|
||||
torch.zeros(
|
||||
2,
|
||||
max_num_blocks,
|
||||
config.block_size,
|
||||
config.num_kv_heads,
|
||||
config.head_dim,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
for _ in range(config.num_layers)
|
||||
]
|
||||
"""Create KV cache tensors for all layers using the backend's methods.
|
||||
|
||||
Uses the backend's get_kv_cache_shape() and get_kv_cache_stride_order()
|
||||
to create the cache with the correct shape and memory layout.
|
||||
"""
|
||||
# Get the logical shape from the backend
|
||||
cache_shape = backend_class.get_kv_cache_shape(
|
||||
num_blocks=max_num_blocks,
|
||||
block_size=config.block_size,
|
||||
num_kv_heads=config.num_kv_heads,
|
||||
head_size=config.head_dim,
|
||||
)
|
||||
|
||||
# Get the stride order for custom memory layout
|
||||
try:
|
||||
stride_order = backend_class.get_kv_cache_stride_order()
|
||||
assert len(stride_order) == len(cache_shape)
|
||||
except (AttributeError, NotImplementedError):
|
||||
stride_order = tuple(range(len(cache_shape)))
|
||||
|
||||
# Permute shape to physical layout order
|
||||
physical_shape = tuple(cache_shape[i] for i in stride_order)
|
||||
|
||||
# Compute inverse permutation to get back to logical view
|
||||
inv_order = [stride_order.index(i) for i in range(len(stride_order))]
|
||||
|
||||
cache_list = []
|
||||
for _ in range(config.num_layers):
|
||||
# Allocate in physical layout order (contiguous in memory)
|
||||
cache = torch.zeros(*physical_shape, device=device, dtype=dtype)
|
||||
# Permute to logical view
|
||||
cache = cache.permute(*inv_order)
|
||||
cache_list.append(cache)
|
||||
|
||||
return cache_list
|
||||
|
||||
|
||||
@@ -418,53 +460,72 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
|
||||
kv_lens = [r.kv_len for r in requests]
|
||||
total_q = sum(q_lens)
|
||||
max_kv = max(kv_lens)
|
||||
batch_size = len(q_lens)
|
||||
|
||||
max_num_blocks = (max_kv + config.block_size - 1) // config.block_size
|
||||
# Calculate total blocks needed: batch_size * max_blocks_per_request
|
||||
max_blocks_per_request = (max_kv + config.block_size - 1) // config.block_size
|
||||
max_num_blocks = batch_size * max_blocks_per_request
|
||||
|
||||
backend_class, impl, layer, dtype = _create_backend_impl(
|
||||
backend_cfg, config, device
|
||||
)
|
||||
# Suppress vLLM logs during setup to reduce spam
|
||||
with log_warnings_and_errors_only():
|
||||
# Create vllm_config first - uses model's native dtype via "auto"
|
||||
vllm_config = _create_vllm_config(config, max_num_blocks)
|
||||
dtype = vllm_config.model_config.dtype
|
||||
|
||||
common_metadata = _build_common_attn_metadata(
|
||||
q_lens, kv_lens, config.block_size, device
|
||||
)
|
||||
# Wrap everything in set_current_vllm_config context
|
||||
# This is required for backends like flashinfer that need global config
|
||||
with set_current_vllm_config(vllm_config):
|
||||
backend_class, impl, layer = _create_backend_impl(
|
||||
backend_cfg, config, device, dtype
|
||||
)
|
||||
|
||||
kv_cache_spec = FullAttentionSpec(
|
||||
block_size=config.block_size,
|
||||
num_kv_heads=config.num_kv_heads,
|
||||
head_size=config.head_dim,
|
||||
dtype=dtype,
|
||||
)
|
||||
# Set KV cache layout if the backend requires a specific one
|
||||
# (e.g., FlashInfer requires HND on SM100/Blackwell for TRTLLM attention)
|
||||
required_layout = backend_class.get_required_kv_cache_layout()
|
||||
if required_layout is not None:
|
||||
set_kv_cache_layout(required_layout)
|
||||
get_kv_cache_layout.cache_clear()
|
||||
|
||||
vllm_config = _create_vllm_config(config, dtype, max_num_blocks)
|
||||
common_metadata = _build_common_attn_metadata(
|
||||
q_lens, kv_lens, config.block_size, device
|
||||
)
|
||||
|
||||
builder = _create_metadata_builder(
|
||||
backend_class, kv_cache_spec, vllm_config, device
|
||||
)
|
||||
kv_cache_spec = FullAttentionSpec(
|
||||
block_size=config.block_size,
|
||||
num_kv_heads=config.num_kv_heads,
|
||||
head_size=config.head_dim,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
attn_metadata = builder.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_metadata,
|
||||
)
|
||||
builder = _create_metadata_builder(
|
||||
backend_class, kv_cache_spec, vllm_config, device, config.backend
|
||||
)
|
||||
|
||||
q_list, k_list, v_list = _create_input_tensors(config, total_q, device, dtype)
|
||||
attn_metadata = builder.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_metadata,
|
||||
)
|
||||
|
||||
cache_list = _create_kv_cache(
|
||||
config, max_num_blocks, backend_cfg["cache_layout"], device, dtype
|
||||
)
|
||||
q_list, k_list, v_list = _create_input_tensors(
|
||||
config, total_q, device, dtype
|
||||
)
|
||||
|
||||
times, mem_stats = _run_single_benchmark(
|
||||
config,
|
||||
impl,
|
||||
layer,
|
||||
q_list,
|
||||
k_list,
|
||||
v_list,
|
||||
cache_list,
|
||||
attn_metadata,
|
||||
device,
|
||||
dtype,
|
||||
)
|
||||
cache_list = _create_kv_cache(
|
||||
config, max_num_blocks, backend_class, device, dtype
|
||||
)
|
||||
|
||||
times, mem_stats = _run_single_benchmark(
|
||||
config,
|
||||
impl,
|
||||
layer,
|
||||
q_list,
|
||||
k_list,
|
||||
v_list,
|
||||
cache_list,
|
||||
attn_metadata,
|
||||
device,
|
||||
dtype,
|
||||
)
|
||||
|
||||
mean_time = np.mean(times)
|
||||
throughput = total_q / mean_time if mean_time > 0 else 0
|
||||
|
||||
Reference in New Issue
Block a user