[Benchmark] Improvements to attention benchmark script (#37115)
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
This commit is contained in:
@@ -140,7 +140,7 @@ def _create_vllm_config(
|
||||
|
||||
cache_config = CacheConfig(
|
||||
block_size=config.block_size,
|
||||
cache_dtype="auto",
|
||||
cache_dtype=config.kv_cache_dtype,
|
||||
)
|
||||
cache_config.num_gpu_blocks = max_num_blocks
|
||||
cache_config.num_cpu_blocks = 0
|
||||
@@ -215,7 +215,7 @@ def _create_backend_impl(
|
||||
num_kv_heads=config.num_kv_heads,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
kv_cache_dtype="auto",
|
||||
kv_cache_dtype=config.kv_cache_dtype,
|
||||
)
|
||||
|
||||
kv_cache_spec = FullAttentionSpec(
|
||||
@@ -288,12 +288,22 @@ def _create_input_tensors(
|
||||
total_q: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
quantize_query: bool = False,
|
||||
) -> tuple:
|
||||
"""Create Q, K, V input tensors for all layers."""
|
||||
"""Create Q, K, V input tensors for all layers.
|
||||
|
||||
When quantize_query is True, queries are cast to fp8 to match backends
|
||||
that require query/key/value dtype consistency.
|
||||
"""
|
||||
q_dtype = dtype
|
||||
if quantize_query:
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
q_dtype = current_platform.fp8_dtype()
|
||||
q_list = [
|
||||
torch.randn(
|
||||
total_q, config.num_q_heads, config.head_dim, device=device, dtype=dtype
|
||||
)
|
||||
).to(q_dtype)
|
||||
for _ in range(config.num_layers)
|
||||
]
|
||||
k_list = [
|
||||
@@ -344,10 +354,17 @@ def _create_kv_cache(
|
||||
# Compute inverse permutation to get back to logical view
|
||||
inv_order = [stride_order.index(i) for i in range(len(stride_order))]
|
||||
|
||||
# Use fp8 dtype for cache when requested.
|
||||
cache_dtype = dtype
|
||||
if config.kv_cache_dtype == "fp8":
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
cache_dtype = current_platform.fp8_dtype()
|
||||
|
||||
cache_list = []
|
||||
for _ in range(config.num_layers):
|
||||
# Allocate in physical layout order (contiguous in memory)
|
||||
cache = torch.zeros(*physical_shape, device=device, dtype=dtype)
|
||||
cache = torch.zeros(*physical_shape, device=device, dtype=cache_dtype)
|
||||
# Permute to logical view
|
||||
cache = cache.permute(*inv_order)
|
||||
cache_list.append(cache)
|
||||
@@ -392,6 +409,37 @@ def _run_single_benchmark(
|
||||
)
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
# Optionally capture a CUDA graph after warmup.
|
||||
# Graph replay eliminates CPU launch overhead so timings reflect pure
|
||||
# kernel time.
|
||||
if config.use_cuda_graphs:
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
for i in range(config.num_layers):
|
||||
impl.forward(
|
||||
layer,
|
||||
q_list[i],
|
||||
k_list[i],
|
||||
v_list[i],
|
||||
cache_list[i],
|
||||
attn_metadata,
|
||||
output=out,
|
||||
)
|
||||
benchmark_fn = graph.replay
|
||||
else:
|
||||
|
||||
def benchmark_fn():
|
||||
for i in range(config.num_layers):
|
||||
impl.forward(
|
||||
layer,
|
||||
q_list[i],
|
||||
k_list[i],
|
||||
v_list[i],
|
||||
cache_list[i],
|
||||
attn_metadata,
|
||||
output=out,
|
||||
)
|
||||
|
||||
# Benchmark
|
||||
times = []
|
||||
for _ in range(config.repeats):
|
||||
@@ -399,16 +447,7 @@ def _run_single_benchmark(
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
start.record()
|
||||
for i in range(config.num_layers):
|
||||
impl.forward(
|
||||
layer,
|
||||
q_list[i],
|
||||
k_list[i],
|
||||
v_list[i],
|
||||
cache_list[i],
|
||||
attn_metadata,
|
||||
output=out,
|
||||
)
|
||||
benchmark_fn()
|
||||
end.record()
|
||||
|
||||
torch.accelerator.synchronize()
|
||||
@@ -502,8 +541,12 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
|
||||
common_attn_metadata=common_metadata,
|
||||
)
|
||||
|
||||
# Only quantize queries when the impl supports it
|
||||
quantize_query = config.kv_cache_dtype.startswith("fp8") and getattr(
|
||||
impl, "supports_quant_query_input", False
|
||||
)
|
||||
q_list, k_list, v_list = _create_input_tensors(
|
||||
config, total_q, device, dtype
|
||||
config, total_q, device, dtype, quantize_query=quantize_query
|
||||
)
|
||||
|
||||
cache_list = _create_kv_cache(
|
||||
|
||||
Reference in New Issue
Block a user