[Misc] rename torch_dtype to dtype (#26695)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -1837,18 +1837,18 @@ def _find_dtype(
|
||||
*,
|
||||
revision: str | None,
|
||||
):
|
||||
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
|
||||
# because config.torch_dtype can be None.
|
||||
config_dtype = getattr(config, "torch_dtype", None)
|
||||
# NOTE: getattr(config, "dtype", torch.float32) is not correct
|
||||
# because config.dtype can be None.
|
||||
config_dtype = getattr(config, "dtype", None)
|
||||
|
||||
# Fallbacks for multi-modal models if the root config
|
||||
# does not define torch_dtype
|
||||
# does not define dtype
|
||||
if config_dtype is None:
|
||||
config_dtype = getattr(config.get_text_config(), "torch_dtype", None)
|
||||
config_dtype = getattr(config.get_text_config(), "dtype", None)
|
||||
if config_dtype is None and hasattr(config, "vision_config"):
|
||||
config_dtype = getattr(config.vision_config, "torch_dtype", None)
|
||||
config_dtype = getattr(config.vision_config, "dtype", None)
|
||||
if config_dtype is None and hasattr(config, "encoder_config"):
|
||||
config_dtype = getattr(config.encoder_config, "torch_dtype", None)
|
||||
config_dtype = getattr(config.encoder_config, "dtype", None)
|
||||
|
||||
# Try to read the dtype of the weights if they are in safetensors format
|
||||
if config_dtype is None:
|
||||
|
||||
Reference in New Issue
Block a user