[UX] Infer dtype for local checkpoint (#36218)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -1116,7 +1116,7 @@ def get_safetensors_params_metadata(
|
||||
revision: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Get the safetensors metadata for remote model repository.
|
||||
Get the safetensors parameters metadata for remote/local model repository.
|
||||
"""
|
||||
full_metadata = {}
|
||||
if (model_path := Path(model)).exists():
|
||||
|
||||
@@ -18,7 +18,7 @@ from vllm.config.utils import getattr_iter
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.config import (
|
||||
ConfigFormat,
|
||||
try_get_safetensors_metadata,
|
||||
get_safetensors_params_metadata,
|
||||
)
|
||||
from vllm.utils.torch_utils import common_broadcastable_dtype
|
||||
|
||||
@@ -165,14 +165,14 @@ class ModelArchConfigConvertorBase:
|
||||
# Try to read the dtype of the weights if they are in safetensors format
|
||||
if config_dtype is None:
|
||||
with _maybe_patch_hf_hub_constants(config_format):
|
||||
repo_mt = try_get_safetensors_metadata(model_id, revision=revision)
|
||||
param_mt = get_safetensors_params_metadata(model_id, revision=revision)
|
||||
|
||||
if repo_mt and (files_mt := repo_mt.files_metadata):
|
||||
if param_mt:
|
||||
param_dtypes: set[torch.dtype] = {
|
||||
_SAFETENSORS_TO_TORCH_DTYPE[dtype_str]
|
||||
for file_mt in files_mt.values()
|
||||
for dtype_str in file_mt.parameter_count
|
||||
if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE
|
||||
_SAFETENSORS_TO_TORCH_DTYPE[dtype]
|
||||
for info in param_mt.values()
|
||||
if (dtype := info.get("dtype", None))
|
||||
and dtype in _SAFETENSORS_TO_TORCH_DTYPE
|
||||
}
|
||||
|
||||
if param_dtypes:
|
||||
|
||||
Reference in New Issue
Block a user