[Attention][AMD] Make flash-attn optional (#30361)
Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
This commit is contained in:
@@ -23,11 +23,13 @@ elif current_platform.is_xpu():
|
||||
elif current_platform.is_rocm():
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_func # noqa: F401
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Rocm platform requires upstream flash-attn "
|
||||
"to be installed. Please install flash-attn first."
|
||||
) from e
|
||||
except ImportError:
|
||||
|
||||
def flash_attn_varlen_func(*args, **kwargs):
|
||||
raise ImportError(
|
||||
"ROCm platform requires upstream flash-attn "
|
||||
"to be installed. Please install flash-attn first."
|
||||
)
|
||||
|
||||
|
||||
def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
|
||||
|
||||
Reference in New Issue
Block a user