[Core] Update dtype detection and defaults (#14858)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-03-19 13:49:33 +08:00
committed by GitHub
parent 8b3e94a357
commit f690372b68
22 changed files with 175 additions and 227 deletions

View File

@@ -347,7 +347,7 @@ class ModelConfig:
self.encoder_config = self._get_encoder_config()
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, revision)
-self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
+self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self.use_async_output_proc = use_async_output_proc
self.mm_processor_kwargs = mm_processor_kwargs
self.disable_mm_preprocessor_cache = disable_mm_preprocessor_cache
@@ -2526,6 +2526,14 @@ def _get_and_verify_dtype(
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None.
config_dtype = getattr(config, "torch_dtype", None)
+# Fallbacks for multi-modal models if the root config
+# does not define torch_dtype
+if config_dtype is None and hasattr(config, "text_config"):
+config_dtype = getattr(config.text_config, "torch_dtype", None)
+if config_dtype is None and hasattr(config, "vision_config"):
+config_dtype = getattr(config.vision_config, "torch_dtype", None)
if config_dtype is None:
config_dtype = torch.float32
@@ -2533,16 +2541,8 @@ def _get_and_verify_dtype(
dtype = dtype.lower()
if dtype == "auto":
if config_dtype == torch.float32:
-if config.model_type in ("gemma2", "gemma3", "gemma3_text"):
-logger.info(
-"For Gemma 2 and 3, we downcast float32 to bfloat16 "
-"instead of float16 by default. Please specify `dtype` "
-"if you want to use float16.")
-torch_dtype = torch.bfloat16
-else:
-# Following the common practice, we use float16 for float32
-# models.
-torch_dtype = torch.float16
+# Following common practice, we use float16 for float32 models
+torch_dtype = torch.float16
else:
torch_dtype = config_dtype