fix(ROCm): Make flash_attn import optional in MLA attention (#33511)
Signed-off-by: rabi <ramishra@redhat.com>
@@ -919,10 +919,20 @@ try:
     is_vllm_fa = True
 except ImportError:
-    # For rocm use upstream flash attention
-    if current_platform.is_rocm():
-        from flash_attn import flash_attn_varlen_func  # type: ignore[no-redef]
     is_vllm_fa = False
+    flash_attn_varlen_func = None  # type: ignore[assignment]
+    # On ROCm, vllm_flash_attn is not available, try upstream flash_attn instead.
+    # On CUDA, vllm_flash_attn should always be available (built with vLLM),
+    # so we don't attempt the fallback there.
+    if current_platform.is_rocm():
+        try:
+            from flash_attn import flash_attn_varlen_func  # type: ignore[no-redef]
+        except ImportError:
+            logger.debug(
+                "flash_attn not available on ROCm; "
+                "MLA models using TRITON_MLA will require flash_attn. "
+                "AITER_MLA backends use aiter kernels instead."
+            )
 
 
 def dynamic_per_batched_tensor_quant(
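For readers outside the vLLM tree, the first hunk is a standard optional-import guard: try the preferred implementation, fall back to another on ROCm, and otherwise leave the symbol bound to None so later code can check availability instead of crashing at import time. Below is a minimal, self-contained sketch of that pattern; the module names vendored_fa and upstream_fa and the RUNNING_ON_ROCM flag are placeholders for this sketch only (the real pieces are vllm.vllm_flash_attn, flash_attn, and current_platform.is_rocm()).

import logging

logger = logging.getLogger(__name__)

# Placeholder for a real platform probe such as current_platform.is_rocm().
RUNNING_ON_ROCM = True

flash_attn_varlen_func = None  # stays None if no implementation is importable
is_vllm_fa = False

try:
    # Preferred: the kernel bundled with the project (hypothetical module name).
    from vendored_fa import flash_attn_varlen_func
    is_vllm_fa = True
except ImportError:
    if RUNNING_ON_ROCM:
        try:
            # Fallback: the upstream package (hypothetical module name).
            from upstream_fa import flash_attn_varlen_func
        except ImportError:
            # Do not raise here; backends that need the kernel check at init time.
            logger.debug("no flash-attention implementation available; "
                         "backends that require it will error at construction")

The important property is that importing the module never hard-fails: flash_attn_varlen_func is either a callable or None, and deciding whether that is fatal is deferred to whoever selects a prefill backend, which is what the second hunk handles.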
@@ -1917,6 +1927,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn
             self._pad_v = False
         else:  # Use FlashAttention
+            if flash_attn_varlen_func is None:
+                raise RuntimeError(
+                    "MLA attention requires FlashAttention but it is not "
+                    "available. Please install flash_attn or use "
+                    "--attention-backend ROCM_AITER_MLA."
+                )
             logger.info_once("Using FlashAttention prefill for MLA", scope="local")
             self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa
             self._run_prefill_new_tokens = self._run_prefill_new_tokens_fa
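The second hunk is the other half of that contract: because the import may leave flash_attn_varlen_func as None, the FlashAttention prefill path verifies it at backend construction and raises an actionable error instead of failing later with an opaque NoneType call. A rough sketch of the same deferred check, with invented names (PrefillBackend, use_cudnn_prefill) that are not vLLM's real classes:

flash_attn_varlen_func = None  # bound (or not) by the optional import shown above


class PrefillBackend:
    """Illustrative stand-in for MLACommonImpl's prefill-path selection."""

    def __init__(self, use_cudnn_prefill: bool = False) -> None:
        if use_cudnn_prefill:
            self._run_prefill = self._run_prefill_cudnn
        else:  # FlashAttention path
            if flash_attn_varlen_func is None:
                # Fail at construction, with a hint, rather than mid-forward-pass.
                raise RuntimeError(
                    "FlashAttention prefill requested but flash_attn is not "
                    "installed; install it or choose a different backend."
                )
            self._run_prefill = self._run_prefill_fa

    def _run_prefill_cudnn(self):
        ...  # placeholder

    def _run_prefill_fa(self):
        ...  # placeholder

Raising at construction keeps the error next to the configuration that caused it, which is why the diff's message points users at installing flash_attn or switching to the ROCM_AITER_MLA backend.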