482 lines
14 KiB
Python
482 lines
14 KiB
Python
|
|
# SPDX-License-Identifier: Apache-2.0
|
||
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||
|
|
|
||
|
|
"""
|
||
|
|
Standard attention benchmark runner - shared utilities for non-MLA benchmarks.
|
||
|
|
|
||
|
|
This module provides helpers for running standard attention backends
|
||
|
|
(FlashAttention, Triton, FlashInfer) with real vLLM integration.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import types
|
||
|
|
|
||
|
|
import numpy as np
|
||
|
|
import torch
|
||
|
|
from batch_spec import parse_batch_spec, reorder_for_flashinfer
|
||
|
|
from common import BenchmarkConfig, BenchmarkResult, MockLayer, get_attention_scale
|
||
|
|
|
||
|
|
from vllm.config import (
|
||
|
|
CacheConfig,
|
||
|
|
CompilationConfig,
|
||
|
|
DeviceConfig,
|
||
|
|
LoadConfig,
|
||
|
|
ModelConfig,
|
||
|
|
ParallelConfig,
|
||
|
|
SchedulerConfig,
|
||
|
|
VllmConfig,
|
||
|
|
)
|
||
|
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||
|
|
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||
|
|
|
||
|
|
# ============================================================================
|
||
|
|
# Backend Configuration
|
||
|
|
# ============================================================================
|
||
|
|
|
||
|
|
|
||
|
|
# Static per-backend settings consumed by the benchmark runner.
#   module        - import path of the vLLM attention backend module
#   backend_class - name of the backend class inside ``module``
#   dtype         - dtype used for the model config, KV cache spec, Q/K/V
#                   inputs and KV cache tensors of the benchmark
#   cache_layout  - KV cache tensor layout selected in ``_create_kv_cache``:
#       "standard":   [2, num_blocks, block_size, num_kv_heads, head_dim]
#       "flashinfer": [num_blocks, 2, block_size, num_kv_heads, head_dim]
_BACKEND_CONFIG = {
    name: {
        "module": f"vllm.v1.attention.backends.{module_suffix}",
        "backend_class": class_name,
        "dtype": tensor_dtype,
        "cache_layout": layout,
    }
    for name, module_suffix, class_name, tensor_dtype, layout in (
        ("flash", "flash_attn", "FlashAttentionBackend", torch.float16, "standard"),
        ("triton", "triton_attn", "TritonAttentionBackend", torch.float32, "standard"),
        ("flashinfer", "flashinfer", "FlashInferBackend", torch.float16, "flashinfer"),
    )
}
|
||
|
|
|
||
|
|
|
||
|
|
def _get_backend_config(backend: str) -> dict:
    """Return the ``_BACKEND_CONFIG`` entry registered under ``backend``.

    Args:
        backend: Backend name, e.g. "flash", "triton" or "flashinfer".

    Returns:
        The shared configuration dict for that backend (not a copy).

    Raises:
        ValueError: If ``backend`` is not a key of ``_BACKEND_CONFIG``.
    """
    registry = _BACKEND_CONFIG
    if backend in registry:
        return registry[backend]
    available = ", ".join(registry.keys())
    raise ValueError(f"Unknown backend: {backend}. Available: {available}")
|
||
|
|
|
||
|
|
|
||
|
|
# ============================================================================
|
||
|
|
# Metadata Building Helpers
|
||
|
|
# ============================================================================
|
||
|
|
|
||
|
|
|
||
|
|
def _build_common_attn_metadata(
    q_lens: list[int],
    kv_lens: list[int],
    block_size: int,
    device: torch.device,
) -> CommonAttentionMetadata:
    """Build causal CommonAttentionMetadata for a synthetic batch.

    Request ``i`` has ``q_lens[i]`` new (query) tokens that follow
    ``kv_lens[i] - q_lens[i]`` already-computed context tokens, so its total
    sequence length is ``kv_lens[i]``.

    Each request owns a private, contiguous run of
    ``max_blocks = ceil(max(kv_lens) / block_size)`` blocks. Request ``i``
    uses block ids ``[i * max_blocks, (i + 1) * max_blocks)``. Any KV cache
    used with this metadata must therefore hold at least
    ``len(q_lens) * max_blocks`` blocks.

    Args:
        q_lens: Number of query tokens per request.
        kv_lens: Total sequence length (context + query) per request.
        block_size: Number of tokens per KV cache block.
        device: Device on which the GPU-side tensors are allocated.

    Returns:
        CommonAttentionMetadata describing the batch (``causal=True``).

    Raises:
        ValueError: If the batch is empty, ``q_lens`` and ``kv_lens`` differ
            in length, or a request has more query tokens than total tokens.
    """
    if not q_lens:
        raise ValueError("Cannot build attention metadata for an empty batch")
    if len(q_lens) != len(kv_lens):
        raise ValueError(
            "q_lens and kv_lens must have equal length, got "
            f"{len(q_lens)} and {len(kv_lens)}"
        )
    for idx, (q, kv) in enumerate(zip(q_lens, kv_lens)):
        if kv < q:
            raise ValueError(f"Request {idx}: kv_len ({kv}) must be >= q_len ({q})")

    batch_size = len(q_lens)
    total_tokens = sum(q_lens)

    # Prefix sums of query lengths: request i's tokens are
    # [query_start_loc[i], query_start_loc[i + 1]).
    query_start_loc = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    query_start_loc[1:] = torch.tensor(q_lens, dtype=torch.int32, device=device).cumsum(
        0
    )
    query_start_loc_cpu = query_start_loc.cpu()

    seq_lens = torch.tensor(kv_lens, dtype=torch.int32, device=device)
    seq_lens_cpu = seq_lens.cpu()
    max_seq_len = int(seq_lens_cpu.max())

    context_lens = [kv - q for kv, q in zip(kv_lens, q_lens)]
    num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)

    max_blocks = (max(kv_lens) + block_size - 1) // block_size
    num_blocks = batch_size * max_blocks
    # Row i holds block ids i * max_blocks .. (i + 1) * max_blocks - 1.
    block_table_tensor = torch.arange(
        num_blocks, dtype=torch.int32, device=device
    ).view(batch_size, max_blocks)

    # The slot of the token at position p of request i must agree with the
    # block table above:
    #   block_table[i, p // block_size] * block_size + p % block_size
    #     == i * max_blocks * block_size + p.
    # New (query) tokens sit at positions [context_len, kv_len) of their
    # request, so the new K/V are written where the kernels will read them.
    tokens_per_request = max_blocks * block_size
    slot_mapping = torch.cat(
        [
            torch.arange(
                i * tokens_per_request + ctx,
                i * tokens_per_request + kv,
                dtype=torch.int64,
            )
            for i, (ctx, kv) in enumerate(zip(context_lens, kv_lens))
        ]
    ).to(device)

    max_query_len = max(q_lens)

    return CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens_cpu,
        num_computed_tokens_cpu=num_computed_tokens_cpu,
        num_reqs=batch_size,
        num_actual_tokens=total_tokens,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_tensor,
        slot_mapping=slot_mapping,
        causal=True,
    )
|
||
|
|
|
||
|
|
|
||
|
|
def _create_vllm_config(
    config: BenchmarkConfig,
    dtype: torch.dtype,
    max_num_blocks: int,
) -> VllmConfig:
    """Assemble a VllmConfig whose attention geometry comes from ``config``.

    A ModelConfig for meta-llama/Meta-Llama-3-8B is constructed. Its
    per-model accessors (layer count, head counts, head size, sliding window,
    soft cap, softmax scale) are then replaced on that instance with bound
    methods that return the benchmark's values.

    Args:
        config: Benchmark configuration supplying block size, layer count and
            head geometry.
        dtype: Model dtype passed to ModelConfig.
        max_num_blocks: Value assigned to ``cache_config.num_gpu_blocks``.

    Returns:
        A VllmConfig combining the model, cache, parallel, scheduler, device,
        load and compilation configs.
    """
    # NOTE(review): model_config.max_model_len (1024) differs from
    # scheduler_config.max_model_len (8192); confirm which value the backend
    # metadata builders rely on before benchmarking long contexts.
    model_config = ModelConfig(
        model="meta-llama/Meta-Llama-3-8B",
        tokenizer="meta-llama/Meta-Llama-3-8B",
        trust_remote_code=False,
        dtype=dtype,
        seed=0,
        max_model_len=1024,
    )

    cache_config = CacheConfig(
        block_size=config.block_size,
        cache_dtype="auto",
        swap_space=0,
    )
    cache_config.num_gpu_blocks = max_num_blocks
    cache_config.num_cpu_blocks = 0

    # Benchmark-controlled replacements for model-derived properties.
    mocked_methods = {
        "get_num_layers": lambda self: config.num_layers,
        "get_sliding_window_for_layer": lambda self, i: None,
        "get_logits_soft_cap_for_layer": lambda self, i: 0.0,
        "get_sm_scale_for_layer": lambda self, i: 1.0 / config.head_dim**0.5,
        "get_num_attention_heads": (
            lambda self, parallel_config=None: config.num_q_heads
        ),
        "get_num_kv_heads": (
            lambda self, parallel_config=None: config.num_kv_heads
        ),
        "get_head_size": lambda self: config.head_dim,
        "get_sliding_window": lambda self: None,
    }
    for method_name, impl_fn in mocked_methods.items():
        setattr(model_config, method_name, types.MethodType(impl_fn, model_config))

    return VllmConfig(
        model_config=model_config,
        cache_config=cache_config,
        parallel_config=ParallelConfig(tensor_parallel_size=1),
        scheduler_config=SchedulerConfig(
            max_num_seqs=256,
            max_num_batched_tokens=8192,
            max_model_len=8192,
            is_encoder_decoder=False,
            enable_chunked_prefill=True,
        ),
        device_config=DeviceConfig(),
        load_config=LoadConfig(),
        compilation_config=CompilationConfig(),
    )
|
||
|
|
|
||
|
|
|
||
|
|
# ============================================================================
|
||
|
|
# Backend Initialization
|
||
|
|
# ============================================================================
|
||
|
|
|
||
|
|
|
||
|
|
def _create_backend_impl(
    backend_cfg: dict,
    config: BenchmarkConfig,
    device: torch.device,
):
    """Import the configured backend and build its attention implementation.

    Args:
        backend_cfg: One entry of ``_BACKEND_CONFIG`` (module path, backend
            class name, dtype, cache layout).
        config: Benchmark configuration supplying head geometry and block size.
        device: Device handed to the MockLayer.

    Returns:
        Tuple ``(backend_class, impl, layer, dtype)``:
        - the imported backend class,
        - an instance of its ``get_impl_cls()`` (no ALiBi, no sliding window,
          "auto" KV cache dtype),
        - a MockLayer carrying a FullAttentionSpec for this geometry,
        - the dtype taken from ``backend_cfg``.
    """
    import importlib

    module = importlib.import_module(backend_cfg["module"])
    backend_class = getattr(module, backend_cfg["backend_class"])

    softmax_scale = get_attention_scale(config.head_dim)
    dtype = backend_cfg["dtype"]

    impl_cls = backend_class.get_impl_cls()
    impl = impl_cls(
        num_heads=config.num_q_heads,
        head_size=config.head_dim,
        scale=softmax_scale,
        num_kv_heads=config.num_kv_heads,
        alibi_slopes=None,
        sliding_window=None,
        kv_cache_dtype="auto",
    )

    layer = MockLayer(
        device,
        kv_cache_spec=FullAttentionSpec(
            block_size=config.block_size,
            num_kv_heads=config.num_kv_heads,
            head_size=config.head_dim,
            dtype=dtype,
        ),
    )

    return backend_class, impl, layer, dtype
|
||
|
|
|
||
|
|
|
||
|
|
def _create_metadata_builder(
    backend_class,
    kv_cache_spec: FullAttentionSpec,
    vllm_config: VllmConfig,
    device: torch.device,
):
    """Instantiate ``backend_class``'s metadata builder for one layer.

    The builder is registered for the single layer name ``"layer_0"``.

    Args:
        backend_class: Backend class exposing ``get_builder_cls()``.
        kv_cache_spec: KV cache specification passed to the builder.
        vllm_config: Full vLLM configuration passed to the builder.
        device: Device passed to the builder.

    Returns:
        The constructed metadata builder instance.
    """
    builder_cls = backend_class.get_builder_cls()
    builder_kwargs = {
        "kv_cache_spec": kv_cache_spec,
        "layer_names": ["layer_0"],
        "vllm_config": vllm_config,
        "device": device,
    }
    return builder_cls(**builder_kwargs)
|
||
|
|
|
||
|
|
|
||
|
|
# ============================================================================
|
||
|
|
# Tensor Creation Helpers
|
||
|
|
# ============================================================================
|
||
|
|
|
||
|
|
|
||
|
|
def _create_input_tensors(
    config: BenchmarkConfig,
    total_q: int,
    device: torch.device,
    dtype: torch.dtype,
) -> tuple:
    """Allocate random Q, K and V tensors, one of each per layer.

    Args:
        config: Benchmark configuration (layer count, head counts, head_dim).
        total_q: Total number of query tokens across the batch.
        device: Device for the tensors.
        dtype: Element dtype for the tensors.

    Returns:
        ``(q_list, k_list, v_list)``. Each is a list of ``config.num_layers``
        tensors shaped ``[total_q, heads, head_dim]``. Q uses
        ``config.num_q_heads`` and K/V use ``config.num_kv_heads``.
    """

    def random_tokens(num_heads: int) -> torch.Tensor:
        return torch.randn(
            total_q, num_heads, config.head_dim, device=device, dtype=dtype
        )

    # Draw all Q tensors first, then all K, then all V (RNG order matters
    # for reproducibility under a fixed seed).
    layer_ids = range(config.num_layers)
    q_list = [random_tokens(config.num_q_heads) for _ in layer_ids]
    k_list = [random_tokens(config.num_kv_heads) for _ in layer_ids]
    v_list = [random_tokens(config.num_kv_heads) for _ in layer_ids]
    return q_list, k_list, v_list
|
||
|
|
|
||
|
|
|
||
|
|
def _create_kv_cache(
    config: BenchmarkConfig,
    max_num_blocks: int,
    cache_layout: str,
    device: torch.device,
    dtype: torch.dtype,
) -> list:
    """Allocate one zero-initialised KV cache tensor per layer.

    Args:
        config: Benchmark configuration (layer count, block size, head geometry).
        max_num_blocks: Number of cache blocks per layer.
        cache_layout: "flashinfer" selects the block-major layout; any other
            value selects the standard K/V-major layout.
        device: Device for the cache tensors.
        dtype: Element dtype for the cache tensors.

    Returns:
        List of ``config.num_layers`` tensors with the selected layout.
    """
    block_shape = (config.block_size, config.num_kv_heads, config.head_dim)
    if cache_layout == "flashinfer":
        # FlashInfer layout: [num_blocks, 2, block_size, num_kv_heads, head_dim]
        cache_shape = (max_num_blocks, 2, *block_shape)
    else:
        # Standard layout: [2, num_blocks, block_size, num_kv_heads, head_dim]
        cache_shape = (2, max_num_blocks, *block_shape)

    return [
        torch.zeros(cache_shape, device=device, dtype=dtype)
        for _ in range(config.num_layers)
    ]
|
||
|
|
|
||
|
|
|
||
|
|
# ============================================================================
|
||
|
|
# Benchmark Execution
|
||
|
|
# ============================================================================
|
||
|
|
|
||
|
|
|
||
|
|
def _run_single_benchmark(
    config: BenchmarkConfig,
    impl,
    layer,
    q_list: list,
    k_list: list,
    v_list: list,
    cache_list: list,
    attn_metadata,
    device: torch.device,
    dtype: torch.dtype,
) -> tuple:
    """Warm up, then time ``config.repeats`` passes over all layers.

    Each pass calls ``impl.forward`` once per layer, writing into a single
    shared output buffer. Passes are timed with CUDA events.

    Args:
        config: Benchmark configuration (layer count, warmup/repeat counts,
            head geometry, ``profile_memory`` flag).
        impl: Backend attention implementation exposing ``forward``.
        layer: Layer object passed through to ``impl.forward``.
        q_list, k_list, v_list: Per-layer query/key/value tensors.
        cache_list: Per-layer KV cache tensors.
        attn_metadata: Backend-specific attention metadata.
        device: Device queried for memory statistics.
        dtype: dtype of the output buffer.

    Returns:
        ``(times, mem_stats)``:
        - ``times`` holds, per repeat, the elapsed seconds divided by
          ``config.num_layers``.
        - ``mem_stats`` maps "allocated_mb"/"reserved_mb" to MiB values when
          ``config.profile_memory`` is set; otherwise it is an empty dict.
    """
    num_tokens = q_list[0].shape[0]
    out = torch.empty(
        num_tokens, config.num_q_heads, config.head_dim, device=device, dtype=dtype
    )

    def forward_all_layers() -> None:
        for layer_idx in range(config.num_layers):
            impl.forward(
                layer,
                q_list[layer_idx],
                k_list[layer_idx],
                v_list[layer_idx],
                cache_list[layer_idx],
                attn_metadata,
                output=out,
            )

    for _ in range(config.warmup_iters):
        forward_all_layers()
    torch.cuda.synchronize()

    times = []
    for _ in range(config.repeats):
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        start_event.record()
        forward_all_layers()
        end_event.record()

        torch.cuda.synchronize()
        # elapsed_time() reports milliseconds; store seconds per layer.
        times.append(start_event.elapsed_time(end_event) / 1000.0 / config.num_layers)

    mem_stats = {}
    if config.profile_memory:
        mib = 1024**2
        mem_stats["allocated_mb"] = torch.cuda.memory_allocated(device) / mib
        mem_stats["reserved_mb"] = torch.cuda.memory_reserved(device) / mib

    return times, mem_stats
|
||
|
|
|
||
|
|
|
||
|
|
# ============================================================================
|
||
|
|
# Public API
|
||
|
|
# ============================================================================
|
||
|
|
|
||
|
|
|
||
|
|
def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
    """
    Run standard attention benchmark with real kernels.

    Supports: flash, triton, flashinfer

    Args:
        config: Benchmark configuration

    Returns:
        BenchmarkResult with timing and memory statistics

    Raises:
        ValueError: If the backend is unknown, ``config.batch_spec`` yields
            no requests, or ``config.repeats`` is smaller than 1.
    """
    if config.repeats < 1:
        # np.min/np.max below cannot summarise an empty timing list; fail
        # before allocating GPU memory instead of after running the warmup.
        raise ValueError(f"repeats must be >= 1, got {config.repeats}")

    device = torch.device(config.device)
    torch.cuda.set_device(device)

    backend_cfg = _get_backend_config(config.backend)

    requests = parse_batch_spec(config.batch_spec)
    if not requests:
        raise ValueError(f"Batch spec {config.batch_spec!r} produced no requests")

    if config.backend == "flashinfer":
        requests = reorder_for_flashinfer(requests)

    q_lens = [r.q_len for r in requests]
    kv_lens = [r.kv_len for r in requests]
    total_q = sum(q_lens)
    max_kv = max(kv_lens)

    # _build_common_attn_metadata gives every request its own run of
    # ceil(max_kv / block_size) blocks, so the block table references ids up
    # to len(requests) * blocks_per_request - 1. The KV cache (and the
    # num_gpu_blocks reported to vLLM) must cover the whole batch. Sizing it
    # for a single request lets the kernels index past the end of the cache
    # whenever the batch holds more than one request.
    blocks_per_request = (max_kv + config.block_size - 1) // config.block_size
    max_num_blocks = len(requests) * blocks_per_request

    backend_class, impl, layer, dtype = _create_backend_impl(
        backend_cfg, config, device
    )

    common_metadata = _build_common_attn_metadata(
        q_lens, kv_lens, config.block_size, device
    )

    kv_cache_spec = FullAttentionSpec(
        block_size=config.block_size,
        num_kv_heads=config.num_kv_heads,
        head_size=config.head_dim,
        dtype=dtype,
    )

    vllm_config = _create_vllm_config(config, dtype, max_num_blocks)

    builder = _create_metadata_builder(
        backend_class, kv_cache_spec, vllm_config, device
    )

    attn_metadata = builder.build(
        common_prefix_len=0,
        common_attn_metadata=common_metadata,
    )

    q_list, k_list, v_list = _create_input_tensors(config, total_q, device, dtype)

    cache_list = _create_kv_cache(
        config, max_num_blocks, backend_cfg["cache_layout"], device, dtype
    )

    times, mem_stats = _run_single_benchmark(
        config,
        impl,
        layer,
        q_list,
        k_list,
        v_list,
        cache_list,
        attn_metadata,
        device,
        dtype,
    )

    mean_time = np.mean(times)
    # times are seconds per layer, so this is query tokens per second per layer.
    throughput = total_q / mean_time if mean_time > 0 else 0

    return BenchmarkResult(
        config=config,
        mean_time=mean_time,
        std_time=np.std(times),
        min_time=np.min(times),
        max_time=np.max(times),
        throughput_tokens_per_sec=throughput,
        memory_allocated_mb=mem_stats.get("allocated_mb"),
        memory_reserved_mb=mem_stats.get("reserved_mb"),
    )
|