diff --git a/.buildkite/test_areas/benchmarks.yaml b/.buildkite/test_areas/benchmarks.yaml index 574b642d4..57080c46f 100644 --- a/.buildkite/test_areas/benchmarks.yaml +++ b/.buildkite/test_areas/benchmarks.yaml @@ -17,3 +17,14 @@ steps: - tests/benchmarks/ commands: - pytest -v -s benchmarks/ + +- label: Attention Benchmarks Smoke Test (B200) + device: b200 + num_gpus: 2 + optional: true + timeout_in_minutes: 10 + source_file_dependencies: + - benchmarks/attention_benchmarks/ + - vllm/v1/attention/ + commands: + - python benchmarks/attention_benchmarks/benchmark.py --backends flash flashinfer --batch-specs "8q1s1k" --repeats 1 --warmup-iters 1 diff --git a/benchmarks/attention_benchmarks/batch_spec.py b/benchmarks/attention_benchmarks/batch_spec.py index 41681796e..9f15f1d80 100644 --- a/benchmarks/attention_benchmarks/batch_spec.py +++ b/benchmarks/attention_benchmarks/batch_spec.py @@ -229,3 +229,40 @@ def get_batch_stats(requests: list[BatchRequest]) -> dict: sum(r.kv_len for r in requests) / len(requests) if requests else 0 ), } + + +def get_batch_type(batch_spec: str, spec_decode_threshold: int = 8) -> str: + """ + Classify a batch spec into a type string. + + Args: + batch_spec: Batch specification string (e.g., "q2k", "8q1s1k", "2q2k_8q1s1k") + spec_decode_threshold: Max q_len to be considered spec-decode vs extend + + Returns: + Type string: "prefill", "decode", "spec-decode", "extend", or "mixed (types...)" + """ + requests = parse_batch_spec(batch_spec) + + # Classify each request + types_present = set() + for req in requests: + if req.is_decode: + types_present.add("decode") + elif req.is_prefill: + types_present.add("prefill") + elif req.is_extend: + # Distinguish spec-decode (small q_len) from extend (chunked prefill) + if req.q_len <= spec_decode_threshold: + types_present.add("spec-decode") + else: + types_present.add("extend") + + if len(types_present) == 1: + return types_present.pop() + elif len(types_present) > 1: + # Sort for consistent output + sorted_types = sorted(types_present) + return f"mixed ({'+'.join(sorted_types)})" + else: + return "unknown" diff --git a/benchmarks/attention_benchmarks/common.py b/benchmarks/attention_benchmarks/common.py index 7155bdc3f..190b2f977 100644 --- a/benchmarks/attention_benchmarks/common.py +++ b/benchmarks/attention_benchmarks/common.py @@ -12,6 +12,7 @@ from typing import Any import numpy as np import torch +from batch_spec import get_batch_type, parse_batch_spec from rich.console import Console from rich.table import Table @@ -316,12 +317,14 @@ class ResultsFormatter: backends: List of backend names being compared compare_to_fastest: Show percentage comparison to fastest """ - # Group by batch spec + # Group by batch spec, preserving first-occurrence order by_spec = {} + specs_order = [] for r in results: spec = r.config.batch_spec if spec not in by_spec: by_spec[spec] = {} + specs_order.append(spec) by_spec[spec][r.config.backend] = r # Create shortened backend names for display @@ -337,6 +340,8 @@ class ResultsFormatter: table = Table(title="Attention Benchmark Results") table.add_column("Batch\nSpec", no_wrap=True) + table.add_column("Type", no_wrap=True) + table.add_column("Batch\nSize", justify="right", no_wrap=True) multi = len(backends) > 1 for backend in backends: @@ -350,12 +355,14 @@ class ResultsFormatter: table.add_column(col_rel, justify="right", no_wrap=False) # Add rows - for spec in sorted(by_spec.keys()): + for spec in specs_order: spec_results = by_spec[spec] times = {b: r.mean_time for b, r in 
spec_results.items() if r.success} best_time = min(times.values()) if times else 0.0 - row = [spec] + batch_type = get_batch_type(spec) + batch_size = len(parse_batch_spec(spec)) + row = [spec, batch_type, str(batch_size)] for backend in backends: if backend in spec_results: r = spec_results[backend] diff --git a/benchmarks/attention_benchmarks/configs/standard_attention.yaml b/benchmarks/attention_benchmarks/configs/standard_attention.yaml index c0bdb98fb..591db6837 100644 --- a/benchmarks/attention_benchmarks/configs/standard_attention.yaml +++ b/benchmarks/attention_benchmarks/configs/standard_attention.yaml @@ -25,10 +25,18 @@ batch_specs: - "4q1k_16q1s2k" # 4 prefill + 16 decode - "2q4k_32q1s1k" # 2 large prefill + 32 decode - # Context extension - - "q1ks2k" # 1k query, 2k sequence (chunked prefill) + # Speculative decode (q <= 8) + - "16q2s1k" # 16 requests, 2 spec tokens, 1k KV cache + - "16q4s1k" # 16 requests, 4 spec tokens, 1k KV cache + - "16q8s1k" # 16 requests, 8 spec tokens, 1k KV cache + - "32q4s2k" # 32 requests, 4 spec tokens, 2k KV cache + - "8q8s4k" # 8 requests, 8 spec tokens, 4k KV cache + + # Context extension (chunked prefill) + - "q1ks2k" # 1k query, 2k sequence - "2q1ks4k" # 2 requests: 1k query, 4k sequence +# Available backends: flash, triton, flashinfer backends: - flash - triton diff --git a/benchmarks/attention_benchmarks/runner.py b/benchmarks/attention_benchmarks/runner.py index bf08a1550..79bfca681 100644 --- a/benchmarks/attention_benchmarks/runner.py +++ b/benchmarks/attention_benchmarks/runner.py @@ -8,7 +8,9 @@ This module provides helpers for running standard attention backends (FlashAttention, Triton, FlashInfer) with real vLLM integration. """ +import logging import types +from contextlib import contextmanager import numpy as np import torch @@ -24,8 +26,13 @@ from vllm.config import ( ParallelConfig, SchedulerConfig, VllmConfig, + set_current_vllm_config, +) +from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, + get_kv_cache_layout, + set_kv_cache_layout, ) -from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import FullAttentionSpec # ============================================================================ @@ -37,22 +44,14 @@ _BACKEND_CONFIG = { "flash": { "module": "vllm.v1.attention.backends.flash_attn", "backend_class": "FlashAttentionBackend", - "dtype": torch.float16, - "cache_layout": "standard", - # ^ [2, num_blocks, block_size, num_kv_heads, head_dim] }, "triton": { "module": "vllm.v1.attention.backends.triton_attn", "backend_class": "TritonAttentionBackend", - "dtype": torch.float32, - "cache_layout": "standard", }, "flashinfer": { "module": "vllm.v1.attention.backends.flashinfer", "backend_class": "FlashInferBackend", - "dtype": torch.float16, - "cache_layout": "flashinfer", - # ^ [num_blocks, 2, block_size, num_kv_heads, head_dim] }, } @@ -66,6 +65,18 @@ def _get_backend_config(backend: str) -> dict: return _BACKEND_CONFIG[backend] +@contextmanager +def log_warnings_and_errors_only(): + """Temporarily set vLLM logger to WARNING level.""" + logger = logging.getLogger("vllm") + old_level = logger.level + logger.setLevel(logging.WARNING) + try: + yield + finally: + logger.setLevel(old_level) + + # ============================================================================ # Metadata Building Helpers # ============================================================================ @@ -88,11 +99,7 @@ def _build_common_attn_metadata( query_start_loc_cpu = 
query_start_loc.cpu() seq_lens = torch.tensor(kv_lens, dtype=torch.int32, device=device) - seq_lens_cpu = seq_lens.cpu() - max_seq_len = int(seq_lens_cpu.max()) - - context_lens = [kv - q for kv, q in zip(kv_lens, q_lens)] - num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32) + max_seq_len = int(seq_lens.max().item()) max_blocks = (max(kv_lens) + block_size - 1) // block_size num_blocks = batch_size * max_blocks @@ -107,8 +114,6 @@ def _build_common_attn_metadata( query_start_loc=query_start_loc, query_start_loc_cpu=query_start_loc_cpu, seq_lens=seq_lens, - seq_lens_cpu=seq_lens_cpu, - num_computed_tokens_cpu=num_computed_tokens_cpu, num_reqs=batch_size, num_actual_tokens=total_tokens, max_query_len=max_query_len, @@ -121,7 +126,6 @@ def _build_common_attn_metadata( def _create_vllm_config( config: BenchmarkConfig, - dtype: torch.dtype, max_num_blocks: int, ) -> VllmConfig: """Create a VllmConfig for benchmarking with mock model methods.""" @@ -129,7 +133,7 @@ def _create_vllm_config( model="meta-llama/Meta-Llama-3-8B", tokenizer="meta-llama/Meta-Llama-3-8B", trust_remote_code=False, - dtype=dtype, + dtype="auto", # Use model's native dtype seed=0, max_model_len=1024, ) @@ -198,6 +202,7 @@ def _create_backend_impl( backend_cfg: dict, config: BenchmarkConfig, device: torch.device, + dtype: torch.dtype, ): """Create backend implementation instance.""" import importlib @@ -206,7 +211,6 @@ def _create_backend_impl( backend_class = getattr(backend_module, backend_cfg["backend_class"]) scale = get_attention_scale(config.head_dim) - dtype = backend_cfg["dtype"] impl = backend_class.get_impl_cls()( num_heads=config.num_q_heads, @@ -227,7 +231,7 @@ def _create_backend_impl( layer = MockLayer(device, kv_cache_spec=kv_cache_spec) - return backend_class, impl, layer, dtype + return backend_class, impl, layer def _create_metadata_builder( @@ -235,11 +239,44 @@ def _create_metadata_builder( kv_cache_spec: FullAttentionSpec, vllm_config: VllmConfig, device: torch.device, + backend_name: str = "", ): """Create metadata builder instance.""" - return backend_class.get_builder_cls()( + layer_names = ["layer_0"] + builder_cls = backend_class.get_builder_cls() + + # Flashinfer needs get_per_layer_parameters mocked since we don't have + # real model layers registered + if backend_name == "flashinfer": + import unittest.mock + + from vllm.v1.attention.backends.utils import PerLayerParameters + + def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls): + head_size = vllm_config.model_config.get_head_size() + return { + layer_name: PerLayerParameters( + window_left=-1, # No sliding window + logits_soft_cap=0.0, # No soft cap + sm_scale=1.0 / (head_size**0.5), # Standard scale + ) + for layer_name in layer_names + } + + with unittest.mock.patch( + "vllm.v1.attention.backends.flashinfer.get_per_layer_parameters", + mock_get_per_layer_parameters, + ): + return builder_cls( + kv_cache_spec=kv_cache_spec, + layer_names=layer_names, + vllm_config=vllm_config, + device=device, + ) + + return builder_cls( kv_cache_spec=kv_cache_spec, - layer_names=["layer_0"], + layer_names=layer_names, vllm_config=vllm_config, device=device, ) @@ -281,39 +318,44 @@ def _create_input_tensors( def _create_kv_cache( config: BenchmarkConfig, max_num_blocks: int, - cache_layout: str, + backend_class, device: torch.device, dtype: torch.dtype, ) -> list: - """Create KV cache tensors for all layers.""" - if cache_layout == "flashinfer": - # FlashInfer layout: [num_blocks, 2, block_size, num_kv_heads, head_dim] - 
cache_list = [ - torch.zeros( - max_num_blocks, - 2, - config.block_size, - config.num_kv_heads, - config.head_dim, - device=device, - dtype=dtype, - ) - for _ in range(config.num_layers) - ] - else: - # Standard layout: [2, num_blocks, block_size, num_kv_heads, head_dim] - cache_list = [ - torch.zeros( - 2, - max_num_blocks, - config.block_size, - config.num_kv_heads, - config.head_dim, - device=device, - dtype=dtype, - ) - for _ in range(config.num_layers) - ] + """Create KV cache tensors for all layers using the backend's methods. + + Uses the backend's get_kv_cache_shape() and get_kv_cache_stride_order() + to create the cache with the correct shape and memory layout. + """ + # Get the logical shape from the backend + cache_shape = backend_class.get_kv_cache_shape( + num_blocks=max_num_blocks, + block_size=config.block_size, + num_kv_heads=config.num_kv_heads, + head_size=config.head_dim, + ) + + # Get the stride order for custom memory layout + try: + stride_order = backend_class.get_kv_cache_stride_order() + assert len(stride_order) == len(cache_shape) + except (AttributeError, NotImplementedError): + stride_order = tuple(range(len(cache_shape))) + + # Permute shape to physical layout order + physical_shape = tuple(cache_shape[i] for i in stride_order) + + # Compute inverse permutation to get back to logical view + inv_order = [stride_order.index(i) for i in range(len(stride_order))] + + cache_list = [] + for _ in range(config.num_layers): + # Allocate in physical layout order (contiguous in memory) + cache = torch.zeros(*physical_shape, device=device, dtype=dtype) + # Permute to logical view + cache = cache.permute(*inv_order) + cache_list.append(cache) + return cache_list @@ -418,53 +460,72 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult: kv_lens = [r.kv_len for r in requests] total_q = sum(q_lens) max_kv = max(kv_lens) + batch_size = len(q_lens) - max_num_blocks = (max_kv + config.block_size - 1) // config.block_size + # Calculate total blocks needed: batch_size * max_blocks_per_request + max_blocks_per_request = (max_kv + config.block_size - 1) // config.block_size + max_num_blocks = batch_size * max_blocks_per_request - backend_class, impl, layer, dtype = _create_backend_impl( - backend_cfg, config, device - ) + # Suppress vLLM logs during setup to reduce spam + with log_warnings_and_errors_only(): + # Create vllm_config first - uses model's native dtype via "auto" + vllm_config = _create_vllm_config(config, max_num_blocks) + dtype = vllm_config.model_config.dtype - common_metadata = _build_common_attn_metadata( - q_lens, kv_lens, config.block_size, device - ) + # Wrap everything in set_current_vllm_config context + # This is required for backends like flashinfer that need global config + with set_current_vllm_config(vllm_config): + backend_class, impl, layer = _create_backend_impl( + backend_cfg, config, device, dtype + ) - kv_cache_spec = FullAttentionSpec( - block_size=config.block_size, - num_kv_heads=config.num_kv_heads, - head_size=config.head_dim, - dtype=dtype, - ) + # Set KV cache layout if the backend requires a specific one + # (e.g., FlashInfer requires HND on SM100/Blackwell for TRTLLM attention) + required_layout = backend_class.get_required_kv_cache_layout() + if required_layout is not None: + set_kv_cache_layout(required_layout) + get_kv_cache_layout.cache_clear() - vllm_config = _create_vllm_config(config, dtype, max_num_blocks) + common_metadata = _build_common_attn_metadata( + q_lens, kv_lens, config.block_size, device + ) - builder = 
_create_metadata_builder( - backend_class, kv_cache_spec, vllm_config, device - ) + kv_cache_spec = FullAttentionSpec( + block_size=config.block_size, + num_kv_heads=config.num_kv_heads, + head_size=config.head_dim, + dtype=dtype, + ) - attn_metadata = builder.build( - common_prefix_len=0, - common_attn_metadata=common_metadata, - ) + builder = _create_metadata_builder( + backend_class, kv_cache_spec, vllm_config, device, config.backend + ) - q_list, k_list, v_list = _create_input_tensors(config, total_q, device, dtype) + attn_metadata = builder.build( + common_prefix_len=0, + common_attn_metadata=common_metadata, + ) - cache_list = _create_kv_cache( - config, max_num_blocks, backend_cfg["cache_layout"], device, dtype - ) + q_list, k_list, v_list = _create_input_tensors( + config, total_q, device, dtype + ) - times, mem_stats = _run_single_benchmark( - config, - impl, - layer, - q_list, - k_list, - v_list, - cache_list, - attn_metadata, - device, - dtype, - ) + cache_list = _create_kv_cache( + config, max_num_blocks, backend_class, device, dtype + ) + + times, mem_stats = _run_single_benchmark( + config, + impl, + layer, + q_list, + k_list, + v_list, + cache_list, + attn_metadata, + device, + dtype, + ) mean_time = np.mean(times) throughput = total_q / mean_time if mean_time > 0 else 0
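
Reviewer note (not part of the diff): the permute-back trick in the new `_create_kv_cache` can be checked in isolation. The sketch below allocates a cache contiguously in physical (stride-order) layout and permutes it back to the logical view, mirroring the diff's logic; the concrete shape and stride order here are hypothetical stand-ins, not values taken from any real backend's get_kv_cache_shape()/get_kv_cache_stride_order().

    import torch

    # Hypothetical logical shape and stride order (illustrative only).
    cache_shape = (2, 64, 16, 8, 128)   # e.g. (2, num_blocks, block_size, num_kv_heads, head_dim)
    stride_order = (1, 0, 2, 3, 4)      # physical layout puts num_blocks outermost

    # Allocate contiguously in physical order, then permute back to the logical view.
    physical_shape = tuple(cache_shape[i] for i in stride_order)
    inv_order = [stride_order.index(i) for i in range(len(stride_order))]
    cache = torch.zeros(*physical_shape).permute(*inv_order)

    # The logical view matches the backend-reported shape...
    assert cache.shape == torch.Size(cache_shape)
    # ...while the dimension listed first in stride_order stays outermost in memory.
    assert cache.stride(stride_order[0]) == max(cache.stride())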