[V1] Support any head size for FlexAttention backend (#20467)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -4,7 +4,7 @@
 import os
 from contextlib import contextmanager
 from functools import cache
-from typing import Generator, Optional, Type
+from typing import Generator, Optional, Union
 
 import torch
 
@@ -79,6 +79,33 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
     return forced_attn_backend
 
 
+def supports_head_size(
+    attn_backend: Union[str, type[AttentionBackend]],
+    head_size: int,
+) -> bool:
+    if isinstance(attn_backend, str):
+        try:
+            attn_backend = resolve_obj_by_qualname(attn_backend)
+        except ImportError:
+            return False
+
+    assert isinstance(attn_backend, type)
+
+    # TODO: Update the interface once V0 is removed
+    if get_supported_head_sizes := getattr(attn_backend,
+                                           "get_supported_head_sizes", None):
+        return head_size in get_supported_head_sizes()
+    if validate_head_size := getattr(attn_backend, "validate_head_size", None):
+        try:
+            validate_head_size(head_size)
+            return True
+        except Exception:
+            return False
+
+    raise NotImplementedError(f"{attn_backend.__name__} does not support "
+                              "head size validation")
+
+
 def get_attn_backend(
     head_size: int,
     dtype: torch.dtype,
@@ -87,7 +114,7 @@ def get_attn_backend(
     is_attention_free: bool,
     is_blocksparse: bool = False,
     use_mla: bool = False,
-) -> Type[AttentionBackend]:
+) -> type[AttentionBackend]:
     """Selects which attention backend to use and lazily imports it."""
     # Accessing envs.* behind an @lru_cache decorator can cause the wrong
     # value to be returned from the cache if the value changes between calls.
@@ -115,7 +142,7 @@ def _cached_get_attn_backend(
     is_blocksparse: bool = False,
     use_v1: bool = False,
     use_mla: bool = False,
-) -> Type[AttentionBackend]:
+) -> type[AttentionBackend]:
     if is_blocksparse:
         logger.info("Using BlocksparseFlashAttention backend.")
         from vllm.attention.backends.blocksparse_attn import (
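For reference, a minimal usage sketch of the new `supports_head_size` helper. It assumes vLLM is installed, that the helper is exposed from `vllm.attention.selector` (the file this diff touches), and that the fully qualified FlexAttention backend path below is illustrative rather than confirmed by this diff:

```python
# Sketch only: the import path of supports_head_size and the FlexAttention
# backend class path are assumptions for illustration.
from vllm.attention.selector import supports_head_size

# The helper accepts either a backend class or its fully qualified name;
# names that fail to import simply report the head size as unsupported.
FLEX_BACKEND = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"

for head_size in (64, 80, 96, 160, 256):
    ok = supports_head_size(FLEX_BACKEND, head_size)
    print(f"head_size={head_size}: {'supported' if ok else 'not supported'}")
```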