[Attention] Use FA4 for MLA prefill (#34732)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
@@ -62,6 +62,7 @@ def create_minimal_vllm_config(
     max_num_seqs: int = 256,
     mla_dims: dict | None = None,
     index_topk: int | None = None,
+    prefill_backend: str | None = None,
 ) -> VllmConfig:
     """
     Create minimal VllmConfig for MLA benchmarks.
@@ -75,6 +76,9 @@ def create_minimal_vllm_config(
             setup_mla_dims(model_name)
         index_topk: Optional topk value for sparse MLA backends. If provided,
             the config will include index_topk for sparse attention.
+        prefill_backend: Prefill backend name (e.g., "fa3", "fa4", "flashinfer",
+            "cudnn", "trtllm"). Configures the attention config to
+            force the specified prefill backend.

     Returns:
         VllmConfig for benchmarking
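
A call exercising the new parameter might look like the sketch below. Only the keyword arguments visible in this diff are confirmed; the import path of the benchmark helper is an assumption.

```python
# Hypothetical usage sketch; arguments other than those shown in the diff
# (and the module the function is imported from) are assumptions.
vllm_config = create_minimal_vllm_config(
    max_num_seqs=256,
    index_topk=None,         # dense MLA: no sparse indexer
    prefill_backend="fa4",   # force FlashAttention 4 for MLA prefill
)
```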
@@ -163,7 +167,7 @@ def create_minimal_vllm_config(

     compilation_config = CompilationConfig()

-    return VllmConfig(
+    vllm_config = VllmConfig(
         model_config=model_config,
         cache_config=cache_config,
         parallel_config=parallel_config,
@@ -171,9 +175,84 @@ def create_minimal_vllm_config(
         compilation_config=compilation_config,
     )

+    if prefill_backend is not None:
+        prefill_cfg = get_prefill_backend_config(prefill_backend)
+        if prefill_cfg["flash_attn_version"] is not None:
+            vllm_config.attention_config.flash_attn_version = prefill_cfg[
+                "flash_attn_version"
+            ]
+        vllm_config.attention_config.disable_flashinfer_prefill = prefill_cfg[
+            "disable_flashinfer_prefill"
+        ]
+        vllm_config.attention_config.use_cudnn_prefill = prefill_cfg[
+            "use_cudnn_prefill"
+        ]
+        vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill = prefill_cfg[
+            "use_trtllm_ragged_deepseek_prefill"
+        ]
+
+    return vllm_config


 # ============================================================================
-# Backend Configuration
+# Prefill Backend Configuration
 # ============================================================================
+
+# Maps prefill backend names to attention config overrides.
+# FA backends set flash_attn_version and disable non-FA paths.
+# Non-FA backends enable their specific path and disable others.
+_PREFILL_BACKEND_CONFIG: dict[str, dict] = {
+    "fa2": {
+        "flash_attn_version": 2,
+        "disable_flashinfer_prefill": True,
+        "use_cudnn_prefill": False,
+        "use_trtllm_ragged_deepseek_prefill": False,
+    },
+    "fa3": {
+        "flash_attn_version": 3,
+        "disable_flashinfer_prefill": True,
+        "use_cudnn_prefill": False,
+        "use_trtllm_ragged_deepseek_prefill": False,
+    },
+    "fa4": {
+        "flash_attn_version": 4,
+        "disable_flashinfer_prefill": True,
+        "use_cudnn_prefill": False,
+        "use_trtllm_ragged_deepseek_prefill": False,
+    },
+    "flashinfer": {
+        "flash_attn_version": None,
+        "disable_flashinfer_prefill": False,
+        "use_cudnn_prefill": False,
+        "use_trtllm_ragged_deepseek_prefill": False,
+    },
+    "cudnn": {
+        "flash_attn_version": None,
+        "disable_flashinfer_prefill": True,
+        "use_cudnn_prefill": True,
+        "use_trtllm_ragged_deepseek_prefill": False,
+    },
+    "trtllm": {
+        "flash_attn_version": None,
+        "disable_flashinfer_prefill": True,
+        "use_cudnn_prefill": False,
+        "use_trtllm_ragged_deepseek_prefill": True,
+    },
+}
+
+
+def get_prefill_backend_config(prefill_backend: str) -> dict:
+    """Get attention config overrides for a prefill backend."""
+    if prefill_backend not in _PREFILL_BACKEND_CONFIG:
+        raise ValueError(
+            f"Unknown prefill backend: {prefill_backend!r}. "
+            f"Available: {list(_PREFILL_BACKEND_CONFIG.keys())}"
+        )
+    return _PREFILL_BACKEND_CONFIG[prefill_backend]
+
+
+# ============================================================================
+# Decode Backend Configuration
+# ============================================================================
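
With the helper above in scope, the table and lookup behave as follows; this follows directly from the code added in this hunk.

```python
cfg = get_prefill_backend_config("fa4")
assert cfg["flash_attn_version"] == 4
assert cfg["disable_flashinfer_prefill"] is True
assert cfg["use_cudnn_prefill"] is False
assert cfg["use_trtllm_ragged_deepseek_prefill"] is False

# Unknown names fail fast and list the valid options:
# >>> get_prefill_backend_config("fa5")
# ValueError: Unknown prefill backend: 'fa5'.
# Available: ['fa2', 'fa3', 'fa4', 'flashinfer', 'cudnn', 'trtllm']
```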
@@ -203,6 +282,7 @@ def _get_backend_config(backend: str) -> dict:
     Returns:
         Dict with backend configuration
     """
+    from vllm.v1.attention.backend import MultipleOf
     from vllm.v1.attention.backends.registry import AttentionBackendEnum

     try:
@@ -219,8 +299,8 @@ def _get_backend_config(backend: str) -> dict:
         block_sizes = backend_class.get_supported_kernel_block_sizes()
         # Use first supported block size (backends typically support one for MLA)
         block_size = block_sizes[0] if block_sizes else None
-        if hasattr(block_size, "value"):
-            # Handle MultipleOf enum
+        if isinstance(block_size, MultipleOf):
+            # No fixed block size; fall back to config value
            block_size = None

         # Check if sparse via class method if available
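
The switch from `hasattr(block_size, "value")` to `isinstance(block_size, MultipleOf)` makes the intent explicit: a `MultipleOf` entry means the kernel accepts any multiple of a base size rather than one fixed size. A minimal standalone sketch of that resolution logic, using a local stand-in for `MultipleOf` (the real class lives in `vllm.v1.attention.backend`):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class MultipleOf:
    """Stand-in: 'any multiple of .base is supported' (assumed shape)."""
    base: int

def resolve_block_size(supported: list) -> int | None:
    """Take the first supported kernel block size; a MultipleOf entry means
    no single fixed size, so defer to the cache config instead."""
    block_size = supported[0] if supported else None
    if isinstance(block_size, MultipleOf):
        # No fixed block size; fall back to config value
        return None
    return block_size

assert resolve_block_size([64]) == 64
assert resolve_block_size([MultipleOf(16)]) is None
assert resolve_block_size([]) is None
```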
@@ -676,16 +756,11 @@ def _run_single_benchmark(
     if is_sparse and indexer is not None:
         indexer.fill_random_indices(total_q, max_kv_len)

-    # Determine which forward method to use
-    if is_sparse:
-        # Sparse backends use forward_mqa
+    # Determine which forward method to use based on metadata
+    if metadata.decode is not None:
         forward_fn = lambda: impl.forward_mqa(decode_inputs, kv_cache, metadata, layer)
-    elif metadata.decode is not None:
-        forward_fn = lambda: impl._forward_decode(
-            decode_inputs, kv_cache, metadata, layer
-        )
     elif metadata.prefill is not None:
-        forward_fn = lambda: impl._forward_prefill(
+        forward_fn = lambda: impl.forward_mha(
             prefill_inputs["q"],
             prefill_inputs["k_c_normed"],
             prefill_inputs["k_pe"],
@@ -732,6 +807,7 @@ def _run_mla_benchmark_batched(
     backend: str,
     configs_with_params: list[tuple],  # [(config, threshold, num_splits), ...]
     index_topk: int = 2048,
+    prefill_backend: str | None = None,
 ) -> list[BenchmarkResult]:
     """
     Unified batched MLA benchmark runner for all backends.
@@ -743,11 +819,13 @@ def _run_mla_benchmark_batched(
     to avoid setup/teardown overhead.

     Args:
-        backend: Backend name
+        backend: Backend name (decode backend used for impl construction)
         configs_with_params: List of (config, threshold, num_splits) tuples
             - threshold: reorder_batch_threshold (FlashAttn/FlashMLA only)
            - num_splits: num_kv_splits (CUTLASS only)
         index_topk: Topk value for sparse MLA backends (default 2048)
+        prefill_backend: Prefill backend name (e.g., "fa3", "fa4").
+            When set, forces the specified FlashAttention version for prefill.

     Returns:
         List of BenchmarkResult objects
@@ -780,11 +858,25 @@ def _run_mla_benchmark_batched(
         block_size=block_size,
         mla_dims=mla_dims,  # Use custom dims from config or default
         index_topk=index_topk if is_sparse else None,
+        prefill_backend=prefill_backend,
     )

     results = []

     with set_current_vllm_config(vllm_config):
+        # Clear cached prefill backend detection functions so they re-evaluate
+        # with the current VllmConfig. These are @functools.cache decorated and
+        # would otherwise return stale results from a previous backend's config.
+        from vllm.model_executor.layers.attention.mla_attention import (
+            use_cudnn_prefill,
+            use_flashinfer_prefill,
+            use_trtllm_ragged_deepseek_prefill,
+        )
+
+        use_flashinfer_prefill.cache_clear()
+        use_cudnn_prefill.cache_clear()
+        use_trtllm_ragged_deepseek_prefill.cache_clear()
+
         # Create backend impl, layer, builder, and indexer (reused across benchmarks)
         impl, layer, builder_instance, indexer = _create_backend_impl(
             backend_cfg,
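
The `cache_clear()` calls guard against a subtle failure mode of `@functools.cache`: the detection helpers would otherwise keep returning answers computed under a previous `VllmConfig`. A minimal standalone reproduction of that staleness, using a stand-in for the real helpers:

```python
import functools

CURRENT_CONFIG = {"use_cudnn_prefill": False}

@functools.cache
def use_cudnn_prefill_demo() -> bool:
    # Stand-in for the real @functools.cache-decorated detection helper.
    return CURRENT_CONFIG["use_cudnn_prefill"]

assert use_cudnn_prefill_demo() is False    # first call populates the cache
CURRENT_CONFIG["use_cudnn_prefill"] = True  # "switch backends"
assert use_cudnn_prefill_demo() is False    # stale: cache ignores the new config
use_cudnn_prefill_demo.cache_clear()
assert use_cudnn_prefill_demo() is True     # re-evaluates after cache_clear()
```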
@@ -794,6 +886,38 @@ def _run_mla_benchmark_batched(
             index_topk=index_topk if is_sparse else None,
         )

+        # Verify the actual prefill backend matches what was requested
+        if prefill_backend is not None:
+            prefill_cfg = get_prefill_backend_config(prefill_backend)
+            fa_version = prefill_cfg["flash_attn_version"]
+
+            if fa_version is not None:
+                # FA backend: verify the impl's FA version
+                actual_fa_version = getattr(impl, "vllm_flash_attn_version", None)
+                if actual_fa_version != fa_version:
+                    raise RuntimeError(
+                        f"Prefill backend '{prefill_backend}' requested FA "
+                        f"version {fa_version}, but the impl is using FA "
+                        f"version {actual_fa_version}. Check "
+                        f"vllm/v1/attention/backends/fa_utils.py."
+                    )
+            else:
+                # Non-FA backend: verify the builder picked the right path
+                expected_flags = {
+                    "flashinfer": "_use_fi_prefill",
+                    "cudnn": "_use_cudnn_prefill",
+                    "trtllm": "_use_trtllm_ragged_prefill",
+                }
+                flag_name = expected_flags.get(prefill_backend)
+                if flag_name and not getattr(builder_instance, flag_name, False):
+                    raise RuntimeError(
+                        f"Prefill backend '{prefill_backend}' was requested "
+                        f"but the metadata builder did not enable it. This "
+                        f"usually means a dependency is missing (e.g., "
+                        f"flashinfer not installed) or the platform doesn't "
+                        f"support it."
+                    )
+
         # Run each benchmark with the shared impl
         for config, threshold, num_splits in configs_with_params:
             # Set threshold for this benchmark (FlashAttn/FlashMLA only)
@@ -844,6 +968,7 @@ def run_mla_benchmark(
     reorder_batch_threshold: int | None = None,
     num_kv_splits: int | None = None,
     index_topk: int = 2048,
+    prefill_backend: str | None = None,
 ) -> BenchmarkResult | list[BenchmarkResult]:
     """
     Unified MLA benchmark runner for all backends.
@@ -861,6 +986,8 @@ def run_mla_benchmark(
             (single config mode only)
         num_kv_splits: Number of KV splits for CUTLASS (single config mode only)
         index_topk: Topk value for sparse MLA backends (default 2048)
+        prefill_backend: Prefill backend name (e.g., "fa3", "fa4").
+            When set, forces the specified FlashAttention version for prefill.

     Returns:
         BenchmarkResult (single mode) or list of BenchmarkResult (batched mode)
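
A hypothetical single-config invocation of the public entry point; the leading positional arguments (decode backend name and a benchmark config) are assumptions based on the docstring and are not shown in this diff.

```python
# Hedged sketch: "flashmla" and my_config are illustrative placeholders.
result = run_mla_benchmark(
    "flashmla",                  # decode backend (assumed name)
    my_config,                   # single config -> a single BenchmarkResult
    reorder_batch_threshold=1,   # FlashAttn/FlashMLA only
    prefill_backend="fa4",       # force FA4 for the MLA prefill path
)
```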
@@ -884,7 +1011,9 @@ def run_mla_benchmark(
         return_single = True

     # Use unified batched execution
-    results = _run_mla_benchmark_batched(backend, configs_with_params, index_topk)
+    results = _run_mla_benchmark_batched(
+        backend, configs_with_params, index_topk, prefill_backend=prefill_backend
+    )

     # Return single result or list based on input
     return results[0] if return_single else results