[Misc] rename torch_dtype to dtype (#26695)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2025-10-15 20:11:48 +08:00
committed by GitHub
parent f93e348010
commit 8f4b313c37
30 changed files with 52 additions and 55 deletions

View File

@@ -1837,18 +1837,18 @@ def _find_dtype(
*,
revision: str | None,
):
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None.
config_dtype = getattr(config, "torch_dtype", None)
# NOTE: getattr(config, "dtype", torch.float32) is not correct
# because config.dtype can be None.
config_dtype = getattr(config, "dtype", None)
# Fallbacks for multi-modal models if the root config
# does not define torch_dtype
# does not define dtype
if config_dtype is None:
config_dtype = getattr(config.get_text_config(), "torch_dtype", None)
config_dtype = getattr(config.get_text_config(), "dtype", None)
if config_dtype is None and hasattr(config, "vision_config"):
config_dtype = getattr(config.vision_config, "torch_dtype", None)
config_dtype = getattr(config.vision_config, "dtype", None)
if config_dtype is None and hasattr(config, "encoder_config"):
config_dtype = getattr(config.encoder_config, "torch_dtype", None)
config_dtype = getattr(config.encoder_config, "dtype", None)
# Try to read the dtype of the weights if they are in safetensors format
if config_dtype is None: