diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py
index f0497a872..fe2e31252 100644
--- a/vllm/model_executor/layers/quantization/kv_cache.py
+++ b/vllm/model_executor/layers/quantization/kv_cache.py
@@ -9,6 +9,7 @@ from vllm.model_executor.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from vllm.platforms import current_platform
+from vllm.v1.attention.backend import is_quantized_kv_cache
 
 logger = init_logger(__name__)
 
@@ -52,11 +53,14 @@ class BaseKVCacheMethod(QuantizeMethodBase):
             assert not hasattr(layer, "prob_scale")
             return
 
-        # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
+        # If the kv-cache is not quantized, we enforce the k/v_scale to be 1.0
         # regardless whether the kv-scale is available in the checkpoint.
         # No need to process kv scales after loading if we are going to
         # calculate them on the fly.
-        if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales:
+        if (
+            is_quantized_kv_cache(layer.kv_cache_dtype)
+            and not layer.calculate_kv_scales
+        ):
             if layer.k_scale > 0.0 and layer.v_scale > 0.0:
                 # We prefer to use separate k_scale and v_scale if present
                 k_scale = layer.k_scale.to("cpu").tolist()
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index 949e9f41e..46465a482 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -16,6 +16,7 @@ import torch
 
 from vllm import envs
 from vllm.logger import init_logger
+from vllm.v1.attention.backend import is_quantized_kv_cache
 from vllm.v1.attention.backends.registry import AttentionBackendEnum
 
 from .interface import CpuArchEnum, Platform, PlatformEnum
@@ -198,13 +199,13 @@ class CpuPlatform(Platform):
         if (
             scheduler_config.enable_chunked_prefill
             or cache_config.enable_prefix_caching
-        ) and cache_config.cache_dtype != "auto":
+        ) and is_quantized_kv_cache(cache_config.cache_dtype):
             raise RuntimeError(
                 "Chunked-prefill and prefix-cache on the CPU "
                 "backend is not compatible with FP8 KV cache."
             )
 
-        if cache_config.cache_dtype != "auto":
+        if cache_config.cache_dtype.startswith("fp8"):
             logger.warning(
                 "CPU backend doesn't support KV cache quantization fallback to auto."
             )
diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py
index b4dcea105..383be8571 100644
--- a/vllm/v1/attention/backend.py
+++ b/vllm/v1/attention/backend.py
@@ -51,7 +51,7 @@ class AttentionBackend(ABC):
     # makes sure the output tensor is allocated inside the cudagraph.
     accept_output_buffer: bool = False
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"]
+    supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto", "bfloat16"]
 
     @staticmethod
     def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
@@ -747,7 +747,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
 
 
 def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
-    return kv_cache_dtype != "auto"
+    return kv_cache_dtype.startswith("fp8")
 
 
 def subclass_attention_backend(
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 6fec5001b..900074ce2 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -151,7 +151,7 @@ class FlashAttentionBackend(AttentionBackend):
             return True
         if kv_cache_dtype.startswith("fp8"):
             return flash_attn_supports_fp8()
-        return kv_cache_dtype in ["auto"]
+        return kv_cache_dtype in ["auto", "bfloat16"]
 
     @classmethod
     def supports_sink(cls) -> bool:
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 69d24deb2..afefc164f 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -281,6 +281,7 @@ class FlashInferBackend(AttentionBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
+        "bfloat16",
         "fp8",
         "fp8_e4m3",
         "fp8_e5m2",
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index 48c8ac6a8..687e2ba1d 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -80,7 +80,7 @@ class FlexAttentionBackend(AttentionBackend):
         torch.bfloat16,
         torch.float32,
     ]
-    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
+    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "bfloat16"]
 
     @staticmethod
     def get_name() -> str:
diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py
index d898ee346..a8ba10080 100644
--- a/vllm/v1/attention/backends/mla/cutlass_mla.py
+++ b/vllm/v1/attention/backends/mla/cutlass_mla.py
@@ -38,6 +38,7 @@ class CutlassMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
+        "bfloat16",
         "fp8",
         "fp8_e4m3",
     ]
diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index 1eb6f0df4..99c3ce55b 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -43,7 +43,10 @@ logger = init_logger(__name__)
 
 class FlashAttnMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
+    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
+        "auto",
+        "bfloat16",
+    ]
 
     @staticmethod
     def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py
index 1905aebcc..236e293ad 100644
--- a/vllm/v1/attention/backends/mla/flashinfer_mla.py
+++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py
@@ -38,6 +38,7 @@ class FlashInferMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
+        "bfloat16",
         "fp8",
         "fp8_e4m3",
     ]
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index 9047d9121..2dd8f4a51 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -48,6 +48,7 @@ class FlashMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
+        "bfloat16",
         "fp8",
         "fp8_e4m3",
     ]
diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py
index fd68e54e8..2f77e3c03 100644
--- a/vllm/v1/attention/backends/mla/flashmla_sparse.py
+++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py
@@ -76,7 +76,11 @@ structured as:
 class FlashMLASparseBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
-    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "fp8_ds_mla"]
+    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
+        "auto",
+        "bfloat16",
+        "fp8_ds_mla",
+    ]
 
     @staticmethod
     def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py
index 0980907e1..84e025dcd 100644
--- a/vllm/v1/attention/backends/mla/triton_mla.py
+++ b/vllm/v1/attention/backends/mla/triton_mla.py
@@ -28,7 +28,10 @@ logger = init_logger(__name__)
 
 class TritonMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
+    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
+        "auto",
+        "bfloat16",
+    ]
 
     @staticmethod
     def get_name() -> str:
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 06cb17211..75a82bfb1 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -259,6 +259,7 @@ class TritonAttentionBackend(AttentionBackend):
     ]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
+        "bfloat16",
         "fp8",
         "fp8_e4m3",
         "fp8_e5m2",