[Attention] Use FA4 for MLA prefill (#34732)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
@@ -1282,8 +1282,6 @@ def is_deepseek_r1_mla_compatible(vllm_config: VllmConfig) -> bool:
 
 @functools.cache
 def use_flashinfer_prefill() -> bool:
-    # For blackwell default to flashinfer prefill if it's available since
-    # it is faster than FA2.
     from vllm.config import get_current_vllm_config
 
     vllm_config = get_current_vllm_config()
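For readers unfamiliar with this helper, here is a minimal sketch of the pattern it follows: a @functools.cache'd predicate gated on device capability. The get_device_capability stub, the helper name, and the capability threshold below are assumptions for illustration; the real use_flashinfer_prefill also consults the vLLM config and backend availability.

```python
import functools

def get_device_capability() -> tuple[int, int] | None:
    # Stand-in for current_platform.get_device_capability(); a real system
    # would query the GPU, e.g. (9, 0) for Hopper or (10, 0) for Blackwell.
    return (10, 0)

@functools.cache
def prefers_flashinfer_prefill() -> bool:
    # Illustrative only: prefer FlashInfer prefill on Blackwell-class devices
    # (major compute capability 10). The real helper also checks the vLLM
    # config and whether the FlashInfer backend is available.
    cap = get_device_capability()
    return cap is not None and cap[0] == 10

print(prefers_flashinfer_prefill())  # True under the stubbed capability
```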
@@ -2154,13 +2152,16 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
 
         # For MLA the v head dim is smaller than qk head dim so we pad out
         # v with 0s to match the qk head dim for attention backends that do
-        # not support different headdims
-        # We don't need to pad V if we are on a hopper system with FA3
+        # not support different headdims.
+        # FA3 on Hopper (SM90) and FA4 natively handle diff headdims.
         device_capability = current_platform.get_device_capability()
         self._pad_v = self.vllm_flash_attn_version is None or not (
-            self.vllm_flash_attn_version == 3
-            and device_capability is not None
-            and device_capability[0] == 9
+            (
+                self.vllm_flash_attn_version == 3
+                and device_capability is not None
+                and device_capability[0] == 9
+            )
+            or self.vllm_flash_attn_version == 4
         )
 
         self.dcp_world_size: int = -1
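As a self-contained sketch of what this flag controls (the tensor shapes are illustrative, not the exact MLA head dims): needs_v_padding mirrors the boolean above, and when it is True, V is zero-padded on its last dimension up to the QK head dim.

```python
import torch
import torch.nn.functional as F

def needs_v_padding(flash_attn_version: int | None,
                    device_capability: tuple[int, int] | None) -> bool:
    # Same decision as self._pad_v above: padding is skipped only for FA3 on
    # Hopper (major capability 9) or for FA4.
    return flash_attn_version is None or not (
        (
            flash_attn_version == 3
            and device_capability is not None
            and device_capability[0] == 9
        )
        or flash_attn_version == 4
    )

# Illustrative shapes only: the V head dim is smaller than the QK head dim.
q = torch.randn(4, 16, 192)  # (tokens, heads, qk_head_dim)
v = torch.randn(4, 16, 128)  # (tokens, heads, v_head_dim)

if needs_v_padding(2, (9, 0)):  # e.g. FA2 on Hopper -> padding required
    v = F.pad(v, (0, q.shape[-1] - v.shape[-1]))  # zero-pad last dim to 192

assert v.shape[-1] == q.shape[-1]
```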