[Misc] Fix up attention benchmarks (#33810)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
@@ -8,7 +8,9 @@ This module provides helpers for running standard attention backends
 (FlashAttention, Triton, FlashInfer) with real vLLM integration.
 """
 
+import logging
 import types
+from contextlib import contextmanager
 
 import numpy as np
 import torch
@@ -24,8 +26,13 @@ from vllm.config import (
     ParallelConfig,
     SchedulerConfig,
     VllmConfig,
+    set_current_vllm_config,
 )
-from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.attention.backends.utils import (
+    CommonAttentionMetadata,
+    get_kv_cache_layout,
+    set_kv_cache_layout,
+)
 from vllm.v1.kv_cache_interface import FullAttentionSpec
 
 # ============================================================================
@@ -37,22 +44,14 @@ _BACKEND_CONFIG = {
     "flash": {
         "module": "vllm.v1.attention.backends.flash_attn",
         "backend_class": "FlashAttentionBackend",
-        "dtype": torch.float16,
-        "cache_layout": "standard",
-        # ^ [2, num_blocks, block_size, num_kv_heads, head_dim]
     },
     "triton": {
         "module": "vllm.v1.attention.backends.triton_attn",
         "backend_class": "TritonAttentionBackend",
-        "dtype": torch.float32,
-        "cache_layout": "standard",
     },
     "flashinfer": {
         "module": "vllm.v1.attention.backends.flashinfer",
         "backend_class": "FlashInferBackend",
-        "dtype": torch.float16,
-        "cache_layout": "flashinfer",
-        # ^ [num_blocks, 2, block_size, num_kv_heads, head_dim]
     },
 }
 
@@ -66,6 +65,18 @@ def _get_backend_config(backend: str) -> dict:
     return _BACKEND_CONFIG[backend]
 
 
+@contextmanager
+def log_warnings_and_errors_only():
+    """Temporarily set vLLM logger to WARNING level."""
+    logger = logging.getLogger("vllm")
+    old_level = logger.level
+    logger.setLevel(logging.WARNING)
+    try:
+        yield
+    finally:
+        logger.setLevel(old_level)
+
+
 # ============================================================================
 # Metadata Building Helpers
 # ============================================================================
@@ -88,11 +99,7 @@ def _build_common_attn_metadata(
     query_start_loc_cpu = query_start_loc.cpu()
 
     seq_lens = torch.tensor(kv_lens, dtype=torch.int32, device=device)
-    seq_lens_cpu = seq_lens.cpu()
-    max_seq_len = int(seq_lens_cpu.max())
-
-    context_lens = [kv - q for kv, q in zip(kv_lens, q_lens)]
-    num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)
+    max_seq_len = int(seq_lens.max().item())
 
     max_blocks = (max(kv_lens) + block_size - 1) // block_size
     num_blocks = batch_size * max_blocks
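
For reference, the metadata built here packs per-request query lengths into cumulative offsets alongside per-request sequence lengths. A minimal illustration with made-up lengths (only torch is assumed; the values are hypothetical, not taken from the benchmark):

    import torch

    # Hypothetical toy batch: 3 requests with 2, 1 and 4 new query tokens,
    # attending over 5, 8 and 4 total KV tokens respectively.
    q_lens = [2, 1, 4]
    kv_lens = [5, 8, 4]

    # Cumulative offsets into the packed query tensor -> [0, 2, 3, 7]
    query_start_loc = torch.cumsum(torch.tensor([0] + q_lens, dtype=torch.int32), dim=0)
    seq_lens = torch.tensor(kv_lens, dtype=torch.int32)
    max_seq_len = int(seq_lens.max().item())  # 8

    print(query_start_loc.tolist(), seq_lens.tolist(), max_seq_len)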
@@ -107,8 +114,6 @@ def _build_common_attn_metadata(
         query_start_loc=query_start_loc,
         query_start_loc_cpu=query_start_loc_cpu,
         seq_lens=seq_lens,
-        seq_lens_cpu=seq_lens_cpu,
-        num_computed_tokens_cpu=num_computed_tokens_cpu,
         num_reqs=batch_size,
         num_actual_tokens=total_tokens,
         max_query_len=max_query_len,
@@ -121,7 +126,6 @@ def _build_common_attn_metadata(
 
 def _create_vllm_config(
     config: BenchmarkConfig,
-    dtype: torch.dtype,
     max_num_blocks: int,
 ) -> VllmConfig:
     """Create a VllmConfig for benchmarking with mock model methods."""
@@ -129,7 +133,7 @@ def _create_vllm_config(
         model="meta-llama/Meta-Llama-3-8B",
         tokenizer="meta-llama/Meta-Llama-3-8B",
         trust_remote_code=False,
-        dtype=dtype,
+        dtype="auto",  # Use model's native dtype
         seed=0,
         max_model_len=1024,
     )
@@ -198,6 +202,7 @@ def _create_backend_impl(
     backend_cfg: dict,
     config: BenchmarkConfig,
     device: torch.device,
+    dtype: torch.dtype,
 ):
     """Create backend implementation instance."""
     import importlib
@@ -206,7 +211,6 @@ def _create_backend_impl(
     backend_class = getattr(backend_module, backend_cfg["backend_class"])
 
     scale = get_attention_scale(config.head_dim)
-    dtype = backend_cfg["dtype"]
 
     impl = backend_class.get_impl_cls()(
         num_heads=config.num_q_heads,
@@ -227,7 +231,7 @@ def _create_backend_impl(
 
     layer = MockLayer(device, kv_cache_spec=kv_cache_spec)
 
-    return backend_class, impl, layer, dtype
+    return backend_class, impl, layer
 
 
 def _create_metadata_builder(
@@ -235,11 +239,44 @@ def _create_metadata_builder(
     kv_cache_spec: FullAttentionSpec,
     vllm_config: VllmConfig,
     device: torch.device,
+    backend_name: str = "",
 ):
     """Create metadata builder instance."""
-    return backend_class.get_builder_cls()(
+    layer_names = ["layer_0"]
+    builder_cls = backend_class.get_builder_cls()
+
+    # Flashinfer needs get_per_layer_parameters mocked since we don't have
+    # real model layers registered
+    if backend_name == "flashinfer":
+        import unittest.mock
+
+        from vllm.v1.attention.backends.utils import PerLayerParameters
+
+        def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls):
+            head_size = vllm_config.model_config.get_head_size()
+            return {
+                layer_name: PerLayerParameters(
+                    window_left=-1,  # No sliding window
+                    logits_soft_cap=0.0,  # No soft cap
+                    sm_scale=1.0 / (head_size**0.5),  # Standard scale
+                )
+                for layer_name in layer_names
+            }
+
+        with unittest.mock.patch(
+            "vllm.v1.attention.backends.flashinfer.get_per_layer_parameters",
+            mock_get_per_layer_parameters,
+        ):
+            return builder_cls(
+                kv_cache_spec=kv_cache_spec,
+                layer_names=layer_names,
+                vllm_config=vllm_config,
+                device=device,
+            )
+
+    return builder_cls(
         kv_cache_spec=kv_cache_spec,
-        layer_names=["layer_0"],
+        layer_names=layer_names,
         vllm_config=vllm_config,
         device=device,
     )
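
The patch above is only in effect inside the `with` block, which is why the builder is constructed within it. A minimal standalone sketch of that scoping (every name here is hypothetical, only the standard library is assumed):

    import unittest.mock

    def get_value():
        return "real"

    def fake_get_value():
        return "fake"

    def caller():
        # Looks up get_value through module globals at call time.
        return get_value()

    # The replacement is visible only while the with-block is active.
    with unittest.mock.patch(f"{__name__}.get_value", fake_get_value):
        assert caller() == "fake"
    assert caller() == "real"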
@@ -281,39 +318,44 @@ def _create_input_tensors(
 def _create_kv_cache(
     config: BenchmarkConfig,
     max_num_blocks: int,
-    cache_layout: str,
+    backend_class,
     device: torch.device,
     dtype: torch.dtype,
 ) -> list:
-    """Create KV cache tensors for all layers."""
-    if cache_layout == "flashinfer":
-        # FlashInfer layout: [num_blocks, 2, block_size, num_kv_heads, head_dim]
-        cache_list = [
-            torch.zeros(
-                max_num_blocks,
-                2,
-                config.block_size,
-                config.num_kv_heads,
-                config.head_dim,
-                device=device,
-                dtype=dtype,
-            )
-            for _ in range(config.num_layers)
-        ]
-    else:
-        # Standard layout: [2, num_blocks, block_size, num_kv_heads, head_dim]
-        cache_list = [
-            torch.zeros(
-                2,
-                max_num_blocks,
-                config.block_size,
-                config.num_kv_heads,
-                config.head_dim,
-                device=device,
-                dtype=dtype,
-            )
-            for _ in range(config.num_layers)
-        ]
+    """Create KV cache tensors for all layers using the backend's methods.
+
+    Uses the backend's get_kv_cache_shape() and get_kv_cache_stride_order()
+    to create the cache with the correct shape and memory layout.
+    """
+    # Get the logical shape from the backend
+    cache_shape = backend_class.get_kv_cache_shape(
+        num_blocks=max_num_blocks,
+        block_size=config.block_size,
+        num_kv_heads=config.num_kv_heads,
+        head_size=config.head_dim,
+    )
+
+    # Get the stride order for custom memory layout
+    try:
+        stride_order = backend_class.get_kv_cache_stride_order()
+        assert len(stride_order) == len(cache_shape)
+    except (AttributeError, NotImplementedError):
+        stride_order = tuple(range(len(cache_shape)))
+
+    # Permute shape to physical layout order
+    physical_shape = tuple(cache_shape[i] for i in stride_order)
+
+    # Compute inverse permutation to get back to logical view
+    inv_order = [stride_order.index(i) for i in range(len(stride_order))]
+
+    cache_list = []
+    for _ in range(config.num_layers):
+        # Allocate in physical layout order (contiguous in memory)
+        cache = torch.zeros(*physical_shape, device=device, dtype=dtype)
+        # Permute to logical view
+        cache = cache.permute(*inv_order)
+        cache_list.append(cache)
 
     return cache_list
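
To illustrate the allocate-then-permute approach used above, a minimal standalone sketch with a made-up shape and stride order (only torch is assumed; the backend calls are replaced by hard-coded values):

    import torch

    # Hypothetical logical shape: (2, num_blocks, block_size, num_kv_heads, head_dim)
    cache_shape = (2, 16, 16, 8, 64)
    # Hypothetical physical layout preference, e.g. blocks outermost.
    stride_order = (1, 0, 2, 3, 4)

    # Allocate contiguously in the physical order...
    physical_shape = tuple(cache_shape[i] for i in stride_order)
    physical = torch.zeros(*physical_shape)

    # ...then permute back so indexing matches the logical shape,
    # while the underlying memory keeps the physical layout.
    inv_order = [stride_order.index(i) for i in range(len(stride_order))]
    logical = physical.permute(*inv_order)

    assert logical.shape == cache_shape
    assert not logical.is_contiguous()  # strides follow the physical layout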
@@ -418,53 +460,72 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
     kv_lens = [r.kv_len for r in requests]
     total_q = sum(q_lens)
     max_kv = max(kv_lens)
+    batch_size = len(q_lens)
 
-    max_num_blocks = (max_kv + config.block_size - 1) // config.block_size
+    # Calculate total blocks needed: batch_size * max_blocks_per_request
+    max_blocks_per_request = (max_kv + config.block_size - 1) // config.block_size
+    max_num_blocks = batch_size * max_blocks_per_request
 
-    backend_class, impl, layer, dtype = _create_backend_impl(
-        backend_cfg, config, device
-    )
+    # Suppress vLLM logs during setup to reduce spam
+    with log_warnings_and_errors_only():
+        # Create vllm_config first - uses model's native dtype via "auto"
+        vllm_config = _create_vllm_config(config, max_num_blocks)
+        dtype = vllm_config.model_config.dtype
 
-    common_metadata = _build_common_attn_metadata(
-        q_lens, kv_lens, config.block_size, device
-    )
+        # Wrap everything in set_current_vllm_config context
+        # This is required for backends like flashinfer that need global config
+        with set_current_vllm_config(vllm_config):
+            backend_class, impl, layer = _create_backend_impl(
+                backend_cfg, config, device, dtype
+            )
 
-    kv_cache_spec = FullAttentionSpec(
-        block_size=config.block_size,
-        num_kv_heads=config.num_kv_heads,
-        head_size=config.head_dim,
-        dtype=dtype,
-    )
+            # Set KV cache layout if the backend requires a specific one
+            # (e.g., FlashInfer requires HND on SM100/Blackwell for TRTLLM attention)
+            required_layout = backend_class.get_required_kv_cache_layout()
+            if required_layout is not None:
+                set_kv_cache_layout(required_layout)
+                get_kv_cache_layout.cache_clear()
 
-    vllm_config = _create_vllm_config(config, dtype, max_num_blocks)
+            common_metadata = _build_common_attn_metadata(
+                q_lens, kv_lens, config.block_size, device
+            )
 
-    builder = _create_metadata_builder(
-        backend_class, kv_cache_spec, vllm_config, device
-    )
+            kv_cache_spec = FullAttentionSpec(
+                block_size=config.block_size,
+                num_kv_heads=config.num_kv_heads,
+                head_size=config.head_dim,
+                dtype=dtype,
+            )
 
-    attn_metadata = builder.build(
-        common_prefix_len=0,
-        common_attn_metadata=common_metadata,
-    )
+            builder = _create_metadata_builder(
+                backend_class, kv_cache_spec, vllm_config, device, config.backend
+            )
 
-    q_list, k_list, v_list = _create_input_tensors(config, total_q, device, dtype)
+            attn_metadata = builder.build(
+                common_prefix_len=0,
+                common_attn_metadata=common_metadata,
+            )
 
-    cache_list = _create_kv_cache(
-        config, max_num_blocks, backend_cfg["cache_layout"], device, dtype
-    )
+            q_list, k_list, v_list = _create_input_tensors(
+                config, total_q, device, dtype
+            )
 
-    times, mem_stats = _run_single_benchmark(
-        config,
-        impl,
-        layer,
-        q_list,
-        k_list,
-        v_list,
-        cache_list,
-        attn_metadata,
-        device,
-        dtype,
-    )
+            cache_list = _create_kv_cache(
+                config, max_num_blocks, backend_class, device, dtype
+            )
+
+            times, mem_stats = _run_single_benchmark(
+                config,
+                impl,
+                layer,
+                q_list,
+                k_list,
+                v_list,
+                cache_list,
+                attn_metadata,
+                device,
+                dtype,
+            )
 
     mean_time = np.mean(times)
     throughput = total_q / mean_time if mean_time > 0 else 0
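
A small worked example of the revised block-count arithmetic, with illustrative numbers (not taken from the benchmark defaults):

    # Hypothetical sizes: 4 requests, longest KV length 1000 tokens, 16-token blocks.
    batch_size = 4
    max_kv = 1000
    block_size = 16

    max_blocks_per_request = (max_kv + block_size - 1) // block_size  # ceil -> 63
    max_num_blocks = batch_size * max_blocks_per_request              # 252

    # The old value (63) only covered a single request's block table;
    # sizing by batch_size leaves room for every request's blocks.
    print(max_blocks_per_request, max_num_blocks)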