[Fix] Load kv-cache dtype from hf_quant_config.json automatically (fix for reverted PR) (#30785)
Signed-off-by: <> Co-authored-by: root <root@gpu-937.slurm-workers-slurm.slurm.svc.cluster.local>
This commit is contained in:
@@ -93,6 +93,7 @@ from vllm.transformers_utils.utils import is_cloud_storage
|
|||||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
from vllm.utils.mem_constants import GiB_bytes
|
from vllm.utils.mem_constants import GiB_bytes
|
||||||
from vllm.utils.network_utils import get_ip
|
from vllm.utils.network_utils import get_ip
|
||||||
|
from vllm.utils.torch_utils import resolve_kv_cache_dtype_string
|
||||||
from vllm.v1.sample.logits_processor import LogitsProcessor
|
from vllm.v1.sample.logits_processor import LogitsProcessor
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -106,6 +107,7 @@ else:
|
|||||||
LoadFormats = Any
|
LoadFormats = Any
|
||||||
UsageContext = Any
|
UsageContext = Any
|
||||||
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
# object is used to allow for special typing forms
|
# object is used to allow for special typing forms
|
||||||
@@ -1361,12 +1363,17 @@ class EngineArgs:
|
|||||||
f"dcp_size={self.decode_context_parallel_size}."
|
f"dcp_size={self.decode_context_parallel_size}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Resolve "auto" kv_cache_dtype to actual value from model config
|
||||||
|
resolved_cache_dtype = resolve_kv_cache_dtype_string(
|
||||||
|
self.kv_cache_dtype, model_config
|
||||||
|
)
|
||||||
|
|
||||||
cache_config = CacheConfig(
|
cache_config = CacheConfig(
|
||||||
block_size=self.block_size,
|
block_size=self.block_size,
|
||||||
gpu_memory_utilization=self.gpu_memory_utilization,
|
gpu_memory_utilization=self.gpu_memory_utilization,
|
||||||
kv_cache_memory_bytes=self.kv_cache_memory_bytes,
|
kv_cache_memory_bytes=self.kv_cache_memory_bytes,
|
||||||
swap_space=self.swap_space,
|
swap_space=self.swap_space,
|
||||||
cache_dtype=self.kv_cache_dtype,
|
cache_dtype=resolved_cache_dtype,
|
||||||
is_attention_free=model_config.is_attention_free,
|
is_attention_free=model_config.is_attention_free,
|
||||||
num_gpu_blocks_override=self.num_gpu_blocks_override,
|
num_gpu_blocks_override=self.num_gpu_blocks_override,
|
||||||
sliding_window=sliding_window,
|
sliding_window=sliding_window,
|
||||||
|
|||||||
@@ -24,6 +24,10 @@ else:
|
|||||||
ModelConfig = object
|
ModelConfig = object
|
||||||
IntermediateTensors = object
|
IntermediateTensors = object
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
STR_DTYPE_TO_TORCH_DTYPE = {
|
STR_DTYPE_TO_TORCH_DTYPE = {
|
||||||
"float32": torch.float32,
|
"float32": torch.float32,
|
||||||
@@ -49,6 +53,13 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Translation table from ModelOpt kv_cache_quant_algo values (lowercased)
# to the cache-dtype strings vLLM understands.
# TODO: Add more modelopt kv cache dtype mappings here when it supported by
# some attention backend (for example supports nvfp4).
MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP = {
    "fp8": "fp8_e4m3",
}
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
@@ -194,6 +205,70 @@ def get_kv_cache_torch_dtype(
|
|||||||
return torch_dtype
|
return torch_dtype
|
||||||
|
|
||||||
|
|
||||||
|
def get_kv_cache_quant_algo_string(quant_cfg: dict[str, Any]) -> str | None:
|
||||||
|
"""Get the KV cache quantization algorithm string from the quantization config.
|
||||||
|
|
||||||
|
Maps various FP8 format names to vLLM's standard cache dtype strings.
|
||||||
|
Returns None if no kv_cache_quant_algo is specified.
|
||||||
|
Returns "auto" if the value is not recognized/supported.
|
||||||
|
"""
|
||||||
|
# Mapping from model config values to vLLM cache_dtype strings
|
||||||
|
|
||||||
|
quant_method = quant_cfg.get("quant_method", "")
|
||||||
|
if quant_method.startswith("modelopt"):
|
||||||
|
quantization_inner = quant_cfg.get("quantization", quant_cfg)
|
||||||
|
# Check if quant config is specified and use kv cache quant algo
|
||||||
|
kv_algo = quantization_inner.get("kv_cache_quant_algo") or quant_cfg.get(
|
||||||
|
"kv_cache_quant_algo"
|
||||||
|
)
|
||||||
|
if isinstance(kv_algo, str):
|
||||||
|
kv_algo_lower = kv_algo.lower()
|
||||||
|
|
||||||
|
# Try to map to vLLM's standard format
|
||||||
|
if kv_algo_lower in MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP:
|
||||||
|
return MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP[kv_algo_lower]
|
||||||
|
else:
|
||||||
|
# Unknown/unsupported format - return "auto" as safe fallback
|
||||||
|
logger.warning(
|
||||||
|
"WARNING: Unknown kv_cache_quant_algo '%s' in model "
|
||||||
|
"config. Supported values: %s. Falling back to 'auto'.",
|
||||||
|
kv_algo,
|
||||||
|
list(MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP.keys()),
|
||||||
|
)
|
||||||
|
return "auto"
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_kv_cache_quant_algo_dtype(quant_cfg: dict[str, Any]) -> torch.dtype | None:
    """Get the KV cache quantization algorithm dtype from the quantization config."""
    algo = get_kv_cache_quant_algo_string(quant_cfg)
    # None means no algo was specified; "auto" is the unrecognized-value
    # fallback — neither has a torch dtype to convert to.
    if algo is None or algo == "auto":
        return None
    return STR_DTYPE_TO_TORCH_DTYPE[algo]
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_kv_cache_dtype_string(
    kv_cache_dtype: str, model_config: ModelConfig
) -> str:
    """Resolve 'auto' kv_cache_dtype to the actual string value from model config.

    Returns the resolved cache_dtype string.
    """
    # An explicitly requested dtype wins outright; nothing to resolve.
    if kv_cache_dtype != "auto":
        return kv_cache_dtype

    # getattr chains default to None, so a missing hf_config or
    # quantization_config simply falls through to the "auto" default.
    hf_cfg = getattr(model_config, "hf_config", None)
    quant_cfg = getattr(hf_cfg, "quantization_config", None)
    if quant_cfg is None:
        # Default to auto (will be handled by downstream code)
        return "auto"

    resolved = get_kv_cache_quant_algo_string(quant_cfg)
    # None means the config named no KV-cache quant algo at all.
    return "auto" if resolved is None else resolved
|
||||||
|
|
||||||
|
|
||||||
def kv_cache_dtype_str_to_dtype(
|
def kv_cache_dtype_str_to_dtype(
|
||||||
kv_cache_dtype: str, model_config: ModelConfig
|
kv_cache_dtype: str, model_config: ModelConfig
|
||||||
) -> torch.dtype:
|
) -> torch.dtype:
|
||||||
|
|||||||
Reference in New Issue
Block a user