Files
vllm/benchmarks/attention_benchmarks/mla_runner.py
Matthew Bonanni e82fa448c4 Add attention benchmarking tools (#26835)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Claude <noreply@anthropic.com>
2026-01-28 00:09:20 +00:00

837 lines
26 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
MLA benchmark runner - shared utilities for MLA benchmarks.
This module provides helpers for running MLA backends without
needing full VllmConfig integration.
"""
import importlib
import numpy as np
import torch
from batch_spec import parse_batch_spec
from common import (
BenchmarkResult,
MockHfConfig,
MockKVBProj,
MockLayer,
setup_mla_dims,
)
from vllm.config import (
CacheConfig,
CompilationConfig,
ModelConfig,
ParallelConfig,
SchedulerConfig,
VllmConfig,
set_current_vllm_config,
)
# ============================================================================
# VllmConfig Creation
# ============================================================================
def _add_mock_methods_to_model_config(model_config: ModelConfig) -> None:
"""
Add mock methods for layer-specific queries to ModelConfig.
These methods are needed by metadata builders but aren't normally
present on ModelConfig when used in benchmark contexts.
"""
import types
model_config.get_num_layers = types.MethodType(lambda self: 1, model_config)
model_config.get_sliding_window_for_layer = types.MethodType(
lambda self, _i: None, model_config
)
model_config.get_logits_soft_cap_for_layer = types.MethodType(
lambda self, _i: None, model_config
)
model_config.get_sm_scale_for_layer = types.MethodType(
lambda self, _i: 1.0 / model_config.get_head_size() ** 0.5, model_config
)
def create_minimal_vllm_config(
model_name: str = "deepseek-v3",
block_size: int = 128,
max_num_seqs: int = 256,
mla_dims: dict | None = None,
) -> VllmConfig:
"""
Create minimal VllmConfig for MLA benchmarks.
Args:
model_name: Model name (deepseek-v2, deepseek-v3, etc.) - used if mla_dims not
provided
block_size: KV cache block size
max_num_seqs: Maximum number of sequences
mla_dims: Optional custom MLA dimensions dict. If not provided, uses
setup_mla_dims(model_name)
Returns:
VllmConfig for benchmarking
"""
# Get MLA dimensions - use provided or load from model name
if mla_dims is None:
mla_dims = setup_mla_dims(model_name)
# Create mock HF config first (avoids downloading from HuggingFace)
mock_hf_config = MockHfConfig(mla_dims)
# Create a temporary minimal config.json to avoid HF downloads
# This ensures consistent ModelConfig construction without network access
import json
import os
import shutil
import tempfile
minimal_config = {
"architectures": ["DeepseekV2ForCausalLM"],
"model_type": "deepseek_v2",
"num_attention_heads": mla_dims["num_q_heads"],
"num_key_value_heads": mla_dims["num_kv_heads"],
"hidden_size": mla_dims["head_dim"] * mla_dims["num_q_heads"],
"torch_dtype": "bfloat16",
"max_position_embeddings": 163840, # DeepSeek V3 default
"rope_theta": 10000.0,
"vocab_size": 128256,
}
# Create temporary directory with config.json
temp_dir = tempfile.mkdtemp(prefix="vllm_bench_")
config_path = os.path.join(temp_dir, "config.json")
with open(config_path, "w") as f:
json.dump(minimal_config, f)
try:
# Create model config using local path - no HF downloads
model_config = ModelConfig(
model=temp_dir, # Use local temp directory
tokenizer=None,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="bfloat16",
seed=0,
max_model_len=32768,
quantization=None,
quantization_param_path=None,
enforce_eager=False,
max_context_len_to_capture=None,
max_seq_len_to_capture=8192,
max_logprobs=20,
disable_sliding_window=False,
skip_tokenizer_init=True,
served_model_name=None,
limit_mm_per_prompt=None,
use_async_output_proc=True,
config_format="auto",
)
finally:
# Clean up temporary directory
shutil.rmtree(temp_dir, ignore_errors=True)
# Override with our mock config
model_config.hf_config = mock_hf_config
model_config.hf_text_config = mock_hf_config
# Add mock methods for layer-specific queries
_add_mock_methods_to_model_config(model_config)
# Create sub-configs
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
enable_prefix_caching=False,
)
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=8192,
max_model_len=32768,
is_encoder_decoder=False,
enable_chunked_prefill=True,
)
parallel_config = ParallelConfig(
tensor_parallel_size=1,
)
compilation_config = CompilationConfig()
return VllmConfig(
model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
compilation_config=compilation_config,
)
# ============================================================================
# Backend Configuration
# ============================================================================
# Backend name to class name prefix mapping
_BACKEND_NAME_MAP = {
"flashattn_mla": "FlashAttnMLA",
"flashmla": "FlashMLA",
"flashinfer_mla": "FlashInferMLA",
"cutlass_mla": "CutlassMLA",
}
# Special properties that differ from defaults
_BACKEND_PROPERTIES = {
"flashmla": {
"query_format": "concat", # Single concatenated tensor (vs tuple)
"block_size": 64, # FlashMLA uses fixed block size
},
"flashinfer_mla": {
"block_size": 64, # FlashInfer MLA only supports 32 or 64
},
}
def _get_backend_config(backend: str) -> dict:
"""
Get backend configuration using naming conventions.
All MLA backends follow the pattern:
- Module: vllm.v1.attention.backends.mla.{backend}
- Impl: {Name}Impl
- Metadata: {Name}Metadata (or MLACommonMetadata)
- DecodeMetadata: {Name}DecodeMetadata (or MLACommonDecodeMetadata)
- MetadataBuilder: {Name}MetadataBuilder
"""
if backend not in _BACKEND_NAME_MAP:
raise ValueError(f"Unknown backend: {backend}")
name = _BACKEND_NAME_MAP[backend]
props = _BACKEND_PROPERTIES.get(backend, {})
# Check if backend uses common metadata (FlashInfer, CUTLASS)
uses_common = backend in ("flashinfer_mla", "cutlass_mla")
return {
"module": f"vllm.v1.attention.backends.mla.{backend}",
"impl_class": f"{name}Impl",
"metadata_class": "MLACommonMetadata" if uses_common else f"{name}Metadata",
"decode_metadata_class": "MLACommonDecodeMetadata"
if uses_common
else f"{name}DecodeMetadata",
"builder_class": f"{name}MetadataBuilder",
"query_format": props.get("query_format", "tuple"),
"block_size": props.get("block_size", None),
}
# ============================================================================
# Metadata Building Helpers
# ============================================================================
def _build_attention_metadata(
requests: list,
block_size: int,
device: torch.device,
builder_instance,
) -> tuple:
"""
Build attention metadata from batch requests.
Args:
requests: List of BatchRequest objects
block_size: KV cache block size
device: Target device
builder_instance: Metadata builder instance
Returns:
Tuple of (metadata, kv_cache_num_blocks)
"""
q_lens = [r.q_len for r in requests]
kv_lens = [r.kv_len for r in requests]
total_q = sum(q_lens)
max_kv = max(kv_lens)
# Build query start locations
q_start_cpu = torch.tensor(
[0] + [sum(q_lens[: i + 1]) for i in range(len(q_lens))],
dtype=torch.int32,
)
q_start_gpu = q_start_cpu.to(device)
# Build sequence lengths
seq_lens_cpu = torch.tensor(kv_lens, dtype=torch.int32)
seq_lens_gpu = seq_lens_cpu.to(device)
# Build num_computed_tokens (context length for each request)
context_lens = [kv_len - q_len for q_len, kv_len in zip(q_lens, kv_lens)]
num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)
# Build block table
num_blocks_per_req = [(kv + block_size - 1) // block_size for kv in kv_lens]
max_num_blocks = max(num_blocks_per_req)
block_table_cpu = np.zeros((len(requests), max_num_blocks), dtype=np.int32)
current_block = 0
for i, num_blocks in enumerate(num_blocks_per_req):
for j in range(num_blocks):
block_table_cpu[i, j] = current_block
current_block += 1
block_table_gpu = torch.from_numpy(block_table_cpu).to(device)
# Build slot mapping
slot_mapping_list = []
for i, (q_len, kv_len, num_blocks) in enumerate(
zip(q_lens, kv_lens, num_blocks_per_req)
):
context_len = kv_len - q_len
for j in range(q_len):
token_kv_idx = context_len + j
block_idx = token_kv_idx // block_size
offset_in_block = token_kv_idx % block_size
global_block_id = block_table_cpu[i, block_idx]
slot_id = global_block_id * block_size + offset_in_block
slot_mapping_list.append(slot_id)
slot_mapping = torch.tensor(slot_mapping_list, dtype=torch.int64, device=device)
# Create CommonAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
common_attn_metadata = CommonAttentionMetadata(
num_reqs=len(requests),
max_query_len=max(q_lens),
max_seq_len=max_kv,
num_actual_tokens=total_q,
query_start_loc=q_start_gpu,
query_start_loc_cpu=q_start_cpu,
seq_lens=seq_lens_gpu,
_seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=num_computed_tokens_cpu,
slot_mapping=slot_mapping,
block_table_tensor=block_table_gpu,
dcp_local_seq_lens=None,
)
# Use the production build() method
metadata = builder_instance.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
fast_build=False,
)
return metadata, current_block
def _create_input_tensors(
total_q: int,
mla_dims: dict,
query_format: str,
device: torch.device,
dtype: torch.dtype,
):
"""
Create input tensors for both decode and prefill modes.
MLA requires different tensor formats for decode vs prefill:
- Decode: Uses kv_lora_rank (512) dimension
- Prefill: Uses qk_nope_head_dim (128) to stay under FlashAttention's 256 limit
Args:
total_q: Total number of query tokens
mla_dims: MLA dimension configuration
query_format: Either "tuple" or "concat"
device: Target device
dtype: Tensor dtype
Returns:
Tuple of (decode_inputs, prefill_inputs)
- decode_inputs: Query tensor(s) for decode mode
- prefill_inputs: Dict with 'q', 'k_c_normed', 'k_pe', 'k_scale' for prefill
"""
if query_format == "tuple":
# Decode mode format: (q_nope, q_pe) where q_nope has kv_lora_rank dim
q_nope_decode = torch.randn(
total_q,
mla_dims["num_q_heads"],
mla_dims["kv_lora_rank"],
device=device,
dtype=dtype,
)
q_pe = torch.randn(
total_q,
mla_dims["num_q_heads"],
mla_dims["qk_rope_head_dim"],
device=device,
dtype=dtype,
)
decode_inputs = (q_nope_decode, q_pe)
# For prefill, we need q with qk_nope_head_dim instead of kv_lora_rank
q_nope_prefill = torch.randn(
total_q,
mla_dims["num_q_heads"],
mla_dims["qk_nope_head_dim"],
device=device,
dtype=dtype,
)
prefill_q = torch.cat([q_nope_prefill, q_pe], dim=-1)
else: # concat
decode_inputs = torch.randn(
total_q,
mla_dims["num_q_heads"],
mla_dims["kv_lora_rank"] + mla_dims["qk_rope_head_dim"],
device=device,
dtype=dtype,
)
# For prefill with concat format
prefill_q = torch.randn(
total_q,
mla_dims["num_q_heads"],
mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"],
device=device,
dtype=dtype,
)
# Create additional inputs needed for prefill forward
k_c_normed = torch.randn(
total_q,
mla_dims["kv_lora_rank"],
device=device,
dtype=dtype,
)
k_pe = torch.randn(
total_q,
1, # Single head for MLA
mla_dims["qk_rope_head_dim"],
device=device,
dtype=dtype,
)
k_scale = torch.ones(1, device=device, dtype=torch.float32)
output = torch.zeros(
total_q,
mla_dims["num_q_heads"] * mla_dims["v_head_dim"],
device=device,
dtype=dtype,
)
prefill_inputs = {
"q": prefill_q,
"k_c_normed": k_c_normed,
"k_pe": k_pe,
"k_scale": k_scale,
"output": output,
}
return decode_inputs, prefill_inputs
# ============================================================================
# Backend Initialization
# ============================================================================
def _create_backend_impl(
backend_cfg: dict,
mla_dims: dict,
vllm_config: VllmConfig,
device: torch.device,
):
"""
Create backend implementation instance.
Args:
backend_cfg: Backend configuration dict
mla_dims: MLA dimension configuration
vllm_config: VllmConfig instance
device: Target device
Returns:
Tuple of (impl, layer, builder_instance)
"""
# Import backend classes
backend_module = importlib.import_module(backend_cfg["module"])
impl_class = getattr(backend_module, backend_cfg["impl_class"])
# Calculate scale
scale = 1.0 / np.sqrt(mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"])
# Create mock kv_b_proj layer for prefill mode
mock_kv_b_proj = MockKVBProj(
num_heads=mla_dims["num_q_heads"],
qk_nope_head_dim=mla_dims["qk_nope_head_dim"],
v_head_dim=mla_dims["v_head_dim"],
)
# Create impl
impl = impl_class(
num_heads=mla_dims["num_q_heads"],
head_size=mla_dims["head_dim"],
scale=scale,
num_kv_heads=mla_dims["num_kv_heads"],
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype="auto",
logits_soft_cap=None,
attn_type="decoder",
kv_sharing_target_layer_name=None,
q_lora_rank=None,
kv_lora_rank=mla_dims["kv_lora_rank"],
qk_nope_head_dim=mla_dims["qk_nope_head_dim"],
qk_rope_head_dim=mla_dims["qk_rope_head_dim"],
qk_head_dim=mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"],
v_head_dim=mla_dims["v_head_dim"],
kv_b_proj=mock_kv_b_proj,
)
# Initialize DCP attributes
if not hasattr(impl, "dcp_world_size") or impl.dcp_world_size in (None, -1):
impl.dcp_world_size = 1
impl.dcp_rank = 0
# Create KV cache spec for MockLayer
from vllm.v1.kv_cache_interface import FullAttentionSpec
kv_cache_spec = FullAttentionSpec(
block_size=backend_cfg["block_size"] or vllm_config.cache_config.block_size,
num_kv_heads=1, # MLA uses 1 KV head
head_size=576, # MLA head dim
dtype=torch.bfloat16,
)
# Create mock layer
layer = MockLayer(device, impl=impl, kv_cache_spec=kv_cache_spec)
# Create builder instance if needed
builder_instance = None
if backend_cfg["builder_class"]:
builder_class = getattr(backend_module, backend_cfg["builder_class"])
# Populate static_forward_context so builder can find the layer
# MockLayer inherits from AttentionLayerBase, so isinstance checks pass
vllm_config.compilation_config.static_forward_context = {"placeholder": layer}
builder_instance = builder_class(
kv_cache_spec=kv_cache_spec,
layer_names=["placeholder"],
vllm_config=vllm_config,
device=device,
)
return impl, layer, builder_instance
# ============================================================================
# Config Helpers
# ============================================================================
def _extract_mla_dims_from_config(config) -> dict | None:
"""
Extract MLA dimensions from BenchmarkConfig if all required fields are present.
Args:
config: BenchmarkConfig instance
Returns:
Dict with MLA dimensions if all fields are provided, None otherwise
"""
# Check if all MLA-specific fields are provided
if all(
[
config.kv_lora_rank is not None,
config.qk_nope_head_dim is not None,
config.qk_rope_head_dim is not None,
config.v_head_dim is not None,
]
):
return {
"kv_lora_rank": config.kv_lora_rank,
"qk_nope_head_dim": config.qk_nope_head_dim,
"qk_rope_head_dim": config.qk_rope_head_dim,
"v_head_dim": config.v_head_dim,
"num_q_heads": config.num_q_heads,
"num_kv_heads": config.num_kv_heads,
"head_dim": config.head_dim,
}
# Fallback: if MLA fields not fully specified, try to construct from basic fields
elif config.head_dim == 576:
# This looks like a DeepSeek MLA config, use standard dimensions with custom
# head count
return {
"kv_lora_rank": 512,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"v_head_dim": 128,
"num_q_heads": config.num_q_heads,
"num_kv_heads": config.num_kv_heads,
"head_dim": config.head_dim,
}
return None
# ============================================================================
# Benchmark Execution
# ============================================================================
def _run_single_benchmark(
config,
impl,
layer,
builder_instance,
backend_cfg: dict,
mla_dims: dict,
device: torch.device,
) -> BenchmarkResult:
"""
Run a single benchmark iteration.
Args:
config: BenchmarkConfig instance
impl: Backend implementation instance
layer: MockLayer instance
builder_instance: Metadata builder instance
backend_cfg: Backend configuration dict
mla_dims: MLA dimension configuration
device: Target device
Returns:
BenchmarkResult with timing statistics
"""
# Parse batch spec
requests = parse_batch_spec(config.batch_spec)
q_lens = [r.q_len for r in requests]
total_q = sum(q_lens)
# Determine block size
block_size = backend_cfg["block_size"] or config.block_size
# Build metadata
metadata, num_blocks = _build_attention_metadata(
requests, block_size, device, builder_instance
)
# Create KV cache
kv_cache = torch.zeros(
num_blocks,
block_size,
mla_dims["kv_lora_rank"] + mla_dims["qk_rope_head_dim"],
device=device,
dtype=torch.bfloat16,
)
# Create input tensors for both decode and prefill modes
decode_inputs, prefill_inputs = _create_input_tensors(
total_q,
mla_dims,
backend_cfg["query_format"],
device,
torch.bfloat16,
)
# Determine which forward method to use based on metadata
if metadata.decode is not None:
forward_fn = lambda: impl._forward_decode(
decode_inputs, kv_cache, metadata, layer
)
elif metadata.prefill is not None:
forward_fn = lambda: impl._forward_prefill(
prefill_inputs["q"],
prefill_inputs["k_c_normed"],
prefill_inputs["k_pe"],
kv_cache,
metadata,
prefill_inputs["k_scale"],
prefill_inputs["output"],
)
else:
raise RuntimeError("Metadata has neither decode nor prefill metadata")
# Warmup
for _ in range(config.warmup_iters):
forward_fn()
torch.cuda.synchronize()
# Benchmark
times = []
for _ in range(config.repeats):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(config.num_layers):
forward_fn()
end.record()
torch.cuda.synchronize()
elapsed_ms = start.elapsed_time(end)
times.append(elapsed_ms / 1000.0 / config.num_layers)
mean_time = float(np.mean(times))
return BenchmarkResult(
config=config,
mean_time=mean_time,
std_time=float(np.std(times)),
min_time=float(np.min(times)),
max_time=float(np.max(times)),
throughput_tokens_per_sec=total_q / mean_time if mean_time > 0 else 0,
)
def _run_mla_benchmark_batched(
backend: str,
configs_with_params: list[tuple], # [(config, threshold, num_splits), ...]
) -> list[BenchmarkResult]:
"""
Unified batched MLA benchmark runner for all backends.
Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla
This function reuses backend initialization across multiple benchmarks
to avoid setup/teardown overhead.
Args:
backend: Backend name
configs_with_params: List of (config, threshold, num_splits) tuples
- threshold: reorder_batch_threshold (FlashAttn/FlashMLA only)
- num_splits: num_kv_splits (CUTLASS only)
Returns:
List of BenchmarkResult objects
"""
if not configs_with_params:
return []
backend_cfg = _get_backend_config(backend)
device = torch.device(configs_with_params[0][0].device)
torch.cuda.set_device(device)
# Determine block size
config_block_size = configs_with_params[0][0].block_size
block_size = backend_cfg["block_size"] or config_block_size
# Extract MLA dimensions from the first config
first_config = configs_with_params[0][0]
mla_dims = _extract_mla_dims_from_config(first_config)
# If config didn't provide MLA dims, fall back to default model
if mla_dims is None:
mla_dims = setup_mla_dims("deepseek-v3")
# Create and set vLLM config for MLA (reused across all benchmarks)
vllm_config = create_minimal_vllm_config(
model_name="deepseek-v3", # Used only for model path
block_size=block_size,
mla_dims=mla_dims, # Use custom dims from config or default
)
results = []
with set_current_vllm_config(vllm_config):
# Create backend impl, layer, and builder (reused across benchmarks)
impl, layer, builder_instance = _create_backend_impl(
backend_cfg, mla_dims, vllm_config, device
)
# Run each benchmark with the shared impl
for config, threshold, num_splits in configs_with_params:
# Set threshold for this benchmark (FlashAttn/FlashMLA only)
original_threshold = None
if threshold is not None and builder_instance:
original_threshold = builder_instance.reorder_batch_threshold
builder_instance.reorder_batch_threshold = threshold
# Set num_splits for CUTLASS
original_num_splits = None
if num_splits is not None and hasattr(impl, "_num_kv_splits"):
original_num_splits = impl._num_kv_splits
impl._num_kv_splits = num_splits
try:
result = _run_single_benchmark(
config,
impl,
layer,
builder_instance,
backend_cfg,
mla_dims,
device,
)
results.append(result)
finally:
# Restore original threshold
if original_threshold is not None:
builder_instance.reorder_batch_threshold = original_threshold
# Restore original num_splits
if original_num_splits is not None:
impl._num_kv_splits = original_num_splits
return results
# ============================================================================
# Public API
# ============================================================================
def run_mla_benchmark(
backend: str,
config,
reorder_batch_threshold: int | None = None,
num_kv_splits: int | None = None,
) -> BenchmarkResult | list[BenchmarkResult]:
"""
Unified MLA benchmark runner for all backends.
Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla
Always uses batched execution internally for optimal performance.
Args:
backend: Backend name (flashattn_mla, flashmla, flashinfer_mla, cutlass_mla)
config: BenchmarkConfig or list of (BenchmarkConfig, param) tuples
reorder_batch_threshold: Threshold override for FlashAttn/FlashMLA
(single config mode only)
num_kv_splits: Number of KV splits for CUTLASS (single config mode only)
Returns:
BenchmarkResult (single mode) or list of BenchmarkResult (batched mode)
"""
# Normalize to batched mode: (config, threshold, num_splits)
if isinstance(config, list):
# Already in batched format
if len(config) > 0 and isinstance(config[0], tuple):
# Format: [(cfg, param), ...] where param is threshold or num_splits
if backend in ("flashattn_mla", "flashmla"):
configs_with_params = [(cfg, param, None) for cfg, param in config]
else: # cutlass_mla or flashinfer_mla
configs_with_params = [(cfg, None, param) for cfg, param in config]
else:
# Format: [cfg, ...] - just configs
configs_with_params = [(cfg, None, None) for cfg in config]
return_single = False
else:
# Single config: convert to batched format
configs_with_params = [(config, reorder_batch_threshold, num_kv_splits)]
return_single = True
# Use unified batched execution
results = _run_mla_benchmark_batched(backend, configs_with_params)
# Return single result or list based on input
return results[0] if return_single else results