#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
Universal vLLM Attention Benchmark

Benchmark any attention backend with the extended grammar.
Supports standard attention (Flash/Triton/FlashInfer) and MLA backends.

Examples:
    # Standard attention
    python benchmark.py --backends flash flashinfer --batch-specs "q2k" "8q1s1k"

    # MLA backends
    python benchmark.py --backends cutlass_mla flashinfer_mla --batch-specs "64q1s1k"

    # Parameter sweep (CLI)
    python benchmark.py --backend cutlass_mla \
        --batch-specs "64q1s1k" \
        --sweep-param num_kv_splits \
        --sweep-values 1 4 8 16

    # Parameter sweep (YAML config - recommended)
    python benchmark.py --config configs/cutlass_numsplits.yaml
"""

import argparse
import sys
from dataclasses import replace
from pathlib import Path

import yaml
from rich.console import Console
from tqdm import tqdm

sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from batch_spec import parse_batch_spec
from common import (
    BenchmarkConfig,
    BenchmarkResult,
    ModelParameterSweep,
    ParameterSweep,
    ResultsFormatter,
    batch_spec_sort_key,
    is_mla_backend,
)

from vllm.v1.worker.workspace import init_workspace_manager


def run_standard_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
    """Run standard attention benchmark (Flash/Triton/FlashInfer)."""
    from runner import run_attention_benchmark

    return run_attention_benchmark(config)


def run_mla_benchmark(config: BenchmarkConfig, **kwargs) -> BenchmarkResult:
    """Run MLA benchmark with appropriate backend."""
    from mla_runner import run_mla_benchmark as run_mla

    return run_mla(
        config.backend, config, prefill_backend=config.prefill_backend, **kwargs
    )


def run_benchmark(config: BenchmarkConfig, **kwargs) -> BenchmarkResult:
    """
    Run a single benchmark with proper backend selection.

    Args:
        config: BenchmarkConfig with backend, batch_spec, and model params
        **kwargs: Additional arguments passed to MLA benchmarks

    Returns:
        BenchmarkResult (may have error field set on failure)
    """
    try:
        if is_mla_backend(config.backend):
            return run_mla_benchmark(config, **kwargs)
        else:
            return run_standard_attention_benchmark(config)
    except Exception as e:
        return BenchmarkResult(
            config=config,
            mean_time=float("inf"),
            std_time=0,
            min_time=float("inf"),
            max_time=float("inf"),
            error=str(e),
        )


def run_model_parameter_sweep(
    backends: list[str],
    batch_specs: list[str],
    base_config_args: dict,
    sweep: ModelParameterSweep,
    console: Console,
) -> list[BenchmarkResult]:
    """
    Run model parameter sweep for given backends and batch specs.

    Args:
        backends: List of backend names
        batch_specs: List of batch specifications
        base_config_args: Base configuration arguments (num_layers, head_dim, etc.)
        sweep: ModelParameterSweep configuration
        console: Rich console for output

    Returns:
        List of BenchmarkResult objects
    """
    all_results = []

    console.print(
        f"[yellow]Model sweep mode: testing {sweep.param_name} = {sweep.values}[/]"
    )

    total = len(backends) * len(batch_specs) * len(sweep.values)

    with tqdm(total=total, desc="Benchmarking") as pbar:
        for backend in backends:
            for spec in batch_specs:
                for value in sweep.values:
                    # Create config with modified model parameter
                    config_args = base_config_args.copy()
                    config_args[sweep.param_name] = value

                    # Create config with original backend for running
                    clean_config = BenchmarkConfig(
                        backend=backend, batch_spec=spec, **config_args
                    )

                    # Run benchmark
                    result = run_benchmark(clean_config)

                    # Replace backend with labeled version for display
                    backend_label = sweep.get_label(backend, value)
                    labeled_config = replace(result.config, backend=backend_label)
                    result = replace(result, config=labeled_config)
                    all_results.append(result)

                    if not result.success:
                        console.print(
                            f"[red]Error {backend} {spec} {sweep.param_name}="
                            f"{value}: {result.error}[/]"
                        )

                    pbar.update(1)

    # Display sweep results - create separate table for each parameter value
    console.print("\n[bold green]Model Parameter Sweep Results:[/]")
    formatter = ResultsFormatter(console)

    # Group results by parameter value and extract backend mapping
    by_param_value = {}
    backend_mapping = {}  # Maps labeled backend -> original backend

    for r in all_results:
        # Extract original backend and param value from labeled backend
        # The label format is: {backend}_{param_name}_{value}
        # We need to reverse engineer this
        labeled_backend = r.config.backend

        # Try each backend to find which one this result belongs to
        for backend in backends:
            for value in sweep.values:
                expected_label = sweep.get_label(backend, value)
                if labeled_backend == expected_label:
                    backend_mapping[labeled_backend] = backend
                    param_value = str(value)

                    if param_value not in by_param_value:
                        by_param_value[param_value] = []
                    by_param_value[param_value].append(r)
                    break

    # Create a table for each parameter value
    sorted_param_values = sorted(
        by_param_value.keys(), key=lambda x: int(x) if x.isdigit() else x
    )

    for param_value in sorted_param_values:
        console.print(f"\n[bold cyan]{sweep.param_name} = {param_value}[/]")
        param_results = by_param_value[param_value]

        # Create modified results with original backend names
        modified_results = []
        for r in param_results:
            # Get the original backend name from our mapping
            original_backend = backend_mapping[r.config.backend]
            modified_config = replace(r.config, backend=original_backend)
            modified_result = replace(r, config=modified_config)
            modified_results.append(modified_result)

        # Print table with original backend names
        formatter.print_table(modified_results, backends, compare_to_fastest=True)

    # Show optimal backend for each (param_value, batch_spec) combination
    console.print(
        f"\n[bold cyan]Optimal backend for each ({sweep.param_name}, batch_spec):[/]"
    )

    # Group by (param_value, batch_spec)
    by_param_and_spec = {}
    for r in all_results:
        if r.success:
            # Find which (backend, value) this result corresponds to
            labeled_backend = r.config.backend
            for backend in backends:
                for value in sweep.values:
                    expected_label = sweep.get_label(backend, value)
                    if labeled_backend == expected_label:
                        param_value = str(value)
                        spec = r.config.batch_spec
                        key = (param_value, spec)

                        if key not in by_param_and_spec:
                            by_param_and_spec[key] = []
                        by_param_and_spec[key].append(r)
                        break

    # Sort by param value then spec (batch_size, q_len, kv_len)
    sorted_keys = sorted(
        by_param_and_spec.keys(),
        key=lambda x: (
            int(x[0]) if x[0].isdigit() else x[0],
            batch_spec_sort_key(x[1]),
        ),
    )

    current_param_value = None
    for param_value, spec in sorted_keys:
        # Print header when param value changes
        if param_value != current_param_value:
            console.print(f"\n [bold]{sweep.param_name}={param_value}:[/]")
            current_param_value = param_value

        results = by_param_and_spec[(param_value, spec)]
        best = min(results, key=lambda r: r.mean_time)

        # Extract original backend name using the mapping
        backend_name = backend_mapping[best.config.backend]

        # Show all backends' times for comparison
        times_str = " | ".join(
            [
                f"{backend_mapping[r.config.backend]}: {r.mean_time:.6f}s"
                for r in sorted(results, key=lambda r: r.mean_time)
            ]
        )

        console.print(
            f" {spec:12s} -> [bold green]{backend_name:15s}[/] ({times_str})"
        )

    return all_results


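# Note on sweep labels (a sketch, not part of the benchmark flow): both sweep
# helpers above and below tag each result by rewriting its backend name via
# sweep.get_label(backend, value). With the default label_format
# "{backend}_{param_name}_{value}" used in main(), a hypothetical sweep such as
#
#     sweep = ParameterSweep(
#         param_name="num_kv_splits",
#         values=[1, 4, 8, 16],
#         include_auto=True,
#         label_format="{backend}_{param_name}_{value}",
#     )
#
# is expected to label a cutlass_mla run with value 8 as
# "cutlass_mla_num_kv_splits_8"; the grouping code above recovers the original
# backend name by regenerating these labels and matching them exactly.

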
def run_parameter_sweep(
    backends: list[str],
    batch_specs: list[str],
    base_config_args: dict,
    sweep: ParameterSweep,
    console: Console,
) -> list[BenchmarkResult]:
    """
    Run parameter sweep for given backends and batch specs.

    Args:
        backends: List of backend names
        batch_specs: List of batch specifications
        base_config_args: Base configuration arguments (num_layers, head_dim, etc.)
        sweep: ParameterSweep configuration
        console: Rich console for output

    Returns:
        List of BenchmarkResult objects
    """
    all_results = []

    # Build list of values to sweep (including auto if requested)
    sweep_values = list(sweep.values)
    if sweep.include_auto:
        sweep_values.append("auto")

    console.print(f"[yellow]Sweep mode: testing {sweep.param_name} = {sweep_values}[/]")

    total = len(backends) * len(batch_specs) * len(sweep_values)

    with tqdm(total=total, desc="Benchmarking") as pbar:
        for backend in backends:
            for spec in batch_specs:
                for value in sweep_values:
                    # Create config with original backend for running
                    config = BenchmarkConfig(
                        backend=backend, batch_spec=spec, **base_config_args
                    )

                    # Prepare kwargs for benchmark runner
                    kwargs = {}
                    if value != "auto":
                        kwargs[sweep.param_name] = value

                    # Run benchmark
                    result = run_benchmark(config, **kwargs)

                    # Replace backend with labeled version for display
                    backend_label = sweep.get_label(backend, value)
                    labeled_config = replace(result.config, backend=backend_label)
                    result = replace(result, config=labeled_config)
                    all_results.append(result)

                    if not result.success:
                        console.print(
                            f"[red]Error {backend} {spec} {sweep.param_name}="
                            f"{value}: {result.error}[/]"
                        )

                    pbar.update(1)

    # Display sweep results
    console.print("\n[bold green]Sweep Results:[/]")
    backend_labels = [sweep.get_label(b, v) for b in backends for v in sweep_values]
    formatter = ResultsFormatter(console)
    formatter.print_table(all_results, backend_labels)

    # Show optimal values
    console.print(f"\n[bold cyan]Optimal {sweep.param_name} per batch spec:[/]")
    by_spec = {}
    for r in all_results:
        if r.success:
            spec = r.config.batch_spec
            if spec not in by_spec:
                by_spec[spec] = []
            by_spec[spec].append(r)

    for spec in sorted(by_spec.keys(), key=batch_spec_sort_key):
        results = by_spec[spec]
        best = min(results, key=lambda r: r.mean_time)
        console.print(
            f" {spec}: [bold green]{best.config.backend}[/] ({best.mean_time:.6f}s)"
        )

    return all_results


def load_config_from_yaml(config_path: str) -> dict:
    """Load configuration from YAML file."""
    with open(config_path) as f:
        return yaml.safe_load(f)


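# Sketch of the YAML schema consumed by main() below. The keys mirror what
# main() actually reads from the loaded config; the concrete values and file
# names here are purely illustrative, not shipped configs.
#
#   description: "Sweep num_kv_splits for the cutlass_mla decode backend"
#   backends: [cutlass_mla]          # or `backend: ...` / `decode_backends: [...]`
#   # prefill_backends: [fa3, fa4]   # optional prefill comparison
#   # mode: decode_vs_prefill        # optional special mode (uses batch_sizes)
#   # batch_sizes: [1, 8, 32]
#   batch_specs: ["64q1s1k"]
#   # batch_spec_ranges:             # alternatively, generate specs from ranges
#   #   - template: "q{q_len}kv1k"
#   #     q_len: {start: 1, stop: 16, step: 1}
#   model:
#     num_layers: 10
#     head_dim: 128
#     num_q_heads: 32
#     num_kv_heads: 8
#     block_size: 16
#   device: cuda:0
#   repeats: 3
#   warmup_iters: 3
#   kv_cache_dtype: auto
#   cuda_graphs: true
#   parameter_sweep:
#     param_name: num_kv_splits
#     values: [1, 4, 8, 16]
#     include_auto: false
#   output:
#     csv: results.csv
#     json: results.json

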
def generate_batch_specs_from_ranges(ranges: list[dict]) -> list[str]:
    """
    Generate batch specs from range specifications.

    Args:
        ranges: List of range specifications, each containing:
            - template: Batch spec template (e.g., "q{q_len}kv1k")
            - q_len: Dict with start, stop, step, end_inclusive (optional)
            - Other parameters can also be ranges

    Returns:
        List of generated batch spec strings

    Example:
        ranges = [
            {
                "template": "q{q_len}kv1k",
                "q_len": {
                    "start": 1,
                    "stop": 16,
                    "step": 1,
                    "end_inclusive": True,  # Optional, defaults to True
                }
            }
        ]
        Returns: ["q1kv1k", "q2kv1k", ..., "q16kv1k"]
    """
    all_specs = []

    for range_spec in ranges:
        template = range_spec.get("template")
        if not template:
            raise ValueError("Range specification must include 'template'")

        # Extract all range parameters from the spec
        range_params = {}
        for key, value in range_spec.items():
            if key == "template":
                continue
            if isinstance(value, dict) and "start" in value:
                # This is a range specification
                start = value["start"]
                stop = value["stop"]
                step = value.get("step", 1)
                # Check if end should be inclusive (default: True)
                end_inclusive = value.get("end_inclusive", True)

                # Adjust stop based on end_inclusive
                if end_inclusive:
                    range_params[key] = list(range(start, stop + 1, step))
                else:
                    range_params[key] = list(range(start, stop, step))
            else:
                # This is a fixed value
                range_params[key] = [value]

        # Generate all combinations (Cartesian product)
        if range_params:
            import itertools

            param_names = list(range_params.keys())
            param_values = [range_params[name] for name in param_names]

            for values in itertools.product(*param_values):
                params = dict(zip(param_names, values))
                spec = template.format(**params)
                all_specs.append(spec)
        else:
            # No parameters, just use template as-is
            all_specs.append(template)

    return all_specs


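# Illustrative expansion (the template string here is just an example, not a
# statement about which specs the extended grammar accepts): multiple range
# parameters expand to the Cartesian product of their values, and fixed values
# pass through unchanged, e.g.
#
#     generate_batch_specs_from_ranges(
#         [
#             {
#                 "template": "{bs}q{q_len}kv1k",
#                 "bs": {"start": 1, "stop": 2},
#                 "q_len": {"start": 1, "stop": 2},
#             }
#         ]
#     )
#     # -> ["1q1kv1k", "1q2kv1k", "2q1kv1k", "2q2kv1k"]

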
def main():
    parser = argparse.ArgumentParser(
        description="Universal vLLM attention benchmark",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )

    # Config file
    parser.add_argument(
        "--config",
        help="Path to YAML config file (overrides other args)",
    )

    # Backend selection
    parser.add_argument(
        "--backends",
        "--decode-backends",
        nargs="+",
        help="Decode backends to benchmark (flash, triton, flashinfer, cutlass_mla, "
        "flashinfer_mla, flashattn_mla, flashmla)",
    )
    parser.add_argument(
        "--backend",
        help="Single backend (alternative to --backends)",
    )
    parser.add_argument(
        "--prefill-backends",
        nargs="+",
        help="Prefill backends to compare (fa2, fa3, fa4). "
        "Uses the first decode backend for impl construction.",
    )

    # Batch specifications
    parser.add_argument(
        "--batch-specs",
        nargs="+",
        default=None,
        help="Batch specifications using extended grammar",
    )

    # Model config
    parser.add_argument("--num-layers", type=int, default=10, help="Number of layers")
    parser.add_argument("--head-dim", type=int, default=128, help="Head dimension")
    parser.add_argument("--num-q-heads", type=int, default=32, help="Query heads")
    parser.add_argument("--num-kv-heads", type=int, default=8, help="KV heads")
    parser.add_argument("--block-size", type=int, default=16, help="Block size")

    # Benchmark settings
    parser.add_argument("--device", default="cuda:0", help="Device")
    parser.add_argument("--repeats", type=int, default=1, help="Repetitions")
    parser.add_argument("--warmup-iters", type=int, default=3, help="Warmup iterations")
    parser.add_argument("--profile-memory", action="store_true", help="Profile memory")
    parser.add_argument(
        "--kv-cache-dtype",
        default="auto",
        choices=["auto", "fp8"],
        help="KV cache dtype: auto or fp8",
    )
    parser.add_argument(
        "--cuda-graphs",
        action=argparse.BooleanOptionalAction,
        default=True,
        help=(
            "Launch kernels with CUDA graphs to eliminate CPU overhead "
            "in measurements (default: True)"
        ),
    )

    # Parameter sweep (use YAML config for advanced sweeps)
    parser.add_argument(
        "--sweep-param",
        help="Parameter name to sweep (e.g., num_kv_splits, reorder_batch_threshold)",
    )
    parser.add_argument(
        "--sweep-values",
        type=int,
        nargs="+",
        help="Values to sweep for the parameter",
    )

    # Output
    parser.add_argument("--output-csv", help="Save to CSV")
    parser.add_argument("--output-json", help="Save to JSON")

    args = parser.parse_args()

    console = Console()
    console.print("[bold cyan]vLLM Attention Benchmark[/]")

    # Load config from YAML if provided
    if args.config:
        console.print(f"[yellow]Loading config from: {args.config}[/]")
        yaml_config = load_config_from_yaml(args.config)

        # Show description if available
        if "description" in yaml_config:
            console.print(f"[dim]{yaml_config['description']}[/]")

        # Override args with YAML values, but CLI args take precedence
        # Check if CLI provided backends (they would be non-None and not default)
        cli_backends_provided = args.backend is not None or args.backends is not None

        # Backend(s) - only use YAML if CLI didn't specify
        if not cli_backends_provided:
            if "backend" in yaml_config:
                args.backend = yaml_config["backend"]
                args.backends = None
            elif "backends" in yaml_config:
                args.backends = yaml_config["backends"]
                args.backend = None
            elif "decode_backends" in yaml_config:
                args.backends = yaml_config["decode_backends"]
                args.backend = None

        # Prefill backends (e.g., ["fa3", "fa4"])
        args.prefill_backends = yaml_config.get("prefill_backends", None)

        # Check for special modes
        args.mode = yaml_config.get("mode", None)

        # Batch specs and sizes
        # Support both explicit batch_specs and generated batch_spec_ranges
        # CLI --batch-specs takes precedence over YAML when provided.
        cli_batch_specs_provided = args.batch_specs is not None
        if not cli_batch_specs_provided:
            if "batch_spec_ranges" in yaml_config:
                # Generate batch specs from ranges
                generated_specs = generate_batch_specs_from_ranges(
                    yaml_config["batch_spec_ranges"]
                )
                # Combine with any explicit batch_specs
                if "batch_specs" in yaml_config:
                    args.batch_specs = yaml_config["batch_specs"] + generated_specs
                else:
                    args.batch_specs = generated_specs
                console.print(
                    f"[dim]Generated {len(generated_specs)} batch specs from ranges[/]"
                )
            elif "batch_specs" in yaml_config:
                args.batch_specs = yaml_config["batch_specs"]

        args.batch_sizes = yaml_config.get("batch_sizes", None)

        # Model config
        if "model" in yaml_config:
            model = yaml_config["model"]
            args.num_layers = model.get("num_layers", args.num_layers)
            args.head_dim = model.get("head_dim", args.head_dim)
            args.num_q_heads = model.get("num_q_heads", args.num_q_heads)
            args.num_kv_heads = model.get("num_kv_heads", args.num_kv_heads)
            args.block_size = model.get("block_size", args.block_size)

        # Benchmark settings (top-level keys)
        if "device" in yaml_config:
            args.device = yaml_config["device"]
        if "repeats" in yaml_config:
            args.repeats = yaml_config["repeats"]
        if "warmup_iters" in yaml_config:
            args.warmup_iters = yaml_config["warmup_iters"]
        if "profile_memory" in yaml_config:
            args.profile_memory = yaml_config["profile_memory"]
        if "kv_cache_dtype" in yaml_config:
            args.kv_cache_dtype = yaml_config["kv_cache_dtype"]
        if "cuda_graphs" in yaml_config:
            args.cuda_graphs = yaml_config["cuda_graphs"]

        # Parameter sweep configuration
        if "parameter_sweep" in yaml_config:
            sweep_config = yaml_config["parameter_sweep"]
            args.parameter_sweep = ParameterSweep(
                param_name=sweep_config["param_name"],
                values=sweep_config["values"],
                include_auto=sweep_config.get("include_auto", False),
                label_format=sweep_config.get(
                    "label_format", "{backend}_{param_name}_{value}"
                ),
            )
        else:
            args.parameter_sweep = None

        # Model parameter sweep configuration
        if "model_parameter_sweep" in yaml_config:
            sweep_config = yaml_config["model_parameter_sweep"]
            args.model_parameter_sweep = ModelParameterSweep(
                param_name=sweep_config["param_name"],
                values=sweep_config["values"],
                label_format=sweep_config.get(
                    "label_format", "{backend}_{param_name}_{value}"
                ),
            )
        else:
            args.model_parameter_sweep = None

        # Output
        if "output" in yaml_config:
            output = yaml_config["output"]
            if "csv" in output and not args.output_csv:
                args.output_csv = output["csv"]
            if "json" in output and not args.output_json:
                args.output_json = output["json"]

    console.print()

    # Handle CLI-based parameter sweep (if not from YAML)
    if (
        (not hasattr(args, "parameter_sweep") or args.parameter_sweep is None)
        and args.sweep_param
        and args.sweep_values
    ):
        args.parameter_sweep = ParameterSweep(
            param_name=args.sweep_param,
            values=args.sweep_values,
            include_auto=False,
            label_format="{backend}_{param_name}_{value}",
        )

    # Determine backends
    backends = args.backends or ([args.backend] if args.backend else ["flash"])
    prefill_backends = getattr(args, "prefill_backends", None)
    if not args.batch_specs:
        args.batch_specs = ["q2k", "8q1s1k"]
    console.print(f"Backends: {', '.join(backends)}")
    if prefill_backends:
        console.print(f"Prefill backends: {', '.join(prefill_backends)}")
    console.print(f"Batch specs: {', '.join(args.batch_specs)}")
    console.print(f"KV cache dtype: {args.kv_cache_dtype}")
    console.print(f"CUDA graphs: {args.cuda_graphs}")
    console.print()

    init_workspace_manager(args.device)

    # Run benchmarks
    all_results = []

    # Handle special mode: decode_vs_prefill comparison
    if hasattr(args, "mode") and args.mode == "decode_vs_prefill":
        console.print("[yellow]Mode: Decode vs Prefill pipeline comparison[/]")
        console.print(
            "[dim]For each query length, testing both decode and prefill pipelines[/]"
        )
        console.print("[dim]Using batched execution for optimal performance[/]")

        # Extract batch sizes from config (fall back to a single batch of 1 if
        # the YAML config did not provide batch_sizes)
        batch_sizes = getattr(args, "batch_sizes", None) or [1]
        backend = backends[0]  # Use first backend (should only be one)

        # Calculate total benchmarks
        total = len(batch_sizes)

        with tqdm(total=total, desc="Benchmarking") as pbar:
            for batch_size in batch_sizes:
                # Prepare all configs for this batch size
                configs_with_thresholds = []

                for spec in args.batch_specs:
                    # Parse the batch spec to get query length
                    requests = parse_batch_spec(spec)
                    if not requests:
                        console.print(
                            f"[red]Error: Could not parse batch spec '{spec}'[/]"
                        )
                        continue

                    # Get query length from first request
                    query_length = requests[0].q_len

                    # Create batch spec for this batch size
                    # For batch_size > 1, we need to prepend the count
                    batch_spec = f"{batch_size}{spec}" if batch_size > 1 else spec

                    # Create base config (without backend name)
                    base_config = BenchmarkConfig(
                        backend=backend,  # Will be overridden later
                        batch_spec=batch_spec,
                        num_layers=args.num_layers,
                        head_dim=args.head_dim,
                        num_q_heads=args.num_q_heads,
                        num_kv_heads=args.num_kv_heads,
                        block_size=args.block_size,
                        device=args.device,
                        repeats=args.repeats,
                        warmup_iters=args.warmup_iters,
                        profile_memory=args.profile_memory,
                        kv_cache_dtype=args.kv_cache_dtype,
                        use_cuda_graphs=args.cuda_graphs,
                    )

                    # Add decode pipeline config
                    decode_threshold = query_length
                    config_decode = replace(
                        base_config,
                        backend=f"{backend}_decode_qlen{query_length}_bs{batch_size}",
                    )
                    configs_with_thresholds.append((config_decode, decode_threshold))

                    # Add prefill pipeline config if query_length > 1
                    if query_length > 1:
                        prefill_threshold = query_length - 1
                        config_prefill = replace(
                            base_config,
                            backend=f"{backend}_prefill_qlen{query_length}"
                            f"_bs{batch_size}",
                        )
                        configs_with_thresholds.append(
                            (config_prefill, prefill_threshold)
                        )

                # Run all benchmarks for this batch size in one go (batched mode)
                try:
                    from mla_runner import run_mla_benchmark as run_mla

                    # Use batched API: pass list of (config, threshold) tuples
                    timing_results = run_mla(backend, configs_with_thresholds)

                    # Create BenchmarkResult objects from timing results
                    for (config, _), timing in zip(
                        configs_with_thresholds, timing_results
                    ):
                        result = BenchmarkResult(
                            config=config,
                            mean_time=timing["mean"],
                            std_time=timing["std"],
                            min_time=timing["min"],
                            max_time=timing["max"],
                            throughput_tokens_per_sec=timing.get("throughput", None),
                        )
                        all_results.append(result)

                except Exception as e:
                    import traceback

                    console.print(
                        f"[red]Error running batched benchmarks for "
                        f"batch_size={batch_size}: {e}[/]"
                    )
                    console.print("[red]Traceback:[/]")
                    traceback.print_exc()
                    # Add error results for all configs
                    for config, _ in configs_with_thresholds:
                        result = BenchmarkResult(
                            config=config,
                            mean_time=float("inf"),
                            std_time=0,
                            min_time=float("inf"),
                            max_time=float("inf"),
                            error=str(e),
                        )
                        all_results.append(result)

                pbar.update(1)

        # Display decode vs prefill results
        console.print("\n[bold green]Decode vs Prefill Results:[/]")

        # Group by batch size
        by_batch_size = {}
        for r in all_results:
            if r.success:
                # Extract batch size from backend name
                parts = r.config.backend.split("_")
                bs_part = [p for p in parts if p.startswith("bs")]
                if bs_part:
                    bs = int(bs_part[0][2:])
                    if bs not in by_batch_size:
                        by_batch_size[bs] = []
                    by_batch_size[bs].append(r)

        # For each batch size, analyze crossover point
        for bs in sorted(by_batch_size.keys()):
            console.print(f"\n[bold cyan]Batch size: {bs}[/]")
            results = by_batch_size[bs]

            # Group by query length
            by_qlen = {}
            for r in results:
                parts = r.config.backend.split("_")
                qlen_part = [p for p in parts if p.startswith("qlen")]
                if qlen_part:
                    qlen = int(qlen_part[0][4:])
                    if qlen not in by_qlen:
                        by_qlen[qlen] = {}

                    pipeline = "decode" if "decode" in r.config.backend else "prefill"
                    by_qlen[qlen][pipeline] = r

            # Find crossover point
            last_decode_faster = None
            for qlen in sorted(by_qlen.keys()):
                pipelines = by_qlen[qlen]
                if "decode" in pipelines and "prefill" in pipelines:
                    decode_time = pipelines["decode"].mean_time
                    prefill_time = pipelines["prefill"].mean_time
                    faster = "decode" if decode_time < prefill_time else "prefill"

                    speedup = (
                        prefill_time / decode_time
                        if decode_time < prefill_time
                        else decode_time / prefill_time
                    )

                    console.print(
                        f" qlen={qlen:3d}: decode={decode_time:.6f}s, "
                        f"prefill={prefill_time:.6f}s -> "
                        f"[bold]{faster}[/] ({speedup:.2f}x)"
                    )

                    if faster == "decode":
                        last_decode_faster = qlen

            if last_decode_faster is not None:
                optimal_threshold = last_decode_faster
                console.print(
                    f"\n [bold green]Optimal threshold for batch_size={bs}: "
                    f"{optimal_threshold}[/]"
                )
                console.print(
                    f" [dim](Use decode pipeline for query_length <= "
                    f"{optimal_threshold})[/]"
                )
            else:
                console.print(
                    f"\n [yellow]Prefill always faster for batch_size={bs}[/]"
                )

    # Handle model parameter sweep mode
    elif hasattr(args, "model_parameter_sweep") and args.model_parameter_sweep:
        # Model parameter sweep
        base_config_args = {
            "num_layers": args.num_layers,
            "head_dim": args.head_dim,
            "num_q_heads": args.num_q_heads,
            "num_kv_heads": args.num_kv_heads,
            "block_size": args.block_size,
            "device": args.device,
            "repeats": args.repeats,
            "warmup_iters": args.warmup_iters,
            "profile_memory": args.profile_memory,
            "kv_cache_dtype": args.kv_cache_dtype,
            "use_cuda_graphs": args.cuda_graphs,
        }
        all_results = run_model_parameter_sweep(
            backends,
            args.batch_specs,
            base_config_args,
            args.model_parameter_sweep,
            console,
        )

    # Handle parameter sweep mode (unified)
    elif hasattr(args, "parameter_sweep") and args.parameter_sweep:
        # Unified parameter sweep
        base_config_args = {
            "num_layers": args.num_layers,
            "head_dim": args.head_dim,
            "num_q_heads": args.num_q_heads,
            "num_kv_heads": args.num_kv_heads,
            "block_size": args.block_size,
            "device": args.device,
            "repeats": args.repeats,
            "warmup_iters": args.warmup_iters,
            "profile_memory": args.profile_memory,
            "kv_cache_dtype": args.kv_cache_dtype,
            "use_cuda_graphs": args.cuda_graphs,
        }
        all_results = run_parameter_sweep(
            backends, args.batch_specs, base_config_args, args.parameter_sweep, console
        )

    else:
        # Normal mode: compare backends
        decode_results = []
        prefill_results = []

        # Run decode backend comparison
        if not prefill_backends:
            # No prefill backends specified: compare decode backends as before
            total = len(backends) * len(args.batch_specs)

            with tqdm(total=total, desc="Benchmarking") as pbar:
                for spec in args.batch_specs:
                    for backend in backends:
                        config = BenchmarkConfig(
                            backend=backend,
                            batch_spec=spec,
                            num_layers=args.num_layers,
                            head_dim=args.head_dim,
                            num_q_heads=args.num_q_heads,
                            num_kv_heads=args.num_kv_heads,
                            block_size=args.block_size,
                            device=args.device,
                            repeats=args.repeats,
                            warmup_iters=args.warmup_iters,
                            profile_memory=args.profile_memory,
                            kv_cache_dtype=args.kv_cache_dtype,
                            use_cuda_graphs=args.cuda_graphs,
                        )

                        result = run_benchmark(config)
                        decode_results.append(result)

                        if not result.success:
                            console.print(
                                f"[red]Error {backend} {spec}: {result.error}[/]"
                            )

                        pbar.update(1)

            console.print("\n[bold green]Results:[/]")
            formatter = ResultsFormatter(console)
            formatter.print_table(decode_results, backends)

        # Run prefill backend comparison
        if prefill_backends:
            # Use first decode backend for impl construction
            decode_backend = backends[0]
            total = len(prefill_backends) * len(args.batch_specs)

            console.print(
                f"[yellow]Prefill comparison mode: "
                f"using {decode_backend} for decode impl[/]"
            )

            with tqdm(total=total, desc="Prefill benchmarking") as pbar:
                for spec in args.batch_specs:
                    for pb in prefill_backends:
                        config = BenchmarkConfig(
                            backend=decode_backend,
                            batch_spec=spec,
                            num_layers=args.num_layers,
                            head_dim=args.head_dim,
                            num_q_heads=args.num_q_heads,
                            num_kv_heads=args.num_kv_heads,
                            block_size=args.block_size,
                            device=args.device,
                            repeats=args.repeats,
                            warmup_iters=args.warmup_iters,
                            profile_memory=args.profile_memory,
                            prefill_backend=pb,
                        )

                        result = run_benchmark(config)

                        # Label result with prefill backend name for display
                        labeled_config = replace(result.config, backend=pb)
                        result = replace(result, config=labeled_config)
                        prefill_results.append(result)

                        if not result.success:
                            console.print(f"[red]Error {pb} {spec}: {result.error}[/]")

                        pbar.update(1)

            console.print("\n[bold green]Prefill Backend Results:[/]")
            formatter = ResultsFormatter(console)
            formatter.print_table(
                prefill_results, prefill_backends, compare_to_fastest=True
            )

        all_results = decode_results + prefill_results

    # Save results
    if all_results:
        formatter = ResultsFormatter(console)
        if args.output_csv:
            formatter.save_csv(all_results, args.output_csv)
        if args.output_json:
            formatter.save_json(all_results, args.output_json)


if __name__ == "__main__":
    main()
|