[Core] Update dtype detection and defaults (#14858)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -347,7 +347,7 @@ class ModelConfig:
|
||||
self.encoder_config = self._get_encoder_config()
|
||||
self.hf_image_processor_config = get_hf_image_processor_config(
|
||||
self.model, revision)
|
||||
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
||||
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
|
||||
self.use_async_output_proc = use_async_output_proc
|
||||
self.mm_processor_kwargs = mm_processor_kwargs
|
||||
self.disable_mm_preprocessor_cache = disable_mm_preprocessor_cache
|
||||
@@ -2526,6 +2526,14 @@ def _get_and_verify_dtype(
|
||||
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
|
||||
# because config.torch_dtype can be None.
|
||||
config_dtype = getattr(config, "torch_dtype", None)
|
||||
|
||||
# Fallbacks for multi-modal models if the root config
|
||||
# does not define torch_dtype
|
||||
if config_dtype is None and hasattr(config, "text_config"):
|
||||
config_dtype = getattr(config.text_config, "torch_dtype", None)
|
||||
if config_dtype is None and hasattr(config, "vision_config"):
|
||||
config_dtype = getattr(config.vision_config, "torch_dtype", None)
|
||||
|
||||
if config_dtype is None:
|
||||
config_dtype = torch.float32
|
||||
|
||||
@@ -2533,16 +2541,8 @@ def _get_and_verify_dtype(
|
||||
dtype = dtype.lower()
|
||||
if dtype == "auto":
|
||||
if config_dtype == torch.float32:
|
||||
if config.model_type in ("gemma2", "gemma3", "gemma3_text"):
|
||||
logger.info(
|
||||
"For Gemma 2 and 3, we downcast float32 to bfloat16 "
|
||||
"instead of float16 by default. Please specify `dtype` "
|
||||
"if you want to use float16.")
|
||||
torch_dtype = torch.bfloat16
|
||||
else:
|
||||
# Following the common practice, we use float16 for float32
|
||||
# models.
|
||||
torch_dtype = torch.float16
|
||||
# Following common practice, we use float16 for float32 models
|
||||
torch_dtype = torch.float16
|
||||
else:
|
||||
torch_dtype = config_dtype
|
||||
|
||||
|
||||
Reference in New Issue
Block a user