[Attention] Use FA4 for MLA prefill (#34732)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Matthew Bonanni
2026-03-12 12:10:17 -04:00
committed by GitHub
parent 85199f9681
commit f444c05c32
9 changed files with 413 additions and 78 deletions

View File

@@ -1282,8 +1282,6 @@ def is_deepseek_r1_mla_compatible(vllm_config: VllmConfig) -> bool:
@functools.cache
def use_flashinfer_prefill() -> bool:
# For blackwell default to flashinfer prefill if it's available since
# it is faster than FA2.
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
@@ -2154,13 +2152,16 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
# We don't need to pad V if we are on a hopper system with FA3
# not support different headdims.
# FA3 on Hopper (SM90) and FA4 natively handle diff headdims.
device_capability = current_platform.get_device_capability()
self._pad_v = self.vllm_flash_attn_version is None or not (
self.vllm_flash_attn_version == 3
and device_capability is not None
and device_capability[0] == 9
(
self.vllm_flash_attn_version == 3
and device_capability is not None
and device_capability[0] == 9
)
or self.vllm_flash_attn_version == 4
)
self.dcp_world_size: int = -1