[V1] Support any head size for FlexAttention backend (#20467)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-07-07 00:54:36 +08:00
committed by GitHub
parent e202dd2736
commit 9fb52e523a
20 changed files with 202 additions and 118 deletions

View File

@@ -4,7 +4,7 @@
import os
from contextlib import contextmanager
from functools import cache
from typing import Generator, Optional, Type
from typing import Generator, Optional, Union
import torch
@@ -79,6 +79,33 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
return forced_attn_backend
def supports_head_size(
    attn_backend: Union[str, type[AttentionBackend]],
    head_size: int,
) -> bool:
    """Report whether an attention backend accepts the given head size.

    Args:
        attn_backend: Either a backend class, or a qualified name that is
            resolved to one with ``resolve_obj_by_qualname``.
        head_size: The head size to check.

    Returns:
        ``True`` when ``head_size`` is contained in the result of the
        backend's ``get_supported_head_sizes()``, or when the backend's
        ``validate_head_size(head_size)`` completes without raising.
        ``False`` when the size is rejected, or when resolving a string
        ``attn_backend`` raises ``ImportError``.

    Raises:
        NotImplementedError: If the backend has neither a truthy
            ``get_supported_head_sizes`` nor a truthy
            ``validate_head_size`` attribute.
    """
    if isinstance(attn_backend, str):
        try:
            attn_backend = resolve_obj_by_qualname(attn_backend)
        except ImportError:
            # A backend that cannot be imported is reported as unsupported.
            return False

    assert isinstance(attn_backend, type)

    # TODO: Update the interface once V0 is removed
    list_sizes = getattr(attn_backend, "get_supported_head_sizes", None)
    if list_sizes:
        return head_size in list_sizes()

    check_size = getattr(attn_backend, "validate_head_size", None)
    if not check_size:
        raise NotImplementedError(f"{attn_backend.__name__} does not support "
                                  "head size validation")

    # Any exception raised by the validator is treated as "not supported".
    try:
        check_size(head_size)
    except Exception:
        return False
    return True
def get_attn_backend(
head_size: int,
dtype: torch.dtype,
@@ -87,7 +114,7 @@ def get_attn_backend(
is_attention_free: bool,
is_blocksparse: bool = False,
use_mla: bool = False,
) -> Type[AttentionBackend]:
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
# value to be returned from the cache if the value changes between calls.
@@ -115,7 +142,7 @@ def _cached_get_attn_backend(
is_blocksparse: bool = False,
use_v1: bool = False,
use_mla: bool = False,
) -> Type[AttentionBackend]:
) -> type[AttentionBackend]:
if is_blocksparse:
logger.info("Using BlocksparseFlashAttention backend.")
from vllm.attention.backends.blocksparse_attn import (