diff --git a/vllm/attention/utils/fa_utils.py b/vllm/attention/utils/fa_utils.py index 189bf3d4f..00107cd7f 100644 --- a/vllm/attention/utils/fa_utils.py +++ b/vllm/attention/utils/fa_utils.py @@ -40,6 +40,9 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None: if current_platform.is_xpu(): return 2 + if current_platform.is_rocm(): + # ROCm doesn't use vllm_flash_attn; return None to skip fa_version arg + return None try: from vllm.vllm_flash_attn.flash_attn_interface import ( fa_version_unsupported_reason,