fix(ROCm): Make flash_attn import optional in MLA attention (#33511)

Signed-off-by: rabi <ramishra@redhat.com>
This commit is contained in:
Rabi Mishra
2026-02-06 07:52:53 +05:30
committed by GitHub
parent 5819ca8944
commit 20d7454c9b

View File

@@ -919,10 +919,20 @@ try:
is_vllm_fa = True
except ImportError:
# For rocm use upstream flash attention
if current_platform.is_rocm():
from flash_attn import flash_attn_varlen_func # type: ignore[no-redef]
is_vllm_fa = False
flash_attn_varlen_func = None # type: ignore[assignment]
# On ROCm, vllm_flash_attn is not available, try upstream flash_attn instead.
# On CUDA, vllm_flash_attn should always be available (built with vLLM),
# so we don't attempt the fallback there.
if current_platform.is_rocm():
try:
from flash_attn import flash_attn_varlen_func # type: ignore[no-redef]
except ImportError:
logger.debug(
"flash_attn not available on ROCm; "
"MLA models using TRITON_MLA will require flash_attn. "
"AITER_MLA backends use aiter kernels instead."
)
def dynamic_per_batched_tensor_quant(
@@ -1917,6 +1927,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn
self._pad_v = False
else: # Use FlashAttention
if flash_attn_varlen_func is None:
raise RuntimeError(
"MLA attention requires FlashAttention but it is not "
"available. Please install flash_attn or use "
"--attention-backend ROCM_AITER_MLA."
)
logger.info_once("Using FlashAttention prefill for MLA", scope="local")
self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa
self._run_prefill_new_tokens = self._run_prefill_new_tokens_fa