[Attention] Refactor CUDA attention backend selection logic (#24794)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
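The diff below replaces per-backend getter classmethods with declarative `ClassVar` attributes, so backend capabilities become plain class-level data that the selection logic can read directly. A minimal sketch of the before/after shape (the names `OldStyleBackend`, `NewStyleBackend`, and `supports_dtype` are illustrative stand-ins, not vLLM API):

```python
from typing import ClassVar

import torch


class OldStyleBackend:
    # Before: each capability is exposed through a getter classmethod.
    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        return [torch.float16, torch.bfloat16]


class NewStyleBackend:
    # After: capabilities are declared as class-level data that the
    # selection logic can inspect directly and subclasses can override.
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]


def supports_dtype(backend: type[NewStyleBackend], dtype: torch.dtype) -> bool:
    # Hypothetical helper: a capability check is now a plain attribute read.
    return dtype in backend.supported_dtypes


assert supports_dtype(NewStyleBackend, torch.bfloat16)
```

Attribute-style capabilities also let a subclass override a value with a one-line assignment instead of redefining a method.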
@@ -3,6 +3,7 @@
 """Attention layer with AiterFlashAttention."""
 
 from dataclasses import dataclass
+from typing import ClassVar
 
 import torch
 
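The only change in this hunk is the new `typing.ClassVar` import, which the declarations in the next hunk rely on. `ClassVar` marks an annotated attribute as class-owned rather than per-instance, which matters to type checkers and to `dataclasses` (such attributes are excluded from the generated `__init__`). A small self-contained illustration:

```python
from dataclasses import dataclass
from typing import ClassVar


@dataclass
class Example:
    per_instance: int           # regular dataclass field
    shared: ClassVar[int] = 16  # class-owned; excluded from __init__


e = Example(per_instance=1)
assert Example.shared == 16
assert "shared" not in vars(e)  # lives on the class, not the instance
```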
@@ -445,31 +446,13 @@ class AiterFlashAttentionMetadataBuilder(
 
 
 class AiterFlashAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
 
-    @classmethod
-    def get_supported_dtypes(cls) -> list[torch.dtype]:
-        return [torch.float16, torch.bfloat16]
+    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
+    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
 
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
         return [64, 128, 256]
 
-    @staticmethod
-    def get_supported_kernel_block_size() -> list[int | MultipleOf]:
-        return [MultipleOf(16)]
-
-    @classmethod
-    def validate_head_size(cls, head_size: int) -> None:
-        supported_head_sizes = cls.get_supported_head_sizes()
-        if head_size not in supported_head_sizes:
-            attn_type = cls.__name__.removesuffix("Backend")
-            raise ValueError(
-                f"Head size {head_size} is not supported by {attn_type}. "
-                f"Supported head sizes are: {supported_head_sizes}. "
-                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
-                "FlexAttention backend which supports all head sizes."
-            )
-
     @staticmethod
     def get_name() -> str:
         return "FLASH_ATTN"
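`validate_head_size` disappears from this subclass without a replacement in this hunk; the natural reading is that the check was hoisted into the shared `AttentionBackend` base as part of centralizing selection logic, though that base-class hunk is not part of this excerpt. Under that assumption, the shared version would look essentially like the removed code:

```python
# Sketch: the removed subclass method, relocated to a shared base class
# (assumption; the base-class side of this refactor is not in this diff).
class AttentionBackendSketch:
    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return []  # subclasses override; empty means "no restriction" here

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        supported_head_sizes = cls.get_supported_head_sizes()
        if supported_head_sizes and head_size not in supported_head_sizes:
            attn_type = cls.__name__.removesuffix("Backend")
            raise ValueError(
                f"Head size {head_size} is not supported by {attn_type}. "
                f"Supported head sizes are: {supported_head_sizes}. "
                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
                "FlexAttention backend which supports all head sizes."
            )
```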
@@ -531,8 +514,6 @@ class AiterFlashAttentionImpl(AttentionImpl):
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        AiterFlashAttentionBackend.validate_head_size(head_size)
-
         if attn_type != AttentionType.DECODER:
             raise NotImplementedError(
                 "Encoder self-attention and "
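The call-site removal above matches the theme: per-implementation validation gives way to checks performed once during backend selection. Separately, the new `supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]` reads as "any block size divisible by 16". A hypothetical matcher over such a list (`MultipleOfSketch` approximates vLLM's `MultipleOf`, whose actual definition is not shown in this diff):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOfSketch:
    # Assumed shape of MultipleOf: matches any positive multiple of `base`.
    base: int


def block_size_supported(
    block_size: int, supported: list[int | MultipleOfSketch]
) -> bool:
    # Exact sizes match directly; MultipleOfSketch entries match by divisibility.
    for entry in supported:
        if isinstance(entry, MultipleOfSketch):
            if block_size > 0 and block_size % entry.base == 0:
                return True
        elif block_size == entry:
            return True
    return False


assert block_size_supported(32, [MultipleOfSketch(16)])
assert not block_size_supported(24, [MultipleOfSketch(16)])
```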