[V1] Support any head size for FlexAttention backend (#20467)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-07-07 00:54:36 +08:00
committed by GitHub
parent e202dd2736
commit 9fb52e523a
20 changed files with 202 additions and 118 deletions

View File

@@ -314,10 +314,21 @@ class AiterFlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_head_sizes() -> list[int]:
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes.")
@staticmethod
def get_name() -> str:
return "FLASH_ATTN_VLLM_V1"
@@ -428,14 +439,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
support_head_sizes = \
AiterFlashAttentionBackend.get_supported_head_sizes()
if head_size not in support_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by "
"AiterFlashAttention. "
f"Supported head sizes are: {support_head_sizes}. "
"Set VLLM_USE_V1=0 to use another attention backend.")
AiterFlashAttentionBackend.validate_head_size(head_size)
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "