[Benchmark] Improvements to attention benchmark script (#37115)

Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
This commit is contained in:
Wei Zhao
2026-03-16 18:22:40 -04:00
committed by GitHub
parent e5b807607c
commit a3a51d20e7
6 changed files with 311 additions and 68 deletions

View File

@@ -140,7 +140,7 @@ def _create_vllm_config(
cache_config = CacheConfig(
block_size=config.block_size,
cache_dtype="auto",
cache_dtype=config.kv_cache_dtype,
)
cache_config.num_gpu_blocks = max_num_blocks
cache_config.num_cpu_blocks = 0
@@ -215,7 +215,7 @@ def _create_backend_impl(
num_kv_heads=config.num_kv_heads,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype="auto",
kv_cache_dtype=config.kv_cache_dtype,
)
kv_cache_spec = FullAttentionSpec(
@@ -288,12 +288,22 @@ def _create_input_tensors(
total_q: int,
device: torch.device,
dtype: torch.dtype,
quantize_query: bool = False,
) -> tuple:
"""Create Q, K, V input tensors for all layers."""
"""Create Q, K, V input tensors for all layers.
When quantize_query is True, queries are cast to fp8 to match backends
that require query/key/value dtype consistency.
"""
q_dtype = dtype
if quantize_query:
from vllm.platforms import current_platform
q_dtype = current_platform.fp8_dtype()
q_list = [
torch.randn(
total_q, config.num_q_heads, config.head_dim, device=device, dtype=dtype
)
).to(q_dtype)
for _ in range(config.num_layers)
]
k_list = [
@@ -344,10 +354,17 @@ def _create_kv_cache(
# Compute inverse permutation to get back to logical view
inv_order = [stride_order.index(i) for i in range(len(stride_order))]
# Use fp8 dtype for cache when requested.
cache_dtype = dtype
if config.kv_cache_dtype == "fp8":
from vllm.platforms import current_platform
cache_dtype = current_platform.fp8_dtype()
cache_list = []
for _ in range(config.num_layers):
# Allocate in physical layout order (contiguous in memory)
cache = torch.zeros(*physical_shape, device=device, dtype=dtype)
cache = torch.zeros(*physical_shape, device=device, dtype=cache_dtype)
# Permute to logical view
cache = cache.permute(*inv_order)
cache_list.append(cache)
@@ -392,6 +409,37 @@ def _run_single_benchmark(
)
torch.accelerator.synchronize()
# Optionally capture a CUDA graph after warmup.
# Graph replay eliminates CPU launch overhead so timings reflect pure
# kernel time.
if config.use_cuda_graphs:
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
for i in range(config.num_layers):
impl.forward(
layer,
q_list[i],
k_list[i],
v_list[i],
cache_list[i],
attn_metadata,
output=out,
)
benchmark_fn = graph.replay
else:
def benchmark_fn():
for i in range(config.num_layers):
impl.forward(
layer,
q_list[i],
k_list[i],
v_list[i],
cache_list[i],
attn_metadata,
output=out,
)
# Benchmark
times = []
for _ in range(config.repeats):
@@ -399,16 +447,7 @@ def _run_single_benchmark(
end = torch.cuda.Event(enable_timing=True)
start.record()
for i in range(config.num_layers):
impl.forward(
layer,
q_list[i],
k_list[i],
v_list[i],
cache_list[i],
attn_metadata,
output=out,
)
benchmark_fn()
end.record()
torch.accelerator.synchronize()
@@ -502,8 +541,12 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
common_attn_metadata=common_metadata,
)
# Only quantize queries when the impl supports it
quantize_query = config.kv_cache_dtype.startswith("fp8") and getattr(
impl, "supports_quant_query_input", False
)
q_list, k_list, v_list = _create_input_tensors(
config, total_q, device, dtype
config, total_q, device, dtype, quantize_query=quantize_query
)
cache_list = _create_kv_cache(