[Bugfix][Attention] Explicitly report support for kv_cache_dtype bfloat16 (#32795)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Author: Matthew Bonanni
Date: 2026-01-22 14:05:18 -05:00
Committed by: GitHub
Parent: 744ef30484
Commit: 955b43a5a5
13 changed files with 31 additions and 11 deletions

View File

@@ -51,7 +51,7 @@ class AttentionBackend(ABC):
     # makes sure the output tensor is allocated inside the cudagraph.
     accept_output_buffer: bool = False
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"]
+    supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto", "bfloat16"]

     @staticmethod
     def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
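
With this change, the abstract base class advertises both "auto" and "bfloat16" as unquantized KV cache dtypes, so any backend that does not override `supported_kv_cache_dtypes` inherits bfloat16 support. A minimal sketch of how a dispatcher might consult this ClassVar; the `validate_kv_cache_dtype` helper and the import path are assumptions for illustration, not part of this diff:

from vllm.attention.backends.abstract import AttentionBackend  # import path assumed

def validate_kv_cache_dtype(backend: type[AttentionBackend], kv_cache_dtype: str) -> None:
    # Hypothetical helper: reject a cache dtype the backend does not list.
    if kv_cache_dtype not in backend.supported_kv_cache_dtypes:
        raise ValueError(
            f"{backend.get_name()} does not support kv_cache_dtype="
            f"{kv_cache_dtype!r}; supported: {backend.supported_kv_cache_dtypes}"
        )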
@@ -747,7 +747,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):

 def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
-    return kv_cache_dtype != "auto"
+    return kv_cache_dtype.startswith("fp8")


 def subclass_attention_backend(
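
The `is_quantized_kv_cache` change is what makes the new list entry safe: under the old `!= "auto"` test, "bfloat16" would have been misclassified as a quantized cache dtype, while `startswith("fp8")` treats only the fp8 variants as quantized. A standalone sketch of the behavioral difference, derived directly from the two return statements above:

for dtype in ("auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2"):
    old = dtype != "auto"              # previous check
    new = dtype.startswith("fp8")      # new check
    print(f"{dtype:10s} old={old!s:5s} new={new}")
# Only "bfloat16" changes: it was treated as quantized before, and is not now.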

View File

@@ -151,7 +151,7 @@ class FlashAttentionBackend(AttentionBackend):
             return True
         if kv_cache_dtype.startswith("fp8"):
             return flash_attn_supports_fp8()
-        return kv_cache_dtype in ["auto"]
+        return kv_cache_dtype in ["auto", "bfloat16"]

     @classmethod
     def supports_sink(cls) -> bool:
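
Pieced together from the context lines, the FlashAttention support check now reads roughly as follows; the method name `supports_kv_cache_dtype` and the leading `None` guard are inferred from surrounding code rather than shown in this hunk:

class FlashAttentionBackend(AttentionBackend):
    @classmethod
    def supports_kv_cache_dtype(cls, kv_cache_dtype: str | None) -> bool:
        # Only the last three lines of this body appear in the hunk above;
        # the rest is a reconstruction from the diff context.
        if kv_cache_dtype is None:
            return True
        if kv_cache_dtype.startswith("fp8"):
            return flash_attn_supports_fp8()
        return kv_cache_dtype in ["auto", "bfloat16"]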

View File

@@ -281,6 +281,7 @@ class FlashInferBackend(AttentionBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
+        "bfloat16",
         "fp8",
         "fp8_e4m3",
         "fp8_e5m2",

View File

@@ -80,7 +80,7 @@ class FlexAttentionBackend(AttentionBackend):
         torch.bfloat16,
         torch.float32,
     ]
-    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
+    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "bfloat16"]

     @staticmethod
     def get_name() -> str:

View File

@@ -38,6 +38,7 @@ class CutlassMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
+        "bfloat16",
         "fp8",
         "fp8_e4m3",
     ]

View File

@@ -43,7 +43,10 @@ logger = init_logger(__name__)

 class FlashAttnMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
+    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
+        "auto",
+        "bfloat16",
+    ]

     @staticmethod
     def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:

View File

@@ -38,6 +38,7 @@ class FlashInferMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
+        "bfloat16",
         "fp8",
         "fp8_e4m3",
     ]

View File

@@ -48,6 +48,7 @@ class FlashMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
+        "bfloat16",
         "fp8",
         "fp8_e4m3",
     ]

View File

@@ -76,7 +76,11 @@ structured as:
 class FlashMLASparseBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
-    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "fp8_ds_mla"]
+    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
+        "auto",
+        "bfloat16",
+        "fp8_ds_mla",
+    ]

     @staticmethod
     def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:

View File

@@ -28,7 +28,10 @@ logger = init_logger(__name__)

 class TritonMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
+    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
+        "auto",
+        "bfloat16",
+    ]

     @staticmethod
     def get_name() -> str:

View File

@@ -259,6 +259,7 @@ class TritonAttentionBackend(AttentionBackend):
     ]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
+        "bfloat16",
         "fp8",
         "fp8_e4m3",
         "fp8_e5m2",