[Misc] Add float16 to CacheDType (#37199)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
@@ -13,6 +13,7 @@ logger = init_logger(__name__)
 CacheDType = Literal[
     "auto",
+    "float16",
     "bfloat16",
     "fp8",
     "fp8_e4m3",
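For context, CacheDType is the value set accepted by the kv_cache_dtype engine/CLI option, so this change lets users request a float16 KV cache explicitly instead of relying on "auto". A minimal usage sketch, assuming a standard vLLM install; the model name is a placeholder and this snippet is not part of the diff:

# Hedged sketch: requesting the new "float16" KV-cache dtype at engine construction.
# The model name is a placeholder; illustrative only, not part of this change.
from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",   # placeholder model for a smoke test
    kv_cache_dtype="float16",    # newly accepted CacheDType value
)
outputs = llm.generate(["Hello, world"])
print(outputs[0].outputs[0].text)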
@@ -51,7 +51,11 @@ class AttentionBackend(ABC):
     # makes sure the output tensor is allocated inside the cudagraph.
     accept_output_buffer: bool = False
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto", "bfloat16"]
+    supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = [
+        "auto",
+        "float16",
+        "bfloat16",
+    ]
 
     # Does attention's forward() include kv cache update?
     forward_includes_kv_cache_update: bool = True
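The supported_kv_cache_dtypes ClassVar gives every backend a declarative list that callers can check before committing to a KV-cache dtype. A minimal sketch of that pattern, using a trimmed stand-in for the real CacheDType literal; "ToyBackend" and "cache_dtype_supported" are hypothetical names used only for illustration:

# Hedged sketch of the ClassVar pattern above. The Literal below is a trimmed
# stand-in for vllm.config.cache.CacheDType; the class and helper are invented.
from typing import ClassVar, Literal

CacheDType = Literal["auto", "float16", "bfloat16", "fp8", "fp8_e4m3"]

class ToyBackend:
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "float16",
        "bfloat16",
    ]

def cache_dtype_supported(backend: type, kv_cache_dtype: CacheDType) -> bool:
    # A backend can serve a request only if it lists the requested dtype.
    return kv_cache_dtype in backend.supported_kv_cache_dtypes

assert cache_dtype_supported(ToyBackend, "float16")
assert not cache_dtype_supported(ToyBackend, "fp8")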
@@ -64,6 +64,11 @@ logger = init_logger(__name__)
 class FlashAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
+    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
+        "auto",
+        "float16",
+        "bfloat16",
+    ]
 
     @staticmethod
     def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
@@ -164,7 +169,7 @@ class FlashAttentionBackend(AttentionBackend):
             return True
         if kv_cache_dtype.startswith("fp8"):
             return flash_attn_supports_fp8()
-        return kv_cache_dtype in ["auto", "bfloat16"]
+        return kv_cache_dtype in ["auto", "float16", "bfloat16"]
 
     @classmethod
     def supports_sink(cls) -> bool:
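The updated check keeps the fp8 path gated behind the FlashAttention capability probe and simply widens the fallback list to include float16. A standalone sketch of that decision logic, with the probe stubbed out; fa_supports_fp8 is a stand-in for the real flash_attn_supports_fp8 helper, and the None short-circuit is an assumption about the lines above the hunk:

# Hedged re-statement of the dtype check above; fa_supports_fp8 is a stub, not
# the real flash_attn_supports_fp8(), and the None handling is assumed.
def fa_supports_fp8() -> bool:
    # Stand-in: the real helper inspects the installed FlashAttention build / GPU.
    return False

def kv_cache_dtype_ok(kv_cache_dtype: str | None) -> bool:
    if kv_cache_dtype is None:
        return True
    if kv_cache_dtype.startswith("fp8"):
        return fa_supports_fp8()
    # "float16" is newly accepted alongside "auto" and "bfloat16".
    return kv_cache_dtype in ["auto", "float16", "bfloat16"]

assert kv_cache_dtype_ok("float16")
assert kv_cache_dtype_ok("fp8") == fa_supports_fp8()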
@@ -291,6 +291,7 @@ class FlashInferBackend(AttentionBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
+        "float16",
         "bfloat16",
         "fp8",
         "fp8_e4m3",
@@ -80,7 +80,11 @@ class FlexAttentionBackend(AttentionBackend):
         torch.bfloat16,
         torch.float32,
     ]
-    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "bfloat16"]
+    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
+        "auto",
+        "float16",
+        "bfloat16",
+    ]
 
     forward_includes_kv_cache_update: bool = False
@@ -39,6 +39,7 @@ class CutlassMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
+        "float16",
         "bfloat16",
         "fp8",
         "fp8_e4m3",
@@ -46,6 +46,7 @@ class FlashAttnMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
+        "float16",
         "bfloat16",
     ]
@@ -38,6 +38,7 @@ class FlashInferMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
+        "float16",
         "bfloat16",
         "fp8",
         "fp8_e4m3",
@@ -62,6 +62,7 @@ class FlashInferMLASparseBackend(AttentionBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
+        "float16",
         "bfloat16",
         "fp8",
         "fp8_e4m3",
@@ -49,6 +49,7 @@ class FlashMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
+        "float16",
         "bfloat16",
         "fp8",
         "fp8_e4m3",
@@ -26,6 +26,7 @@ class AiterMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
+        "float16",
         "bfloat16",
         "fp8",
         "fp8_e4m3",
@@ -82,6 +82,7 @@ class ROCMAiterMLASparseBackend(AttentionBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
+        "float16",
         "bfloat16",
     ]
@@ -31,6 +31,7 @@ class TritonMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
+        "float16",
         "bfloat16",
         "fp8",
         "fp8_e4m3",
@@ -38,6 +38,7 @@ class XPUMLASparseBackend(AttentionBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
+        "float16",
         "bfloat16",
     ]
@@ -736,6 +736,7 @@ class AiterFlashAttentionBackend(AttentionBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
+        "float16",
         "bfloat16",
         "fp8",
         "fp8_e4m3",
@@ -166,6 +166,7 @@ class RocmAttentionBackend(AttentionBackend):
     ]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
+        "float16",
         "bfloat16",
         "fp8",
         "fp8_e4m3",
@@ -10,6 +10,7 @@ import torch
 
 from vllm import _custom_ops as ops
 from vllm.config import VllmConfig
+from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
 from vllm.v1.attention.backend import (
     AttentionBackend,
@@ -31,6 +32,11 @@ logger = init_logger(__name__)
 class TreeAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
+    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
+        "auto",
+        "float16",
+        "bfloat16",
+    ]
     forward_includes_kv_cache_update: bool = False
 
     @staticmethod
@@ -263,6 +263,7 @@ class TritonAttentionBackend(AttentionBackend):
     ]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
+        "float16",
         "bfloat16",
         "fp8",
         "fp8_e4m3",
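Taken together, the per-backend lists make KV-cache dtype support data rather than scattered conditionals, which is what lets a selector skip backends that cannot serve a requested dtype. A hypothetical illustration of that filtering; the classes and pick_backend are invented for this sketch, and real backend selection also weighs platform, head size, and block size:

# Hedged sketch: choose the first backend whose declared list covers the
# requested KV-cache dtype. All names here are hypothetical.
from typing import ClassVar

class FlashLike:
    supported_kv_cache_dtypes: ClassVar[list[str]] = ["auto", "float16", "bfloat16"]

class FlashInferLike:
    supported_kv_cache_dtypes: ClassVar[list[str]] = [
        "auto", "float16", "bfloat16", "fp8", "fp8_e4m3",
    ]

def pick_backend(kv_cache_dtype: str, candidates: list[type]) -> type | None:
    for backend in candidates:
        if kv_cache_dtype in backend.supported_kv_cache_dtypes:
            return backend
    return None

print(pick_backend("float16", [FlashLike, FlashInferLike]).__name__)  # FlashLike
print(pick_backend("fp8", [FlashLike, FlashInferLike]).__name__)      # FlashInferLike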