From 116f4be405ff8aaadbc885d8d527d1694c7fcf0a Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 1 Apr 2026 00:08:01 -0400 Subject: [PATCH] [1/N][Cleanup] Standardize on use of `is_quantized_kv_cache` (#38659) Signed-off-by: Matthew Bonanni --- vllm/config/cache.py | 3 ++- .../layers/attention/mla_attention.py | 9 +++++---- .../layers/quantization/kv_cache.py | 2 +- .../models/extract_hidden_states.py | 3 +-- vllm/platforms/cpu.py | 4 ++-- vllm/platforms/cuda.py | 3 ++- vllm/utils/torch_utils.py | 4 ++++ vllm/v1/attention/backend.py | 4 ---- vllm/v1/attention/backends/cpu_attn.py | 2 +- vllm/v1/attention/backends/flash_attn.py | 10 +++++----- .../attention/backends/flash_attn_diffkv.py | 3 ++- vllm/v1/attention/backends/flashinfer.py | 19 +++++++++---------- vllm/v1/attention/backends/flex_attention.py | 3 +-- vllm/v1/attention/backends/mla/cutlass_mla.py | 2 +- .../attention/backends/mla/flashattn_mla.py | 4 ++-- .../attention/backends/mla/flashinfer_mla.py | 6 +++--- .../backends/mla/flashinfer_mla_sparse.py | 5 +++-- vllm/v1/attention/backends/mla/flashmla.py | 9 ++++++--- .../attention/backends/mla/flashmla_sparse.py | 3 ++- vllm/v1/attention/backends/mla/triton_mla.py | 2 +- .../attention/backends/mla/xpu_mla_sparse.py | 3 ++- vllm/v1/attention/backends/rocm_aiter_fa.py | 15 ++++++++------- .../backends/rocm_aiter_unified_attn.py | 5 +++-- vllm/v1/attention/backends/rocm_attn.py | 7 ++++--- vllm/v1/attention/backends/triton_attn.py | 9 +++++---- .../ops/triton_reshape_and_cache_flash.py | 19 +++++++++++-------- vllm/v1/worker/gpu_model_runner.py | 3 ++- vllm/v1/worker/gpu_worker.py | 4 ++-- 28 files changed, 90 insertions(+), 75 deletions(-) diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 49c8868e7..1fdce002e 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -8,6 +8,7 @@ from pydantic import Field, SkipValidation, field_validator, model_validator from vllm.config.utils import config from vllm.logger import init_logger +from vllm.utils.torch_utils import is_quantized_kv_cache logger = init_logger(__name__) @@ -236,7 +237,7 @@ class CacheConfig: @field_validator("cache_dtype", mode="after") @classmethod def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType: - if cache_dtype.startswith("fp8"): + if is_quantized_kv_cache(cache_dtype): logger.info( "Using fp8 data type to store kv cache. It reduces the GPU " "memory footprint and boosts the performance. " diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index 4977c62b9..0be46fbbc 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -241,6 +241,7 @@ from vllm.utils.flashinfer import has_flashinfer, has_nvidia_artifactory from vllm.utils.math_utils import cdiv, round_down from vllm.utils.torch_utils import ( direct_register_custom_op, + is_quantized_kv_cache, kv_cache_dtype_str_to_dtype, ) from vllm.v1.attention.backend import ( @@ -342,7 +343,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): # Automatically convert fp8 kv-cache format to "fp8_ds_mla" if ( self.attn_backend.get_name() == "FLASHMLA_SPARSE" - and kv_cache_dtype.startswith("fp8") + and is_quantized_kv_cache(kv_cache_dtype) and kv_cache_dtype != "fp8_ds_mla" ): assert cache_config is not None @@ -356,7 +357,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): if ( self.attn_backend.get_name() == "FLASHINFER_MLA_SPARSE" - and kv_cache_dtype.startswith("fp8") + and is_quantized_kv_cache(kv_cache_dtype) ): logger.info_once( "Using standard fp8 KV cache format. To use DeepSeek's fp8_ds_mla " @@ -571,7 +572,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): if self.impl.dcp_world_size == -1: self.impl.dcp_world_size = get_dcp_group().world_size - fp8_attention = self.kv_cache_dtype.startswith("fp8") + fp8_attention = is_quantized_kv_cache(self.kv_cache_dtype) num_actual_toks = attn_metadata.num_actual_tokens @@ -1434,7 +1435,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): is enabled, else model dtype. """ use_fp8 = ( - vllm_config.cache_config.cache_dtype.startswith("fp8") + is_quantized_kv_cache(vllm_config.cache_config.cache_dtype) and vllm_config.attention_config.use_prefill_query_quantization and backend_supports_prefill_query_quantization() ) diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index fe2e31252..2fb67aacc 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizeMethodBase, ) from vllm.platforms import current_platform -from vllm.v1.attention.backend import is_quantized_kv_cache +from vllm.utils.torch_utils import is_quantized_kv_cache logger = init_logger(__name__) diff --git a/vllm/model_executor/models/extract_hidden_states.py b/vllm/model_executor/models/extract_hidden_states.py index d969441ac..608e93d6a 100644 --- a/vllm/model_executor/models/extract_hidden_states.py +++ b/vllm/model_executor/models/extract_hidden_states.py @@ -23,14 +23,13 @@ from vllm.model_executor.layers.attention.kv_transfer_utils import ( ) from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.models.utils import maybe_prefix -from vllm.utils.torch_utils import kv_cache_dtype_str_to_dtype +from vllm.utils.torch_utils import is_quantized_kv_cache, kv_cache_dtype_str_to_dtype from vllm.v1.attention.backend import ( AttentionBackend, AttentionImpl, AttentionMetadataBuilder, AttentionType, CommonAttentionMetadata, - is_quantized_kv_cache, ) from vllm.v1.kv_cache_interface import ( AttentionSpec, diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 7fbad3e4c..ff6b22e55 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -16,7 +16,7 @@ import torch from vllm import envs from vllm.logger import init_logger -from vllm.v1.attention.backend import is_quantized_kv_cache +from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.v1.attention.backends.registry import AttentionBackendEnum from .interface import CpuArchEnum, Platform, PlatformEnum @@ -183,7 +183,7 @@ class CpuPlatform(Platform): "backend is not compatible with FP8 KV cache." ) - if cache_config.cache_dtype.startswith("fp8"): + if is_quantized_kv_cache(cache_config.cache_dtype): logger.warning( "CPU backend doesn't support KV cache quantization fallback to auto." ) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 73bfbeef1..10dd1b869 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -23,6 +23,7 @@ import vllm._C_stable_libtorch # noqa import vllm.envs as envs from vllm.logger import init_logger from vllm.utils.import_utils import import_pynvml +from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.v1.attention.backends.registry import AttentionBackendEnum from .interface import DeviceCapability, Platform, PlatformEnum @@ -87,7 +88,7 @@ def _get_backend_priorities( # Sparse MLA backend priorities # See https://github.com/vllm-project/vllm/issues/35807 for # benchmark results - if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"): + if kv_cache_dtype is not None and is_quantized_kv_cache(kv_cache_dtype): # Prefer FlashInfer for fp8 kv cache sparse_backends = [ AttentionBackendEnum.FLASHINFER_MLA_SPARSE, diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index bd9741024..59c19a56e 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -61,6 +61,10 @@ MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP = { T = TypeVar("T") +def is_quantized_kv_cache(kv_cache_dtype: str) -> bool: + return kv_cache_dtype.startswith("fp8") + + def is_strictly_contiguous(t: torch.Tensor) -> bool: """ Check if tensor is contiguous AND has no degenerate strides. diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index ec21b0fe9..32fac520c 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -954,10 +954,6 @@ class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]): ) -def is_quantized_kv_cache(kv_cache_dtype: str) -> bool: - return kv_cache_dtype.startswith("fp8") - - def subclass_attention_backend( name_prefix: str, attention_backend_cls: type[AttentionBackend], diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 5fa3844c8..1df0fe654 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -9,6 +9,7 @@ from vllm import _custom_ops as ops from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import CpuArchEnum, current_platform +from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.v1.attention.backend import ( AttentionBackend, AttentionImpl, @@ -16,7 +17,6 @@ from vllm.v1.attention.backend import ( AttentionMetadataBuilder, AttentionType, CommonAttentionMetadata, - is_quantized_kv_cache, ) from vllm.v1.attention.backends.utils import ( split_decodes_and_prefills, diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index e0c7c7287..d72c2aeb6 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -10,12 +10,12 @@ import numpy as np import torch from vllm.model_executor.layers.attention import Attention +from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.v1.attention.backend import ( AttentionBackend, AttentionImpl, AttentionType, MultipleOf, - is_quantized_kv_cache, ) from vllm.v1.attention.backends.fa_utils import ( flash_attn_supports_fp8, @@ -177,7 +177,7 @@ class FlashAttentionBackend(AttentionBackend): def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool: if kv_cache_dtype is None: return True - if kv_cache_dtype.startswith("fp8"): + if is_quantized_kv_cache(kv_cache_dtype): return flash_attn_supports_fp8() return kv_cache_dtype in ["auto", "float16", "bfloat16"] @@ -430,7 +430,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal ): cache_dtype = self.cache_config.cache_dtype - if cache_dtype.startswith("fp8"): + if is_quantized_kv_cache(cache_dtype): qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn( cache_dtype ) @@ -726,7 +726,7 @@ class FlashAttentionImpl(AttentionImpl): # For decoder and cross-attention, use KV cache as before key_cache, value_cache = kv_cache.unbind(0) - if self.kv_cache_dtype.startswith("fp8"): + if is_quantized_kv_cache(self.kv_cache_dtype): # queries are quantized in the attention layer dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn( self.kv_cache_dtype @@ -978,7 +978,7 @@ class FlashAttentionImpl(AttentionImpl): ) # For encoder attention, process FP8 quantization if needed - if self.kv_cache_dtype.startswith("fp8"): + if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( "quantization is not supported for encoder attention" ) diff --git a/vllm/v1/attention/backends/flash_attn_diffkv.py b/vllm/v1/attention/backends/flash_attn_diffkv.py index cc7ab473f..2a8ab86af 100644 --- a/vllm/v1/attention/backends/flash_attn_diffkv.py +++ b/vllm/v1/attention/backends/flash_attn_diffkv.py @@ -4,6 +4,7 @@ import torch +from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.v1.attention.backend import AttentionType from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available from vllm.v1.attention.ops.triton_reshape_and_cache_flash import ( @@ -191,7 +192,7 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl): key_cache = kv_cache[..., : self.head_size] value_cache = kv_cache[..., self.head_size :] - if self.kv_cache_dtype.startswith("fp8"): + if is_quantized_kv_cache(self.kv_cache_dtype): # queries are quantized in the attention layer dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn( self.kv_cache_dtype diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 5b6c198e7..e2f9f2b8c 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -42,7 +42,7 @@ from vllm.utils.flashinfer import ( ) from vllm.utils.math_utils import cdiv from vllm.utils.platform_utils import is_pin_memory_available -from vllm.utils.torch_utils import is_strictly_contiguous +from vllm.utils.torch_utils import is_quantized_kv_cache, is_strictly_contiguous from vllm.v1.attention.backend import ( AttentionBackend, AttentionCGSupport, @@ -602,7 +602,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): self.page_size = self.kv_cache_spec.block_size self.cache_dtype = self.cache_config.cache_dtype - if self.cache_dtype.startswith("fp8"): + if is_quantized_kv_cache(self.cache_dtype): self.kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( self.cache_dtype ) @@ -1269,7 +1269,7 @@ class FlashInferImpl(AttentionImpl): def fused_output_quant_supported(self, quant_key: QuantKey): return ( self.support_trtllm_attn - and self.kv_cache_dtype.startswith("fp8") + and is_quantized_kv_cache(self.kv_cache_dtype) and quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic) ) @@ -1317,12 +1317,12 @@ class FlashInferImpl(AttentionImpl): if self.bmm1_scale is None: self.bmm1_scale = self.scale - if self.kv_cache_dtype.startswith("fp8"): + if is_quantized_kv_cache(self.kv_cache_dtype): self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float if self.bmm2_scale is None: self.bmm2_scale = 1.0 - if self.kv_cache_dtype.startswith("fp8"): + if is_quantized_kv_cache(self.kv_cache_dtype): self.bmm2_scale *= layer._v_scale_float prefill_use_trtllm = isinstance(attn_metadata.prefill, TRTLLMPrefill) @@ -1375,8 +1375,8 @@ class FlashInferImpl(AttentionImpl): # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 # to process the cache when the kv_cache_dtype is fp8 - if self.kv_sharing_target_layer_name is None and self.kv_cache_dtype.startswith( - "fp8" + if self.kv_sharing_target_layer_name is None and is_quantized_kv_cache( + self.kv_cache_dtype ): torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( self.kv_cache_dtype @@ -1486,9 +1486,8 @@ class FlashInferImpl(AttentionImpl): assert self.o_sf_scale is None out = output[num_decode_tokens:] - if ( - attn_metadata.q_data_type != FP8_DTYPE - and self.kv_cache_dtype.startswith("fp8") + if attn_metadata.q_data_type != FP8_DTYPE and is_quantized_kv_cache( + self.kv_cache_dtype ): # TRTLLM prefill attention does not support BF16 Q # and fp8 kv cache. So to enable prefill attention diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index 16874c177..e832f6bdd 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -27,14 +27,13 @@ from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.math_utils import cdiv -from vllm.utils.torch_utils import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_quantized_kv_cache, is_torch_equal_or_newer from vllm.v1.attention.backend import ( AttentionBackend, AttentionImpl, AttentionMetadataBuilder, AttentionType, CommonAttentionMetadata, - is_quantized_kv_cache, ) from vllm.v1.kv_cache_interface import AttentionSpec diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index fd4d9ab84..b731ea75d 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -17,12 +17,12 @@ from vllm.model_executor.layers.attention.mla_attention import ( ) from vllm.platforms.interface import DeviceCapability from vllm.utils.platform_utils import num_compute_units +from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.v1.attention.backend import ( AttentionCGSupport, AttentionLayer, AttentionType, MultipleOf, - is_quantized_kv_cache, ) logger = init_logger(__name__) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 82d463dcd..f58d9aeb3 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -20,12 +20,12 @@ from vllm.model_executor.layers.attention.mla_attention import ( ) from vllm.platforms.interface import DeviceCapability from vllm.utils.math_utils import round_up +from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.v1.attention.backend import ( AttentionCGSupport, AttentionLayer, AttentionType, MultipleOf, - is_quantized_kv_cache, ) from vllm.v1.attention.backends.fa_utils import ( flash_attn_supports_mla, @@ -319,7 +319,7 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]): q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 ) - if self.kv_cache_dtype.startswith("fp8"): + if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError("FP8 FlashAttention MLA not yet supported") kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank] diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index 16d01bd33..658382cb2 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -16,12 +16,12 @@ from vllm.model_executor.layers.attention.mla_attention import ( QueryLenSupport, ) from vllm.platforms.interface import DeviceCapability +from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.v1.attention.backend import ( AttentionCGSupport, AttentionLayer, AttentionType, MultipleOf, - is_quantized_kv_cache, ) from vllm.v1.attention.backends.utils import KVCacheLayoutType @@ -184,12 +184,12 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]): if self.bmm1_scale is None: self.bmm1_scale = self.scale - if self.kv_cache_dtype.startswith("fp8"): + if is_quantized_kv_cache(self.kv_cache_dtype): self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float if self.bmm2_scale is None: self.bmm2_scale = 1.0 - if self.kv_cache_dtype.startswith("fp8"): + if is_quantized_kv_cache(self.kv_cache_dtype): self.bmm2_scale *= layer._k_scale_float # Reuse pre-allocated zero-init output buffer to avoid a memset diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py b/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py index 7b5ec0d49..1eb12f72e 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py @@ -26,6 +26,7 @@ from vllm.model_executor.layers.attention.mla_attention import ( get_mla_dims, ) from vllm.platforms.interface import DeviceCapability +from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.v1.attention.backend import ( AttentionBackend, AttentionCGSupport, @@ -341,11 +342,11 @@ class FlashInferMLASparseImpl(SparseMLAAttentionImpl[FlashInferMLASparseMetadata if self.bmm1_scale is None: self.bmm1_scale = self.scale - if self.kv_cache_dtype.startswith("fp8"): + if is_quantized_kv_cache(self.kv_cache_dtype): self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float if self.bmm2_scale is None: self.bmm2_scale = 1.0 - if self.kv_cache_dtype.startswith("fp8"): + if is_quantized_kv_cache(self.kv_cache_dtype): self.bmm2_scale *= layer._k_scale_float o = trtllm_batch_decode_with_kv_cache_mla( diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index df54b865a..2f6058d69 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -20,6 +20,7 @@ from vllm.model_executor.layers.attention.mla_attention import ( ) from vllm.platforms.interface import DeviceCapability from vllm.utils.platform_utils import num_compute_units +from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.v1.attention.backend import ( AttentionCGSupport, AttentionLayer, @@ -128,7 +129,9 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): self.cg_buf_tile_scheduler_metadata = None self.cg_buf_num_splits = None - self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8") + self.is_fp8_kvcache = is_quantized_kv_cache( + vllm_config.cache_config.cache_dtype + ) num_sms = num_compute_units(self.device.index) @@ -269,7 +272,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): q = reshape_query_for_spec_decode(q, num_decodes) scheduler_metadata = attn_metadata.decode.scheduler_metadata - if envs.VLLM_BATCH_INVARIANT and not self.kv_cache_dtype.startswith("fp8"): + if envs.VLLM_BATCH_INVARIANT and not is_quantized_kv_cache(self.kv_cache_dtype): device = q.device dtype = torch.int32 @@ -299,7 +302,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata scheduler_metadata.num_splits = num_splits - if self.kv_cache_dtype.startswith("fp8"): + if is_quantized_kv_cache(self.kv_cache_dtype): o, lse = flash_mla_with_kvcache_fp8( q=q, k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 7cc50ec84..816ad88a8 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -16,6 +16,7 @@ from vllm.model_executor.layers.attention.mla_attention import ( from vllm.platforms import current_platform from vllm.platforms.interface import DeviceCapability from vllm.utils.platform_utils import num_compute_units +from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.v1.attention.backend import ( AttentionBackend, AttentionCGSupport, @@ -571,7 +572,7 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]): vllm_config = get_current_vllm_config() max_tokens = vllm_config.scheduler_config.max_num_batched_tokens q_concat_shape = (max_tokens, num_heads, head_size) - if kv_cache_dtype.startswith("fp8"): + if is_quantized_kv_cache(kv_cache_dtype): assert kv_cache_dtype == "fp8_ds_mla", ( "FlashMLA Sparse Attention backend fp8 only supports " "fp8_ds_mla kv-cache dtype" diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 3de5be31d..6fa1bbf20 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -14,11 +14,11 @@ from vllm.model_executor.layers.attention.mla_attention import ( MLACommonMetadata, ) from vllm.platforms.interface import DeviceCapability +from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.v1.attention.backend import ( AttentionLayer, AttentionType, MultipleOf, - is_quantized_kv_cache, ) from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd diff --git a/vllm/v1/attention/backends/mla/xpu_mla_sparse.py b/vllm/v1/attention/backends/mla/xpu_mla_sparse.py index 44455a700..59ec42e93 100644 --- a/vllm/v1/attention/backends/mla/xpu_mla_sparse.py +++ b/vllm/v1/attention/backends/mla/xpu_mla_sparse.py @@ -13,6 +13,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.attention.mla_attention import ( get_mla_dims, ) +from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.v1.attention.backend import ( AttentionBackend, AttentionCGSupport, @@ -231,7 +232,7 @@ class XPUMLASparseImpl(SparseMLAAttentionImpl[XPUMLASparseMetadata]): # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use # MQA 576/512 approach for both prefill and decode - if self.kv_cache_dtype.startswith("fp8"): + if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError("FP8 kv is not supported with XPU MLA Sparse yet") # Concatenate q if it's a tuple (ql_nope, q_pe) diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index d0aebf614..29351fcbf 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -16,6 +16,7 @@ from vllm.platforms import current_platform from vllm.platforms.interface import DeviceCapability from vllm.utils.math_utils import cdiv from vllm.utils.platform_utils import num_compute_units +from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.v1.attention.backend import ( AttentionBackend, AttentionCGSupport, @@ -291,7 +292,7 @@ if current_platform.is_rocm(): new_key_cache = key_cache.view_as(k_cache_template) new_value_cache = value_cache.view_as(v_cache_template) QUANT = False - if kv_cache_dtype.startswith("fp8"): + if is_quantized_kv_cache(kv_cache_dtype): QUANT = True grid = ( num_tokens, @@ -494,7 +495,7 @@ class AiterFlashAttentionMetadataBuilder( if ( rocm_aiter_ops.is_shuffle_kv_cache_enabled() and self.scale.numel() == 1 - and self.vllm_config.cache_config.cache_dtype.startswith("fp8") + and is_quantized_kv_cache(self.vllm_config.cache_config.cache_dtype) ): layers = get_layers_from_vllm_config(self.vllm_config, Attention) first_layer_name = [k for k in layers][0] @@ -887,7 +888,7 @@ class AiterFlashAttentionImpl(AttentionImpl): cu_seqlens_kv=swa_cu_seqlens, token_to_batch=swa_token_to_batch, seq_starts=swa_seq_starts, - dequant=self.kv_cache_dtype.startswith("fp8"), + dequant=is_quantized_kv_cache(self.kv_cache_dtype), kv_cache_layout="NHD", total_tokens=swa_total_tokens, ) @@ -982,7 +983,7 @@ class AiterFlashAttentionImpl(AttentionImpl): cu_seqlens_kv=cu_seqlens_kv[chunk_idx], token_to_batch=token_to_batch[chunk_idx], seq_starts=chunk_starts[chunk_idx], - dequant=self.kv_cache_dtype.startswith("fp8"), + dequant=is_quantized_kv_cache(self.kv_cache_dtype), kv_cache_layout="SHUFFLE" if rocm_aiter_ops.is_shuffle_kv_cache_enabled() else "NHD", @@ -1081,7 +1082,7 @@ class AiterFlashAttentionImpl(AttentionImpl): num_actual_tokens = attn_metadata.num_actual_tokens key_cache, value_cache = kv_cache.unbind(0) - if self.kv_cache_dtype.startswith("fp8"): + if is_quantized_kv_cache(self.kv_cache_dtype): key_cache = key_cache.view(current_platform.fp8_dtype()) value_cache = value_cache.view(current_platform.fp8_dtype()) @@ -1370,7 +1371,7 @@ class AiterFlashAttentionImpl(AttentionImpl): # key and value may be None in the case of cross attention. They are # calculated once based on the output from the encoder and then cached # in KV cache. - if self.kv_cache_dtype.startswith("fp8"): + if is_quantized_kv_cache(self.kv_cache_dtype): key_cache = key_cache.view(current_platform.fp8_dtype()) value_cache = value_cache.view(current_platform.fp8_dtype()) # Reshape the input keys and values and store them in the cache. @@ -1436,7 +1437,7 @@ class AiterFlashAttentionImpl(AttentionImpl): key_cache, value_cache = kv_cache.unbind(0) flash_layout = True - is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8") + is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype) if is_fp8_kv_cache: key_cache = key_cache.view(current_platform.fp8_dtype()) value_cache = value_cache.view(current_platform.fp8_dtype()) diff --git a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py index bd7f137f9..c91f8a225 100644 --- a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py +++ b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py @@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8StaticTensorSym, ) +from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.v1.attention.backend import AttentionLayer, AttentionType, MultipleOf from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.rocm_attn import ( @@ -200,7 +201,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl): softmax_scale = self.scale fp8_post_attn_v_rescale = False - if self.kv_cache_dtype.startswith("fp8"): + if is_quantized_kv_cache(self.kv_cache_dtype): key_cache = key_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype) # When Q is FP8, triton kernel skips K/V dequant (for fp8xfp8 matmul). @@ -299,7 +300,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl): key_cache, value_cache = kv_cache.unbind(0) flash_layout = True - is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8") + is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype) if is_fp8_kv_cache: key_cache = key_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype) diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 6afb617f2..a8448c489 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticTensorSym, ) from vllm.platforms import current_platform +from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.v1.attention.backend import ( AttentionBackend, AttentionCGSupport, @@ -315,7 +316,7 @@ class RocmAttentionImpl(AttentionImpl): layer: The attention layer """ # For encoder attention, process FP8 quantization if needed - if self.kv_cache_dtype.startswith("fp8"): + if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( "quantization is not supported for encoder attention" ) @@ -406,7 +407,7 @@ class RocmAttentionImpl(AttentionImpl): kv_cache, self.num_kv_heads, self.head_size ) - if self.kv_cache_dtype.startswith("fp8"): + if is_quantized_kv_cache(self.kv_cache_dtype): key_cache = key_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype) assert layer._q_scale_float == 1.0, ( @@ -513,7 +514,7 @@ class RocmAttentionImpl(AttentionImpl): ) flash_layout = False - is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8") + is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype) if is_fp8_kv_cache: key_cache = key_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 6d967b515..f9a688f65 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -18,6 +18,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.platforms import current_platform from vllm.platforms.interface import DeviceCapability from vllm.utils.math_utils import next_power_of_2 +from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.v1.attention.backend import ( AttentionBackend, AttentionCGSupport, @@ -472,7 +473,7 @@ class TritonAttentionImpl(AttentionImpl): # For decoder and cross-attention, use KV cache as before key_cache, value_cache = kv_cache.unbind(1) - if self.kv_cache_dtype.startswith("fp8"): + if is_quantized_kv_cache(self.kv_cache_dtype): if key_cache.dtype != self.fp8_dtype: key_cache = key_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype) @@ -546,7 +547,7 @@ class TritonAttentionImpl(AttentionImpl): layer: The attention layer """ # For encoder attention, process FP8 quantization if needed - if self.kv_cache_dtype.startswith("fp8"): + if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( "quantization is not supported for encoder attention" ) @@ -588,7 +589,7 @@ class TritonAttentionImpl(AttentionImpl): key_cache, value_cache = kv_cache.unbind(1) # Reshape the input keys and values and store them in the cache. - if self.kv_cache_dtype.startswith("fp8"): + if is_quantized_kv_cache(self.kv_cache_dtype): key_cache = key_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype) # triton kernel does not support uint8 kv_cache @@ -623,7 +624,7 @@ class TritonAttentionImpl(AttentionImpl): key_cache, value_cache = kv_cache.unbind(1) flash_layout = True - is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8") + is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype) if is_fp8_kv_cache: key_cache = key_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype) diff --git a/vllm/v1/attention/ops/triton_reshape_and_cache_flash.py b/vllm/v1/attention/ops/triton_reshape_and_cache_flash.py index c5c9a9c96..f98b50f8f 100644 --- a/vllm/v1/attention/ops/triton_reshape_and_cache_flash.py +++ b/vllm/v1/attention/ops/triton_reshape_and_cache_flash.py @@ -5,6 +5,7 @@ import torch from vllm.platforms import current_platform from vllm.triton_utils import tl, triton +from vllm.utils.torch_utils import is_quantized_kv_cache @triton.jit @@ -145,16 +146,18 @@ def triton_reshape_and_cache_flash( block_stride = key_cache.stride()[0] page_stride = key_cache.stride()[1] - assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), ( + assert kv_cache_dtype == "auto" or is_quantized_kv_cache(kv_cache_dtype), ( f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}." ) kv_cache_torch_dtype = ( current_platform.fp8_dtype() - if kv_cache_dtype.startswith("fp8") + if is_quantized_kv_cache(kv_cache_dtype) else key_cache.dtype ) - if key_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith("fp8"): + if key_cache.dtype != kv_cache_torch_dtype and is_quantized_kv_cache( + kv_cache_dtype + ): # to avoid erounous implicit cast in triton kernel (tl.store to uint8) # (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4) key_cache = key_cache.view(kv_cache_torch_dtype) @@ -164,7 +167,7 @@ def triton_reshape_and_cache_flash( "uint8 is not supported by triton reshape_and_cache_flash" ) - FP8_KV_CACHE = kv_cache_dtype.startswith("fp8") + FP8_KV_CACHE = is_quantized_kv_cache(kv_cache_dtype) assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [ torch.float8_e4m3fn, torch.float8_e5m2, @@ -323,16 +326,16 @@ def triton_reshape_and_cache_flash_diffkv( block_stride = kv_cache.stride()[0] page_stride = kv_cache.stride()[1] - assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), ( + assert kv_cache_dtype == "auto" or is_quantized_kv_cache(kv_cache_dtype), ( f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}." ) kv_cache_torch_dtype = ( current_platform.fp8_dtype() - if kv_cache_dtype.startswith("fp8") + if is_quantized_kv_cache(kv_cache_dtype) else kv_cache.dtype ) - if kv_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith("fp8"): + if kv_cache.dtype != kv_cache_torch_dtype and is_quantized_kv_cache(kv_cache_dtype): # to avoid erounous implicit cast in triton kernel (tl.store to uint8) # (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4) kv_cache = kv_cache.view(kv_cache_torch_dtype) @@ -341,7 +344,7 @@ def triton_reshape_and_cache_flash_diffkv( "uint8 is not supported by triton reshape_and_cache_flash_diffkv" ) - FP8_KV_CACHE = kv_cache_dtype.startswith("fp8") + FP8_KV_CACHE = is_quantized_kv_cache(kv_cache_dtype) assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [ torch.float8_e4m3fn, torch.float8_e5m2, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index cdbaef859..02450ef4e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -109,6 +109,7 @@ from vllm.utils.nvtx_pytorch_hooks import PytHooks from vllm.utils.platform_utils import is_pin_memory_available, num_compute_units from vllm.utils.torch_utils import ( get_dtype_size, + is_quantized_kv_cache, kv_cache_dtype_str_to_dtype, ) from vllm.v1.attention.backend import ( @@ -896,7 +897,7 @@ class GPUModelRunner( If these are left at 0.0 (default after wake_up), all KV cache values become effectively zero, causing gibberish output. """ - if not self.cache_config.cache_dtype.startswith("fp8"): + if not is_quantized_kv_cache(self.cache_config.cache_dtype): return kv_caches = getattr(self, "kv_caches", []) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index ec6bc2a71..8a6236a4a 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -46,7 +46,7 @@ from vllm.tasks import SupportedTask from vllm.tracing import instrument from vllm.utils.mem_constants import GiB_bytes from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling -from vllm.utils.torch_utils import set_random_seed +from vllm.utils.torch_utils import is_quantized_kv_cache, set_random_seed from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ( @@ -197,7 +197,7 @@ class Worker(WorkerBase): # especially the FP8 scaling factor. if ( (tags is None or "kv_cache" in tags) - and self.cache_config.cache_dtype.startswith("fp8") + and is_quantized_kv_cache(self.cache_config.cache_dtype) and hasattr(self.model_runner, "init_fp8_kv_scales") ): self.model_runner.init_fp8_kv_scales()