[Benchmark] Improvements to attention benchmark script (#37115)
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
@@ -60,9 +60,11 @@ def create_minimal_vllm_config(
     model_name: str = "deepseek-v3",
     block_size: int = 128,
     max_num_seqs: int = 256,
+    max_num_batched_tokens: int = 8192,
     mla_dims: dict | None = None,
     index_topk: int | None = None,
     prefill_backend: str | None = None,
+    kv_cache_dtype: str = "auto",
 ) -> VllmConfig:
     """
     Create minimal VllmConfig for MLA benchmarks.
@@ -149,13 +151,13 @@ def create_minimal_vllm_config(
     cache_config = CacheConfig(
         block_size=block_size,
         gpu_memory_utilization=0.9,
-        cache_dtype="auto",
+        cache_dtype=kv_cache_dtype,
         enable_prefix_caching=False,
     )

     scheduler_config = SchedulerConfig(
         max_num_seqs=max_num_seqs,
-        max_num_batched_tokens=8192,
+        max_num_batched_tokens=max(max_num_batched_tokens, max_num_seqs),
         max_model_len=32768,
         is_encoder_decoder=False,
         enable_chunked_prefill=True,
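For context on the max(...) clamp above: max_num_batched_tokens is now caller-supplied (a later hunk passes max_total_q), and a decode-only batch spec can make it smaller than max_num_seqs, so the clamp keeps the scheduler's token budget at least as large as the sequence cap. A toy illustration with hypothetical values:

# Illustrative only: a decode-heavy spec may have fewer total query tokens
# than the sequence cap, so the benchmark raises the budget to compensate.
max_num_seqs = 256
requested_budget = 64                          # e.g. 64 decode requests, 1 token each
budget = max(requested_budget, max_num_seqs)
assert budget == 256                           # every scheduled sequence gets a slot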
@@ -535,6 +537,7 @@ def _create_backend_impl(
     device: torch.device,
     max_num_tokens: int = 8192,
     index_topk: int | None = None,
+    kv_cache_dtype: str = "auto",
 ):
     """
     Create backend implementation instance.
@@ -583,7 +586,7 @@ def _create_backend_impl(
         "num_kv_heads": mla_dims["num_kv_heads"],
         "alibi_slopes": None,
         "sliding_window": None,
-        "kv_cache_dtype": "auto",
+        "kv_cache_dtype": kv_cache_dtype,
         "logits_soft_cap": None,
         "attn_type": "decoder",
         "kv_sharing_target_layer_name": None,
@@ -701,6 +704,7 @@ def _run_single_benchmark(
     mla_dims: dict,
     device: torch.device,
     indexer=None,
+    kv_cache_dtype: str | None = None,
 ) -> BenchmarkResult:
     """
     Run a single benchmark iteration.
@@ -734,49 +738,124 @@ def _run_single_benchmark(
     )

-    # Create KV cache
-    kv_cache = torch.zeros(
-        num_blocks,
-        block_size,
-        mla_dims["kv_lora_rank"] + mla_dims["qk_rope_head_dim"],
-        device=device,
-        dtype=torch.bfloat16,
-    )
+    if kv_cache_dtype is None:
+        kv_cache_dtype = getattr(config, "kv_cache_dtype", "auto")
+    head_size = mla_dims["kv_lora_rank"] + mla_dims["qk_rope_head_dim"]
+    if kv_cache_dtype == "fp8_ds_mla":
+        # FlashMLA sparse custom format: 656 bytes per token, stored as uint8.
+        # Layout: kv_lora_rank fp8 bytes + 4 float32 tile scales
+        #         + 2*rope_dim bf16 bytes
+        #         = 512 + 16 + 128 = 656 bytes for DeepSeek dims.
+        kv_cache = torch.zeros(
+            num_blocks,
+            block_size,
+            656,
+            device=device,
+            dtype=torch.uint8,
+        )
+    elif kv_cache_dtype == "fp8":
+        from vllm.platforms import current_platform

-    # Create input tensors for both decode and prefill modes
-    decode_inputs, prefill_inputs = _create_input_tensors(
-        total_q,
-        mla_dims,
-        backend_cfg["query_format"],
-        device,
-        torch.bfloat16,
-    )
+        kv_cache = torch.zeros(
+            num_blocks,
+            block_size,
+            head_size,
+            device=device,
+            dtype=torch.uint8,
+        ).view(current_platform.fp8_dtype())
+    else:
+        kv_cache = torch.zeros(
+            num_blocks,
+            block_size,
+            head_size,
+            device=device,
+            dtype=torch.bfloat16,
+        )

     # Fill indexer with random indices for sparse backends
     is_sparse = backend_cfg.get("is_sparse", False)
     if is_sparse and indexer is not None:
         indexer.fill_random_indices(total_q, max_kv_len)

-    # Determine which forward method to use based on metadata
-    if metadata.decode is not None:
-        forward_fn = lambda: impl.forward_mqa(decode_inputs, kv_cache, metadata, layer)
-    elif metadata.prefill is not None:
-        forward_fn = lambda: impl.forward_mha(
-            prefill_inputs["q"],
-            prefill_inputs["k_c_normed"],
-            prefill_inputs["k_pe"],
-            kv_cache,
-            metadata,
-            prefill_inputs["k_scale"],
-            prefill_inputs["output"],
-        )
-    else:
+    # Determine which forward methods to use based on metadata.
+    # Sparse MLA backends always use forward_mqa
+    has_decode = is_sparse or getattr(metadata, "decode", None) is not None
+    has_prefill = not is_sparse and getattr(metadata, "prefill", None) is not None
+    if not has_decode and not has_prefill:
         raise RuntimeError("Metadata has neither decode nor prefill metadata")

+    num_decode = (
+        metadata.num_decode_tokens
+        if (has_decode and has_prefill)
+        else total_q
+        if has_decode
+        else 0
+    )
+    num_prefill = total_q - num_decode
+
+    # Some backends require fp8 queries when using fp8 KV cache.
+    is_fp8_kvcache = kv_cache_dtype.startswith("fp8")
+    quantize_query = is_fp8_kvcache and getattr(
+        impl, "supports_quant_query_input", False
+    )
+
+    # quantize_query forces concat format
+    query_fmt = "concat" if quantize_query else backend_cfg["query_format"]
+
+    # Create decode query tensors
+    if has_decode:
+        decode_inputs, _ = _create_input_tensors(
+            num_decode, mla_dims, query_fmt, device, torch.bfloat16
+        )
+        # Cast decode query to fp8 if the backend supports it
+        if quantize_query:
+            from vllm.platforms import current_platform
+
+            if isinstance(decode_inputs, tuple):
+                decode_inputs = torch.cat(list(decode_inputs), dim=-1)
+            decode_inputs = decode_inputs.to(current_platform.fp8_dtype())
+
+    # Create prefill input tensors
+    if has_prefill:
+        _, prefill_inputs = _create_input_tensors(
+            num_prefill, mla_dims, query_fmt, device, torch.bfloat16
+        )
+
+    # Build forward function
+    def forward_fn():
+        results = []
+        if has_decode:
+            results.append(impl.forward_mqa(decode_inputs, kv_cache, metadata, layer))
+        if has_prefill:
+            results.append(
+                impl.forward_mha(
+                    prefill_inputs["q"],
+                    prefill_inputs["k_c_normed"],
+                    prefill_inputs["k_pe"],
+                    kv_cache,
+                    metadata,
+                    prefill_inputs["k_scale"],
+                    prefill_inputs["output"],
+                )
+            )
+        return results[0] if len(results) == 1 else tuple(results)

     # Warmup
     for _ in range(config.warmup_iters):
         forward_fn()
     torch.accelerator.synchronize()

+    # Optionally capture a CUDA graph after warmup.
+    # Graph replay eliminates CPU launch overhead so timings reflect pure
+    # kernel time.
+    if config.use_cuda_graphs:
+        graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(graph):
+            forward_fn()
+        benchmark_fn = graph.replay
+    else:
+        benchmark_fn = forward_fn
+
     # Benchmark
     times = []
     for _ in range(config.repeats):
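As a sanity check on the 656-byte entry in the fp8_ds_mla branch above, here is the byte arithmetic spelled out for the DeepSeek dims the in-diff comment assumes (kv_lora_rank=512, qk_rope_head_dim=64; the 128-wide scale tile is inferred from 4 scales covering a 512-element payload):

# Worked arithmetic for the fp8_ds_mla per-token entry size; dims follow the
# comment in the diff (DeepSeek: kv_lora_rank=512, qk_rope_head_dim=64).
kv_lora_rank = 512        # fp8 payload, 1 byte per element             -> 512 B
num_tile_scales = 4       # one float32 scale per 128-element tile, 4*4 ->  16 B
qk_rope_head_dim = 64     # RoPE part kept in bf16, 2 bytes per element -> 128 B

entry_bytes = kv_lora_rank * 1 + num_tile_scales * 4 + qk_rope_head_dim * 2
assert entry_bytes == 656  # matches the uint8 last dim of the kv_cache tensor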
@@ -785,7 +864,7 @@ def _run_single_benchmark(

         start.record()
         for _ in range(config.num_layers):
-            forward_fn()
+            benchmark_fn()
         end.record()

         torch.accelerator.synchronize()
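The two hunks above follow the usual capture-then-replay timing pattern: warm up eagerly, capture one iteration into a CUDA graph, then time graph.replay() with device events so CPU launch overhead stays out of the measurement. A minimal self-contained sketch of that pattern (work() is a stand-in for forward_fn, not the benchmark's actual code):

import torch

a = torch.randn(1024, 1024, device="cuda")

def work():                      # stand-in for the benchmarked forward_fn
    return a @ a

for _ in range(3):               # warm up before capture (allocator, autotune)
    work()
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):    # record the kernel sequence once
    work()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
graph.replay()                   # one CPU call relaunches all captured kernels
end.record()
torch.cuda.synchronize()
print(f"{start.elapsed_time(end):.3f} ms")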
@@ -852,13 +931,30 @@ def _run_mla_benchmark_batched(
     # Determine if this is a sparse backend
     is_sparse = backend_cfg.get("is_sparse", False)

+    # Extract kv_cache_dtype from the first config
+    kv_cache_dtype = getattr(first_config, "kv_cache_dtype", "auto")
+
+    # FlashMLA sparse only supports "fp8_ds_mla" internally (not generic "fp8").
+    # Remap here so the user can pass --kv-cache-dtype fp8 regardless of backend.
+    if backend.upper() == "FLASHMLA_SPARSE" and kv_cache_dtype == "fp8":
+        kv_cache_dtype = "fp8_ds_mla"
+
+    # Compute max total_q across all configs so the metadata builder buffer
+    # and scheduler config are large enough for all batch specs.
+    max_total_q = max(
+        sum(r.q_len for r in parse_batch_spec(cfg.batch_spec))
+        for cfg, *_ in configs_with_params
+    )
+
     # Create and set vLLM config for MLA (reused across all benchmarks)
     vllm_config = create_minimal_vllm_config(
         model_name="deepseek-v3",  # Used only for model path
         block_size=block_size,
+        max_num_batched_tokens=max_total_q,
         mla_dims=mla_dims,  # Use custom dims from config or default
         index_topk=index_topk if is_sparse else None,
         prefill_backend=prefill_backend,
+        kv_cache_dtype=kv_cache_dtype,
     )

     results = []
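The max_total_q computation above sizes shared state once for the whole sweep instead of per config. Roughly, with the per-request q_len values parse_batch_spec would yield (illustrative numbers, not real batch specs):

# Illustrative only: each inner list stands in for the q_len of every request
# that parse_batch_spec would produce for one benchmark config.
per_config_q_lens = [
    [1] * 64,          # decode-only: 64 requests, 1 query token each
    [512, 512],        # prefill: 2 requests, 512 query tokens each
    [1, 1, 256],       # mixed batch
]
max_total_q = max(sum(q_lens) for q_lens in per_config_q_lens)
assert max_total_q == 1024   # buffers and scheduler budget sized once for all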
@@ -883,7 +979,9 @@ def _run_mla_benchmark_batched(
         mla_dims,
         vllm_config,
         device,
+        max_num_tokens=max_total_q,
         index_topk=index_topk if is_sparse else None,
+        kv_cache_dtype=kv_cache_dtype,
     )

     # Verify the actual prefill backend matches what was requested
@@ -942,6 +1040,7 @@ def _run_mla_benchmark_batched(
             mla_dims,
             device,
             indexer=indexer,
+            kv_cache_dtype=kv_cache_dtype,
         )
         results.append(result)