[1/N][Cleanup] Standardize on use of is_quantized_kv_cache (#38659)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -8,6 +8,7 @@ from pydantic import Field, SkipValidation, field_validator, model_validator
|
|||||||
|
|
||||||
from vllm.config.utils import config
|
from vllm.config.utils import config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils.torch_utils import is_quantized_kv_cache
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@@ -236,7 +237,7 @@ class CacheConfig:
|
|||||||
@field_validator("cache_dtype", mode="after")
|
@field_validator("cache_dtype", mode="after")
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType:
|
def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType:
|
||||||
if cache_dtype.startswith("fp8"):
|
if is_quantized_kv_cache(cache_dtype):
|
||||||
logger.info(
|
logger.info(
|
||||||
"Using fp8 data type to store kv cache. It reduces the GPU "
|
"Using fp8 data type to store kv cache. It reduces the GPU "
|
||||||
"memory footprint and boosts the performance. "
|
"memory footprint and boosts the performance. "
|
||||||
|
|||||||
@@ -241,6 +241,7 @@ from vllm.utils.flashinfer import has_flashinfer, has_nvidia_artifactory
|
|||||||
from vllm.utils.math_utils import cdiv, round_down
|
from vllm.utils.math_utils import cdiv, round_down
|
||||||
from vllm.utils.torch_utils import (
|
from vllm.utils.torch_utils import (
|
||||||
direct_register_custom_op,
|
direct_register_custom_op,
|
||||||
|
is_quantized_kv_cache,
|
||||||
kv_cache_dtype_str_to_dtype,
|
kv_cache_dtype_str_to_dtype,
|
||||||
)
|
)
|
||||||
from vllm.v1.attention.backend import (
|
from vllm.v1.attention.backend import (
|
||||||
@@ -342,7 +343,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
|||||||
# Automatically convert fp8 kv-cache format to "fp8_ds_mla"
|
# Automatically convert fp8 kv-cache format to "fp8_ds_mla"
|
||||||
if (
|
if (
|
||||||
self.attn_backend.get_name() == "FLASHMLA_SPARSE"
|
self.attn_backend.get_name() == "FLASHMLA_SPARSE"
|
||||||
and kv_cache_dtype.startswith("fp8")
|
and is_quantized_kv_cache(kv_cache_dtype)
|
||||||
and kv_cache_dtype != "fp8_ds_mla"
|
and kv_cache_dtype != "fp8_ds_mla"
|
||||||
):
|
):
|
||||||
assert cache_config is not None
|
assert cache_config is not None
|
||||||
@@ -356,7 +357,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
|||||||
|
|
||||||
if (
|
if (
|
||||||
self.attn_backend.get_name() == "FLASHINFER_MLA_SPARSE"
|
self.attn_backend.get_name() == "FLASHINFER_MLA_SPARSE"
|
||||||
and kv_cache_dtype.startswith("fp8")
|
and is_quantized_kv_cache(kv_cache_dtype)
|
||||||
):
|
):
|
||||||
logger.info_once(
|
logger.info_once(
|
||||||
"Using standard fp8 KV cache format. To use DeepSeek's fp8_ds_mla "
|
"Using standard fp8 KV cache format. To use DeepSeek's fp8_ds_mla "
|
||||||
@@ -571,7 +572,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
|||||||
if self.impl.dcp_world_size == -1:
|
if self.impl.dcp_world_size == -1:
|
||||||
self.impl.dcp_world_size = get_dcp_group().world_size
|
self.impl.dcp_world_size = get_dcp_group().world_size
|
||||||
|
|
||||||
fp8_attention = self.kv_cache_dtype.startswith("fp8")
|
fp8_attention = is_quantized_kv_cache(self.kv_cache_dtype)
|
||||||
|
|
||||||
num_actual_toks = attn_metadata.num_actual_tokens
|
num_actual_toks = attn_metadata.num_actual_tokens
|
||||||
|
|
||||||
@@ -1434,7 +1435,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
is enabled, else model dtype.
|
is enabled, else model dtype.
|
||||||
"""
|
"""
|
||||||
use_fp8 = (
|
use_fp8 = (
|
||||||
vllm_config.cache_config.cache_dtype.startswith("fp8")
|
is_quantized_kv_cache(vllm_config.cache_config.cache_dtype)
|
||||||
and vllm_config.attention_config.use_prefill_query_quantization
|
and vllm_config.attention_config.use_prefill_query_quantization
|
||||||
and backend_supports_prefill_query_quantization()
|
and backend_supports_prefill_query_quantization()
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.base_config import (
|
|||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.v1.attention.backend import is_quantized_kv_cache
|
from vllm.utils.torch_utils import is_quantized_kv_cache
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -23,14 +23,13 @@ from vllm.model_executor.layers.attention.kv_transfer_utils import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||||
from vllm.model_executor.models.utils import maybe_prefix
|
from vllm.model_executor.models.utils import maybe_prefix
|
||||||
from vllm.utils.torch_utils import kv_cache_dtype_str_to_dtype
|
from vllm.utils.torch_utils import is_quantized_kv_cache, kv_cache_dtype_str_to_dtype
|
||||||
from vllm.v1.attention.backend import (
|
from vllm.v1.attention.backend import (
|
||||||
AttentionBackend,
|
AttentionBackend,
|
||||||
AttentionImpl,
|
AttentionImpl,
|
||||||
AttentionMetadataBuilder,
|
AttentionMetadataBuilder,
|
||||||
AttentionType,
|
AttentionType,
|
||||||
CommonAttentionMetadata,
|
CommonAttentionMetadata,
|
||||||
is_quantized_kv_cache,
|
|
||||||
)
|
)
|
||||||
from vllm.v1.kv_cache_interface import (
|
from vllm.v1.kv_cache_interface import (
|
||||||
AttentionSpec,
|
AttentionSpec,
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import torch
|
|||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.v1.attention.backend import is_quantized_kv_cache
|
from vllm.utils.torch_utils import is_quantized_kv_cache
|
||||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||||
|
|
||||||
from .interface import CpuArchEnum, Platform, PlatformEnum
|
from .interface import CpuArchEnum, Platform, PlatformEnum
|
||||||
@@ -183,7 +183,7 @@ class CpuPlatform(Platform):
|
|||||||
"backend is not compatible with FP8 KV cache."
|
"backend is not compatible with FP8 KV cache."
|
||||||
)
|
)
|
||||||
|
|
||||||
if cache_config.cache_dtype.startswith("fp8"):
|
if is_quantized_kv_cache(cache_config.cache_dtype):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"CPU backend doesn't support KV cache quantization fallback to auto."
|
"CPU backend doesn't support KV cache quantization fallback to auto."
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import vllm._C_stable_libtorch # noqa
|
|||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils.import_utils import import_pynvml
|
from vllm.utils.import_utils import import_pynvml
|
||||||
|
from vllm.utils.torch_utils import is_quantized_kv_cache
|
||||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||||
|
|
||||||
from .interface import DeviceCapability, Platform, PlatformEnum
|
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||||
@@ -87,7 +88,7 @@ def _get_backend_priorities(
|
|||||||
# Sparse MLA backend priorities
|
# Sparse MLA backend priorities
|
||||||
# See https://github.com/vllm-project/vllm/issues/35807 for
|
# See https://github.com/vllm-project/vllm/issues/35807 for
|
||||||
# benchmark results
|
# benchmark results
|
||||||
if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
|
if kv_cache_dtype is not None and is_quantized_kv_cache(kv_cache_dtype):
|
||||||
# Prefer FlashInfer for fp8 kv cache
|
# Prefer FlashInfer for fp8 kv cache
|
||||||
sparse_backends = [
|
sparse_backends = [
|
||||||
AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
|
AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
|
||||||
|
|||||||
@@ -61,6 +61,10 @@ MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP = {
|
|||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
|
||||||
|
return kv_cache_dtype.startswith("fp8")
|
||||||
|
|
||||||
|
|
||||||
def is_strictly_contiguous(t: torch.Tensor) -> bool:
|
def is_strictly_contiguous(t: torch.Tensor) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if tensor is contiguous AND has no degenerate strides.
|
Check if tensor is contiguous AND has no degenerate strides.
|
||||||
|
|||||||
@@ -954,10 +954,6 @@ class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
|
|
||||||
return kv_cache_dtype.startswith("fp8")
|
|
||||||
|
|
||||||
|
|
||||||
def subclass_attention_backend(
|
def subclass_attention_backend(
|
||||||
name_prefix: str,
|
name_prefix: str,
|
||||||
attention_backend_cls: type[AttentionBackend],
|
attention_backend_cls: type[AttentionBackend],
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from vllm import _custom_ops as ops
|
|||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import CpuArchEnum, current_platform
|
from vllm.platforms import CpuArchEnum, current_platform
|
||||||
|
from vllm.utils.torch_utils import is_quantized_kv_cache
|
||||||
from vllm.v1.attention.backend import (
|
from vllm.v1.attention.backend import (
|
||||||
AttentionBackend,
|
AttentionBackend,
|
||||||
AttentionImpl,
|
AttentionImpl,
|
||||||
@@ -16,7 +17,6 @@ from vllm.v1.attention.backend import (
|
|||||||
AttentionMetadataBuilder,
|
AttentionMetadataBuilder,
|
||||||
AttentionType,
|
AttentionType,
|
||||||
CommonAttentionMetadata,
|
CommonAttentionMetadata,
|
||||||
is_quantized_kv_cache,
|
|
||||||
)
|
)
|
||||||
from vllm.v1.attention.backends.utils import (
|
from vllm.v1.attention.backends.utils import (
|
||||||
split_decodes_and_prefills,
|
split_decodes_and_prefills,
|
||||||
|
|||||||
@@ -10,12 +10,12 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.model_executor.layers.attention import Attention
|
from vllm.model_executor.layers.attention import Attention
|
||||||
|
from vllm.utils.torch_utils import is_quantized_kv_cache
|
||||||
from vllm.v1.attention.backend import (
|
from vllm.v1.attention.backend import (
|
||||||
AttentionBackend,
|
AttentionBackend,
|
||||||
AttentionImpl,
|
AttentionImpl,
|
||||||
AttentionType,
|
AttentionType,
|
||||||
MultipleOf,
|
MultipleOf,
|
||||||
is_quantized_kv_cache,
|
|
||||||
)
|
)
|
||||||
from vllm.v1.attention.backends.fa_utils import (
|
from vllm.v1.attention.backends.fa_utils import (
|
||||||
flash_attn_supports_fp8,
|
flash_attn_supports_fp8,
|
||||||
@@ -177,7 +177,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
|
def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
|
||||||
if kv_cache_dtype is None:
|
if kv_cache_dtype is None:
|
||||||
return True
|
return True
|
||||||
if kv_cache_dtype.startswith("fp8"):
|
if is_quantized_kv_cache(kv_cache_dtype):
|
||||||
return flash_attn_supports_fp8()
|
return flash_attn_supports_fp8()
|
||||||
return kv_cache_dtype in ["auto", "float16", "bfloat16"]
|
return kv_cache_dtype in ["auto", "float16", "bfloat16"]
|
||||||
|
|
||||||
@@ -430,7 +430,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
|||||||
batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
|
batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
|
||||||
):
|
):
|
||||||
cache_dtype = self.cache_config.cache_dtype
|
cache_dtype = self.cache_config.cache_dtype
|
||||||
if cache_dtype.startswith("fp8"):
|
if is_quantized_kv_cache(cache_dtype):
|
||||||
qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
|
qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
|
||||||
cache_dtype
|
cache_dtype
|
||||||
)
|
)
|
||||||
@@ -726,7 +726,7 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
# For decoder and cross-attention, use KV cache as before
|
# For decoder and cross-attention, use KV cache as before
|
||||||
key_cache, value_cache = kv_cache.unbind(0)
|
key_cache, value_cache = kv_cache.unbind(0)
|
||||||
|
|
||||||
if self.kv_cache_dtype.startswith("fp8"):
|
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||||
# queries are quantized in the attention layer
|
# queries are quantized in the attention layer
|
||||||
dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
|
dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
|
||||||
self.kv_cache_dtype
|
self.kv_cache_dtype
|
||||||
@@ -978,7 +978,7 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# For encoder attention, process FP8 quantization if needed
|
# For encoder attention, process FP8 quantization if needed
|
||||||
if self.kv_cache_dtype.startswith("fp8"):
|
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"quantization is not supported for encoder attention"
|
"quantization is not supported for encoder attention"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.utils.torch_utils import is_quantized_kv_cache
|
||||||
from vllm.v1.attention.backend import AttentionType
|
from vllm.v1.attention.backend import AttentionType
|
||||||
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
|
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
|
||||||
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
|
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
|
||||||
@@ -191,7 +192,7 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
|
|||||||
key_cache = kv_cache[..., : self.head_size]
|
key_cache = kv_cache[..., : self.head_size]
|
||||||
value_cache = kv_cache[..., self.head_size :]
|
value_cache = kv_cache[..., self.head_size :]
|
||||||
|
|
||||||
if self.kv_cache_dtype.startswith("fp8"):
|
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||||
# queries are quantized in the attention layer
|
# queries are quantized in the attention layer
|
||||||
dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
|
dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
|
||||||
self.kv_cache_dtype
|
self.kv_cache_dtype
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ from vllm.utils.flashinfer import (
|
|||||||
)
|
)
|
||||||
from vllm.utils.math_utils import cdiv
|
from vllm.utils.math_utils import cdiv
|
||||||
from vllm.utils.platform_utils import is_pin_memory_available
|
from vllm.utils.platform_utils import is_pin_memory_available
|
||||||
from vllm.utils.torch_utils import is_strictly_contiguous
|
from vllm.utils.torch_utils import is_quantized_kv_cache, is_strictly_contiguous
|
||||||
from vllm.v1.attention.backend import (
|
from vllm.v1.attention.backend import (
|
||||||
AttentionBackend,
|
AttentionBackend,
|
||||||
AttentionCGSupport,
|
AttentionCGSupport,
|
||||||
@@ -602,7 +602,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
self.page_size = self.kv_cache_spec.block_size
|
self.page_size = self.kv_cache_spec.block_size
|
||||||
|
|
||||||
self.cache_dtype = self.cache_config.cache_dtype
|
self.cache_dtype = self.cache_config.cache_dtype
|
||||||
if self.cache_dtype.startswith("fp8"):
|
if is_quantized_kv_cache(self.cache_dtype):
|
||||||
self.kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
|
self.kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
|
||||||
self.cache_dtype
|
self.cache_dtype
|
||||||
)
|
)
|
||||||
@@ -1269,7 +1269,7 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
def fused_output_quant_supported(self, quant_key: QuantKey):
|
def fused_output_quant_supported(self, quant_key: QuantKey):
|
||||||
return (
|
return (
|
||||||
self.support_trtllm_attn
|
self.support_trtllm_attn
|
||||||
and self.kv_cache_dtype.startswith("fp8")
|
and is_quantized_kv_cache(self.kv_cache_dtype)
|
||||||
and quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic)
|
and quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1317,12 +1317,12 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
|
|
||||||
if self.bmm1_scale is None:
|
if self.bmm1_scale is None:
|
||||||
self.bmm1_scale = self.scale
|
self.bmm1_scale = self.scale
|
||||||
if self.kv_cache_dtype.startswith("fp8"):
|
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||||
self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float
|
self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float
|
||||||
|
|
||||||
if self.bmm2_scale is None:
|
if self.bmm2_scale is None:
|
||||||
self.bmm2_scale = 1.0
|
self.bmm2_scale = 1.0
|
||||||
if self.kv_cache_dtype.startswith("fp8"):
|
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||||
self.bmm2_scale *= layer._v_scale_float
|
self.bmm2_scale *= layer._v_scale_float
|
||||||
|
|
||||||
prefill_use_trtllm = isinstance(attn_metadata.prefill, TRTLLMPrefill)
|
prefill_use_trtllm = isinstance(attn_metadata.prefill, TRTLLMPrefill)
|
||||||
@@ -1375,8 +1375,8 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
|
|
||||||
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
|
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
|
||||||
# to process the cache when the kv_cache_dtype is fp8
|
# to process the cache when the kv_cache_dtype is fp8
|
||||||
if self.kv_sharing_target_layer_name is None and self.kv_cache_dtype.startswith(
|
if self.kv_sharing_target_layer_name is None and is_quantized_kv_cache(
|
||||||
"fp8"
|
self.kv_cache_dtype
|
||||||
):
|
):
|
||||||
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
|
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
|
||||||
self.kv_cache_dtype
|
self.kv_cache_dtype
|
||||||
@@ -1486,9 +1486,8 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
assert self.o_sf_scale is None
|
assert self.o_sf_scale is None
|
||||||
out = output[num_decode_tokens:]
|
out = output[num_decode_tokens:]
|
||||||
|
|
||||||
if (
|
if attn_metadata.q_data_type != FP8_DTYPE and is_quantized_kv_cache(
|
||||||
attn_metadata.q_data_type != FP8_DTYPE
|
self.kv_cache_dtype
|
||||||
and self.kv_cache_dtype.startswith("fp8")
|
|
||||||
):
|
):
|
||||||
# TRTLLM prefill attention does not support BF16 Q
|
# TRTLLM prefill attention does not support BF16 Q
|
||||||
# and fp8 kv cache. So to enable prefill attention
|
# and fp8 kv cache. So to enable prefill attention
|
||||||
|
|||||||
@@ -27,14 +27,13 @@ from vllm.config.cache import CacheDType
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.math_utils import cdiv
|
from vllm.utils.math_utils import cdiv
|
||||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
from vllm.utils.torch_utils import is_quantized_kv_cache, is_torch_equal_or_newer
|
||||||
from vllm.v1.attention.backend import (
|
from vllm.v1.attention.backend import (
|
||||||
AttentionBackend,
|
AttentionBackend,
|
||||||
AttentionImpl,
|
AttentionImpl,
|
||||||
AttentionMetadataBuilder,
|
AttentionMetadataBuilder,
|
||||||
AttentionType,
|
AttentionType,
|
||||||
CommonAttentionMetadata,
|
CommonAttentionMetadata,
|
||||||
is_quantized_kv_cache,
|
|
||||||
)
|
)
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
|
|
||||||
|
|||||||
@@ -17,12 +17,12 @@ from vllm.model_executor.layers.attention.mla_attention import (
|
|||||||
)
|
)
|
||||||
from vllm.platforms.interface import DeviceCapability
|
from vllm.platforms.interface import DeviceCapability
|
||||||
from vllm.utils.platform_utils import num_compute_units
|
from vllm.utils.platform_utils import num_compute_units
|
||||||
|
from vllm.utils.torch_utils import is_quantized_kv_cache
|
||||||
from vllm.v1.attention.backend import (
|
from vllm.v1.attention.backend import (
|
||||||
AttentionCGSupport,
|
AttentionCGSupport,
|
||||||
AttentionLayer,
|
AttentionLayer,
|
||||||
AttentionType,
|
AttentionType,
|
||||||
MultipleOf,
|
MultipleOf,
|
||||||
is_quantized_kv_cache,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|||||||
@@ -20,12 +20,12 @@ from vllm.model_executor.layers.attention.mla_attention import (
|
|||||||
)
|
)
|
||||||
from vllm.platforms.interface import DeviceCapability
|
from vllm.platforms.interface import DeviceCapability
|
||||||
from vllm.utils.math_utils import round_up
|
from vllm.utils.math_utils import round_up
|
||||||
|
from vllm.utils.torch_utils import is_quantized_kv_cache
|
||||||
from vllm.v1.attention.backend import (
|
from vllm.v1.attention.backend import (
|
||||||
AttentionCGSupport,
|
AttentionCGSupport,
|
||||||
AttentionLayer,
|
AttentionLayer,
|
||||||
AttentionType,
|
AttentionType,
|
||||||
MultipleOf,
|
MultipleOf,
|
||||||
is_quantized_kv_cache,
|
|
||||||
)
|
)
|
||||||
from vllm.v1.attention.backends.fa_utils import (
|
from vllm.v1.attention.backends.fa_utils import (
|
||||||
flash_attn_supports_mla,
|
flash_attn_supports_mla,
|
||||||
@@ -319,7 +319,7 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
|
|||||||
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.kv_cache_dtype.startswith("fp8"):
|
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||||
raise NotImplementedError("FP8 FlashAttention MLA not yet supported")
|
raise NotImplementedError("FP8 FlashAttention MLA not yet supported")
|
||||||
|
|
||||||
kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
|
kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
|
||||||
|
|||||||
@@ -16,12 +16,12 @@ from vllm.model_executor.layers.attention.mla_attention import (
|
|||||||
QueryLenSupport,
|
QueryLenSupport,
|
||||||
)
|
)
|
||||||
from vllm.platforms.interface import DeviceCapability
|
from vllm.platforms.interface import DeviceCapability
|
||||||
|
from vllm.utils.torch_utils import is_quantized_kv_cache
|
||||||
from vllm.v1.attention.backend import (
|
from vllm.v1.attention.backend import (
|
||||||
AttentionCGSupport,
|
AttentionCGSupport,
|
||||||
AttentionLayer,
|
AttentionLayer,
|
||||||
AttentionType,
|
AttentionType,
|
||||||
MultipleOf,
|
MultipleOf,
|
||||||
is_quantized_kv_cache,
|
|
||||||
)
|
)
|
||||||
from vllm.v1.attention.backends.utils import KVCacheLayoutType
|
from vllm.v1.attention.backends.utils import KVCacheLayoutType
|
||||||
|
|
||||||
@@ -184,12 +184,12 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
|||||||
|
|
||||||
if self.bmm1_scale is None:
|
if self.bmm1_scale is None:
|
||||||
self.bmm1_scale = self.scale
|
self.bmm1_scale = self.scale
|
||||||
if self.kv_cache_dtype.startswith("fp8"):
|
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||||
self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float
|
self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float
|
||||||
|
|
||||||
if self.bmm2_scale is None:
|
if self.bmm2_scale is None:
|
||||||
self.bmm2_scale = 1.0
|
self.bmm2_scale = 1.0
|
||||||
if self.kv_cache_dtype.startswith("fp8"):
|
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||||
self.bmm2_scale *= layer._k_scale_float
|
self.bmm2_scale *= layer._k_scale_float
|
||||||
|
|
||||||
# Reuse pre-allocated zero-init output buffer to avoid a memset
|
# Reuse pre-allocated zero-init output buffer to avoid a memset
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
|
|||||||
get_mla_dims,
|
get_mla_dims,
|
||||||
)
|
)
|
||||||
from vllm.platforms.interface import DeviceCapability
|
from vllm.platforms.interface import DeviceCapability
|
||||||
|
from vllm.utils.torch_utils import is_quantized_kv_cache
|
||||||
from vllm.v1.attention.backend import (
|
from vllm.v1.attention.backend import (
|
||||||
AttentionBackend,
|
AttentionBackend,
|
||||||
AttentionCGSupport,
|
AttentionCGSupport,
|
||||||
@@ -341,11 +342,11 @@ class FlashInferMLASparseImpl(SparseMLAAttentionImpl[FlashInferMLASparseMetadata
|
|||||||
|
|
||||||
if self.bmm1_scale is None:
|
if self.bmm1_scale is None:
|
||||||
self.bmm1_scale = self.scale
|
self.bmm1_scale = self.scale
|
||||||
if self.kv_cache_dtype.startswith("fp8"):
|
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||||
self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float
|
self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float
|
||||||
if self.bmm2_scale is None:
|
if self.bmm2_scale is None:
|
||||||
self.bmm2_scale = 1.0
|
self.bmm2_scale = 1.0
|
||||||
if self.kv_cache_dtype.startswith("fp8"):
|
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||||
self.bmm2_scale *= layer._k_scale_float
|
self.bmm2_scale *= layer._k_scale_float
|
||||||
|
|
||||||
o = trtllm_batch_decode_with_kv_cache_mla(
|
o = trtllm_batch_decode_with_kv_cache_mla(
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
|
|||||||
)
|
)
|
||||||
from vllm.platforms.interface import DeviceCapability
|
from vllm.platforms.interface import DeviceCapability
|
||||||
from vllm.utils.platform_utils import num_compute_units
|
from vllm.utils.platform_utils import num_compute_units
|
||||||
|
from vllm.utils.torch_utils import is_quantized_kv_cache
|
||||||
from vllm.v1.attention.backend import (
|
from vllm.v1.attention.backend import (
|
||||||
AttentionCGSupport,
|
AttentionCGSupport,
|
||||||
AttentionLayer,
|
AttentionLayer,
|
||||||
@@ -128,7 +129,9 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
|||||||
|
|
||||||
self.cg_buf_tile_scheduler_metadata = None
|
self.cg_buf_tile_scheduler_metadata = None
|
||||||
self.cg_buf_num_splits = None
|
self.cg_buf_num_splits = None
|
||||||
self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8")
|
self.is_fp8_kvcache = is_quantized_kv_cache(
|
||||||
|
vllm_config.cache_config.cache_dtype
|
||||||
|
)
|
||||||
|
|
||||||
num_sms = num_compute_units(self.device.index)
|
num_sms = num_compute_units(self.device.index)
|
||||||
|
|
||||||
@@ -269,7 +272,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
|||||||
q = reshape_query_for_spec_decode(q, num_decodes)
|
q = reshape_query_for_spec_decode(q, num_decodes)
|
||||||
|
|
||||||
scheduler_metadata = attn_metadata.decode.scheduler_metadata
|
scheduler_metadata = attn_metadata.decode.scheduler_metadata
|
||||||
if envs.VLLM_BATCH_INVARIANT and not self.kv_cache_dtype.startswith("fp8"):
|
if envs.VLLM_BATCH_INVARIANT and not is_quantized_kv_cache(self.kv_cache_dtype):
|
||||||
device = q.device
|
device = q.device
|
||||||
dtype = torch.int32
|
dtype = torch.int32
|
||||||
|
|
||||||
@@ -299,7 +302,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
|||||||
scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata
|
scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata
|
||||||
scheduler_metadata.num_splits = num_splits
|
scheduler_metadata.num_splits = num_splits
|
||||||
|
|
||||||
if self.kv_cache_dtype.startswith("fp8"):
|
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||||
o, lse = flash_mla_with_kvcache_fp8(
|
o, lse = flash_mla_with_kvcache_fp8(
|
||||||
q=q,
|
q=q,
|
||||||
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
|
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.platforms.interface import DeviceCapability
|
from vllm.platforms.interface import DeviceCapability
|
||||||
from vllm.utils.platform_utils import num_compute_units
|
from vllm.utils.platform_utils import num_compute_units
|
||||||
|
from vllm.utils.torch_utils import is_quantized_kv_cache
|
||||||
from vllm.v1.attention.backend import (
|
from vllm.v1.attention.backend import (
|
||||||
AttentionBackend,
|
AttentionBackend,
|
||||||
AttentionCGSupport,
|
AttentionCGSupport,
|
||||||
@@ -571,7 +572,7 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
|
|||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
max_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
max_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
||||||
q_concat_shape = (max_tokens, num_heads, head_size)
|
q_concat_shape = (max_tokens, num_heads, head_size)
|
||||||
if kv_cache_dtype.startswith("fp8"):
|
if is_quantized_kv_cache(kv_cache_dtype):
|
||||||
assert kv_cache_dtype == "fp8_ds_mla", (
|
assert kv_cache_dtype == "fp8_ds_mla", (
|
||||||
"FlashMLA Sparse Attention backend fp8 only supports "
|
"FlashMLA Sparse Attention backend fp8 only supports "
|
||||||
"fp8_ds_mla kv-cache dtype"
|
"fp8_ds_mla kv-cache dtype"
|
||||||
|
|||||||
@@ -14,11 +14,11 @@ from vllm.model_executor.layers.attention.mla_attention import (
|
|||||||
MLACommonMetadata,
|
MLACommonMetadata,
|
||||||
)
|
)
|
||||||
from vllm.platforms.interface import DeviceCapability
|
from vllm.platforms.interface import DeviceCapability
|
||||||
|
from vllm.utils.torch_utils import is_quantized_kv_cache
|
||||||
from vllm.v1.attention.backend import (
|
from vllm.v1.attention.backend import (
|
||||||
AttentionLayer,
|
AttentionLayer,
|
||||||
AttentionType,
|
AttentionType,
|
||||||
MultipleOf,
|
MultipleOf,
|
||||||
is_quantized_kv_cache,
|
|
||||||
)
|
)
|
||||||
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd
|
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from vllm.logger import init_logger
|
|||||||
from vllm.model_executor.layers.attention.mla_attention import (
|
from vllm.model_executor.layers.attention.mla_attention import (
|
||||||
get_mla_dims,
|
get_mla_dims,
|
||||||
)
|
)
|
||||||
|
from vllm.utils.torch_utils import is_quantized_kv_cache
|
||||||
from vllm.v1.attention.backend import (
|
from vllm.v1.attention.backend import (
|
||||||
AttentionBackend,
|
AttentionBackend,
|
||||||
AttentionCGSupport,
|
AttentionCGSupport,
|
||||||
@@ -231,7 +232,7 @@ class XPUMLASparseImpl(SparseMLAAttentionImpl[XPUMLASparseMetadata]):
|
|||||||
# NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
|
# NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
|
||||||
# MQA 576/512 approach for both prefill and decode
|
# MQA 576/512 approach for both prefill and decode
|
||||||
|
|
||||||
if self.kv_cache_dtype.startswith("fp8"):
|
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||||
raise NotImplementedError("FP8 kv is not supported with XPU MLA Sparse yet")
|
raise NotImplementedError("FP8 kv is not supported with XPU MLA Sparse yet")
|
||||||
|
|
||||||
# Concatenate q if it's a tuple (ql_nope, q_pe)
|
# Concatenate q if it's a tuple (ql_nope, q_pe)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from vllm.platforms import current_platform
|
|||||||
from vllm.platforms.interface import DeviceCapability
|
from vllm.platforms.interface import DeviceCapability
|
||||||
from vllm.utils.math_utils import cdiv
|
from vllm.utils.math_utils import cdiv
|
||||||
from vllm.utils.platform_utils import num_compute_units
|
from vllm.utils.platform_utils import num_compute_units
|
||||||
|
from vllm.utils.torch_utils import is_quantized_kv_cache
|
||||||
from vllm.v1.attention.backend import (
|
from vllm.v1.attention.backend import (
|
||||||
AttentionBackend,
|
AttentionBackend,
|
||||||
AttentionCGSupport,
|
AttentionCGSupport,
|
||||||
@@ -291,7 +292,7 @@ if current_platform.is_rocm():
|
|||||||
new_key_cache = key_cache.view_as(k_cache_template)
|
new_key_cache = key_cache.view_as(k_cache_template)
|
||||||
new_value_cache = value_cache.view_as(v_cache_template)
|
new_value_cache = value_cache.view_as(v_cache_template)
|
||||||
QUANT = False
|
QUANT = False
|
||||||
if kv_cache_dtype.startswith("fp8"):
|
if is_quantized_kv_cache(kv_cache_dtype):
|
||||||
QUANT = True
|
QUANT = True
|
||||||
grid = (
|
grid = (
|
||||||
num_tokens,
|
num_tokens,
|
||||||
@@ -494,7 +495,7 @@ class AiterFlashAttentionMetadataBuilder(
|
|||||||
if (
|
if (
|
||||||
rocm_aiter_ops.is_shuffle_kv_cache_enabled()
|
rocm_aiter_ops.is_shuffle_kv_cache_enabled()
|
||||||
and self.scale.numel() == 1
|
and self.scale.numel() == 1
|
||||||
and self.vllm_config.cache_config.cache_dtype.startswith("fp8")
|
and is_quantized_kv_cache(self.vllm_config.cache_config.cache_dtype)
|
||||||
):
|
):
|
||||||
layers = get_layers_from_vllm_config(self.vllm_config, Attention)
|
layers = get_layers_from_vllm_config(self.vllm_config, Attention)
|
||||||
first_layer_name = [k for k in layers][0]
|
first_layer_name = [k for k in layers][0]
|
||||||
@@ -887,7 +888,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
|||||||
cu_seqlens_kv=swa_cu_seqlens,
|
cu_seqlens_kv=swa_cu_seqlens,
|
||||||
token_to_batch=swa_token_to_batch,
|
token_to_batch=swa_token_to_batch,
|
||||||
seq_starts=swa_seq_starts,
|
seq_starts=swa_seq_starts,
|
||||||
dequant=self.kv_cache_dtype.startswith("fp8"),
|
dequant=is_quantized_kv_cache(self.kv_cache_dtype),
|
||||||
kv_cache_layout="NHD",
|
kv_cache_layout="NHD",
|
||||||
total_tokens=swa_total_tokens,
|
total_tokens=swa_total_tokens,
|
||||||
)
|
)
|
||||||
@@ -982,7 +983,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
|||||||
cu_seqlens_kv=cu_seqlens_kv[chunk_idx],
|
cu_seqlens_kv=cu_seqlens_kv[chunk_idx],
|
||||||
token_to_batch=token_to_batch[chunk_idx],
|
token_to_batch=token_to_batch[chunk_idx],
|
||||||
seq_starts=chunk_starts[chunk_idx],
|
seq_starts=chunk_starts[chunk_idx],
|
||||||
dequant=self.kv_cache_dtype.startswith("fp8"),
|
dequant=is_quantized_kv_cache(self.kv_cache_dtype),
|
||||||
kv_cache_layout="SHUFFLE"
|
kv_cache_layout="SHUFFLE"
|
||||||
if rocm_aiter_ops.is_shuffle_kv_cache_enabled()
|
if rocm_aiter_ops.is_shuffle_kv_cache_enabled()
|
||||||
else "NHD",
|
else "NHD",
|
||||||
@@ -1081,7 +1082,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
|||||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||||
key_cache, value_cache = kv_cache.unbind(0)
|
key_cache, value_cache = kv_cache.unbind(0)
|
||||||
|
|
||||||
if self.kv_cache_dtype.startswith("fp8"):
|
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||||
key_cache = key_cache.view(current_platform.fp8_dtype())
|
key_cache = key_cache.view(current_platform.fp8_dtype())
|
||||||
value_cache = value_cache.view(current_platform.fp8_dtype())
|
value_cache = value_cache.view(current_platform.fp8_dtype())
|
||||||
|
|
||||||
@@ -1370,7 +1371,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
|||||||
# key and value may be None in the case of cross attention. They are
|
# key and value may be None in the case of cross attention. They are
|
||||||
# calculated once based on the output from the encoder and then cached
|
# calculated once based on the output from the encoder and then cached
|
||||||
# in KV cache.
|
# in KV cache.
|
||||||
if self.kv_cache_dtype.startswith("fp8"):
|
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||||
key_cache = key_cache.view(current_platform.fp8_dtype())
|
key_cache = key_cache.view(current_platform.fp8_dtype())
|
||||||
value_cache = value_cache.view(current_platform.fp8_dtype())
|
value_cache = value_cache.view(current_platform.fp8_dtype())
|
||||||
# Reshape the input keys and values and store them in the cache.
|
# Reshape the input keys and values and store them in the cache.
|
||||||
@@ -1436,7 +1437,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
|||||||
key_cache, value_cache = kv_cache.unbind(0)
|
key_cache, value_cache = kv_cache.unbind(0)
|
||||||
flash_layout = True
|
flash_layout = True
|
||||||
|
|
||||||
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
|
is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype)
|
||||||
if is_fp8_kv_cache:
|
if is_fp8_kv_cache:
|
||||||
key_cache = key_cache.view(current_platform.fp8_dtype())
|
key_cache = key_cache.view(current_platform.fp8_dtype())
|
||||||
value_cache = value_cache.view(current_platform.fp8_dtype())
|
value_cache = value_cache.view(current_platform.fp8_dtype())
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|||||||
QuantKey,
|
QuantKey,
|
||||||
kFp8StaticTensorSym,
|
kFp8StaticTensorSym,
|
||||||
)
|
)
|
||||||
|
from vllm.utils.torch_utils import is_quantized_kv_cache
|
||||||
from vllm.v1.attention.backend import AttentionLayer, AttentionType, MultipleOf
|
from vllm.v1.attention.backend import AttentionLayer, AttentionType, MultipleOf
|
||||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||||
from vllm.v1.attention.backends.rocm_attn import (
|
from vllm.v1.attention.backends.rocm_attn import (
|
||||||
@@ -200,7 +201,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
|
|||||||
|
|
||||||
softmax_scale = self.scale
|
softmax_scale = self.scale
|
||||||
fp8_post_attn_v_rescale = False
|
fp8_post_attn_v_rescale = False
|
||||||
if self.kv_cache_dtype.startswith("fp8"):
|
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||||
key_cache = key_cache.view(self.fp8_dtype)
|
key_cache = key_cache.view(self.fp8_dtype)
|
||||||
value_cache = value_cache.view(self.fp8_dtype)
|
value_cache = value_cache.view(self.fp8_dtype)
|
||||||
# When Q is FP8, triton kernel skips K/V dequant (for fp8xfp8 matmul).
|
# When Q is FP8, triton kernel skips K/V dequant (for fp8xfp8 matmul).
|
||||||
@@ -299,7 +300,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
|
|||||||
key_cache, value_cache = kv_cache.unbind(0)
|
key_cache, value_cache = kv_cache.unbind(0)
|
||||||
flash_layout = True
|
flash_layout = True
|
||||||
|
|
||||||
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
|
is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype)
|
||||||
if is_fp8_kv_cache:
|
if is_fp8_kv_cache:
|
||||||
key_cache = key_cache.view(self.fp8_dtype)
|
key_cache = key_cache.view(self.fp8_dtype)
|
||||||
value_cache = value_cache.view(self.fp8_dtype)
|
value_cache = value_cache.view(self.fp8_dtype)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|||||||
kFp8StaticTensorSym,
|
kFp8StaticTensorSym,
|
||||||
)
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils.torch_utils import is_quantized_kv_cache
|
||||||
from vllm.v1.attention.backend import (
|
from vllm.v1.attention.backend import (
|
||||||
AttentionBackend,
|
AttentionBackend,
|
||||||
AttentionCGSupport,
|
AttentionCGSupport,
|
||||||
@@ -315,7 +316,7 @@ class RocmAttentionImpl(AttentionImpl):
|
|||||||
layer: The attention layer
|
layer: The attention layer
|
||||||
"""
|
"""
|
||||||
# For encoder attention, process FP8 quantization if needed
|
# For encoder attention, process FP8 quantization if needed
|
||||||
if self.kv_cache_dtype.startswith("fp8"):
|
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"quantization is not supported for encoder attention"
|
"quantization is not supported for encoder attention"
|
||||||
)
|
)
|
||||||
@@ -406,7 +407,7 @@ class RocmAttentionImpl(AttentionImpl):
|
|||||||
kv_cache, self.num_kv_heads, self.head_size
|
kv_cache, self.num_kv_heads, self.head_size
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.kv_cache_dtype.startswith("fp8"):
|
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||||
key_cache = key_cache.view(self.fp8_dtype)
|
key_cache = key_cache.view(self.fp8_dtype)
|
||||||
value_cache = value_cache.view(self.fp8_dtype)
|
value_cache = value_cache.view(self.fp8_dtype)
|
||||||
assert layer._q_scale_float == 1.0, (
|
assert layer._q_scale_float == 1.0, (
|
||||||
@@ -513,7 +514,7 @@ class RocmAttentionImpl(AttentionImpl):
|
|||||||
)
|
)
|
||||||
flash_layout = False
|
flash_layout = False
|
||||||
|
|
||||||
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
|
is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype)
|
||||||
if is_fp8_kv_cache:
|
if is_fp8_kv_cache:
|
||||||
key_cache = key_cache.view(self.fp8_dtype)
|
key_cache = key_cache.view(self.fp8_dtype)
|
||||||
value_cache = value_cache.view(self.fp8_dtype)
|
value_cache = value_cache.view(self.fp8_dtype)
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.platforms.interface import DeviceCapability
|
from vllm.platforms.interface import DeviceCapability
|
||||||
from vllm.utils.math_utils import next_power_of_2
|
from vllm.utils.math_utils import next_power_of_2
|
||||||
|
from vllm.utils.torch_utils import is_quantized_kv_cache
|
||||||
from vllm.v1.attention.backend import (
|
from vllm.v1.attention.backend import (
|
||||||
AttentionBackend,
|
AttentionBackend,
|
||||||
AttentionCGSupport,
|
AttentionCGSupport,
|
||||||
@@ -472,7 +473,7 @@ class TritonAttentionImpl(AttentionImpl):
|
|||||||
|
|
||||||
# For decoder and cross-attention, use KV cache as before
|
# For decoder and cross-attention, use KV cache as before
|
||||||
key_cache, value_cache = kv_cache.unbind(1)
|
key_cache, value_cache = kv_cache.unbind(1)
|
||||||
if self.kv_cache_dtype.startswith("fp8"):
|
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||||
if key_cache.dtype != self.fp8_dtype:
|
if key_cache.dtype != self.fp8_dtype:
|
||||||
key_cache = key_cache.view(self.fp8_dtype)
|
key_cache = key_cache.view(self.fp8_dtype)
|
||||||
value_cache = value_cache.view(self.fp8_dtype)
|
value_cache = value_cache.view(self.fp8_dtype)
|
||||||
@@ -546,7 +547,7 @@ class TritonAttentionImpl(AttentionImpl):
|
|||||||
layer: The attention layer
|
layer: The attention layer
|
||||||
"""
|
"""
|
||||||
# For encoder attention, process FP8 quantization if needed
|
# For encoder attention, process FP8 quantization if needed
|
||||||
if self.kv_cache_dtype.startswith("fp8"):
|
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"quantization is not supported for encoder attention"
|
"quantization is not supported for encoder attention"
|
||||||
)
|
)
|
||||||
@@ -588,7 +589,7 @@ class TritonAttentionImpl(AttentionImpl):
|
|||||||
key_cache, value_cache = kv_cache.unbind(1)
|
key_cache, value_cache = kv_cache.unbind(1)
|
||||||
|
|
||||||
# Reshape the input keys and values and store them in the cache.
|
# Reshape the input keys and values and store them in the cache.
|
||||||
if self.kv_cache_dtype.startswith("fp8"):
|
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||||
key_cache = key_cache.view(self.fp8_dtype)
|
key_cache = key_cache.view(self.fp8_dtype)
|
||||||
value_cache = value_cache.view(self.fp8_dtype)
|
value_cache = value_cache.view(self.fp8_dtype)
|
||||||
# triton kernel does not support uint8 kv_cache
|
# triton kernel does not support uint8 kv_cache
|
||||||
@@ -623,7 +624,7 @@ class TritonAttentionImpl(AttentionImpl):
|
|||||||
key_cache, value_cache = kv_cache.unbind(1)
|
key_cache, value_cache = kv_cache.unbind(1)
|
||||||
flash_layout = True
|
flash_layout = True
|
||||||
|
|
||||||
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
|
is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype)
|
||||||
if is_fp8_kv_cache:
|
if is_fp8_kv_cache:
|
||||||
key_cache = key_cache.view(self.fp8_dtype)
|
key_cache = key_cache.view(self.fp8_dtype)
|
||||||
value_cache = value_cache.view(self.fp8_dtype)
|
value_cache = value_cache.view(self.fp8_dtype)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import torch
|
|||||||
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
|
from vllm.utils.torch_utils import is_quantized_kv_cache
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
@@ -145,16 +146,18 @@ def triton_reshape_and_cache_flash(
|
|||||||
block_stride = key_cache.stride()[0]
|
block_stride = key_cache.stride()[0]
|
||||||
page_stride = key_cache.stride()[1]
|
page_stride = key_cache.stride()[1]
|
||||||
|
|
||||||
assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), (
|
assert kv_cache_dtype == "auto" or is_quantized_kv_cache(kv_cache_dtype), (
|
||||||
f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
|
f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
|
||||||
)
|
)
|
||||||
kv_cache_torch_dtype = (
|
kv_cache_torch_dtype = (
|
||||||
current_platform.fp8_dtype()
|
current_platform.fp8_dtype()
|
||||||
if kv_cache_dtype.startswith("fp8")
|
if is_quantized_kv_cache(kv_cache_dtype)
|
||||||
else key_cache.dtype
|
else key_cache.dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
if key_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith("fp8"):
|
if key_cache.dtype != kv_cache_torch_dtype and is_quantized_kv_cache(
|
||||||
|
kv_cache_dtype
|
||||||
|
):
|
||||||
# to avoid erounous implicit cast in triton kernel (tl.store to uint8)
|
# to avoid erounous implicit cast in triton kernel (tl.store to uint8)
|
||||||
# (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
|
# (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
|
||||||
key_cache = key_cache.view(kv_cache_torch_dtype)
|
key_cache = key_cache.view(kv_cache_torch_dtype)
|
||||||
@@ -164,7 +167,7 @@ def triton_reshape_and_cache_flash(
|
|||||||
"uint8 is not supported by triton reshape_and_cache_flash"
|
"uint8 is not supported by triton reshape_and_cache_flash"
|
||||||
)
|
)
|
||||||
|
|
||||||
FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
|
FP8_KV_CACHE = is_quantized_kv_cache(kv_cache_dtype)
|
||||||
assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
|
assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
|
||||||
torch.float8_e4m3fn,
|
torch.float8_e4m3fn,
|
||||||
torch.float8_e5m2,
|
torch.float8_e5m2,
|
||||||
@@ -323,16 +326,16 @@ def triton_reshape_and_cache_flash_diffkv(
|
|||||||
block_stride = kv_cache.stride()[0]
|
block_stride = kv_cache.stride()[0]
|
||||||
page_stride = kv_cache.stride()[1]
|
page_stride = kv_cache.stride()[1]
|
||||||
|
|
||||||
assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), (
|
assert kv_cache_dtype == "auto" or is_quantized_kv_cache(kv_cache_dtype), (
|
||||||
f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
|
f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
|
||||||
)
|
)
|
||||||
kv_cache_torch_dtype = (
|
kv_cache_torch_dtype = (
|
||||||
current_platform.fp8_dtype()
|
current_platform.fp8_dtype()
|
||||||
if kv_cache_dtype.startswith("fp8")
|
if is_quantized_kv_cache(kv_cache_dtype)
|
||||||
else kv_cache.dtype
|
else kv_cache.dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
if kv_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith("fp8"):
|
if kv_cache.dtype != kv_cache_torch_dtype and is_quantized_kv_cache(kv_cache_dtype):
|
||||||
# to avoid erounous implicit cast in triton kernel (tl.store to uint8)
|
# to avoid erounous implicit cast in triton kernel (tl.store to uint8)
|
||||||
# (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
|
# (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
|
||||||
kv_cache = kv_cache.view(kv_cache_torch_dtype)
|
kv_cache = kv_cache.view(kv_cache_torch_dtype)
|
||||||
@@ -341,7 +344,7 @@ def triton_reshape_and_cache_flash_diffkv(
|
|||||||
"uint8 is not supported by triton reshape_and_cache_flash_diffkv"
|
"uint8 is not supported by triton reshape_and_cache_flash_diffkv"
|
||||||
)
|
)
|
||||||
|
|
||||||
FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
|
FP8_KV_CACHE = is_quantized_kv_cache(kv_cache_dtype)
|
||||||
assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
|
assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
|
||||||
torch.float8_e4m3fn,
|
torch.float8_e4m3fn,
|
||||||
torch.float8_e5m2,
|
torch.float8_e5m2,
|
||||||
|
|||||||
@@ -109,6 +109,7 @@ from vllm.utils.nvtx_pytorch_hooks import PytHooks
|
|||||||
from vllm.utils.platform_utils import is_pin_memory_available, num_compute_units
|
from vllm.utils.platform_utils import is_pin_memory_available, num_compute_units
|
||||||
from vllm.utils.torch_utils import (
|
from vllm.utils.torch_utils import (
|
||||||
get_dtype_size,
|
get_dtype_size,
|
||||||
|
is_quantized_kv_cache,
|
||||||
kv_cache_dtype_str_to_dtype,
|
kv_cache_dtype_str_to_dtype,
|
||||||
)
|
)
|
||||||
from vllm.v1.attention.backend import (
|
from vllm.v1.attention.backend import (
|
||||||
@@ -896,7 +897,7 @@ class GPUModelRunner(
|
|||||||
If these are left at 0.0 (default after wake_up), all KV cache values
|
If these are left at 0.0 (default after wake_up), all KV cache values
|
||||||
become effectively zero, causing gibberish output.
|
become effectively zero, causing gibberish output.
|
||||||
"""
|
"""
|
||||||
if not self.cache_config.cache_dtype.startswith("fp8"):
|
if not is_quantized_kv_cache(self.cache_config.cache_dtype):
|
||||||
return
|
return
|
||||||
|
|
||||||
kv_caches = getattr(self, "kv_caches", [])
|
kv_caches = getattr(self, "kv_caches", [])
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ from vllm.tasks import SupportedTask
|
|||||||
from vllm.tracing import instrument
|
from vllm.tracing import instrument
|
||||||
from vllm.utils.mem_constants import GiB_bytes
|
from vllm.utils.mem_constants import GiB_bytes
|
||||||
from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling
|
from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling
|
||||||
from vllm.utils.torch_utils import set_random_seed
|
from vllm.utils.torch_utils import is_quantized_kv_cache, set_random_seed
|
||||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||||
from vllm.v1.outputs import (
|
from vllm.v1.outputs import (
|
||||||
@@ -197,7 +197,7 @@ class Worker(WorkerBase):
|
|||||||
# especially the FP8 scaling factor.
|
# especially the FP8 scaling factor.
|
||||||
if (
|
if (
|
||||||
(tags is None or "kv_cache" in tags)
|
(tags is None or "kv_cache" in tags)
|
||||||
and self.cache_config.cache_dtype.startswith("fp8")
|
and is_quantized_kv_cache(self.cache_config.cache_dtype)
|
||||||
and hasattr(self.model_runner, "init_fp8_kv_scales")
|
and hasattr(self.model_runner, "init_fp8_kv_scales")
|
||||||
):
|
):
|
||||||
self.model_runner.init_fp8_kv_scales()
|
self.model_runner.init_fp8_kv_scales()
|
||||||
|
|||||||
Reference in New Issue
Block a user