[Bugfix][Attention] Explicitly report support for kv_cache_dtype bfloat16 (#32795)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -9,6 +9,7 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backend import is_quantized_kv_cache
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -52,11 +53,14 @@ class BaseKVCacheMethod(QuantizeMethodBase):
|
||||
assert not hasattr(layer, "prob_scale")
|
||||
return
|
||||
|
||||
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
|
||||
# If the kv-cache is not quantized, we enforce the k/v_scale to be 1.0
|
||||
# regardless of whether the kv-scale is available in the checkpoint.
|
||||
# No need to process kv scales after loading if we are going to
|
||||
# calculate them on the fly.
|
||||
if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales:
|
||||
if (
|
||||
is_quantized_kv_cache(layer.kv_cache_dtype)
|
||||
and not layer.calculate_kv_scales
|
||||
):
|
||||
if layer.k_scale > 0.0 and layer.v_scale > 0.0:
|
||||
# We prefer to use separate k_scale and v_scale if present
|
||||
k_scale = layer.k_scale.to("cpu").tolist()
|
||||
|
||||
@@ -16,6 +16,7 @@ import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backend import is_quantized_kv_cache
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
from .interface import CpuArchEnum, Platform, PlatformEnum
|
||||
@@ -198,13 +199,13 @@ class CpuPlatform(Platform):
|
||||
if (
|
||||
scheduler_config.enable_chunked_prefill
|
||||
or cache_config.enable_prefix_caching
|
||||
) and cache_config.cache_dtype != "auto":
|
||||
) and is_quantized_kv_cache(cache_config.cache_dtype):
|
||||
raise RuntimeError(
|
||||
"Chunked-prefill and prefix-cache on the CPU "
|
||||
"backend is not compatible with FP8 KV cache."
|
||||
)
|
||||
|
||||
if cache_config.cache_dtype != "auto":
|
||||
if cache_config.cache_dtype.startswith("fp8"):
|
||||
logger.warning(
|
||||
"CPU backend doesn't support KV cache quantization fallback to auto."
|
||||
)
|
||||
|
||||
@@ -51,7 +51,7 @@ class AttentionBackend(ABC):
|
||||
# makes sure the output tensor is allocated inside the cudagraph.
|
||||
accept_output_buffer: bool = False
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"]
|
||||
supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto", "bfloat16"]
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
@@ -747,7 +747,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
|
||||
|
||||
|
||||
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
|
||||
return kv_cache_dtype != "auto"
|
||||
return kv_cache_dtype.startswith("fp8")
|
||||
|
||||
|
||||
def subclass_attention_backend(
|
||||
|
||||
@@ -151,7 +151,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
return True
|
||||
if kv_cache_dtype.startswith("fp8"):
|
||||
return flash_attn_supports_fp8()
|
||||
return kv_cache_dtype in ["auto"]
|
||||
return kv_cache_dtype in ["auto", "bfloat16"]
|
||||
|
||||
@classmethod
|
||||
def supports_sink(cls) -> bool:
|
||||
|
||||
@@ -281,6 +281,7 @@ class FlashInferBackend(AttentionBackend):
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"bfloat16",
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
"fp8_e5m2",
|
||||
|
||||
@@ -80,7 +80,7 @@ class FlexAttentionBackend(AttentionBackend):
|
||||
torch.bfloat16,
|
||||
torch.float32,
|
||||
]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "bfloat16"]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
|
||||
@@ -38,6 +38,7 @@ class CutlassMLABackend(MLACommonBackend):
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"bfloat16",
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
]
|
||||
|
||||
@@ -43,7 +43,10 @@ logger = init_logger(__name__)
|
||||
|
||||
class FlashAttnMLABackend(MLACommonBackend):
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"bfloat16",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
|
||||
@@ -38,6 +38,7 @@ class FlashInferMLABackend(MLACommonBackend):
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"bfloat16",
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
]
|
||||
|
||||
@@ -48,6 +48,7 @@ class FlashMLABackend(MLACommonBackend):
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"bfloat16",
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
]
|
||||
|
||||
@@ -76,7 +76,11 @@ structured as:
|
||||
class FlashMLASparseBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "fp8_ds_mla"]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"bfloat16",
|
||||
"fp8_ds_mla",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
|
||||
@@ -28,7 +28,10 @@ logger = init_logger(__name__)
|
||||
|
||||
class TritonMLABackend(MLACommonBackend):
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"bfloat16",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
|
||||
@@ -259,6 +259,7 @@ class TritonAttentionBackend(AttentionBackend):
|
||||
]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"bfloat16",
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
"fp8_e5m2",
|
||||
|
||||
Reference in New Issue
Block a user