diff --git a/benchmarks/attention_benchmarks/benchmark.py b/benchmarks/attention_benchmarks/benchmark.py index 0329d1102..a8b1c5478 100644 --- a/benchmarks/attention_benchmarks/benchmark.py +++ b/benchmarks/attention_benchmarks/benchmark.py @@ -47,6 +47,8 @@ from common import ( is_mla_backend, ) +from vllm.v1.worker.workspace import init_workspace_manager + def run_standard_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult: """Run standard attention benchmark (Flash/Triton/FlashInfer).""" @@ -462,7 +464,7 @@ def main(): parser.add_argument( "--batch-specs", nargs="+", - default=["q2k", "8q1s1k"], + default=None, help="Batch specifications using extended grammar", ) @@ -478,6 +480,21 @@ def main(): parser.add_argument("--repeats", type=int, default=1, help="Repetitions") parser.add_argument("--warmup-iters", type=int, default=3, help="Warmup iterations") parser.add_argument("--profile-memory", action="store_true", help="Profile memory") + parser.add_argument( + "--kv-cache-dtype", + default="auto", + choices=["auto", "fp8"], + help="KV cache dtype: auto or fp8", + ) + parser.add_argument( + "--cuda-graphs", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "Launch kernels with CUDA graphs to eliminate CPU overhead" + "in measurements (default: True)" + ), + ) # Parameter sweep (use YAML config for advanced sweeps) parser.add_argument( @@ -536,21 +553,24 @@ def main(): # Batch specs and sizes # Support both explicit batch_specs and generated batch_spec_ranges - if "batch_spec_ranges" in yaml_config: - # Generate batch specs from ranges - generated_specs = generate_batch_specs_from_ranges( - yaml_config["batch_spec_ranges"] - ) - # Combine with any explicit batch_specs - if "batch_specs" in yaml_config: - args.batch_specs = yaml_config["batch_specs"] + generated_specs - else: - args.batch_specs = generated_specs - console.print( - f"[dim]Generated {len(generated_specs)} batch specs from ranges[/]" - ) - elif "batch_specs" in yaml_config: - args.batch_specs = yaml_config["batch_specs"] + # CLI --batch-specs takes precedence over YAML when provided. + cli_batch_specs_provided = args.batch_specs is not None + if not cli_batch_specs_provided: + if "batch_spec_ranges" in yaml_config: + # Generate batch specs from ranges + generated_specs = generate_batch_specs_from_ranges( + yaml_config["batch_spec_ranges"] + ) + # Combine with any explicit batch_specs + if "batch_specs" in yaml_config: + args.batch_specs = yaml_config["batch_specs"] + generated_specs + else: + args.batch_specs = generated_specs + console.print( + f"[dim]Generated {len(generated_specs)} batch specs from ranges[/]" + ) + elif "batch_specs" in yaml_config: + args.batch_specs = yaml_config["batch_specs"] if "batch_sizes" in yaml_config: args.batch_sizes = yaml_config["batch_sizes"] @@ -575,6 +595,10 @@ def main(): args.warmup_iters = yaml_config["warmup_iters"] if "profile_memory" in yaml_config: args.profile_memory = yaml_config["profile_memory"] + if "kv_cache_dtype" in yaml_config: + args.kv_cache_dtype = yaml_config["kv_cache_dtype"] + if "cuda_graphs" in yaml_config: + args.cuda_graphs = yaml_config["cuda_graphs"] # Parameter sweep configuration if "parameter_sweep" in yaml_config: @@ -629,12 +653,18 @@ def main(): # Determine backends backends = args.backends or ([args.backend] if args.backend else ["flash"]) prefill_backends = getattr(args, "prefill_backends", None) + if not args.batch_specs: + args.batch_specs = ["q2k", "8q1s1k"] console.print(f"Backends: {', '.join(backends)}") if prefill_backends: console.print(f"Prefill backends: {', '.join(prefill_backends)}") console.print(f"Batch specs: {', '.join(args.batch_specs)}") + console.print(f"KV cache dtype: {args.kv_cache_dtype}") + console.print(f"CUDA graphs: {args.cuda_graphs}") console.print() + init_workspace_manager(args.device) + # Run benchmarks all_results = [] @@ -687,6 +717,8 @@ def main(): repeats=args.repeats, warmup_iters=args.warmup_iters, profile_memory=args.profile_memory, + kv_cache_dtype=args.kv_cache_dtype, + use_cuda_graphs=args.cuda_graphs, ) # Add decode pipeline config @@ -839,6 +871,8 @@ def main(): "repeats": args.repeats, "warmup_iters": args.warmup_iters, "profile_memory": args.profile_memory, + "kv_cache_dtype": args.kv_cache_dtype, + "use_cuda_graphs": args.cuda_graphs, } all_results = run_model_parameter_sweep( backends, @@ -861,6 +895,8 @@ def main(): "repeats": args.repeats, "warmup_iters": args.warmup_iters, "profile_memory": args.profile_memory, + "kv_cache_dtype": args.kv_cache_dtype, + "use_cuda_graphs": args.cuda_graphs, } all_results = run_parameter_sweep( backends, args.batch_specs, base_config_args, args.parameter_sweep, console @@ -891,6 +927,8 @@ def main(): repeats=args.repeats, warmup_iters=args.warmup_iters, profile_memory=args.profile_memory, + kv_cache_dtype=args.kv_cache_dtype, + use_cuda_graphs=args.cuda_graphs, ) result = run_benchmark(config) diff --git a/benchmarks/attention_benchmarks/common.py b/benchmarks/attention_benchmarks/common.py index 208d6273c..74d9e2397 100644 --- a/benchmarks/attention_benchmarks/common.py +++ b/benchmarks/attention_benchmarks/common.py @@ -213,6 +213,9 @@ class BenchmarkConfig: profile_memory: bool = False use_cuda_graphs: bool = False + # "auto" or "fp8" + kv_cache_dtype: str = "auto" + # MLA-specific prefill_backend: str | None = None kv_lora_rank: int | None = None @@ -369,6 +372,7 @@ class ResultsFormatter: "backend", "batch_spec", "num_layers", + "kv_cache_dtype", "mean_time", "std_time", "throughput", @@ -382,6 +386,7 @@ class ResultsFormatter: "backend": r.config.backend, "batch_spec": r.config.batch_spec, "num_layers": r.config.num_layers, + "kv_cache_dtype": r.config.kv_cache_dtype, "mean_time": r.mean_time, "std_time": r.std_time, "throughput": r.throughput_tokens_per_sec or 0, diff --git a/benchmarks/attention_benchmarks/configs/mla_mixed_batch.yaml b/benchmarks/attention_benchmarks/configs/mla_mixed_batch.yaml index b555d90cb..c342e9fb8 100644 --- a/benchmarks/attention_benchmarks/configs/mla_mixed_batch.yaml +++ b/benchmarks/attention_benchmarks/configs/mla_mixed_batch.yaml @@ -30,9 +30,9 @@ batch_specs: - "2q16k_32q1s4k" # 2 very large prefill + 32 decode # Context extension + decode - - "2q1kkv2k_16q1s1k" # 2 extend + 16 decode - - "4q2kkv4k_32q1s2k" # 4 extend + 32 decode - - "2q1kkv8k_32q1s2k" # 2 large extend + 32 decode + - "2q1ks2k_16q1s1k" # 2 extend + 16 decode + - "4q2ks4k_32q1s2k" # 4 extend + 32 decode + - "2q1ks8k_32q1s2k" # 2 large extend + 32 decode # Explicitly chunked prefill - "q8k" # 8k prefill with chunking hint diff --git a/benchmarks/attention_benchmarks/configs/mla_sparse_decode.yaml b/benchmarks/attention_benchmarks/configs/mla_sparse_decode.yaml new file mode 100644 index 000000000..689c9f3c3 --- /dev/null +++ b/benchmarks/attention_benchmarks/configs/mla_sparse_decode.yaml @@ -0,0 +1,58 @@ +# MLA decode-only benchmark configuration + +model: + name: "deepseek-v3" + num_layers: 60 + num_q_heads: 128 # Base value, can be swept for TP simulation + num_kv_heads: 1 # MLA uses single latent KV + head_dim: 576 + kv_lora_rank: 512 + qk_nope_head_dim: 128 + qk_rope_head_dim: 64 + v_head_dim: 128 + block_size: 128 # CUTLASS MLA and FlashAttn MLA use 128 + +# Model parameter sweep: simulate tensor parallelism by varying num_q_heads +# TP=1: 128 heads, TP=2: 64 heads, TP=4: 32 heads, TP=8: 16 heads +model_parameter_sweep: + param_name: "num_q_heads" + values: [128, 64, 32, 16] + label_format: "{backend}_{value}h" + +batch_specs: + # Small batches, varying sequence lengths + - "16q1s512" # 16 requests, 512 KV cache + - "16q1s1k" # 16 requests, 1k KV cache + - "16q1s2k" # 16 requests, 2k KV cache + - "16q1s4k" # 16 requests, 4k KV cache + + # Medium batches + - "32q1s1k" # 32 requests, 1k KV cache + - "32q1s2k" # 32 requests, 2k KV cache + - "32q1s4k" # 32 requests, 4k KV cache + - "32q1s8k" # 32 requests, 8k KV cache + + # Large batches + - "64q1s1k" # 64 requests, 1k KV cache + - "64q1s2k" # 64 requests, 2k KV cache + - "64q1s4k" # 64 requests, 4k KV cache + - "64q1s8k" # 64 requests, 8k KV cache + + # Very large batches + - "128q1s1k" # 128 requests, 1k KV cache + - "128q1s2k" # 128 requests, 2k KV cache + - "128q1s4k" # 128 requests, 4k KV cache + - "128q1s8k" # 128 requests, 8k KV cache + + # Long context + - "32q1s16k" # 32 requests, 16k KV cache + - "32q1s32k" # 32 requests, 32k KV cache + +backends: + - FLASHMLA_SPARSE + - FLASHINFER_MLA_SPARSE + +device: "cuda:0" +repeats: 100 +warmup_iters: 10 +profile_memory: true diff --git a/benchmarks/attention_benchmarks/mla_runner.py b/benchmarks/attention_benchmarks/mla_runner.py index 0d612e374..f8bc7b4a1 100644 --- a/benchmarks/attention_benchmarks/mla_runner.py +++ b/benchmarks/attention_benchmarks/mla_runner.py @@ -60,9 +60,11 @@ def create_minimal_vllm_config( model_name: str = "deepseek-v3", block_size: int = 128, max_num_seqs: int = 256, + max_num_batched_tokens: int = 8192, mla_dims: dict | None = None, index_topk: int | None = None, prefill_backend: str | None = None, + kv_cache_dtype: str = "auto", ) -> VllmConfig: """ Create minimal VllmConfig for MLA benchmarks. @@ -149,13 +151,13 @@ def create_minimal_vllm_config( cache_config = CacheConfig( block_size=block_size, gpu_memory_utilization=0.9, - cache_dtype="auto", + cache_dtype=kv_cache_dtype, enable_prefix_caching=False, ) scheduler_config = SchedulerConfig( max_num_seqs=max_num_seqs, - max_num_batched_tokens=8192, + max_num_batched_tokens=max(max_num_batched_tokens, max_num_seqs), max_model_len=32768, is_encoder_decoder=False, enable_chunked_prefill=True, @@ -535,6 +537,7 @@ def _create_backend_impl( device: torch.device, max_num_tokens: int = 8192, index_topk: int | None = None, + kv_cache_dtype: str = "auto", ): """ Create backend implementation instance. @@ -583,7 +586,7 @@ def _create_backend_impl( "num_kv_heads": mla_dims["num_kv_heads"], "alibi_slopes": None, "sliding_window": None, - "kv_cache_dtype": "auto", + "kv_cache_dtype": kv_cache_dtype, "logits_soft_cap": None, "attn_type": "decoder", "kv_sharing_target_layer_name": None, @@ -701,6 +704,7 @@ def _run_single_benchmark( mla_dims: dict, device: torch.device, indexer=None, + kv_cache_dtype: str | None = None, ) -> BenchmarkResult: """ Run a single benchmark iteration. @@ -734,49 +738,124 @@ def _run_single_benchmark( ) # Create KV cache - kv_cache = torch.zeros( - num_blocks, - block_size, - mla_dims["kv_lora_rank"] + mla_dims["qk_rope_head_dim"], - device=device, - dtype=torch.bfloat16, - ) + if kv_cache_dtype is None: + kv_cache_dtype = getattr(config, "kv_cache_dtype", "auto") + head_size = mla_dims["kv_lora_rank"] + mla_dims["qk_rope_head_dim"] + if kv_cache_dtype == "fp8_ds_mla": + # FlashMLA sparse custom format: 656 bytes per token, stored as uint8. + # Layout: kv_lora_rank fp8 bytes + 4 float32 tile scales + # + 2*rope_dim bf16 bytes + # = 512 + 16 + 128 = 656 bytes for DeepSeek dims. + kv_cache = torch.zeros( + num_blocks, + block_size, + 656, + device=device, + dtype=torch.uint8, + ) + elif kv_cache_dtype == "fp8": + from vllm.platforms import current_platform - # Create input tensors for both decode and prefill modes - decode_inputs, prefill_inputs = _create_input_tensors( - total_q, - mla_dims, - backend_cfg["query_format"], - device, - torch.bfloat16, - ) + kv_cache = torch.zeros( + num_blocks, + block_size, + head_size, + device=device, + dtype=torch.uint8, + ).view(current_platform.fp8_dtype()) + else: + kv_cache = torch.zeros( + num_blocks, + block_size, + head_size, + device=device, + dtype=torch.bfloat16, + ) # Fill indexer with random indices for sparse backends is_sparse = backend_cfg.get("is_sparse", False) if is_sparse and indexer is not None: indexer.fill_random_indices(total_q, max_kv_len) - # Determine which forward method to use based on metadata - if metadata.decode is not None: - forward_fn = lambda: impl.forward_mqa(decode_inputs, kv_cache, metadata, layer) - elif metadata.prefill is not None: - forward_fn = lambda: impl.forward_mha( - prefill_inputs["q"], - prefill_inputs["k_c_normed"], - prefill_inputs["k_pe"], - kv_cache, - metadata, - prefill_inputs["k_scale"], - prefill_inputs["output"], - ) - else: + # Determine which forward methods to use based on metadata. + # Sparse MLA backends always use forward_mqa + has_decode = is_sparse or getattr(metadata, "decode", None) is not None + has_prefill = not is_sparse and getattr(metadata, "prefill", None) is not None + if not has_decode and not has_prefill: raise RuntimeError("Metadata has neither decode nor prefill metadata") + num_decode = ( + metadata.num_decode_tokens + if (has_decode and has_prefill) + else total_q + if has_decode + else 0 + ) + num_prefill = total_q - num_decode + + # Some backends requires fp8 queries when using fp8 KV cache. + is_fp8_kvcache = kv_cache_dtype.startswith("fp8") + quantize_query = is_fp8_kvcache and getattr( + impl, "supports_quant_query_input", False + ) + + # quantize_query forces concat format + query_fmt = "concat" if quantize_query else backend_cfg["query_format"] + + # Create decode query tensors + if has_decode: + decode_inputs, _ = _create_input_tensors( + num_decode, mla_dims, query_fmt, device, torch.bfloat16 + ) + # Cast decode query to fp8 if the backend supports it + if quantize_query: + from vllm.platforms import current_platform + + if isinstance(decode_inputs, tuple): + decode_inputs = torch.cat(list(decode_inputs), dim=-1) + decode_inputs = decode_inputs.to(current_platform.fp8_dtype()) + + # Create prefill input tensors + if has_prefill: + _, prefill_inputs = _create_input_tensors( + num_prefill, mla_dims, query_fmt, device, torch.bfloat16 + ) + + # Build forward function + def forward_fn(): + results = [] + if has_decode: + results.append(impl.forward_mqa(decode_inputs, kv_cache, metadata, layer)) + if has_prefill: + results.append( + impl.forward_mha( + prefill_inputs["q"], + prefill_inputs["k_c_normed"], + prefill_inputs["k_pe"], + kv_cache, + metadata, + prefill_inputs["k_scale"], + prefill_inputs["output"], + ) + ) + return results[0] if len(results) == 1 else tuple(results) + # Warmup for _ in range(config.warmup_iters): forward_fn() torch.accelerator.synchronize() + # Optionally capture a CUDA graph after warmup. + # Graph replay eliminates CPU launch overhead so timings reflect pure + # kernel time. + if config.use_cuda_graphs: + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + forward_fn() + benchmark_fn = graph.replay + else: + benchmark_fn = forward_fn + # Benchmark times = [] for _ in range(config.repeats): @@ -785,7 +864,7 @@ def _run_single_benchmark( start.record() for _ in range(config.num_layers): - forward_fn() + benchmark_fn() end.record() torch.accelerator.synchronize() @@ -852,13 +931,30 @@ def _run_mla_benchmark_batched( # Determine if this is a sparse backend is_sparse = backend_cfg.get("is_sparse", False) + # Extract kv_cache_dtype from the first config + kv_cache_dtype = getattr(first_config, "kv_cache_dtype", "auto") + + # FlashMLA sparse only supports "fp8_ds_mla" internally (not generic "fp8"). + # Remap here so the user can pass --kv-cache-dtype fp8 regardless of backend. + if backend.upper() == "FLASHMLA_SPARSE" and kv_cache_dtype == "fp8": + kv_cache_dtype = "fp8_ds_mla" + + # Compute max total_q across all configs so the metadata builder buffer + # and scheduler config are large enough for all batch specs. + max_total_q = max( + sum(r.q_len for r in parse_batch_spec(cfg.batch_spec)) + for cfg, *_ in configs_with_params + ) + # Create and set vLLM config for MLA (reused across all benchmarks) vllm_config = create_minimal_vllm_config( model_name="deepseek-v3", # Used only for model path block_size=block_size, + max_num_batched_tokens=max_total_q, mla_dims=mla_dims, # Use custom dims from config or default index_topk=index_topk if is_sparse else None, prefill_backend=prefill_backend, + kv_cache_dtype=kv_cache_dtype, ) results = [] @@ -883,7 +979,9 @@ def _run_mla_benchmark_batched( mla_dims, vllm_config, device, + max_num_tokens=max_total_q, index_topk=index_topk if is_sparse else None, + kv_cache_dtype=kv_cache_dtype, ) # Verify the actual prefill backend matches what was requested @@ -942,6 +1040,7 @@ def _run_mla_benchmark_batched( mla_dims, device, indexer=indexer, + kv_cache_dtype=kv_cache_dtype, ) results.append(result) diff --git a/benchmarks/attention_benchmarks/runner.py b/benchmarks/attention_benchmarks/runner.py index 6af56e0e9..aa636cd9c 100644 --- a/benchmarks/attention_benchmarks/runner.py +++ b/benchmarks/attention_benchmarks/runner.py @@ -140,7 +140,7 @@ def _create_vllm_config( cache_config = CacheConfig( block_size=config.block_size, - cache_dtype="auto", + cache_dtype=config.kv_cache_dtype, ) cache_config.num_gpu_blocks = max_num_blocks cache_config.num_cpu_blocks = 0 @@ -215,7 +215,7 @@ def _create_backend_impl( num_kv_heads=config.num_kv_heads, alibi_slopes=None, sliding_window=None, - kv_cache_dtype="auto", + kv_cache_dtype=config.kv_cache_dtype, ) kv_cache_spec = FullAttentionSpec( @@ -288,12 +288,22 @@ def _create_input_tensors( total_q: int, device: torch.device, dtype: torch.dtype, + quantize_query: bool = False, ) -> tuple: - """Create Q, K, V input tensors for all layers.""" + """Create Q, K, V input tensors for all layers. + + When quantize_query is True, queries are cast to fp8 to match backends + that require query/key/value dtype consistency. + """ + q_dtype = dtype + if quantize_query: + from vllm.platforms import current_platform + + q_dtype = current_platform.fp8_dtype() q_list = [ torch.randn( total_q, config.num_q_heads, config.head_dim, device=device, dtype=dtype - ) + ).to(q_dtype) for _ in range(config.num_layers) ] k_list = [ @@ -344,10 +354,17 @@ def _create_kv_cache( # Compute inverse permutation to get back to logical view inv_order = [stride_order.index(i) for i in range(len(stride_order))] + # Use fp8 dtype for cache when requested. + cache_dtype = dtype + if config.kv_cache_dtype == "fp8": + from vllm.platforms import current_platform + + cache_dtype = current_platform.fp8_dtype() + cache_list = [] for _ in range(config.num_layers): # Allocate in physical layout order (contiguous in memory) - cache = torch.zeros(*physical_shape, device=device, dtype=dtype) + cache = torch.zeros(*physical_shape, device=device, dtype=cache_dtype) # Permute to logical view cache = cache.permute(*inv_order) cache_list.append(cache) @@ -392,6 +409,37 @@ def _run_single_benchmark( ) torch.accelerator.synchronize() + # Optionally capture a CUDA graph after warmup. + # Graph replay eliminates CPU launch overhead so timings reflect pure + # kernel time. + if config.use_cuda_graphs: + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + for i in range(config.num_layers): + impl.forward( + layer, + q_list[i], + k_list[i], + v_list[i], + cache_list[i], + attn_metadata, + output=out, + ) + benchmark_fn = graph.replay + else: + + def benchmark_fn(): + for i in range(config.num_layers): + impl.forward( + layer, + q_list[i], + k_list[i], + v_list[i], + cache_list[i], + attn_metadata, + output=out, + ) + # Benchmark times = [] for _ in range(config.repeats): @@ -399,16 +447,7 @@ def _run_single_benchmark( end = torch.cuda.Event(enable_timing=True) start.record() - for i in range(config.num_layers): - impl.forward( - layer, - q_list[i], - k_list[i], - v_list[i], - cache_list[i], - attn_metadata, - output=out, - ) + benchmark_fn() end.record() torch.accelerator.synchronize() @@ -502,8 +541,12 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult: common_attn_metadata=common_metadata, ) + # Only quantize queries when the impl supports it + quantize_query = config.kv_cache_dtype.startswith("fp8") and getattr( + impl, "supports_quant_query_input", False + ) q_list, k_list, v_list = _create_input_tensors( - config, total_q, device, dtype + config, total_q, device, dtype, quantize_query=quantize_query ) cache_list = _create_kv_cache(