[Misc] Log the reason for falling back to FlexAttention (#20699)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-07-14 19:16:51 +08:00
committed by GitHub
parent a4851cfe68
commit e8cc53af5e
10 changed files with 105 additions and 33 deletions

View File

@@ -3,6 +3,7 @@
import os
from contextlib import contextmanager
from dataclasses import dataclass
from functools import cache
from typing import Generator, Optional, Union
@@ -79,31 +80,61 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
return forced_attn_backend
def supports_head_size(
@dataclass(frozen=True)
class _IsSupported:
can_import: bool
head_size: bool
dtype: bool
def __bool__(self) -> bool:
return self.can_import and self.head_size and self.dtype
def is_attn_backend_supported(
attn_backend: Union[str, type[AttentionBackend]],
head_size: int,
) -> bool:
dtype: torch.dtype,
*,
allow_import_error: bool = True,
) -> _IsSupported:
if isinstance(attn_backend, str):
try:
attn_backend = resolve_obj_by_qualname(attn_backend)
except ImportError:
return False
if not allow_import_error:
raise
return _IsSupported(can_import=False, head_size=False, dtype=False)
assert isinstance(attn_backend, type)
# TODO: Update the interface once V0 is removed
if get_supported_head_sizes := getattr(attn_backend,
"get_supported_head_sizes", None):
return head_size in get_supported_head_sizes()
if validate_head_size := getattr(attn_backend, "validate_head_size", None):
is_head_size_supported = head_size in get_supported_head_sizes()
elif validate_head_size := getattr(attn_backend, "validate_head_size",
None):
try:
validate_head_size(head_size)
return True
is_head_size_supported = True
except Exception:
return False
is_head_size_supported = False
else:
raise NotImplementedError(f"{attn_backend.__name__} does not support "
"head size validation")
raise NotImplementedError(f"{attn_backend.__name__} does not support "
"head size validation")
if get_supported_dtypes := getattr(attn_backend, "get_supported_dtypes",
None):
is_dtype_supported = dtype in get_supported_dtypes()
else:
raise NotImplementedError(f"{attn_backend.__name__} does not support "
"dtype validation")
return _IsSupported(
can_import=True,
head_size=is_head_size_supported,
dtype=is_dtype_supported,
)
def get_attn_backend(