[1/N][Cleanup] Standardize on use of is_quantized_kv_cache (#38659)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Matthew Bonanni
2026-04-01 00:08:01 -04:00
committed by GitHub
parent 7b01d97a22
commit 116f4be405
28 changed files with 90 additions and 75 deletions

View File

@@ -8,6 +8,7 @@ from pydantic import Field, SkipValidation, field_validator, model_validator
from vllm.config.utils import config from vllm.config.utils import config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.torch_utils import is_quantized_kv_cache
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -236,7 +237,7 @@ class CacheConfig:
@field_validator("cache_dtype", mode="after") @field_validator("cache_dtype", mode="after")
@classmethod @classmethod
def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType: def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType:
if cache_dtype.startswith("fp8"): if is_quantized_kv_cache(cache_dtype):
logger.info( logger.info(
"Using fp8 data type to store kv cache. It reduces the GPU " "Using fp8 data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. " "memory footprint and boosts the performance. "

View File

@@ -241,6 +241,7 @@ from vllm.utils.flashinfer import has_flashinfer, has_nvidia_artifactory
from vllm.utils.math_utils import cdiv, round_down from vllm.utils.math_utils import cdiv, round_down
from vllm.utils.torch_utils import ( from vllm.utils.torch_utils import (
direct_register_custom_op, direct_register_custom_op,
is_quantized_kv_cache,
kv_cache_dtype_str_to_dtype, kv_cache_dtype_str_to_dtype,
) )
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
@@ -342,7 +343,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
# Automatically convert fp8 kv-cache format to "fp8_ds_mla" # Automatically convert fp8 kv-cache format to "fp8_ds_mla"
if ( if (
self.attn_backend.get_name() == "FLASHMLA_SPARSE" self.attn_backend.get_name() == "FLASHMLA_SPARSE"
and kv_cache_dtype.startswith("fp8") and is_quantized_kv_cache(kv_cache_dtype)
and kv_cache_dtype != "fp8_ds_mla" and kv_cache_dtype != "fp8_ds_mla"
): ):
assert cache_config is not None assert cache_config is not None
@@ -356,7 +357,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
if ( if (
self.attn_backend.get_name() == "FLASHINFER_MLA_SPARSE" self.attn_backend.get_name() == "FLASHINFER_MLA_SPARSE"
and kv_cache_dtype.startswith("fp8") and is_quantized_kv_cache(kv_cache_dtype)
): ):
logger.info_once( logger.info_once(
"Using standard fp8 KV cache format. To use DeepSeek's fp8_ds_mla " "Using standard fp8 KV cache format. To use DeepSeek's fp8_ds_mla "
@@ -571,7 +572,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
if self.impl.dcp_world_size == -1: if self.impl.dcp_world_size == -1:
self.impl.dcp_world_size = get_dcp_group().world_size self.impl.dcp_world_size = get_dcp_group().world_size
fp8_attention = self.kv_cache_dtype.startswith("fp8") fp8_attention = is_quantized_kv_cache(self.kv_cache_dtype)
num_actual_toks = attn_metadata.num_actual_tokens num_actual_toks = attn_metadata.num_actual_tokens
@@ -1434,7 +1435,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
is enabled, else model dtype. is enabled, else model dtype.
""" """
use_fp8 = ( use_fp8 = (
vllm_config.cache_config.cache_dtype.startswith("fp8") is_quantized_kv_cache(vllm_config.cache_config.cache_dtype)
and vllm_config.attention_config.use_prefill_query_quantization and vllm_config.attention_config.use_prefill_query_quantization
and backend_supports_prefill_query_quantization() and backend_supports_prefill_query_quantization()
) )

View File

@@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase, QuantizeMethodBase,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.attention.backend import is_quantized_kv_cache from vllm.utils.torch_utils import is_quantized_kv_cache
logger = init_logger(__name__) logger = init_logger(__name__)

View File

@@ -23,14 +23,13 @@ from vllm.model_executor.layers.attention.kv_transfer_utils import (
) )
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.models.utils import maybe_prefix from vllm.model_executor.models.utils import maybe_prefix
from vllm.utils.torch_utils import kv_cache_dtype_str_to_dtype from vllm.utils.torch_utils import is_quantized_kv_cache, kv_cache_dtype_str_to_dtype
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionImpl, AttentionImpl,
AttentionMetadataBuilder, AttentionMetadataBuilder,
AttentionType, AttentionType,
CommonAttentionMetadata, CommonAttentionMetadata,
is_quantized_kv_cache,
) )
from vllm.v1.kv_cache_interface import ( from vllm.v1.kv_cache_interface import (
AttentionSpec, AttentionSpec,

View File

@@ -16,7 +16,7 @@ import torch
from vllm import envs from vllm import envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backend import is_quantized_kv_cache from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
from .interface import CpuArchEnum, Platform, PlatformEnum from .interface import CpuArchEnum, Platform, PlatformEnum
@@ -183,7 +183,7 @@ class CpuPlatform(Platform):
"backend is not compatible with FP8 KV cache." "backend is not compatible with FP8 KV cache."
) )
if cache_config.cache_dtype.startswith("fp8"): if is_quantized_kv_cache(cache_config.cache_dtype):
logger.warning( logger.warning(
"CPU backend doesn't support KV cache quantization fallback to auto." "CPU backend doesn't support KV cache quantization fallback to auto."
) )

View File

@@ -23,6 +23,7 @@ import vllm._C_stable_libtorch # noqa
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.import_utils import import_pynvml from vllm.utils.import_utils import import_pynvml
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
from .interface import DeviceCapability, Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
@@ -87,7 +88,7 @@ def _get_backend_priorities(
# Sparse MLA backend priorities # Sparse MLA backend priorities
# See https://github.com/vllm-project/vllm/issues/35807 for # See https://github.com/vllm-project/vllm/issues/35807 for
# benchmark results # benchmark results
if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"): if kv_cache_dtype is not None and is_quantized_kv_cache(kv_cache_dtype):
# Prefer FlashInfer for fp8 kv cache # Prefer FlashInfer for fp8 kv cache
sparse_backends = [ sparse_backends = [
AttentionBackendEnum.FLASHINFER_MLA_SPARSE, AttentionBackendEnum.FLASHINFER_MLA_SPARSE,

View File

@@ -61,6 +61,10 @@ MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP = {
T = TypeVar("T") T = TypeVar("T")
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
return kv_cache_dtype.startswith("fp8")
def is_strictly_contiguous(t: torch.Tensor) -> bool: def is_strictly_contiguous(t: torch.Tensor) -> bool:
""" """
Check if tensor is contiguous AND has no degenerate strides. Check if tensor is contiguous AND has no degenerate strides.

View File

@@ -954,10 +954,6 @@ class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
) )
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
return kv_cache_dtype.startswith("fp8")
def subclass_attention_backend( def subclass_attention_backend(
name_prefix: str, name_prefix: str,
attention_backend_cls: type[AttentionBackend], attention_backend_cls: type[AttentionBackend],

View File

@@ -9,6 +9,7 @@ from vllm import _custom_ops as ops
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform from vllm.platforms import CpuArchEnum, current_platform
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionImpl, AttentionImpl,
@@ -16,7 +17,6 @@ from vllm.v1.attention.backend import (
AttentionMetadataBuilder, AttentionMetadataBuilder,
AttentionType, AttentionType,
CommonAttentionMetadata, CommonAttentionMetadata,
is_quantized_kv_cache,
) )
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills, split_decodes_and_prefills,

View File

@@ -10,12 +10,12 @@ import numpy as np
import torch import torch
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionImpl, AttentionImpl,
AttentionType, AttentionType,
MultipleOf, MultipleOf,
is_quantized_kv_cache,
) )
from vllm.v1.attention.backends.fa_utils import ( from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_fp8, flash_attn_supports_fp8,
@@ -177,7 +177,7 @@ class FlashAttentionBackend(AttentionBackend):
def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool: def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
if kv_cache_dtype is None: if kv_cache_dtype is None:
return True return True
if kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(kv_cache_dtype):
return flash_attn_supports_fp8() return flash_attn_supports_fp8()
return kv_cache_dtype in ["auto", "float16", "bfloat16"] return kv_cache_dtype in ["auto", "float16", "bfloat16"]
@@ -430,7 +430,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
): ):
cache_dtype = self.cache_config.cache_dtype cache_dtype = self.cache_config.cache_dtype
if cache_dtype.startswith("fp8"): if is_quantized_kv_cache(cache_dtype):
qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn( qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
cache_dtype cache_dtype
) )
@@ -726,7 +726,7 @@ class FlashAttentionImpl(AttentionImpl):
# For decoder and cross-attention, use KV cache as before # For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(0) key_cache, value_cache = kv_cache.unbind(0)
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
# queries are quantized in the attention layer # queries are quantized in the attention layer
dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn( dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
self.kv_cache_dtype self.kv_cache_dtype
@@ -978,7 +978,7 @@ class FlashAttentionImpl(AttentionImpl):
) )
# For encoder attention, process FP8 quantization if needed # For encoder attention, process FP8 quantization if needed
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError( raise NotImplementedError(
"quantization is not supported for encoder attention" "quantization is not supported for encoder attention"
) )

View File

@@ -4,6 +4,7 @@
import torch import torch
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import AttentionType from vllm.v1.attention.backend import AttentionType
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import ( from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
@@ -191,7 +192,7 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
key_cache = kv_cache[..., : self.head_size] key_cache = kv_cache[..., : self.head_size]
value_cache = kv_cache[..., self.head_size :] value_cache = kv_cache[..., self.head_size :]
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
# queries are quantized in the attention layer # queries are quantized in the attention layer
dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn( dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
self.kv_cache_dtype self.kv_cache_dtype

View File

@@ -42,7 +42,7 @@ from vllm.utils.flashinfer import (
) )
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import is_pin_memory_available from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import is_strictly_contiguous from vllm.utils.torch_utils import is_quantized_kv_cache, is_strictly_contiguous
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport, AttentionCGSupport,
@@ -602,7 +602,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.page_size = self.kv_cache_spec.block_size self.page_size = self.kv_cache_spec.block_size
self.cache_dtype = self.cache_config.cache_dtype self.cache_dtype = self.cache_config.cache_dtype
if self.cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.cache_dtype):
self.kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( self.kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.cache_dtype self.cache_dtype
) )
@@ -1269,7 +1269,7 @@ class FlashInferImpl(AttentionImpl):
def fused_output_quant_supported(self, quant_key: QuantKey): def fused_output_quant_supported(self, quant_key: QuantKey):
return ( return (
self.support_trtllm_attn self.support_trtllm_attn
and self.kv_cache_dtype.startswith("fp8") and is_quantized_kv_cache(self.kv_cache_dtype)
and quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic) and quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic)
) )
@@ -1317,12 +1317,12 @@ class FlashInferImpl(AttentionImpl):
if self.bmm1_scale is None: if self.bmm1_scale is None:
self.bmm1_scale = self.scale self.bmm1_scale = self.scale
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float
if self.bmm2_scale is None: if self.bmm2_scale is None:
self.bmm2_scale = 1.0 self.bmm2_scale = 1.0
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
self.bmm2_scale *= layer._v_scale_float self.bmm2_scale *= layer._v_scale_float
prefill_use_trtllm = isinstance(attn_metadata.prefill, TRTLLMPrefill) prefill_use_trtllm = isinstance(attn_metadata.prefill, TRTLLMPrefill)
@@ -1375,8 +1375,8 @@ class FlashInferImpl(AttentionImpl):
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8 # to process the cache when the kv_cache_dtype is fp8
if self.kv_sharing_target_layer_name is None and self.kv_cache_dtype.startswith( if self.kv_sharing_target_layer_name is None and is_quantized_kv_cache(
"fp8" self.kv_cache_dtype
): ):
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.kv_cache_dtype self.kv_cache_dtype
@@ -1486,9 +1486,8 @@ class FlashInferImpl(AttentionImpl):
assert self.o_sf_scale is None assert self.o_sf_scale is None
out = output[num_decode_tokens:] out = output[num_decode_tokens:]
if ( if attn_metadata.q_data_type != FP8_DTYPE and is_quantized_kv_cache(
attn_metadata.q_data_type != FP8_DTYPE self.kv_cache_dtype
and self.kv_cache_dtype.startswith("fp8")
): ):
# TRTLLM prefill attention does not support BF16 Q # TRTLLM prefill attention does not support BF16 Q
# and fp8 kv cache. So to enable prefill attention # and fp8 kv cache. So to enable prefill attention

View File

@@ -27,14 +27,13 @@ from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_quantized_kv_cache, is_torch_equal_or_newer
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionImpl, AttentionImpl,
AttentionMetadataBuilder, AttentionMetadataBuilder,
AttentionType, AttentionType,
CommonAttentionMetadata, CommonAttentionMetadata,
is_quantized_kv_cache,
) )
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec

View File

@@ -17,12 +17,12 @@ from vllm.model_executor.layers.attention.mla_attention import (
) )
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.platform_utils import num_compute_units from vllm.utils.platform_utils import num_compute_units
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionCGSupport, AttentionCGSupport,
AttentionLayer, AttentionLayer,
AttentionType, AttentionType,
MultipleOf, MultipleOf,
is_quantized_kv_cache,
) )
logger = init_logger(__name__) logger = init_logger(__name__)

View File

@@ -20,12 +20,12 @@ from vllm.model_executor.layers.attention.mla_attention import (
) )
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import round_up from vllm.utils.math_utils import round_up
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionCGSupport, AttentionCGSupport,
AttentionLayer, AttentionLayer,
AttentionType, AttentionType,
MultipleOf, MultipleOf,
is_quantized_kv_cache,
) )
from vllm.v1.attention.backends.fa_utils import ( from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_mla, flash_attn_supports_mla,
@@ -319,7 +319,7 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
) )
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError("FP8 FlashAttention MLA not yet supported") raise NotImplementedError("FP8 FlashAttention MLA not yet supported")
kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank] kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]

View File

@@ -16,12 +16,12 @@ from vllm.model_executor.layers.attention.mla_attention import (
QueryLenSupport, QueryLenSupport,
) )
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionCGSupport, AttentionCGSupport,
AttentionLayer, AttentionLayer,
AttentionType, AttentionType,
MultipleOf, MultipleOf,
is_quantized_kv_cache,
) )
from vllm.v1.attention.backends.utils import KVCacheLayoutType from vllm.v1.attention.backends.utils import KVCacheLayoutType
@@ -184,12 +184,12 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
if self.bmm1_scale is None: if self.bmm1_scale is None:
self.bmm1_scale = self.scale self.bmm1_scale = self.scale
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float
if self.bmm2_scale is None: if self.bmm2_scale is None:
self.bmm2_scale = 1.0 self.bmm2_scale = 1.0
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
self.bmm2_scale *= layer._k_scale_float self.bmm2_scale *= layer._k_scale_float
# Reuse pre-allocated zero-init output buffer to avoid a memset # Reuse pre-allocated zero-init output buffer to avoid a memset

View File

@@ -26,6 +26,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
get_mla_dims, get_mla_dims,
) )
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport, AttentionCGSupport,
@@ -341,11 +342,11 @@ class FlashInferMLASparseImpl(SparseMLAAttentionImpl[FlashInferMLASparseMetadata
if self.bmm1_scale is None: if self.bmm1_scale is None:
self.bmm1_scale = self.scale self.bmm1_scale = self.scale
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float
if self.bmm2_scale is None: if self.bmm2_scale is None:
self.bmm2_scale = 1.0 self.bmm2_scale = 1.0
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
self.bmm2_scale *= layer._k_scale_float self.bmm2_scale *= layer._k_scale_float
o = trtllm_batch_decode_with_kv_cache_mla( o = trtllm_batch_decode_with_kv_cache_mla(

View File

@@ -20,6 +20,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
) )
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.platform_utils import num_compute_units from vllm.utils.platform_utils import num_compute_units
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionCGSupport, AttentionCGSupport,
AttentionLayer, AttentionLayer,
@@ -128,7 +129,9 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
self.cg_buf_tile_scheduler_metadata = None self.cg_buf_tile_scheduler_metadata = None
self.cg_buf_num_splits = None self.cg_buf_num_splits = None
self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8") self.is_fp8_kvcache = is_quantized_kv_cache(
vllm_config.cache_config.cache_dtype
)
num_sms = num_compute_units(self.device.index) num_sms = num_compute_units(self.device.index)
@@ -269,7 +272,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q = reshape_query_for_spec_decode(q, num_decodes) q = reshape_query_for_spec_decode(q, num_decodes)
scheduler_metadata = attn_metadata.decode.scheduler_metadata scheduler_metadata = attn_metadata.decode.scheduler_metadata
if envs.VLLM_BATCH_INVARIANT and not self.kv_cache_dtype.startswith("fp8"): if envs.VLLM_BATCH_INVARIANT and not is_quantized_kv_cache(self.kv_cache_dtype):
device = q.device device = q.device
dtype = torch.int32 dtype = torch.int32
@@ -299,7 +302,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata
scheduler_metadata.num_splits = num_splits scheduler_metadata.num_splits = num_splits
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
o, lse = flash_mla_with_kvcache_fp8( o, lse = flash_mla_with_kvcache_fp8(
q=q, q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1

View File

@@ -16,6 +16,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.platform_utils import num_compute_units from vllm.utils.platform_utils import num_compute_units
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport, AttentionCGSupport,
@@ -571,7 +572,7 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
max_tokens = vllm_config.scheduler_config.max_num_batched_tokens max_tokens = vllm_config.scheduler_config.max_num_batched_tokens
q_concat_shape = (max_tokens, num_heads, head_size) q_concat_shape = (max_tokens, num_heads, head_size)
if kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(kv_cache_dtype):
assert kv_cache_dtype == "fp8_ds_mla", ( assert kv_cache_dtype == "fp8_ds_mla", (
"FlashMLA Sparse Attention backend fp8 only supports " "FlashMLA Sparse Attention backend fp8 only supports "
"fp8_ds_mla kv-cache dtype" "fp8_ds_mla kv-cache dtype"

View File

@@ -14,11 +14,11 @@ from vllm.model_executor.layers.attention.mla_attention import (
MLACommonMetadata, MLACommonMetadata,
) )
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionLayer, AttentionLayer,
AttentionType, AttentionType,
MultipleOf, MultipleOf,
is_quantized_kv_cache,
) )
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd

View File

@@ -13,6 +13,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import ( from vllm.model_executor.layers.attention.mla_attention import (
get_mla_dims, get_mla_dims,
) )
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport, AttentionCGSupport,
@@ -231,7 +232,7 @@ class XPUMLASparseImpl(SparseMLAAttentionImpl[XPUMLASparseMetadata]):
# NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
# MQA 576/512 approach for both prefill and decode # MQA 576/512 approach for both prefill and decode
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError("FP8 kv is not supported with XPU MLA Sparse yet") raise NotImplementedError("FP8 kv is not supported with XPU MLA Sparse yet")
# Concatenate q if it's a tuple (ql_nope, q_pe) # Concatenate q if it's a tuple (ql_nope, q_pe)

View File

@@ -16,6 +16,7 @@ from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import num_compute_units from vllm.utils.platform_utils import num_compute_units
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport, AttentionCGSupport,
@@ -291,7 +292,7 @@ if current_platform.is_rocm():
new_key_cache = key_cache.view_as(k_cache_template) new_key_cache = key_cache.view_as(k_cache_template)
new_value_cache = value_cache.view_as(v_cache_template) new_value_cache = value_cache.view_as(v_cache_template)
QUANT = False QUANT = False
if kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(kv_cache_dtype):
QUANT = True QUANT = True
grid = ( grid = (
num_tokens, num_tokens,
@@ -494,7 +495,7 @@ class AiterFlashAttentionMetadataBuilder(
if ( if (
rocm_aiter_ops.is_shuffle_kv_cache_enabled() rocm_aiter_ops.is_shuffle_kv_cache_enabled()
and self.scale.numel() == 1 and self.scale.numel() == 1
and self.vllm_config.cache_config.cache_dtype.startswith("fp8") and is_quantized_kv_cache(self.vllm_config.cache_config.cache_dtype)
): ):
layers = get_layers_from_vllm_config(self.vllm_config, Attention) layers = get_layers_from_vllm_config(self.vllm_config, Attention)
first_layer_name = [k for k in layers][0] first_layer_name = [k for k in layers][0]
@@ -887,7 +888,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
cu_seqlens_kv=swa_cu_seqlens, cu_seqlens_kv=swa_cu_seqlens,
token_to_batch=swa_token_to_batch, token_to_batch=swa_token_to_batch,
seq_starts=swa_seq_starts, seq_starts=swa_seq_starts,
dequant=self.kv_cache_dtype.startswith("fp8"), dequant=is_quantized_kv_cache(self.kv_cache_dtype),
kv_cache_layout="NHD", kv_cache_layout="NHD",
total_tokens=swa_total_tokens, total_tokens=swa_total_tokens,
) )
@@ -982,7 +983,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
cu_seqlens_kv=cu_seqlens_kv[chunk_idx], cu_seqlens_kv=cu_seqlens_kv[chunk_idx],
token_to_batch=token_to_batch[chunk_idx], token_to_batch=token_to_batch[chunk_idx],
seq_starts=chunk_starts[chunk_idx], seq_starts=chunk_starts[chunk_idx],
dequant=self.kv_cache_dtype.startswith("fp8"), dequant=is_quantized_kv_cache(self.kv_cache_dtype),
kv_cache_layout="SHUFFLE" kv_cache_layout="SHUFFLE"
if rocm_aiter_ops.is_shuffle_kv_cache_enabled() if rocm_aiter_ops.is_shuffle_kv_cache_enabled()
else "NHD", else "NHD",
@@ -1081,7 +1082,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
num_actual_tokens = attn_metadata.num_actual_tokens num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = kv_cache.unbind(0) key_cache, value_cache = kv_cache.unbind(0)
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
key_cache = key_cache.view(current_platform.fp8_dtype()) key_cache = key_cache.view(current_platform.fp8_dtype())
value_cache = value_cache.view(current_platform.fp8_dtype()) value_cache = value_cache.view(current_platform.fp8_dtype())
@@ -1370,7 +1371,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
# key and value may be None in the case of cross attention. They are # key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached # calculated once based on the output from the encoder and then cached
# in KV cache. # in KV cache.
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
key_cache = key_cache.view(current_platform.fp8_dtype()) key_cache = key_cache.view(current_platform.fp8_dtype())
value_cache = value_cache.view(current_platform.fp8_dtype()) value_cache = value_cache.view(current_platform.fp8_dtype())
# Reshape the input keys and values and store them in the cache. # Reshape the input keys and values and store them in the cache.
@@ -1436,7 +1437,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
key_cache, value_cache = kv_cache.unbind(0) key_cache, value_cache = kv_cache.unbind(0)
flash_layout = True flash_layout = True
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8") is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype)
if is_fp8_kv_cache: if is_fp8_kv_cache:
key_cache = key_cache.view(current_platform.fp8_dtype()) key_cache = key_cache.view(current_platform.fp8_dtype())
value_cache = value_cache.view(current_platform.fp8_dtype()) value_cache = value_cache.view(current_platform.fp8_dtype())

View File

@@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
kFp8StaticTensorSym, kFp8StaticTensorSym,
) )
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import AttentionLayer, AttentionType, MultipleOf from vllm.v1.attention.backend import AttentionLayer, AttentionType, MultipleOf
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.rocm_attn import ( from vllm.v1.attention.backends.rocm_attn import (
@@ -200,7 +201,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
softmax_scale = self.scale softmax_scale = self.scale
fp8_post_attn_v_rescale = False fp8_post_attn_v_rescale = False
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
key_cache = key_cache.view(self.fp8_dtype) key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype)
# When Q is FP8, triton kernel skips K/V dequant (for fp8xfp8 matmul). # When Q is FP8, triton kernel skips K/V dequant (for fp8xfp8 matmul).
@@ -299,7 +300,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
key_cache, value_cache = kv_cache.unbind(0) key_cache, value_cache = kv_cache.unbind(0)
flash_layout = True flash_layout = True
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8") is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype)
if is_fp8_kv_cache: if is_fp8_kv_cache:
key_cache = key_cache.view(self.fp8_dtype) key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype)

View File

@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym, kFp8StaticTensorSym,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport, AttentionCGSupport,
@@ -315,7 +316,7 @@ class RocmAttentionImpl(AttentionImpl):
layer: The attention layer layer: The attention layer
""" """
# For encoder attention, process FP8 quantization if needed # For encoder attention, process FP8 quantization if needed
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError( raise NotImplementedError(
"quantization is not supported for encoder attention" "quantization is not supported for encoder attention"
) )
@@ -406,7 +407,7 @@ class RocmAttentionImpl(AttentionImpl):
kv_cache, self.num_kv_heads, self.head_size kv_cache, self.num_kv_heads, self.head_size
) )
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
key_cache = key_cache.view(self.fp8_dtype) key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype)
assert layer._q_scale_float == 1.0, ( assert layer._q_scale_float == 1.0, (
@@ -513,7 +514,7 @@ class RocmAttentionImpl(AttentionImpl):
) )
flash_layout = False flash_layout = False
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8") is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype)
if is_fp8_kv_cache: if is_fp8_kv_cache:
key_cache = key_cache.view(self.fp8_dtype) key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype)

View File

@@ -18,6 +18,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import next_power_of_2 from vllm.utils.math_utils import next_power_of_2
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport, AttentionCGSupport,
@@ -472,7 +473,7 @@ class TritonAttentionImpl(AttentionImpl):
# For decoder and cross-attention, use KV cache as before # For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(1) key_cache, value_cache = kv_cache.unbind(1)
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
if key_cache.dtype != self.fp8_dtype: if key_cache.dtype != self.fp8_dtype:
key_cache = key_cache.view(self.fp8_dtype) key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype)
@@ -546,7 +547,7 @@ class TritonAttentionImpl(AttentionImpl):
layer: The attention layer layer: The attention layer
""" """
# For encoder attention, process FP8 quantization if needed # For encoder attention, process FP8 quantization if needed
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError( raise NotImplementedError(
"quantization is not supported for encoder attention" "quantization is not supported for encoder attention"
) )
@@ -588,7 +589,7 @@ class TritonAttentionImpl(AttentionImpl):
key_cache, value_cache = kv_cache.unbind(1) key_cache, value_cache = kv_cache.unbind(1)
# Reshape the input keys and values and store them in the cache. # Reshape the input keys and values and store them in the cache.
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
key_cache = key_cache.view(self.fp8_dtype) key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype)
# triton kernel does not support uint8 kv_cache # triton kernel does not support uint8 kv_cache
@@ -623,7 +624,7 @@ class TritonAttentionImpl(AttentionImpl):
key_cache, value_cache = kv_cache.unbind(1) key_cache, value_cache = kv_cache.unbind(1)
flash_layout = True flash_layout = True
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8") is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype)
if is_fp8_kv_cache: if is_fp8_kv_cache:
key_cache = key_cache.view(self.fp8_dtype) key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype)

View File

@@ -5,6 +5,7 @@ import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import is_quantized_kv_cache
@triton.jit @triton.jit
@@ -145,16 +146,18 @@ def triton_reshape_and_cache_flash(
block_stride = key_cache.stride()[0] block_stride = key_cache.stride()[0]
page_stride = key_cache.stride()[1] page_stride = key_cache.stride()[1]
assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), ( assert kv_cache_dtype == "auto" or is_quantized_kv_cache(kv_cache_dtype), (
f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}." f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
) )
kv_cache_torch_dtype = ( kv_cache_torch_dtype = (
current_platform.fp8_dtype() current_platform.fp8_dtype()
if kv_cache_dtype.startswith("fp8") if is_quantized_kv_cache(kv_cache_dtype)
else key_cache.dtype else key_cache.dtype
) )
if key_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith("fp8"): if key_cache.dtype != kv_cache_torch_dtype and is_quantized_kv_cache(
kv_cache_dtype
):
# to avoid erounous implicit cast in triton kernel (tl.store to uint8) # to avoid erounous implicit cast in triton kernel (tl.store to uint8)
# (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4) # (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
key_cache = key_cache.view(kv_cache_torch_dtype) key_cache = key_cache.view(kv_cache_torch_dtype)
@@ -164,7 +167,7 @@ def triton_reshape_and_cache_flash(
"uint8 is not supported by triton reshape_and_cache_flash" "uint8 is not supported by triton reshape_and_cache_flash"
) )
FP8_KV_CACHE = kv_cache_dtype.startswith("fp8") FP8_KV_CACHE = is_quantized_kv_cache(kv_cache_dtype)
assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [ assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
torch.float8_e4m3fn, torch.float8_e4m3fn,
torch.float8_e5m2, torch.float8_e5m2,
@@ -323,16 +326,16 @@ def triton_reshape_and_cache_flash_diffkv(
block_stride = kv_cache.stride()[0] block_stride = kv_cache.stride()[0]
page_stride = kv_cache.stride()[1] page_stride = kv_cache.stride()[1]
assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), ( assert kv_cache_dtype == "auto" or is_quantized_kv_cache(kv_cache_dtype), (
f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}." f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
) )
kv_cache_torch_dtype = ( kv_cache_torch_dtype = (
current_platform.fp8_dtype() current_platform.fp8_dtype()
if kv_cache_dtype.startswith("fp8") if is_quantized_kv_cache(kv_cache_dtype)
else kv_cache.dtype else kv_cache.dtype
) )
if kv_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith("fp8"): if kv_cache.dtype != kv_cache_torch_dtype and is_quantized_kv_cache(kv_cache_dtype):
# to avoid erounous implicit cast in triton kernel (tl.store to uint8) # to avoid erounous implicit cast in triton kernel (tl.store to uint8)
# (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4) # (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
kv_cache = kv_cache.view(kv_cache_torch_dtype) kv_cache = kv_cache.view(kv_cache_torch_dtype)
@@ -341,7 +344,7 @@ def triton_reshape_and_cache_flash_diffkv(
"uint8 is not supported by triton reshape_and_cache_flash_diffkv" "uint8 is not supported by triton reshape_and_cache_flash_diffkv"
) )
FP8_KV_CACHE = kv_cache_dtype.startswith("fp8") FP8_KV_CACHE = is_quantized_kv_cache(kv_cache_dtype)
assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [ assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
torch.float8_e4m3fn, torch.float8_e4m3fn,
torch.float8_e5m2, torch.float8_e5m2,

View File

@@ -109,6 +109,7 @@ from vllm.utils.nvtx_pytorch_hooks import PytHooks
from vllm.utils.platform_utils import is_pin_memory_available, num_compute_units from vllm.utils.platform_utils import is_pin_memory_available, num_compute_units
from vllm.utils.torch_utils import ( from vllm.utils.torch_utils import (
get_dtype_size, get_dtype_size,
is_quantized_kv_cache,
kv_cache_dtype_str_to_dtype, kv_cache_dtype_str_to_dtype,
) )
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
@@ -896,7 +897,7 @@ class GPUModelRunner(
If these are left at 0.0 (default after wake_up), all KV cache values If these are left at 0.0 (default after wake_up), all KV cache values
become effectively zero, causing gibberish output. become effectively zero, causing gibberish output.
""" """
if not self.cache_config.cache_dtype.startswith("fp8"): if not is_quantized_kv_cache(self.cache_config.cache_dtype):
return return
kv_caches = getattr(self, "kv_caches", []) kv_caches = getattr(self, "kv_caches", [])

View File

@@ -46,7 +46,7 @@ from vllm.tasks import SupportedTask
from vllm.tracing import instrument from vllm.tracing import instrument
from vllm.utils.mem_constants import GiB_bytes from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import is_quantized_kv_cache, set_random_seed
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ( from vllm.v1.outputs import (
@@ -197,7 +197,7 @@ class Worker(WorkerBase):
# especially the FP8 scaling factor. # especially the FP8 scaling factor.
if ( if (
(tags is None or "kv_cache" in tags) (tags is None or "kv_cache" in tags)
and self.cache_config.cache_dtype.startswith("fp8") and is_quantized_kv_cache(self.cache_config.cache_dtype)
and hasattr(self.model_runner, "init_fp8_kv_scales") and hasattr(self.model_runner, "init_fp8_kv_scales")
): ):
self.model_runner.init_fp8_kv_scales() self.model_runner.init_fp8_kv_scales()