[Attention] Use FA4 for MLA prefill (#34732)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -59,7 +59,9 @@ def run_mla_benchmark(config: BenchmarkConfig, **kwargs) -> BenchmarkResult:
|
||||
"""Run MLA benchmark with appropriate backend."""
|
||||
from mla_runner import run_mla_benchmark as run_mla
|
||||
|
||||
return run_mla(config.backend, config, **kwargs)
|
||||
return run_mla(
|
||||
config.backend, config, prefill_backend=config.prefill_backend, **kwargs
|
||||
)
|
||||
|
||||
|
||||
def run_benchmark(config: BenchmarkConfig, **kwargs) -> BenchmarkResult:
|
||||
@@ -440,14 +442,21 @@ def main():
|
||||
# Backend selection
|
||||
parser.add_argument(
|
||||
"--backends",
|
||||
"--decode-backends",
|
||||
nargs="+",
|
||||
help="Backends to benchmark (flash, triton, flashinfer, cutlass_mla, "
|
||||
help="Decode backends to benchmark (flash, triton, flashinfer, cutlass_mla, "
|
||||
"flashinfer_mla, flashattn_mla, flashmla)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
help="Single backend (alternative to --backends)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefill-backends",
|
||||
nargs="+",
|
||||
help="Prefill backends to compare (fa2, fa3, fa4). "
|
||||
"Uses the first decode backend for impl construction.",
|
||||
)
|
||||
|
||||
# Batch specifications
|
||||
parser.add_argument(
|
||||
@@ -502,7 +511,7 @@ def main():
|
||||
|
||||
# Override args with YAML values, but CLI args take precedence
|
||||
# Check if CLI provided backends (they would be non-None and not default)
|
||||
cli_backends_provided = args.backends is not None or args.backend is not None
|
||||
cli_backends_provided = args.backend is not None or args.backends is not None
|
||||
|
||||
# Backend(s) - only use YAML if CLI didn't specify
|
||||
if not cli_backends_provided:
|
||||
@@ -512,6 +521,12 @@ def main():
|
||||
elif "backends" in yaml_config:
|
||||
args.backends = yaml_config["backends"]
|
||||
args.backend = None
|
||||
elif "decode_backends" in yaml_config:
|
||||
args.backends = yaml_config["decode_backends"]
|
||||
args.backend = None
|
||||
|
||||
# Prefill backends (e.g., ["fa3", "fa4"])
|
||||
args.prefill_backends = yaml_config.get("prefill_backends", None)
|
||||
|
||||
# Check for special modes
|
||||
if "mode" in yaml_config:
|
||||
@@ -613,7 +628,10 @@ def main():
|
||||
|
||||
# Determine backends
|
||||
backends = args.backends or ([args.backend] if args.backend else ["flash"])
|
||||
prefill_backends = getattr(args, "prefill_backends", None)
|
||||
console.print(f"Backends: {', '.join(backends)}")
|
||||
if prefill_backends:
|
||||
console.print(f"Prefill backends: {', '.join(prefill_backends)}")
|
||||
console.print(f"Batch specs: {', '.join(args.batch_specs)}")
|
||||
console.print()
|
||||
|
||||
@@ -850,37 +868,93 @@ def main():
|
||||
|
||||
else:
|
||||
# Normal mode: compare backends
|
||||
total = len(backends) * len(args.batch_specs)
|
||||
decode_results = []
|
||||
prefill_results = []
|
||||
|
||||
with tqdm(total=total, desc="Benchmarking") as pbar:
|
||||
for spec in args.batch_specs:
|
||||
for backend in backends:
|
||||
config = BenchmarkConfig(
|
||||
backend=backend,
|
||||
batch_spec=spec,
|
||||
num_layers=args.num_layers,
|
||||
head_dim=args.head_dim,
|
||||
num_q_heads=args.num_q_heads,
|
||||
num_kv_heads=args.num_kv_heads,
|
||||
block_size=args.block_size,
|
||||
device=args.device,
|
||||
repeats=args.repeats,
|
||||
warmup_iters=args.warmup_iters,
|
||||
profile_memory=args.profile_memory,
|
||||
)
|
||||
# Run decode backend comparison
|
||||
if not prefill_backends:
|
||||
# No prefill backends specified: compare decode backends as before
|
||||
total = len(backends) * len(args.batch_specs)
|
||||
|
||||
result = run_benchmark(config)
|
||||
all_results.append(result)
|
||||
with tqdm(total=total, desc="Benchmarking") as pbar:
|
||||
for spec in args.batch_specs:
|
||||
for backend in backends:
|
||||
config = BenchmarkConfig(
|
||||
backend=backend,
|
||||
batch_spec=spec,
|
||||
num_layers=args.num_layers,
|
||||
head_dim=args.head_dim,
|
||||
num_q_heads=args.num_q_heads,
|
||||
num_kv_heads=args.num_kv_heads,
|
||||
block_size=args.block_size,
|
||||
device=args.device,
|
||||
repeats=args.repeats,
|
||||
warmup_iters=args.warmup_iters,
|
||||
profile_memory=args.profile_memory,
|
||||
)
|
||||
|
||||
if not result.success:
|
||||
console.print(f"[red]Error {backend} {spec}: {result.error}[/]")
|
||||
result = run_benchmark(config)
|
||||
decode_results.append(result)
|
||||
|
||||
pbar.update(1)
|
||||
if not result.success:
|
||||
console.print(
|
||||
f"[red]Error {backend} {spec}: {result.error}[/]"
|
||||
)
|
||||
|
||||
# Display results
|
||||
console.print("\n[bold green]Results:[/]")
|
||||
formatter = ResultsFormatter(console)
|
||||
formatter.print_table(all_results, backends)
|
||||
pbar.update(1)
|
||||
|
||||
console.print("\n[bold green]Results:[/]")
|
||||
formatter = ResultsFormatter(console)
|
||||
formatter.print_table(decode_results, backends)
|
||||
|
||||
# Run prefill backend comparison
|
||||
if prefill_backends:
|
||||
# Use first decode backend for impl construction
|
||||
decode_backend = backends[0]
|
||||
total = len(prefill_backends) * len(args.batch_specs)
|
||||
|
||||
console.print(
|
||||
f"[yellow]Prefill comparison mode: "
|
||||
f"using {decode_backend} for decode impl[/]"
|
||||
)
|
||||
|
||||
with tqdm(total=total, desc="Prefill benchmarking") as pbar:
|
||||
for spec in args.batch_specs:
|
||||
for pb in prefill_backends:
|
||||
config = BenchmarkConfig(
|
||||
backend=decode_backend,
|
||||
batch_spec=spec,
|
||||
num_layers=args.num_layers,
|
||||
head_dim=args.head_dim,
|
||||
num_q_heads=args.num_q_heads,
|
||||
num_kv_heads=args.num_kv_heads,
|
||||
block_size=args.block_size,
|
||||
device=args.device,
|
||||
repeats=args.repeats,
|
||||
warmup_iters=args.warmup_iters,
|
||||
profile_memory=args.profile_memory,
|
||||
prefill_backend=pb,
|
||||
)
|
||||
|
||||
result = run_benchmark(config)
|
||||
|
||||
# Label result with prefill backend name for display
|
||||
labeled_config = replace(result.config, backend=pb)
|
||||
result = replace(result, config=labeled_config)
|
||||
prefill_results.append(result)
|
||||
|
||||
if not result.success:
|
||||
console.print(f"[red]Error {pb} {spec}: {result.error}[/]")
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
console.print("\n[bold green]Prefill Backend Results:[/]")
|
||||
formatter = ResultsFormatter(console)
|
||||
formatter.print_table(
|
||||
prefill_results, prefill_backends, compare_to_fastest=True
|
||||
)
|
||||
|
||||
all_results = decode_results + prefill_results
|
||||
|
||||
# Save results
|
||||
if all_results:
|
||||
|
||||
@@ -77,6 +77,7 @@ class MockKVBProj:
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.out_dim = qk_nope_head_dim + v_head_dim
|
||||
self.weight = torch.empty(0, dtype=torch.bfloat16)
|
||||
|
||||
def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor]:
|
||||
"""
|
||||
@@ -213,6 +214,7 @@ class BenchmarkConfig:
|
||||
use_cuda_graphs: bool = False
|
||||
|
||||
# MLA-specific
|
||||
prefill_backend: str | None = None
|
||||
kv_lora_rank: int | None = None
|
||||
qk_nope_head_dim: int | None = None
|
||||
qk_rope_head_dim: int | None = None
|
||||
|
||||
@@ -1,4 +1,19 @@
|
||||
# MLA prefill-only benchmark configuration for sparse backends
|
||||
# MLA prefill backend comparison
|
||||
#
|
||||
# Compares all available MLA prefill backends:
|
||||
# FA backends: fa2, fa3, fa4 (FlashAttention versions)
|
||||
# Non-FA: flashinfer, cudnn, trtllm (Blackwell-only, require flashinfer)
|
||||
#
|
||||
# Uses cutlass_mla as the decode backend for impl construction
|
||||
# (only the prefill path is exercised).
|
||||
#
|
||||
# Backends that aren't available on the current platform will report errors
|
||||
# in the results table (e.g., fa3 on Blackwell, cudnn without artifactory).
|
||||
#
|
||||
# Usage:
|
||||
# python benchmark.py --config configs/mla_prefill.yaml
|
||||
|
||||
description: "MLA prefill backend comparison"
|
||||
|
||||
model:
|
||||
name: "deepseek-v3"
|
||||
@@ -12,20 +27,25 @@ model:
|
||||
v_head_dim: 128
|
||||
block_size: 128
|
||||
|
||||
# Model parameter sweep: simulate tensor parallelism by varying num_q_heads
|
||||
# TP=1: 128 heads, TP=2: 64 heads, TP=4: 32 heads, TP=8: 16 heads
|
||||
model_parameter_sweep:
|
||||
param_name: "num_q_heads"
|
||||
values: [128, 64, 32, 16]
|
||||
label_format: "{backend}_{value}h"
|
||||
# model:
|
||||
# name: "deepseek-v2-lite"
|
||||
# num_layers: 27
|
||||
# num_q_heads: 16
|
||||
# num_kv_heads: 1
|
||||
# head_dim: 576
|
||||
# kv_lora_rank: 512
|
||||
# qk_nope_head_dim: 128
|
||||
# qk_rope_head_dim: 64
|
||||
# v_head_dim: 128
|
||||
# block_size: 128
|
||||
|
||||
batch_specs:
|
||||
# Pure prefill
|
||||
- "1q512"
|
||||
- "1q1k"
|
||||
- "1q2k"
|
||||
- "1q4k"
|
||||
- "1q8k"
|
||||
- "q512"
|
||||
- "q1k"
|
||||
- "q2k"
|
||||
- "q4k"
|
||||
- "q8k"
|
||||
|
||||
# Batched pure prefill
|
||||
- "2q512"
|
||||
@@ -44,19 +64,63 @@ batch_specs:
|
||||
- "8q4k"
|
||||
- "8q8k"
|
||||
|
||||
# Extend
|
||||
- "1q512s4k"
|
||||
- "1q512s8k"
|
||||
- "1q1ks8k"
|
||||
- "1q2ks8k"
|
||||
- "1q2ks16k"
|
||||
- "1q4ks16k"
|
||||
# Chunked prefill / extend
|
||||
# Short context
|
||||
- "q128s1k"
|
||||
- "q256s2k"
|
||||
- "q512s4k"
|
||||
- "q1ks4k"
|
||||
- "q2ks8k"
|
||||
- "2q128s1k"
|
||||
- "2q256s2k"
|
||||
- "2q512s4k"
|
||||
- "2q1ks4k"
|
||||
- "2q2ks8k"
|
||||
- "4q128s1k"
|
||||
- "4q256s2k"
|
||||
- "4q512s4k"
|
||||
- "4q1ks4k"
|
||||
- "4q2ks8k"
|
||||
- "8q128s1k"
|
||||
- "8q256s2k"
|
||||
- "8q512s4k"
|
||||
- "8q1ks4k"
|
||||
|
||||
backends:
|
||||
- FLASHMLA_SPARSE
|
||||
- FLASHINFER_MLA_SPARSE
|
||||
# Medium context
|
||||
- "q128s16k"
|
||||
- "q512s16k"
|
||||
- "q1ks16k"
|
||||
- "q2ks16k"
|
||||
- "2q128s16k"
|
||||
- "2q512s16k"
|
||||
- "2q1ks16k"
|
||||
- "2q2ks16k"
|
||||
- "4q128s16k"
|
||||
- "4q512s16k"
|
||||
- "4q1ks16k"
|
||||
- "4q2ks16k"
|
||||
|
||||
# Long context
|
||||
- "q128s64k"
|
||||
- "q512s64k"
|
||||
- "q1ks64k"
|
||||
- "q2ks64k"
|
||||
- "2q128s64k"
|
||||
- "2q512s64k"
|
||||
- "2q1ks64k"
|
||||
- "2q2ks64k"
|
||||
|
||||
decode_backends:
|
||||
- CUTLASS_MLA
|
||||
|
||||
prefill_backends:
|
||||
- fa2
|
||||
- fa3
|
||||
- fa4
|
||||
- flashinfer
|
||||
- cudnn
|
||||
- trtllm
|
||||
|
||||
device: "cuda:0"
|
||||
repeats: 10
|
||||
warmup_iters: 3
|
||||
profile_memory: true
|
||||
repeats: 20
|
||||
warmup_iters: 5
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
# MLA prefill-only benchmark configuration for sparse backends
|
||||
|
||||
model:
|
||||
name: "deepseek-v3"
|
||||
num_layers: 60
|
||||
num_q_heads: 128
|
||||
num_kv_heads: 1
|
||||
head_dim: 576
|
||||
kv_lora_rank: 512
|
||||
qk_nope_head_dim: 128
|
||||
qk_rope_head_dim: 64
|
||||
v_head_dim: 128
|
||||
block_size: 128
|
||||
|
||||
# Model parameter sweep: simulate tensor parallelism by varying num_q_heads
|
||||
# TP=1: 128 heads, TP=2: 64 heads, TP=4: 32 heads, TP=8: 16 heads
|
||||
model_parameter_sweep:
|
||||
param_name: "num_q_heads"
|
||||
values: [128, 64, 32, 16]
|
||||
label_format: "{backend}_{value}h"
|
||||
|
||||
batch_specs:
|
||||
# Pure prefill
|
||||
- "1q512"
|
||||
- "1q1k"
|
||||
- "1q2k"
|
||||
- "1q4k"
|
||||
- "1q8k"
|
||||
|
||||
# Batched pure prefill
|
||||
- "2q512"
|
||||
- "2q1k"
|
||||
- "2q2k"
|
||||
- "2q4k"
|
||||
- "2q8k"
|
||||
- "4q512"
|
||||
- "4q1k"
|
||||
- "4q2k"
|
||||
- "4q4k"
|
||||
- "4q8k"
|
||||
- "8q512"
|
||||
- "8q1k"
|
||||
- "8q2k"
|
||||
- "8q4k"
|
||||
- "8q8k"
|
||||
|
||||
# Extend
|
||||
- "1q512s4k"
|
||||
- "1q512s8k"
|
||||
- "1q1ks8k"
|
||||
- "1q2ks8k"
|
||||
- "1q2ks16k"
|
||||
- "1q4ks16k"
|
||||
|
||||
backends:
|
||||
- FLASHMLA_SPARSE
|
||||
- FLASHINFER_MLA_SPARSE
|
||||
|
||||
device: "cuda:0"
|
||||
repeats: 10
|
||||
warmup_iters: 3
|
||||
profile_memory: true
|
||||
@@ -62,6 +62,7 @@ def create_minimal_vllm_config(
|
||||
max_num_seqs: int = 256,
|
||||
mla_dims: dict | None = None,
|
||||
index_topk: int | None = None,
|
||||
prefill_backend: str | None = None,
|
||||
) -> VllmConfig:
|
||||
"""
|
||||
Create minimal VllmConfig for MLA benchmarks.
|
||||
@@ -75,6 +76,9 @@ def create_minimal_vllm_config(
|
||||
setup_mla_dims(model_name)
|
||||
index_topk: Optional topk value for sparse MLA backends. If provided,
|
||||
the config will include index_topk for sparse attention.
|
||||
prefill_backend: Prefill backend name (e.g., "fa3", "fa4", "flashinfer",
|
||||
"cudnn", "trtllm"). Configures the attention config to
|
||||
force the specified prefill backend.
|
||||
|
||||
Returns:
|
||||
VllmConfig for benchmarking
|
||||
@@ -163,7 +167,7 @@ def create_minimal_vllm_config(
|
||||
|
||||
compilation_config = CompilationConfig()
|
||||
|
||||
return VllmConfig(
|
||||
vllm_config = VllmConfig(
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
parallel_config=parallel_config,
|
||||
@@ -171,9 +175,84 @@ def create_minimal_vllm_config(
|
||||
compilation_config=compilation_config,
|
||||
)
|
||||
|
||||
if prefill_backend is not None:
|
||||
prefill_cfg = get_prefill_backend_config(prefill_backend)
|
||||
if prefill_cfg["flash_attn_version"] is not None:
|
||||
vllm_config.attention_config.flash_attn_version = prefill_cfg[
|
||||
"flash_attn_version"
|
||||
]
|
||||
vllm_config.attention_config.disable_flashinfer_prefill = prefill_cfg[
|
||||
"disable_flashinfer_prefill"
|
||||
]
|
||||
vllm_config.attention_config.use_cudnn_prefill = prefill_cfg[
|
||||
"use_cudnn_prefill"
|
||||
]
|
||||
vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill = prefill_cfg[
|
||||
"use_trtllm_ragged_deepseek_prefill"
|
||||
]
|
||||
|
||||
return vllm_config
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Backend Configuration
|
||||
# Prefill Backend Configuration
|
||||
# ============================================================================
|
||||
|
||||
# Maps prefill backend names to attention config overrides.
|
||||
# FA backends set flash_attn_version and disable non-FA paths.
|
||||
# Non-FA backends enable their specific path and disable others.
|
||||
_PREFILL_BACKEND_CONFIG: dict[str, dict] = {
|
||||
"fa2": {
|
||||
"flash_attn_version": 2,
|
||||
"disable_flashinfer_prefill": True,
|
||||
"use_cudnn_prefill": False,
|
||||
"use_trtllm_ragged_deepseek_prefill": False,
|
||||
},
|
||||
"fa3": {
|
||||
"flash_attn_version": 3,
|
||||
"disable_flashinfer_prefill": True,
|
||||
"use_cudnn_prefill": False,
|
||||
"use_trtllm_ragged_deepseek_prefill": False,
|
||||
},
|
||||
"fa4": {
|
||||
"flash_attn_version": 4,
|
||||
"disable_flashinfer_prefill": True,
|
||||
"use_cudnn_prefill": False,
|
||||
"use_trtllm_ragged_deepseek_prefill": False,
|
||||
},
|
||||
"flashinfer": {
|
||||
"flash_attn_version": None,
|
||||
"disable_flashinfer_prefill": False,
|
||||
"use_cudnn_prefill": False,
|
||||
"use_trtllm_ragged_deepseek_prefill": False,
|
||||
},
|
||||
"cudnn": {
|
||||
"flash_attn_version": None,
|
||||
"disable_flashinfer_prefill": True,
|
||||
"use_cudnn_prefill": True,
|
||||
"use_trtllm_ragged_deepseek_prefill": False,
|
||||
},
|
||||
"trtllm": {
|
||||
"flash_attn_version": None,
|
||||
"disable_flashinfer_prefill": True,
|
||||
"use_cudnn_prefill": False,
|
||||
"use_trtllm_ragged_deepseek_prefill": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_prefill_backend_config(prefill_backend: str) -> dict:
|
||||
"""Get attention config overrides for a prefill backend."""
|
||||
if prefill_backend not in _PREFILL_BACKEND_CONFIG:
|
||||
raise ValueError(
|
||||
f"Unknown prefill backend: {prefill_backend!r}. "
|
||||
f"Available: {list(_PREFILL_BACKEND_CONFIG.keys())}"
|
||||
)
|
||||
return _PREFILL_BACKEND_CONFIG[prefill_backend]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Decode Backend Configuration
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@@ -203,6 +282,7 @@ def _get_backend_config(backend: str) -> dict:
|
||||
Returns:
|
||||
Dict with backend configuration
|
||||
"""
|
||||
from vllm.v1.attention.backend import MultipleOf
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
try:
|
||||
@@ -219,8 +299,8 @@ def _get_backend_config(backend: str) -> dict:
|
||||
block_sizes = backend_class.get_supported_kernel_block_sizes()
|
||||
# Use first supported block size (backends typically support one for MLA)
|
||||
block_size = block_sizes[0] if block_sizes else None
|
||||
if hasattr(block_size, "value"):
|
||||
# Handle MultipleOf enum
|
||||
if isinstance(block_size, MultipleOf):
|
||||
# No fixed block size; fall back to config value
|
||||
block_size = None
|
||||
|
||||
# Check if sparse via class method if available
|
||||
@@ -676,16 +756,11 @@ def _run_single_benchmark(
|
||||
if is_sparse and indexer is not None:
|
||||
indexer.fill_random_indices(total_q, max_kv_len)
|
||||
|
||||
# Determine which forward method to use
|
||||
if is_sparse:
|
||||
# Sparse backends use forward_mqa
|
||||
# Determine which forward method to use based on metadata
|
||||
if metadata.decode is not None:
|
||||
forward_fn = lambda: impl.forward_mqa(decode_inputs, kv_cache, metadata, layer)
|
||||
elif metadata.decode is not None:
|
||||
forward_fn = lambda: impl._forward_decode(
|
||||
decode_inputs, kv_cache, metadata, layer
|
||||
)
|
||||
elif metadata.prefill is not None:
|
||||
forward_fn = lambda: impl._forward_prefill(
|
||||
forward_fn = lambda: impl.forward_mha(
|
||||
prefill_inputs["q"],
|
||||
prefill_inputs["k_c_normed"],
|
||||
prefill_inputs["k_pe"],
|
||||
@@ -732,6 +807,7 @@ def _run_mla_benchmark_batched(
|
||||
backend: str,
|
||||
configs_with_params: list[tuple], # [(config, threshold, num_splits), ...]
|
||||
index_topk: int = 2048,
|
||||
prefill_backend: str | None = None,
|
||||
) -> list[BenchmarkResult]:
|
||||
"""
|
||||
Unified batched MLA benchmark runner for all backends.
|
||||
@@ -743,11 +819,13 @@ def _run_mla_benchmark_batched(
|
||||
to avoid setup/teardown overhead.
|
||||
|
||||
Args:
|
||||
backend: Backend name
|
||||
backend: Backend name (decode backend used for impl construction)
|
||||
configs_with_params: List of (config, threshold, num_splits) tuples
|
||||
- threshold: reorder_batch_threshold (FlashAttn/FlashMLA only)
|
||||
- num_splits: num_kv_splits (CUTLASS only)
|
||||
index_topk: Topk value for sparse MLA backends (default 2048)
|
||||
prefill_backend: Prefill backend name (e.g., "fa3", "fa4").
|
||||
When set, forces the specified FlashAttention version for prefill.
|
||||
|
||||
Returns:
|
||||
List of BenchmarkResult objects
|
||||
@@ -780,11 +858,25 @@ def _run_mla_benchmark_batched(
|
||||
block_size=block_size,
|
||||
mla_dims=mla_dims, # Use custom dims from config or default
|
||||
index_topk=index_topk if is_sparse else None,
|
||||
prefill_backend=prefill_backend,
|
||||
)
|
||||
|
||||
results = []
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
# Clear cached prefill backend detection functions so they re-evaluate
|
||||
# with the current VllmConfig. These are @functools.cache decorated and
|
||||
# would otherwise return stale results from a previous backend's config.
|
||||
from vllm.model_executor.layers.attention.mla_attention import (
|
||||
use_cudnn_prefill,
|
||||
use_flashinfer_prefill,
|
||||
use_trtllm_ragged_deepseek_prefill,
|
||||
)
|
||||
|
||||
use_flashinfer_prefill.cache_clear()
|
||||
use_cudnn_prefill.cache_clear()
|
||||
use_trtllm_ragged_deepseek_prefill.cache_clear()
|
||||
|
||||
# Create backend impl, layer, builder, and indexer (reused across benchmarks)
|
||||
impl, layer, builder_instance, indexer = _create_backend_impl(
|
||||
backend_cfg,
|
||||
@@ -794,6 +886,38 @@ def _run_mla_benchmark_batched(
|
||||
index_topk=index_topk if is_sparse else None,
|
||||
)
|
||||
|
||||
# Verify the actual prefill backend matches what was requested
|
||||
if prefill_backend is not None:
|
||||
prefill_cfg = get_prefill_backend_config(prefill_backend)
|
||||
fa_version = prefill_cfg["flash_attn_version"]
|
||||
|
||||
if fa_version is not None:
|
||||
# FA backend: verify the impl's FA version
|
||||
actual_fa_version = getattr(impl, "vllm_flash_attn_version", None)
|
||||
if actual_fa_version != fa_version:
|
||||
raise RuntimeError(
|
||||
f"Prefill backend '{prefill_backend}' requested FA "
|
||||
f"version {fa_version}, but the impl is using FA "
|
||||
f"version {actual_fa_version}. Check "
|
||||
f"vllm/v1/attention/backends/fa_utils.py."
|
||||
)
|
||||
else:
|
||||
# Non-FA backend: verify the builder picked the right path
|
||||
expected_flags = {
|
||||
"flashinfer": "_use_fi_prefill",
|
||||
"cudnn": "_use_cudnn_prefill",
|
||||
"trtllm": "_use_trtllm_ragged_prefill",
|
||||
}
|
||||
flag_name = expected_flags.get(prefill_backend)
|
||||
if flag_name and not getattr(builder_instance, flag_name, False):
|
||||
raise RuntimeError(
|
||||
f"Prefill backend '{prefill_backend}' was requested "
|
||||
f"but the metadata builder did not enable it. This "
|
||||
f"usually means a dependency is missing (e.g., "
|
||||
f"flashinfer not installed) or the platform doesn't "
|
||||
f"support it."
|
||||
)
|
||||
|
||||
# Run each benchmark with the shared impl
|
||||
for config, threshold, num_splits in configs_with_params:
|
||||
# Set threshold for this benchmark (FlashAttn/FlashMLA only)
|
||||
@@ -844,6 +968,7 @@ def run_mla_benchmark(
|
||||
reorder_batch_threshold: int | None = None,
|
||||
num_kv_splits: int | None = None,
|
||||
index_topk: int = 2048,
|
||||
prefill_backend: str | None = None,
|
||||
) -> BenchmarkResult | list[BenchmarkResult]:
|
||||
"""
|
||||
Unified MLA benchmark runner for all backends.
|
||||
@@ -861,6 +986,8 @@ def run_mla_benchmark(
|
||||
(single config mode only)
|
||||
num_kv_splits: Number of KV splits for CUTLASS (single config mode only)
|
||||
index_topk: Topk value for sparse MLA backends (default 2048)
|
||||
prefill_backend: Prefill backend name (e.g., "fa3", "fa4").
|
||||
When set, forces the specified FlashAttention version for prefill.
|
||||
|
||||
Returns:
|
||||
BenchmarkResult (single mode) or list of BenchmarkResult (batched mode)
|
||||
@@ -884,7 +1011,9 @@ def run_mla_benchmark(
|
||||
return_single = True
|
||||
|
||||
# Use unified batched execution
|
||||
results = _run_mla_benchmark_batched(backend, configs_with_params, index_topk)
|
||||
results = _run_mla_benchmark_batched(
|
||||
backend, configs_with_params, index_topk, prefill_backend=prefill_backend
|
||||
)
|
||||
|
||||
# Return single result or list based on input
|
||||
return results[0] if return_single else results
|
||||
|
||||
@@ -39,7 +39,7 @@ else()
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn
|
||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||
GIT_TAG 140c00c0241bb60cc6e44e7c1be9998d4b20d8d2
|
||||
GIT_TAG 1488682bb545f7d020e958a33116b1419d1cfc83
|
||||
GIT_PROGRESS TRUE
|
||||
# Don't share the vllm-flash-attn build between build types
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
|
||||
@@ -30,14 +30,14 @@ class AttentionConfig:
|
||||
use_cudnn_prefill: bool = False
|
||||
"""Whether to use cudnn prefill."""
|
||||
|
||||
use_trtllm_ragged_deepseek_prefill: bool = True
|
||||
use_trtllm_ragged_deepseek_prefill: bool = False
|
||||
"""Whether to use TRTLLM ragged deepseek prefill."""
|
||||
|
||||
use_trtllm_attention: bool | None = None
|
||||
"""If set to True/False, use or don't use the TRTLLM attention backend
|
||||
in flashinfer. If None, auto-detect the attention backend in flashinfer."""
|
||||
|
||||
disable_flashinfer_prefill: bool = False
|
||||
disable_flashinfer_prefill: bool = True
|
||||
"""Whether to disable flashinfer prefill."""
|
||||
|
||||
disable_flashinfer_q_quantization: bool = False
|
||||
|
||||
@@ -1282,8 +1282,6 @@ def is_deepseek_r1_mla_compatible(vllm_config: VllmConfig) -> bool:
|
||||
|
||||
@functools.cache
|
||||
def use_flashinfer_prefill() -> bool:
|
||||
# For blackwell default to flashinfer prefill if it's available since
|
||||
# it is faster than FA2.
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
@@ -2154,13 +2152,16 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
|
||||
# For MLA the v head dim is smaller than qk head dim so we pad out
|
||||
# v with 0s to match the qk head dim for attention backends that do
|
||||
# not support different headdims
|
||||
# We don't need to pad V if we are on a hopper system with FA3
|
||||
# not support different headdims.
|
||||
# FA3 on Hopper (SM90) and FA4 natively handle diff headdims.
|
||||
device_capability = current_platform.get_device_capability()
|
||||
self._pad_v = self.vllm_flash_attn_version is None or not (
|
||||
self.vllm_flash_attn_version == 3
|
||||
and device_capability is not None
|
||||
and device_capability[0] == 9
|
||||
(
|
||||
self.vllm_flash_attn_version == 3
|
||||
and device_capability is not None
|
||||
and device_capability[0] == 9
|
||||
)
|
||||
or self.vllm_flash_attn_version == 4
|
||||
)
|
||||
|
||||
self.dcp_world_size: int = -1
|
||||
|
||||
@@ -125,11 +125,14 @@ def get_flash_attn_version(
|
||||
# FA4 on SM100 (Blackwell) has TMEM capacity limits that restrict
|
||||
# supported head dimensions.
|
||||
# See: https://github.com/Dao-AILab/flash-attention/issues/1959
|
||||
# Exception: hdim 192 is supported for MLA's diff-headdim case
|
||||
# (qk=192, v=128), added upstream in commits 1a15733e/1b36ab19.
|
||||
if (
|
||||
fa_version == 4
|
||||
and device_capability.major >= 10
|
||||
and head_size is not None
|
||||
and head_size > 128
|
||||
and head_size != 192
|
||||
):
|
||||
logger.warning_once(
|
||||
"FA4 on Blackwell does not support head_size=%d due to TMEM "
|
||||
|
||||
Reference in New Issue
Block a user