From e568cf88bc65531a95403110b186cd54dbfdc0e6 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Wed, 11 Mar 2026 16:50:04 +0800 Subject: [PATCH] [UX] Infer dtype for local checkpoint (#36218) Signed-off-by: Isotr0py --- vllm/transformers_utils/config.py | 2 +- .../model_arch_config_convertor.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index dd22ed544..fc8d377da 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -1116,7 +1116,7 @@ def get_safetensors_params_metadata( revision: str | None = None, ) -> dict[str, Any]: """ - Get the safetensors metadata for remote model repository. + Get the safetensors parameters metadata for remote/local model repository. """ full_metadata = {} if (model_path := Path(model)).exists(): diff --git a/vllm/transformers_utils/model_arch_config_convertor.py b/vllm/transformers_utils/model_arch_config_convertor.py index 4444469dc..3aeb37502 100644 --- a/vllm/transformers_utils/model_arch_config_convertor.py +++ b/vllm/transformers_utils/model_arch_config_convertor.py @@ -18,7 +18,7 @@ from vllm.config.utils import getattr_iter from vllm.logger import init_logger from vllm.transformers_utils.config import ( ConfigFormat, - try_get_safetensors_metadata, + get_safetensors_params_metadata, ) from vllm.utils.torch_utils import common_broadcastable_dtype @@ -165,14 +165,14 @@ class ModelArchConfigConvertorBase: # Try to read the dtype of the weights if they are in safetensors format if config_dtype is None: with _maybe_patch_hf_hub_constants(config_format): - repo_mt = try_get_safetensors_metadata(model_id, revision=revision) + param_mt = get_safetensors_params_metadata(model_id, revision=revision) - if repo_mt and (files_mt := repo_mt.files_metadata): + if param_mt: param_dtypes: set[torch.dtype] = { - _SAFETENSORS_TO_TORCH_DTYPE[dtype_str] - for file_mt in files_mt.values() - for dtype_str in file_mt.parameter_count - if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE + _SAFETENSORS_TO_TORCH_DTYPE[dtype] + for info in param_mt.values() + if (dtype := info.get("dtype", None)) + and dtype in _SAFETENSORS_TO_TORCH_DTYPE } if param_dtypes: