[Attention] Refactor CUDA attention backend selection logic (#24794)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Author: Matthew Bonanni
Date: 2025-11-11 06:40:44 -06:00
Committed by: GitHub
Parent: 2e78150d24
Commit: b30dfa03c5
61 changed files with 1338 additions and 1002 deletions

@@ -18,12 +18,14 @@ from vllm.attention.ops.triton_reshape_and_cache_flash import (
 )
 from vllm.attention.ops.triton_unified_attention import unified_attention
 from vllm.config import VllmConfig
+from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
     kFp8StaticTensorSym,
 )
 from vllm.platforms import current_platform
+from vllm.platforms.interface import DeviceCapability
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport,
     AttentionMetadataBuilder,
@@ -147,25 +149,18 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet
 class TritonAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
 
-    @classmethod
-    def get_supported_dtypes(cls) -> list[torch.dtype]:
-        return [torch.float16, torch.bfloat16, torch.float32]
-
-    @staticmethod
-    def get_supported_kernel_block_size() -> list[int | MultipleOf]:
-        return [MultipleOf(16)]
-
-    @classmethod
-    def validate_head_size(cls, head_size: int) -> None:
-        # Triton Attention supports any head size above 32
-        if head_size < 32:
-            raise ValueError(
-                f"Head size {head_size} is not supported by TritonAttention."
-                f"Head sizes need to be larger or equal 32 for this backend. "
-                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
-                "FlexAttention backend which supports all head sizes."
-            )
-
+    supported_dtypes: ClassVar[list[torch.dtype]] = [
+        torch.float16,
+        torch.bfloat16,
+        torch.float32,
+    ]
+    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
+    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
+        "auto",
+        "fp8",
+        "fp8_e4m3",
+        "fp8_e5m2",
+    ]
 
     @staticmethod
     def get_name() -> str:
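With capabilities declared as ClassVar attributes instead of get_*() classmethods, selection logic can filter candidate backends with plain attribute reads. A minimal sketch of that consumption pattern, not part of this commit: the backend_supports helper and the MockBackend stub below are hypothetical, and only the attribute names mirror the diff.

from typing import ClassVar

import torch

class MockBackend:
    # Stub standing in for TritonAttentionBackend; same attribute names as above.
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    supported_kv_cache_dtypes: ClassVar[list[str]] = ["auto", "fp8"]

def backend_supports(backend: type, dtype: torch.dtype, kv_cache_dtype: str) -> bool:
    # Plain attribute reads replace the old per-backend classmethod calls.
    return (dtype in backend.supported_dtypes
            and kv_cache_dtype in backend.supported_kv_cache_dtypes)

assert backend_supports(MockBackend, torch.bfloat16, "fp8")
assert not backend_supports(MockBackend, torch.float32, "fp8")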
@@ -195,6 +190,18 @@ class TritonAttentionBackend(AttentionBackend):
     def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
         return TritonAttentionMetadataBuilder
 
+    @classmethod
+    def supports_head_size(cls, head_size: int) -> bool:
+        return head_size >= 32
+
+    @classmethod
+    def supports_sink(cls) -> bool:
+        return True
+
+    @classmethod
+    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
+        return True
+
 
 class TritonAttentionImpl(AttentionImpl):
     def fused_output_quant_supported(self, quant_key: QuantKey):
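Because the new supports_*() hooks return booleans instead of raising (compare the removed validate_head_size above), a central selector can accumulate the reasons a backend is ineligible and report them together. A minimal sketch under that assumption, not part of this commit: check_backend and its reasons list are hypothetical; only the supports_head_size / supports_sink / supports_compute_capability names come from the diff.

def check_backend(backend, head_size: int, capability, needs_sink: bool = False) -> list[str]:
    # Collect human-readable reasons this backend cannot serve the config;
    # an empty list means the backend is eligible.
    reasons = []
    if not backend.supports_head_size(head_size):
        reasons.append(f"head size {head_size} unsupported")
    if needs_sink and not backend.supports_sink():
        reasons.append("attention sinks unsupported")
    if not backend.supports_compute_capability(capability):
        reasons.append(f"compute capability {capability} unsupported")
    return reasons

For example, check_backend(TritonAttentionBackend, 16, cap) would report only the head-size failure (16 < 32), and the selector could fold each candidate's reasons into one combined error message.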
@@ -237,8 +244,6 @@ class TritonAttentionImpl(AttentionImpl):
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        TritonAttentionBackend.validate_head_size(head_size)
-
         if attn_type != AttentionType.DECODER:
             raise NotImplementedError(
                 "Encoder self-attention and "