[Bugfix][Attention] Explicitly report support for kv_cache_dtype bfloat16 (#32795)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Author: Matthew Bonanni
Date: 2026-01-22 14:05:18 -05:00
Committed by: GitHub
Parent: 744ef30484
Commit: 955b43a5a5
13 changed files with 31 additions and 11 deletions

@@ -51,7 +51,7 @@ class AttentionBackend(ABC):
     # makes sure the output tensor is allocated inside the cudagraph.
     accept_output_buffer: bool = False
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"]
+    supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto", "bfloat16"]
 
     @staticmethod
     def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
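
Illustrative sketch (not vLLM's actual backend-selection code): the class-level supported_kv_cache_dtypes list changed above is what a backend reports, and a caller can test a requested cache dtype against it before picking the backend. ToyBackend and supports_kv_cache_dtype are hypothetical names used only for this example.

from typing import ClassVar


class ToyBackend:
    # Mirrors the attribute changed in the hunk above:
    # "bfloat16" is now listed explicitly alongside "auto".
    supported_kv_cache_dtypes: ClassVar[list[str]] = ["auto", "bfloat16"]

    @classmethod
    def supports_kv_cache_dtype(cls, kv_cache_dtype: str) -> bool:
        # A backend supports a cache dtype only if it declares it.
        return kv_cache_dtype in cls.supported_kv_cache_dtypes


assert ToyBackend.supports_kv_cache_dtype("bfloat16")      # newly reported as supported
assert not ToyBackend.supports_kv_cache_dtype("fp8_e4m3")  # not declared by this toy backend
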
@@ -747,7 +747,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
 def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
-    return kv_cache_dtype != "auto"
+    return kv_cache_dtype.startswith("fp8")
 
 
 def subclass_attention_backend(
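
Sketch of the behavioral difference in the hunk above: the old check reported any value other than "auto" (including "bfloat16") as a quantized KV cache, while the new check only flags fp8 variants. The _old/_new function names and the "fp8_e4m3" value are assumptions made for illustration, not identifiers from the patch.

def is_quantized_kv_cache_old(kv_cache_dtype: str) -> bool:
    # Old check: anything other than "auto" counted as quantized,
    # so "bfloat16" was treated as a quantized cache dtype.
    return kv_cache_dtype != "auto"


def is_quantized_kv_cache_new(kv_cache_dtype: str) -> bool:
    # New check: only fp8 variants count as quantized.
    return kv_cache_dtype.startswith("fp8")


assert is_quantized_kv_cache_old("bfloat16") is True    # previous, misleading result
assert is_quantized_kv_cache_new("bfloat16") is False   # bfloat16 is a full-width dtype
assert is_quantized_kv_cache_new("fp8_e4m3") is True    # fp8 remains quantized
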