Add attention benchmarking tools (#26835)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
836
benchmarks/attention_benchmarks/mla_runner.py
Normal file
836
benchmarks/attention_benchmarks/mla_runner.py
Normal file
@@ -0,0 +1,836 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""
|
||||
MLA benchmark runner - shared utilities for MLA benchmarks.
|
||||
|
||||
This module provides helpers for running MLA backends without
|
||||
needing full VllmConfig integration.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from batch_spec import parse_batch_spec
|
||||
from common import (
|
||||
BenchmarkResult,
|
||||
MockHfConfig,
|
||||
MockKVBProj,
|
||||
MockLayer,
|
||||
setup_mla_dims,
|
||||
)
|
||||
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
CompilationConfig,
|
||||
ModelConfig,
|
||||
ParallelConfig,
|
||||
SchedulerConfig,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
|
||||
# ============================================================================
|
||||
# VllmConfig Creation
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _add_mock_methods_to_model_config(model_config: ModelConfig) -> None:
|
||||
"""
|
||||
Add mock methods for layer-specific queries to ModelConfig.
|
||||
|
||||
These methods are needed by metadata builders but aren't normally
|
||||
present on ModelConfig when used in benchmark contexts.
|
||||
"""
|
||||
import types
|
||||
|
||||
model_config.get_num_layers = types.MethodType(lambda self: 1, model_config)
|
||||
model_config.get_sliding_window_for_layer = types.MethodType(
|
||||
lambda self, _i: None, model_config
|
||||
)
|
||||
model_config.get_logits_soft_cap_for_layer = types.MethodType(
|
||||
lambda self, _i: None, model_config
|
||||
)
|
||||
model_config.get_sm_scale_for_layer = types.MethodType(
|
||||
lambda self, _i: 1.0 / model_config.get_head_size() ** 0.5, model_config
|
||||
)
|
||||
|
||||
|
||||
def create_minimal_vllm_config(
|
||||
model_name: str = "deepseek-v3",
|
||||
block_size: int = 128,
|
||||
max_num_seqs: int = 256,
|
||||
mla_dims: dict | None = None,
|
||||
) -> VllmConfig:
|
||||
"""
|
||||
Create minimal VllmConfig for MLA benchmarks.
|
||||
|
||||
Args:
|
||||
model_name: Model name (deepseek-v2, deepseek-v3, etc.) - used if mla_dims not
|
||||
provided
|
||||
block_size: KV cache block size
|
||||
max_num_seqs: Maximum number of sequences
|
||||
mla_dims: Optional custom MLA dimensions dict. If not provided, uses
|
||||
setup_mla_dims(model_name)
|
||||
|
||||
Returns:
|
||||
VllmConfig for benchmarking
|
||||
"""
|
||||
# Get MLA dimensions - use provided or load from model name
|
||||
if mla_dims is None:
|
||||
mla_dims = setup_mla_dims(model_name)
|
||||
|
||||
# Create mock HF config first (avoids downloading from HuggingFace)
|
||||
mock_hf_config = MockHfConfig(mla_dims)
|
||||
|
||||
# Create a temporary minimal config.json to avoid HF downloads
|
||||
# This ensures consistent ModelConfig construction without network access
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
minimal_config = {
|
||||
"architectures": ["DeepseekV2ForCausalLM"],
|
||||
"model_type": "deepseek_v2",
|
||||
"num_attention_heads": mla_dims["num_q_heads"],
|
||||
"num_key_value_heads": mla_dims["num_kv_heads"],
|
||||
"hidden_size": mla_dims["head_dim"] * mla_dims["num_q_heads"],
|
||||
"torch_dtype": "bfloat16",
|
||||
"max_position_embeddings": 163840, # DeepSeek V3 default
|
||||
"rope_theta": 10000.0,
|
||||
"vocab_size": 128256,
|
||||
}
|
||||
|
||||
# Create temporary directory with config.json
|
||||
temp_dir = tempfile.mkdtemp(prefix="vllm_bench_")
|
||||
config_path = os.path.join(temp_dir, "config.json")
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(minimal_config, f)
|
||||
|
||||
try:
|
||||
# Create model config using local path - no HF downloads
|
||||
model_config = ModelConfig(
|
||||
model=temp_dir, # Use local temp directory
|
||||
tokenizer=None,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=True,
|
||||
dtype="bfloat16",
|
||||
seed=0,
|
||||
max_model_len=32768,
|
||||
quantization=None,
|
||||
quantization_param_path=None,
|
||||
enforce_eager=False,
|
||||
max_context_len_to_capture=None,
|
||||
max_seq_len_to_capture=8192,
|
||||
max_logprobs=20,
|
||||
disable_sliding_window=False,
|
||||
skip_tokenizer_init=True,
|
||||
served_model_name=None,
|
||||
limit_mm_per_prompt=None,
|
||||
use_async_output_proc=True,
|
||||
config_format="auto",
|
||||
)
|
||||
finally:
|
||||
# Clean up temporary directory
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
# Override with our mock config
|
||||
model_config.hf_config = mock_hf_config
|
||||
model_config.hf_text_config = mock_hf_config
|
||||
|
||||
# Add mock methods for layer-specific queries
|
||||
_add_mock_methods_to_model_config(model_config)
|
||||
|
||||
# Create sub-configs
|
||||
cache_config = CacheConfig(
|
||||
block_size=block_size,
|
||||
gpu_memory_utilization=0.9,
|
||||
swap_space=0,
|
||||
cache_dtype="auto",
|
||||
enable_prefix_caching=False,
|
||||
)
|
||||
|
||||
scheduler_config = SchedulerConfig(
|
||||
max_num_seqs=max_num_seqs,
|
||||
max_num_batched_tokens=8192,
|
||||
max_model_len=32768,
|
||||
is_encoder_decoder=False,
|
||||
enable_chunked_prefill=True,
|
||||
)
|
||||
|
||||
parallel_config = ParallelConfig(
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
compilation_config = CompilationConfig()
|
||||
|
||||
return VllmConfig(
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
parallel_config=parallel_config,
|
||||
scheduler_config=scheduler_config,
|
||||
compilation_config=compilation_config,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Backend Configuration
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# Backend name to class name prefix mapping
|
||||
_BACKEND_NAME_MAP = {
|
||||
"flashattn_mla": "FlashAttnMLA",
|
||||
"flashmla": "FlashMLA",
|
||||
"flashinfer_mla": "FlashInferMLA",
|
||||
"cutlass_mla": "CutlassMLA",
|
||||
}
|
||||
|
||||
# Special properties that differ from defaults
|
||||
_BACKEND_PROPERTIES = {
|
||||
"flashmla": {
|
||||
"query_format": "concat", # Single concatenated tensor (vs tuple)
|
||||
"block_size": 64, # FlashMLA uses fixed block size
|
||||
},
|
||||
"flashinfer_mla": {
|
||||
"block_size": 64, # FlashInfer MLA only supports 32 or 64
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _get_backend_config(backend: str) -> dict:
|
||||
"""
|
||||
Get backend configuration using naming conventions.
|
||||
|
||||
All MLA backends follow the pattern:
|
||||
- Module: vllm.v1.attention.backends.mla.{backend}
|
||||
- Impl: {Name}Impl
|
||||
- Metadata: {Name}Metadata (or MLACommonMetadata)
|
||||
- DecodeMetadata: {Name}DecodeMetadata (or MLACommonDecodeMetadata)
|
||||
- MetadataBuilder: {Name}MetadataBuilder
|
||||
"""
|
||||
if backend not in _BACKEND_NAME_MAP:
|
||||
raise ValueError(f"Unknown backend: {backend}")
|
||||
|
||||
name = _BACKEND_NAME_MAP[backend]
|
||||
props = _BACKEND_PROPERTIES.get(backend, {})
|
||||
|
||||
# Check if backend uses common metadata (FlashInfer, CUTLASS)
|
||||
uses_common = backend in ("flashinfer_mla", "cutlass_mla")
|
||||
|
||||
return {
|
||||
"module": f"vllm.v1.attention.backends.mla.{backend}",
|
||||
"impl_class": f"{name}Impl",
|
||||
"metadata_class": "MLACommonMetadata" if uses_common else f"{name}Metadata",
|
||||
"decode_metadata_class": "MLACommonDecodeMetadata"
|
||||
if uses_common
|
||||
else f"{name}DecodeMetadata",
|
||||
"builder_class": f"{name}MetadataBuilder",
|
||||
"query_format": props.get("query_format", "tuple"),
|
||||
"block_size": props.get("block_size", None),
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Metadata Building Helpers
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _build_attention_metadata(
|
||||
requests: list,
|
||||
block_size: int,
|
||||
device: torch.device,
|
||||
builder_instance,
|
||||
) -> tuple:
|
||||
"""
|
||||
Build attention metadata from batch requests.
|
||||
|
||||
Args:
|
||||
requests: List of BatchRequest objects
|
||||
block_size: KV cache block size
|
||||
device: Target device
|
||||
builder_instance: Metadata builder instance
|
||||
|
||||
Returns:
|
||||
Tuple of (metadata, kv_cache_num_blocks)
|
||||
"""
|
||||
q_lens = [r.q_len for r in requests]
|
||||
kv_lens = [r.kv_len for r in requests]
|
||||
total_q = sum(q_lens)
|
||||
max_kv = max(kv_lens)
|
||||
|
||||
# Build query start locations
|
||||
q_start_cpu = torch.tensor(
|
||||
[0] + [sum(q_lens[: i + 1]) for i in range(len(q_lens))],
|
||||
dtype=torch.int32,
|
||||
)
|
||||
q_start_gpu = q_start_cpu.to(device)
|
||||
|
||||
# Build sequence lengths
|
||||
seq_lens_cpu = torch.tensor(kv_lens, dtype=torch.int32)
|
||||
seq_lens_gpu = seq_lens_cpu.to(device)
|
||||
|
||||
# Build num_computed_tokens (context length for each request)
|
||||
context_lens = [kv_len - q_len for q_len, kv_len in zip(q_lens, kv_lens)]
|
||||
num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)
|
||||
|
||||
# Build block table
|
||||
num_blocks_per_req = [(kv + block_size - 1) // block_size for kv in kv_lens]
|
||||
max_num_blocks = max(num_blocks_per_req)
|
||||
|
||||
block_table_cpu = np.zeros((len(requests), max_num_blocks), dtype=np.int32)
|
||||
current_block = 0
|
||||
for i, num_blocks in enumerate(num_blocks_per_req):
|
||||
for j in range(num_blocks):
|
||||
block_table_cpu[i, j] = current_block
|
||||
current_block += 1
|
||||
|
||||
block_table_gpu = torch.from_numpy(block_table_cpu).to(device)
|
||||
|
||||
# Build slot mapping
|
||||
slot_mapping_list = []
|
||||
for i, (q_len, kv_len, num_blocks) in enumerate(
|
||||
zip(q_lens, kv_lens, num_blocks_per_req)
|
||||
):
|
||||
context_len = kv_len - q_len
|
||||
for j in range(q_len):
|
||||
token_kv_idx = context_len + j
|
||||
block_idx = token_kv_idx // block_size
|
||||
offset_in_block = token_kv_idx % block_size
|
||||
global_block_id = block_table_cpu[i, block_idx]
|
||||
slot_id = global_block_id * block_size + offset_in_block
|
||||
slot_mapping_list.append(slot_id)
|
||||
|
||||
slot_mapping = torch.tensor(slot_mapping_list, dtype=torch.int64, device=device)
|
||||
|
||||
# Create CommonAttentionMetadata
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
|
||||
common_attn_metadata = CommonAttentionMetadata(
|
||||
num_reqs=len(requests),
|
||||
max_query_len=max(q_lens),
|
||||
max_seq_len=max_kv,
|
||||
num_actual_tokens=total_q,
|
||||
query_start_loc=q_start_gpu,
|
||||
query_start_loc_cpu=q_start_cpu,
|
||||
seq_lens=seq_lens_gpu,
|
||||
_seq_lens_cpu=seq_lens_cpu,
|
||||
_num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
slot_mapping=slot_mapping,
|
||||
block_table_tensor=block_table_gpu,
|
||||
dcp_local_seq_lens=None,
|
||||
)
|
||||
|
||||
# Use the production build() method
|
||||
metadata = builder_instance.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
fast_build=False,
|
||||
)
|
||||
|
||||
return metadata, current_block
|
||||
|
||||
|
||||
def _create_input_tensors(
|
||||
total_q: int,
|
||||
mla_dims: dict,
|
||||
query_format: str,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""
|
||||
Create input tensors for both decode and prefill modes.
|
||||
|
||||
MLA requires different tensor formats for decode vs prefill:
|
||||
- Decode: Uses kv_lora_rank (512) dimension
|
||||
- Prefill: Uses qk_nope_head_dim (128) to stay under FlashAttention's 256 limit
|
||||
|
||||
Args:
|
||||
total_q: Total number of query tokens
|
||||
mla_dims: MLA dimension configuration
|
||||
query_format: Either "tuple" or "concat"
|
||||
device: Target device
|
||||
dtype: Tensor dtype
|
||||
|
||||
Returns:
|
||||
Tuple of (decode_inputs, prefill_inputs)
|
||||
- decode_inputs: Query tensor(s) for decode mode
|
||||
- prefill_inputs: Dict with 'q', 'k_c_normed', 'k_pe', 'k_scale' for prefill
|
||||
"""
|
||||
if query_format == "tuple":
|
||||
# Decode mode format: (q_nope, q_pe) where q_nope has kv_lora_rank dim
|
||||
q_nope_decode = torch.randn(
|
||||
total_q,
|
||||
mla_dims["num_q_heads"],
|
||||
mla_dims["kv_lora_rank"],
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
q_pe = torch.randn(
|
||||
total_q,
|
||||
mla_dims["num_q_heads"],
|
||||
mla_dims["qk_rope_head_dim"],
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
decode_inputs = (q_nope_decode, q_pe)
|
||||
|
||||
# For prefill, we need q with qk_nope_head_dim instead of kv_lora_rank
|
||||
q_nope_prefill = torch.randn(
|
||||
total_q,
|
||||
mla_dims["num_q_heads"],
|
||||
mla_dims["qk_nope_head_dim"],
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
prefill_q = torch.cat([q_nope_prefill, q_pe], dim=-1)
|
||||
else: # concat
|
||||
decode_inputs = torch.randn(
|
||||
total_q,
|
||||
mla_dims["num_q_heads"],
|
||||
mla_dims["kv_lora_rank"] + mla_dims["qk_rope_head_dim"],
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
# For prefill with concat format
|
||||
prefill_q = torch.randn(
|
||||
total_q,
|
||||
mla_dims["num_q_heads"],
|
||||
mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"],
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Create additional inputs needed for prefill forward
|
||||
k_c_normed = torch.randn(
|
||||
total_q,
|
||||
mla_dims["kv_lora_rank"],
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
k_pe = torch.randn(
|
||||
total_q,
|
||||
1, # Single head for MLA
|
||||
mla_dims["qk_rope_head_dim"],
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
k_scale = torch.ones(1, device=device, dtype=torch.float32)
|
||||
|
||||
output = torch.zeros(
|
||||
total_q,
|
||||
mla_dims["num_q_heads"] * mla_dims["v_head_dim"],
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
prefill_inputs = {
|
||||
"q": prefill_q,
|
||||
"k_c_normed": k_c_normed,
|
||||
"k_pe": k_pe,
|
||||
"k_scale": k_scale,
|
||||
"output": output,
|
||||
}
|
||||
|
||||
return decode_inputs, prefill_inputs
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Backend Initialization
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _create_backend_impl(
|
||||
backend_cfg: dict,
|
||||
mla_dims: dict,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
"""
|
||||
Create backend implementation instance.
|
||||
|
||||
Args:
|
||||
backend_cfg: Backend configuration dict
|
||||
mla_dims: MLA dimension configuration
|
||||
vllm_config: VllmConfig instance
|
||||
device: Target device
|
||||
|
||||
Returns:
|
||||
Tuple of (impl, layer, builder_instance)
|
||||
"""
|
||||
# Import backend classes
|
||||
backend_module = importlib.import_module(backend_cfg["module"])
|
||||
impl_class = getattr(backend_module, backend_cfg["impl_class"])
|
||||
|
||||
# Calculate scale
|
||||
scale = 1.0 / np.sqrt(mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"])
|
||||
|
||||
# Create mock kv_b_proj layer for prefill mode
|
||||
mock_kv_b_proj = MockKVBProj(
|
||||
num_heads=mla_dims["num_q_heads"],
|
||||
qk_nope_head_dim=mla_dims["qk_nope_head_dim"],
|
||||
v_head_dim=mla_dims["v_head_dim"],
|
||||
)
|
||||
|
||||
# Create impl
|
||||
impl = impl_class(
|
||||
num_heads=mla_dims["num_q_heads"],
|
||||
head_size=mla_dims["head_dim"],
|
||||
scale=scale,
|
||||
num_kv_heads=mla_dims["num_kv_heads"],
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
kv_cache_dtype="auto",
|
||||
logits_soft_cap=None,
|
||||
attn_type="decoder",
|
||||
kv_sharing_target_layer_name=None,
|
||||
q_lora_rank=None,
|
||||
kv_lora_rank=mla_dims["kv_lora_rank"],
|
||||
qk_nope_head_dim=mla_dims["qk_nope_head_dim"],
|
||||
qk_rope_head_dim=mla_dims["qk_rope_head_dim"],
|
||||
qk_head_dim=mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"],
|
||||
v_head_dim=mla_dims["v_head_dim"],
|
||||
kv_b_proj=mock_kv_b_proj,
|
||||
)
|
||||
|
||||
# Initialize DCP attributes
|
||||
if not hasattr(impl, "dcp_world_size") or impl.dcp_world_size in (None, -1):
|
||||
impl.dcp_world_size = 1
|
||||
impl.dcp_rank = 0
|
||||
|
||||
# Create KV cache spec for MockLayer
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
|
||||
kv_cache_spec = FullAttentionSpec(
|
||||
block_size=backend_cfg["block_size"] or vllm_config.cache_config.block_size,
|
||||
num_kv_heads=1, # MLA uses 1 KV head
|
||||
head_size=576, # MLA head dim
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
# Create mock layer
|
||||
layer = MockLayer(device, impl=impl, kv_cache_spec=kv_cache_spec)
|
||||
|
||||
# Create builder instance if needed
|
||||
builder_instance = None
|
||||
if backend_cfg["builder_class"]:
|
||||
builder_class = getattr(backend_module, backend_cfg["builder_class"])
|
||||
|
||||
# Populate static_forward_context so builder can find the layer
|
||||
# MockLayer inherits from AttentionLayerBase, so isinstance checks pass
|
||||
vllm_config.compilation_config.static_forward_context = {"placeholder": layer}
|
||||
|
||||
builder_instance = builder_class(
|
||||
kv_cache_spec=kv_cache_spec,
|
||||
layer_names=["placeholder"],
|
||||
vllm_config=vllm_config,
|
||||
device=device,
|
||||
)
|
||||
|
||||
return impl, layer, builder_instance
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Config Helpers
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _extract_mla_dims_from_config(config) -> dict | None:
|
||||
"""
|
||||
Extract MLA dimensions from BenchmarkConfig if all required fields are present.
|
||||
|
||||
Args:
|
||||
config: BenchmarkConfig instance
|
||||
|
||||
Returns:
|
||||
Dict with MLA dimensions if all fields are provided, None otherwise
|
||||
"""
|
||||
# Check if all MLA-specific fields are provided
|
||||
if all(
|
||||
[
|
||||
config.kv_lora_rank is not None,
|
||||
config.qk_nope_head_dim is not None,
|
||||
config.qk_rope_head_dim is not None,
|
||||
config.v_head_dim is not None,
|
||||
]
|
||||
):
|
||||
return {
|
||||
"kv_lora_rank": config.kv_lora_rank,
|
||||
"qk_nope_head_dim": config.qk_nope_head_dim,
|
||||
"qk_rope_head_dim": config.qk_rope_head_dim,
|
||||
"v_head_dim": config.v_head_dim,
|
||||
"num_q_heads": config.num_q_heads,
|
||||
"num_kv_heads": config.num_kv_heads,
|
||||
"head_dim": config.head_dim,
|
||||
}
|
||||
# Fallback: if MLA fields not fully specified, try to construct from basic fields
|
||||
elif config.head_dim == 576:
|
||||
# This looks like a DeepSeek MLA config, use standard dimensions with custom
|
||||
# head count
|
||||
return {
|
||||
"kv_lora_rank": 512,
|
||||
"qk_nope_head_dim": 128,
|
||||
"qk_rope_head_dim": 64,
|
||||
"v_head_dim": 128,
|
||||
"num_q_heads": config.num_q_heads,
|
||||
"num_kv_heads": config.num_kv_heads,
|
||||
"head_dim": config.head_dim,
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Benchmark Execution
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _run_single_benchmark(
|
||||
config,
|
||||
impl,
|
||||
layer,
|
||||
builder_instance,
|
||||
backend_cfg: dict,
|
||||
mla_dims: dict,
|
||||
device: torch.device,
|
||||
) -> BenchmarkResult:
|
||||
"""
|
||||
Run a single benchmark iteration.
|
||||
|
||||
Args:
|
||||
config: BenchmarkConfig instance
|
||||
impl: Backend implementation instance
|
||||
layer: MockLayer instance
|
||||
builder_instance: Metadata builder instance
|
||||
backend_cfg: Backend configuration dict
|
||||
mla_dims: MLA dimension configuration
|
||||
device: Target device
|
||||
|
||||
Returns:
|
||||
BenchmarkResult with timing statistics
|
||||
"""
|
||||
# Parse batch spec
|
||||
requests = parse_batch_spec(config.batch_spec)
|
||||
q_lens = [r.q_len for r in requests]
|
||||
total_q = sum(q_lens)
|
||||
|
||||
# Determine block size
|
||||
block_size = backend_cfg["block_size"] or config.block_size
|
||||
|
||||
# Build metadata
|
||||
metadata, num_blocks = _build_attention_metadata(
|
||||
requests, block_size, device, builder_instance
|
||||
)
|
||||
|
||||
# Create KV cache
|
||||
kv_cache = torch.zeros(
|
||||
num_blocks,
|
||||
block_size,
|
||||
mla_dims["kv_lora_rank"] + mla_dims["qk_rope_head_dim"],
|
||||
device=device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
# Create input tensors for both decode and prefill modes
|
||||
decode_inputs, prefill_inputs = _create_input_tensors(
|
||||
total_q,
|
||||
mla_dims,
|
||||
backend_cfg["query_format"],
|
||||
device,
|
||||
torch.bfloat16,
|
||||
)
|
||||
|
||||
# Determine which forward method to use based on metadata
|
||||
if metadata.decode is not None:
|
||||
forward_fn = lambda: impl._forward_decode(
|
||||
decode_inputs, kv_cache, metadata, layer
|
||||
)
|
||||
elif metadata.prefill is not None:
|
||||
forward_fn = lambda: impl._forward_prefill(
|
||||
prefill_inputs["q"],
|
||||
prefill_inputs["k_c_normed"],
|
||||
prefill_inputs["k_pe"],
|
||||
kv_cache,
|
||||
metadata,
|
||||
prefill_inputs["k_scale"],
|
||||
prefill_inputs["output"],
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("Metadata has neither decode nor prefill metadata")
|
||||
|
||||
# Warmup
|
||||
for _ in range(config.warmup_iters):
|
||||
forward_fn()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Benchmark
|
||||
times = []
|
||||
for _ in range(config.repeats):
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
start.record()
|
||||
for _ in range(config.num_layers):
|
||||
forward_fn()
|
||||
end.record()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
elapsed_ms = start.elapsed_time(end)
|
||||
times.append(elapsed_ms / 1000.0 / config.num_layers)
|
||||
|
||||
mean_time = float(np.mean(times))
|
||||
return BenchmarkResult(
|
||||
config=config,
|
||||
mean_time=mean_time,
|
||||
std_time=float(np.std(times)),
|
||||
min_time=float(np.min(times)),
|
||||
max_time=float(np.max(times)),
|
||||
throughput_tokens_per_sec=total_q / mean_time if mean_time > 0 else 0,
|
||||
)
|
||||
|
||||
|
||||
def _run_mla_benchmark_batched(
|
||||
backend: str,
|
||||
configs_with_params: list[tuple], # [(config, threshold, num_splits), ...]
|
||||
) -> list[BenchmarkResult]:
|
||||
"""
|
||||
Unified batched MLA benchmark runner for all backends.
|
||||
|
||||
Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla
|
||||
|
||||
This function reuses backend initialization across multiple benchmarks
|
||||
to avoid setup/teardown overhead.
|
||||
|
||||
Args:
|
||||
backend: Backend name
|
||||
configs_with_params: List of (config, threshold, num_splits) tuples
|
||||
- threshold: reorder_batch_threshold (FlashAttn/FlashMLA only)
|
||||
- num_splits: num_kv_splits (CUTLASS only)
|
||||
|
||||
Returns:
|
||||
List of BenchmarkResult objects
|
||||
"""
|
||||
if not configs_with_params:
|
||||
return []
|
||||
|
||||
backend_cfg = _get_backend_config(backend)
|
||||
device = torch.device(configs_with_params[0][0].device)
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
# Determine block size
|
||||
config_block_size = configs_with_params[0][0].block_size
|
||||
block_size = backend_cfg["block_size"] or config_block_size
|
||||
|
||||
# Extract MLA dimensions from the first config
|
||||
first_config = configs_with_params[0][0]
|
||||
mla_dims = _extract_mla_dims_from_config(first_config)
|
||||
|
||||
# If config didn't provide MLA dims, fall back to default model
|
||||
if mla_dims is None:
|
||||
mla_dims = setup_mla_dims("deepseek-v3")
|
||||
|
||||
# Create and set vLLM config for MLA (reused across all benchmarks)
|
||||
vllm_config = create_minimal_vllm_config(
|
||||
model_name="deepseek-v3", # Used only for model path
|
||||
block_size=block_size,
|
||||
mla_dims=mla_dims, # Use custom dims from config or default
|
||||
)
|
||||
|
||||
results = []
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
# Create backend impl, layer, and builder (reused across benchmarks)
|
||||
impl, layer, builder_instance = _create_backend_impl(
|
||||
backend_cfg, mla_dims, vllm_config, device
|
||||
)
|
||||
|
||||
# Run each benchmark with the shared impl
|
||||
for config, threshold, num_splits in configs_with_params:
|
||||
# Set threshold for this benchmark (FlashAttn/FlashMLA only)
|
||||
original_threshold = None
|
||||
if threshold is not None and builder_instance:
|
||||
original_threshold = builder_instance.reorder_batch_threshold
|
||||
builder_instance.reorder_batch_threshold = threshold
|
||||
|
||||
# Set num_splits for CUTLASS
|
||||
original_num_splits = None
|
||||
if num_splits is not None and hasattr(impl, "_num_kv_splits"):
|
||||
original_num_splits = impl._num_kv_splits
|
||||
impl._num_kv_splits = num_splits
|
||||
|
||||
try:
|
||||
result = _run_single_benchmark(
|
||||
config,
|
||||
impl,
|
||||
layer,
|
||||
builder_instance,
|
||||
backend_cfg,
|
||||
mla_dims,
|
||||
device,
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
finally:
|
||||
# Restore original threshold
|
||||
if original_threshold is not None:
|
||||
builder_instance.reorder_batch_threshold = original_threshold
|
||||
|
||||
# Restore original num_splits
|
||||
if original_num_splits is not None:
|
||||
impl._num_kv_splits = original_num_splits
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Public API
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def run_mla_benchmark(
|
||||
backend: str,
|
||||
config,
|
||||
reorder_batch_threshold: int | None = None,
|
||||
num_kv_splits: int | None = None,
|
||||
) -> BenchmarkResult | list[BenchmarkResult]:
|
||||
"""
|
||||
Unified MLA benchmark runner for all backends.
|
||||
|
||||
Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla
|
||||
|
||||
Always uses batched execution internally for optimal performance.
|
||||
|
||||
Args:
|
||||
backend: Backend name (flashattn_mla, flashmla, flashinfer_mla, cutlass_mla)
|
||||
config: BenchmarkConfig or list of (BenchmarkConfig, param) tuples
|
||||
reorder_batch_threshold: Threshold override for FlashAttn/FlashMLA
|
||||
(single config mode only)
|
||||
num_kv_splits: Number of KV splits for CUTLASS (single config mode only)
|
||||
|
||||
Returns:
|
||||
BenchmarkResult (single mode) or list of BenchmarkResult (batched mode)
|
||||
"""
|
||||
# Normalize to batched mode: (config, threshold, num_splits)
|
||||
if isinstance(config, list):
|
||||
# Already in batched format
|
||||
if len(config) > 0 and isinstance(config[0], tuple):
|
||||
# Format: [(cfg, param), ...] where param is threshold or num_splits
|
||||
if backend in ("flashattn_mla", "flashmla"):
|
||||
configs_with_params = [(cfg, param, None) for cfg, param in config]
|
||||
else: # cutlass_mla or flashinfer_mla
|
||||
configs_with_params = [(cfg, None, param) for cfg, param in config]
|
||||
else:
|
||||
# Format: [cfg, ...] - just configs
|
||||
configs_with_params = [(cfg, None, None) for cfg in config]
|
||||
return_single = False
|
||||
else:
|
||||
# Single config: convert to batched format
|
||||
configs_with_params = [(config, reorder_batch_threshold, num_kv_splits)]
|
||||
return_single = True
|
||||
|
||||
# Use unified batched execution
|
||||
results = _run_mla_benchmark_batched(backend, configs_with_params)
|
||||
|
||||
# Return single result or list based on input
|
||||
return results[0] if return_single else results
|
||||
Reference in New Issue
Block a user