[Attention] Refactor CUDA attention backend selection logic (#24794)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
@@ -18,12 +18,14 @@ from vllm.attention.ops.triton_reshape_and_cache_flash import (
)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    QuantKey,
    kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.utils import (
    AttentionCGSupport,
    AttentionMetadataBuilder,
@@ -147,25 +149,18 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet

class TritonAttentionBackend(AttentionBackend):
    accept_output_buffer: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        return [torch.float16, torch.bfloat16, torch.float32]

    @staticmethod
    def get_supported_kernel_block_size() -> list[int | MultipleOf]:
        return [MultipleOf(16)]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        # Triton Attention supports any head size above 32
        if head_size < 32:
            raise ValueError(
                f"Head size {head_size} is not supported by TritonAttention."
                f"Head sizes need to be larger or equal 32 for this backend. "
                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
                "FlexAttention backend which supports all head sizes."
            )
    supported_dtypes: ClassVar[list[torch.dtype]] = [
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ]
    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "fp8",
        "fp8_e4m3",
        "fp8_e5m2",
    ]

    @staticmethod
    def get_name() -> str:
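In the hunk above, the per-property getter and validator classmethods give way to class-level capability declarations (supported_dtypes, supported_kernel_block_sizes, supported_kv_cache_dtypes). As a rough, hypothetical sketch of how a selection routine could consume such declarations (this is not vLLM's actual selector; _ToyBackend and pick_backend are invented names):

# Illustrative sketch only: _ToyBackend and pick_backend are hypothetical and
# not part of vLLM; they just mirror the declared-capability pattern above.
import torch


class _ToyBackend:
    # Class-level declarations analogous to those added in the diff.
    supported_dtypes = [torch.float16, torch.bfloat16, torch.float32]
    supported_kv_cache_dtypes = ["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]


def pick_backend(candidates, dtype, kv_cache_dtype):
    # Return the first backend whose declared capabilities cover the request.
    for backend in candidates:
        if dtype not in backend.supported_dtypes:
            continue
        if kv_cache_dtype not in backend.supported_kv_cache_dtypes:
            continue
        return backend
    return None


# Example: an fp16 model with an fp8 KV cache matches _ToyBackend.
chosen = pick_backend([_ToyBackend], torch.float16, "fp8")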
@@ -195,6 +190,18 @@ class TritonAttentionBackend(AttentionBackend):
    def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
        return TritonAttentionMetadataBuilder

    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        return head_size >= 32

    @classmethod
    def supports_sink(cls) -> bool:
        return True

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        return True


class TritonAttentionImpl(AttentionImpl):
    def fused_output_quant_supported(self, quant_key: QuantKey):
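The supports_head_size, supports_sink, and supports_compute_capability classmethods added in this hunk expose the same checks as queryable predicates rather than hard failures like the old validate_head_size. A minimal, hypothetical sketch of combining them during selection (filter_backends is an invented helper, not vLLM code, and capability stands for whatever capability object the caller holds):

# Hypothetical helper, not vLLM code: keeps only backends whose predicate
# classmethods accept the requested configuration.
def filter_backends(candidates, head_size, capability, needs_sink=False):
    usable = []
    for backend in candidates:
        if not backend.supports_head_size(head_size):
            continue
        if not backend.supports_compute_capability(capability):
            continue
        if needs_sink and not backend.supports_sink():
            continue
        usable.append(backend)
    return usable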
@@ -237,8 +244,6 @@ class TritonAttentionImpl(AttentionImpl):

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        TritonAttentionBackend.validate_head_size(head_size)

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "