[Fix] Load kv-cache dtype from hf_quant_config.json automatically (fix for reverted PR) (#30785)

Signed-off-by: <>
Co-authored-by: root <root@gpu-937.slurm-workers-slurm.slurm.svc.cluster.local>
This commit is contained in:
danielafrimi
2025-12-17 11:56:38 +02:00
committed by GitHub
parent 9db1db5949
commit 7b966ae2ba
2 changed files with 83 additions and 1 deletions

View File

@@ -24,6 +24,10 @@ else:
ModelConfig = object
IntermediateTensors = object
import logging
logger = logging.getLogger(__name__)
STR_DTYPE_TO_TORCH_DTYPE = {
"float32": torch.float32,
@@ -49,6 +53,13 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = {
}
# Translation table from ModelOpt kv_cache_quant_algo values (lowercased)
# to vLLM's standard cache_dtype strings.
# TODO: extend with further ModelOpt kv-cache dtypes once an attention
# backend supports them (e.g. nvfp4).
MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP = {"fp8": "fp8_e4m3"}
T = TypeVar("T")
@@ -194,6 +205,70 @@ def get_kv_cache_torch_dtype(
return torch_dtype
def get_kv_cache_quant_algo_string(quant_cfg: dict[str, Any]) -> str | None:
"""Get the KV cache quantization algorithm string from the quantization config.
Maps various FP8 format names to vLLM's standard cache dtype strings.
Returns None if no kv_cache_quant_algo is specified.
Returns "auto" if the value is not recognized/supported.
"""
# Mapping from model config values to vLLM cache_dtype strings
quant_method = quant_cfg.get("quant_method", "")
if quant_method.startswith("modelopt"):
quantization_inner = quant_cfg.get("quantization", quant_cfg)
# Check if quant config is specified and use kv cache quant algo
kv_algo = quantization_inner.get("kv_cache_quant_algo") or quant_cfg.get(
"kv_cache_quant_algo"
)
if isinstance(kv_algo, str):
kv_algo_lower = kv_algo.lower()
# Try to map to vLLM's standard format
if kv_algo_lower in MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP:
return MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP[kv_algo_lower]
else:
# Unknown/unsupported format - return "auto" as safe fallback
logger.warning(
"WARNING: Unknown kv_cache_quant_algo '%s' in model "
"config. Supported values: %s. Falling back to 'auto'.",
kv_algo,
list(MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP.keys()),
)
return "auto"
return None
def get_kv_cache_quant_algo_dtype(quant_cfg: dict[str, Any]) -> torch.dtype | None:
    """Get the KV cache quantization algorithm dtype from the quantization config."""
    algo = get_kv_cache_quant_algo_string(quant_cfg)
    # "auto" is a fallback marker, not a concrete dtype string, so only a
    # recognized value is translated into a torch dtype.
    if algo is None or algo == "auto":
        return None
    return STR_DTYPE_TO_TORCH_DTYPE[algo]
def resolve_kv_cache_dtype_string(
    kv_cache_dtype: str, model_config: ModelConfig
) -> str:
    """Resolve 'auto' kv_cache_dtype to the actual string value from model config.

    Returns the resolved cache_dtype string.
    """
    # An explicit user choice is passed through untouched.
    if kv_cache_dtype != "auto":
        return kv_cache_dtype

    # Dig out hf_config.quantization_config, tolerating a missing hf_config.
    quant_cfg = getattr(
        getattr(model_config, "hf_config", None), "quantization_config", None
    )
    if quant_cfg is not None:
        resolved = get_kv_cache_quant_algo_string(quant_cfg)
        if resolved is not None:
            return resolved

    # Default to auto (will be handled by downstream code)
    return "auto"
def kv_cache_dtype_str_to_dtype(
kv_cache_dtype: str, model_config: ModelConfig
) -> torch.dtype: