476 lines
16 KiB
Python
476 lines
16 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
"""Common utilities for attention benchmarking."""
|
|
|
|
import csv
|
|
import json
|
|
import math
|
|
from dataclasses import asdict, dataclass
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import torch
|
|
from batch_spec import get_batch_type, parse_batch_spec
|
|
from rich.console import Console
|
|
from rich.table import Table
|
|
|
|
|
|
def batch_spec_sort_key(spec: str) -> tuple[int, int, int]:
    """
    Build a numeric sort key ``(batch_size, max_q_len, max_kv_len)`` for a spec.

    Sorting with this key orders batch specs by number of requests, then by the
    longest query, then by the longest KV sequence, instead of comparing the
    spec strings alphabetically.

    Args:
        spec: Batch specification string accepted by ``parse_batch_spec``.

    Returns:
        The key tuple. A spec that parses to no requests, or that fails to
        parse at all, maps to ``(0, 0, 0)``.
    """
    try:
        parsed = parse_batch_spec(spec)
        if not parsed:
            return (0, 0, 0)
        longest_q = max(req.q_len for req in parsed)
        longest_kv = max(req.kv_len for req in parsed)
        return (len(parsed), longest_q, longest_kv)
    except Exception:
        # Deliberately best-effort: an unparseable spec must not break sorting.
        return (0, 0, 0)
|
|
|
|
|
|
# Mock classes for vLLM attention infrastructure
|
|
|
|
|
|
class MockHfConfig:
|
|
"""Mock HuggingFace config that satisfies vLLM's requirements."""
|
|
|
|
def __init__(self, mla_dims: dict, index_topk: int | None = None):
|
|
self.num_attention_heads = mla_dims["num_q_heads"]
|
|
self.num_key_value_heads = mla_dims["num_kv_heads"]
|
|
self.hidden_size = mla_dims["head_dim"] * mla_dims["num_q_heads"]
|
|
self.model_type = "deepseek_v2"
|
|
self.is_encoder_decoder = False
|
|
self.kv_lora_rank = mla_dims["kv_lora_rank"]
|
|
self.qk_nope_head_dim = mla_dims["qk_nope_head_dim"]
|
|
self.qk_rope_head_dim = mla_dims["qk_rope_head_dim"]
|
|
self.v_head_dim = mla_dims["v_head_dim"]
|
|
self.qk_head_dim = mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"]
|
|
if index_topk is not None:
|
|
self.index_topk = index_topk
|
|
|
|
def get_text_config(self):
|
|
return self
|
|
|
|
|
|
# Import AttentionLayerBase at module level to avoid circular dependencies
|
|
try:
|
|
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
|
except ImportError:
|
|
AttentionLayerBase = object # Fallback
|
|
|
|
|
|
class MockKVBProj:
    """Weightless stand-in for the MLA ``kv_b_proj`` layer (prefill mode).

    Mimics the ColumnParallelLinear ``kv_b_proj`` used by MLA backends: each
    token is mapped to one vector of width ``qk_nope_head_dim + v_head_dim``
    per head. The output values are random, since only shapes matter for
    benchmarking.
    """

    def __init__(self, num_heads: int, qk_nope_head_dim: int, v_head_dim: int):
        self.num_heads = num_heads
        self.qk_nope_head_dim = qk_nope_head_dim
        self.v_head_dim = v_head_dim
        # Per-head output width.
        self.out_dim = qk_nope_head_dim + v_head_dim

    def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor]:
        """
        Produce a random tensor shaped like the real projection's output.

        Args:
            x: Input tensor ``[num_tokens, kv_lora_rank]``; only its leading
                dimension, device and dtype are used.

        Returns:
            A 1-tuple holding a ``[num_tokens, num_heads, out_dim]`` tensor,
            matching the tuple-returning ColumnParallelLinear API.
        """
        out_shape = (x.shape[0], self.num_heads, self.out_dim)
        projected = torch.randn(out_shape, device=x.device, dtype=x.dtype)
        return (projected,)
|
|
|
|
|
|
class MockIndexer:
    """Stand-in for the sparse-MLA indexer, exposing ``topk_indices_buffer``.

    Sparse MLA backends read ``topk_indices_buffer`` to decide which KV cache
    slots each token attends to. Here it is a preallocated int32 tensor of
    shape ``[max_num_tokens, topk_tokens]``, zero-filled at construction.
    """

    def __init__(
        self,
        max_num_tokens: int,
        topk_tokens: int,
        device: torch.device,
    ):
        self.topk_tokens = topk_tokens
        buffer_shape = (max_num_tokens, topk_tokens)
        self.topk_indices_buffer = torch.zeros(
            buffer_shape, dtype=torch.int32, device=device
        )

    def fill_random_indices(self, num_tokens: int, max_kv_len: int):
        """Overwrite the first ``num_tokens`` rows with random indices.

        Every entry is drawn uniformly from ``[0, max_kv_len)``; rows beyond
        ``num_tokens`` are left as they were.
        """
        buf = self.topk_indices_buffer
        random_rows = torch.randint(
            0,
            max_kv_len,
            (num_tokens, self.topk_tokens),
            dtype=torch.int32,
            device=buf.device,
        )
        buf[:num_tokens] = random_rows
|
|
|
|
|
|
class MockLayer(AttentionLayerBase):
    """Minimal attention layer carrying unit scales, an impl and a KV spec.

    Subclasses AttentionLayerBase so that it passes the isinstance checks in
    get_layers_from_vllm_config when FlashInfer prefill is enabled.
    """

    def __init__(self, device: torch.device, impl=None, kv_cache_spec=None):
        # AttentionLayerBase has no __init__ of its own, so super().__init__()
        # is intentionally skipped.
        unit_scale = 1.0
        self._k_scale = torch.tensor(unit_scale, device=device)
        self._v_scale = torch.tensor(unit_scale, device=device)
        self._q_scale = torch.tensor(unit_scale, device=device)

        # Python-float copies of the same scales, for kernels taking scalars.
        self._k_scale_float = float(self._k_scale.item())
        self._v_scale_float = float(self._v_scale.item())
        self._q_scale_float = float(self._q_scale.item())

        # Exposed so metadata builders can query the AttentionImpl.
        self.impl = impl
        # Returned verbatim by get_kv_cache_spec().
        self._kv_cache_spec = kv_cache_spec

    def get_attn_backend(self):
        """Return no backend class (required by AttentionLayerBase).

        This layer only exists to satisfy benchmark plumbing, so None is
        returned.
        """
        return None

    def get_kv_cache_spec(self):
        """Return the KV cache spec given at construction (required by AttentionLayerBase)."""
        return self._kv_cache_spec
|
|
|
|
|
|
@dataclass
class ParameterSweep:
    """Describe a sweep over one backend tuning parameter.

    Fields:
        param_name: Backend parameter being swept.
        values: Values to test.
        include_auto: Also test with the parameter left unset (auto mode).
        label_format: ``str.format`` template for result labels; may use the
            ``{backend}``, ``{param_name}`` and ``{value}`` placeholders.
    """

    param_name: str
    values: list[Any]
    include_auto: bool = False
    label_format: str = "{backend}_{param_name}_{value}"

    def get_label(self, backend: str, value: Any) -> str:
        """Render ``label_format`` for ``backend`` and one swept ``value``."""
        placeholders = {
            "backend": backend,
            "param_name": self.param_name,
            "value": value,
        }
        return self.label_format.format(**placeholders)
|
|
|
|
|
|
@dataclass
class ModelParameterSweep:
    """Describe a sweep over one model-configuration parameter.

    Fields:
        param_name: Model config parameter to vary (e.g. "num_q_heads").
        values: Values to test.
        label_format: ``str.format`` template for result labels; may use the
            ``{backend}``, ``{param_name}`` and ``{value}`` placeholders.
    """

    param_name: str
    values: list[Any]
    label_format: str = "{backend}_{param_name}_{value}"

    def get_label(self, backend: str, value: Any) -> str:
        """Render ``label_format`` for ``backend`` and one swept ``value``."""
        template = self.label_format
        return template.format(
            value=value, backend=backend, param_name=self.param_name
        )
|
|
|
|
|
|
@dataclass
class BenchmarkConfig:
    """Configuration for a single benchmark run.

    Field declaration order is the positional constructor signature and the
    key order produced by ``dataclasses.asdict``; fields without defaults
    must stay ahead of the defaulted ones.
    """

    backend: str
    # Parsed elsewhere in this module with parse_batch_spec.
    batch_spec: str
    num_layers: int
    head_dim: int
    num_q_heads: int
    num_kv_heads: int
    block_size: int
    device: str
    # A torch.dtype object; JSON export relies on json.dump(default=str).
    dtype: torch.dtype = torch.float16
    repeats: int = 1
    warmup_iters: int = 3
    profile_memory: bool = False
    use_cuda_graphs: bool = False

    # MLA-specific
    # NOTE(review): presumably left as None for non-MLA backends — confirm.
    kv_lora_rank: int | None = None
    qk_nope_head_dim: int | None = None
    qk_rope_head_dim: int | None = None
    v_head_dim: int | None = None

    # Backend-specific tuning (None presumably means "use backend default").
    num_kv_splits: int | None = None  # CUTLASS MLA
    reorder_batch_threshold: int | None = None  # FlashAttn MLA, FlashMLA
|
|
|
|
|
|
@dataclass
class BenchmarkResult:
    """Outcome of one benchmark run: timing statistics plus optional metrics.

    ``error`` is None for a successful run and holds a message otherwise
    (see ``success``). All ``*_time`` fields are in seconds.
    """

    config: BenchmarkConfig
    mean_time: float  # seconds
    std_time: float  # seconds
    min_time: float  # seconds
    max_time: float  # seconds
    throughput_tokens_per_sec: float | None = None
    memory_allocated_mb: float | None = None
    memory_reserved_mb: float | None = None
    error: str | None = None

    @property
    def success(self) -> bool:
        """True when the run finished without recording an error."""
        return self.error is None

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict view of this result for serialization.

        ``config`` is expanded with ``dataclasses.asdict``; non-JSON values in
        it (such as its ``torch.dtype``) are not converted here.
        """
        scalar_keys = (
            "mean_time",
            "std_time",
            "min_time",
            "max_time",
            "throughput_tokens_per_sec",
            "memory_allocated_mb",
            "memory_reserved_mb",
            "error",
        )
        serialized: dict[str, Any] = {"config": asdict(self.config)}
        serialized.update((key, getattr(self, key)) for key in scalar_keys)
        return serialized
|
|
|
|
|
|
class ResultsFormatter:
    """Format and display benchmark results (rich table, CSV, JSON)."""

    def __init__(self, console: Console | None = None):
        # A default rich Console is created when none is supplied.
        self.console = console or Console()

    @staticmethod
    def _shorten_backend_name(name: str) -> str:
        """Abbreviate known backend-name fragments for table column headers."""
        # Replacements are applied sequentially, in the listed order.
        for long_form, short_form in (
            ("flashattn_mla", "famla"),
            ("flashinfer_mla", "fimla"),
            ("flashmla", "fmla"),
            ("cutlass_mla", "cmla"),
            ("numsplits", "ns"),
        ):
            name = name.replace(long_form, short_form)
        return name

    @staticmethod
    def _group_by_spec(
        results: list[BenchmarkResult],
    ) -> dict[str, dict[str, BenchmarkResult]]:
        """Index results as ``{batch_spec: {backend: result}}``; later entries win."""
        by_spec: dict[str, dict[str, BenchmarkResult]] = {}
        for r in results:
            by_spec.setdefault(r.config.batch_spec, {})[r.config.backend] = r
        return by_spec

    @staticmethod
    def _format_cells(
        result: BenchmarkResult | None, best_time: float, show_relative: bool
    ) -> list[str]:
        """Return one backend's time cell, plus its "vs Best" cell if enabled.

        A missing result renders as "-", a failed one as a red ERROR; in both
        cases the relative cell (when shown) is "-".
        """
        if result is None or not result.success:
            time_cell = "-" if result is None else "[red]ERROR[/]"
            return [time_cell, "-"] if show_relative else [time_cell]

        cells = [f"{result.mean_time:.6f}"]
        if show_relative:
            # Percentage of the fastest time; 0 when no positive best exists.
            pct = (result.mean_time / best_time * 100) if best_time > 0 else 0
            pct_str = f"{pct:.1f}%"
            if result.mean_time == best_time:
                pct_str = f"[bold green]{pct_str}[/]"
            cells.append(pct_str)
        return cells

    def print_table(
        self,
        results: list[BenchmarkResult],
        backends: list[str],
        compare_to_fastest: bool = True,
    ):
        """
        Print results as a rich table, one row per batch spec.

        Rows are ordered by ``batch_spec_sort_key`` (batch size, then max query
        length, then max KV length) rather than alphabetically. Each backend
        gets a mean-time column. When more than one backend is shown and
        ``compare_to_fastest`` is set, each backend also gets a "vs Best"
        column: its time as a percentage of the fastest successful result
        recorded for that spec. The fastest entry is highlighted.

        Args:
            results: List of BenchmarkResult
            backends: Backend names to display, in column order
            compare_to_fastest: Show percentage comparison to fastest
        """
        by_spec = self._group_by_spec(results)
        show_relative = len(backends) > 1 and compare_to_fastest

        table = Table(title="Attention Benchmark Results")
        table.add_column("Batch\nSpec", no_wrap=True)
        table.add_column("Type", no_wrap=True)
        table.add_column("Batch\nSize", justify="right", no_wrap=True)
        for backend in backends:
            short_name = self._shorten_backend_name(backend)
            table.add_column(
                f"{short_name}\nTime (s)", justify="right", no_wrap=False
            )
            if show_relative:
                table.add_column(
                    f"{short_name}\nvs Best", justify="right", no_wrap=False
                )

        for spec in sorted(by_spec, key=batch_spec_sort_key):
            spec_results = by_spec[spec]
            # Fastest successful run among all backends recorded for this spec.
            times = [r.mean_time for r in spec_results.values() if r.success]
            best_time = min(times) if times else 0.0

            row = [spec, get_batch_type(spec), str(len(parse_batch_spec(spec)))]
            for backend in backends:
                row.extend(
                    self._format_cells(
                        spec_results.get(backend), best_time, show_relative
                    )
                )
            table.add_row(*row)

        self.console.print(table)

    def save_csv(self, results: list[BenchmarkResult], path: str):
        """Save results to a UTF-8 CSV file, creating parent directories.

        Missing throughput/memory values are written as 0. Nothing is written
        (and nothing is printed) when ``results`` is empty.
        """
        if not results:
            return

        path_obj = Path(path)
        path_obj.parent.mkdir(parents=True, exist_ok=True)

        fieldnames = [
            "backend",
            "batch_spec",
            "num_layers",
            "mean_time",
            "std_time",
            "throughput",
            "memory_mb",
        ]
        # Explicit encoding: the platform default is locale-dependent.
        with path_obj.open("w", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(
                {
                    "backend": r.config.backend,
                    "batch_spec": r.config.batch_spec,
                    "num_layers": r.config.num_layers,
                    "mean_time": r.mean_time,
                    "std_time": r.std_time,
                    "throughput": r.throughput_tokens_per_sec or 0,
                    "memory_mb": r.memory_allocated_mb or 0,
                }
                for r in results
            )

        self.console.print(f"[green]Saved CSV results to {path}[/]")

    def save_json(self, results: list[BenchmarkResult], path: str):
        """Save results to a UTF-8 JSON file, creating parent directories.

        Values JSON cannot encode natively (e.g. the config's ``torch.dtype``)
        are written via ``str()``.
        """
        path_obj = Path(path)
        path_obj.parent.mkdir(parents=True, exist_ok=True)

        data = [r.to_dict() for r in results]
        with path_obj.open("w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, default=str)

        self.console.print(f"[green]Saved JSON results to {path}[/]")
|
|
|
|
|
|
def setup_mla_dims(model_name: str = "deepseek-v3") -> dict:
    """
    Look up the MLA dimension set for a known model.

    Args:
        model_name: One of "deepseek-v2", "deepseek-v3" or "deepseek-v2-lite".

    Returns:
        A dict with ``kv_lora_rank``, ``qk_nope_head_dim``, ``qk_rope_head_dim``,
        ``v_head_dim``, ``num_q_heads``, ``num_kv_heads`` and ``head_dim``.
        A new dict is built on every call, so callers may mutate it.

    Raises:
        ValueError: If ``model_name`` is not one of the known models.
    """
    # DeepSeek-V2 and V3 share the same MLA geometry; V2-Lite differs only in
    # its number of query heads.
    full_size = {
        "kv_lora_rank": 512,
        "qk_nope_head_dim": 128,
        "qk_rope_head_dim": 64,
        "v_head_dim": 128,
        "num_q_heads": 128,
        "num_kv_heads": 1,
        "head_dim": 576,
    }
    known = {
        "deepseek-v2": full_size,
        "deepseek-v3": dict(full_size),
        "deepseek-v2-lite": {**full_size, "num_q_heads": 16},
    }

    if model_name in known:
        return known[model_name]
    raise ValueError(
        f"Unknown model '{model_name}'. Known models: {list(known.keys())}"
    )
|
|
|
|
|
|
def get_attention_scale(head_dim: int) -> float:
    """Return the attention scale factor ``1 / sqrt(head_dim)``."""
    root = math.sqrt(head_dim)
    return 1.0 / root
|
|
|
|
|
|
def is_mla_backend(backend: str) -> bool:
    """
    Report whether ``backend`` names an MLA attention backend.

    Args:
        backend: Exact ``AttentionBackendEnum`` member name
            (e.g., "FLASHMLA_SPARSE").

    Returns:
        The backend class's ``is_mla()`` result, or False when the name is not
        an enum member or its class cannot be loaded or queried.
    """
    # Imported here rather than at module top. Note that this import sits
    # outside the try below, so a missing vLLM install still raises.
    from vllm.v1.attention.backends.registry import AttentionBackendEnum

    try:
        return AttentionBackendEnum[backend].get_class().is_mla()
    except (KeyError, ValueError, ImportError, AttributeError):
        return False
|