diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py
index 4859af43a..c31aa7b41 100644
--- a/vllm/model_executor/layers/attention/mla_attention.py
+++ b/vllm/model_executor/layers/attention/mla_attention.py
@@ -919,10 +919,20 @@
 try:
     is_vllm_fa = True
 except ImportError:
-    # For rocm use upstream flash attention
-    if current_platform.is_rocm():
-        from flash_attn import flash_attn_varlen_func  # type: ignore[no-redef]
     is_vllm_fa = False
+    flash_attn_varlen_func = None  # type: ignore[assignment]
+    # On ROCm, vllm_flash_attn is not available, try upstream flash_attn instead.
+    # On CUDA, vllm_flash_attn should always be available (built with vLLM),
+    # so we don't attempt the fallback there.
+    if current_platform.is_rocm():
+        try:
+            from flash_attn import flash_attn_varlen_func  # type: ignore[no-redef]
+        except ImportError:
+            logger.debug(
+                "flash_attn not available on ROCm; "
+                "MLA models using TRITON_MLA will require flash_attn. "
+                "AITER_MLA backends use aiter kernels instead."
+            )


 def dynamic_per_batched_tensor_quant(
@@ -1917,6 +1927,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn
             self._pad_v = False
         else:  # Use FlashAttention
+            if flash_attn_varlen_func is None:
+                raise RuntimeError(
+                    "MLA attention requires FlashAttention but it is not "
+                    "available. Please install flash_attn or use "
+                    "--attention-backend ROCM_AITER_MLA."
+                )
             logger.info_once("Using FlashAttention prefill for MLA", scope="local")
             self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa
             self._run_prefill_new_tokens = self._run_prefill_new_tokens_fa