[Model] Add support for Gemma 3 (#14660)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Roger Wang <ywang@roblox.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -350,10 +350,11 @@ class ModelConfig:
         if self.enforce_eager is None:
             self.enforce_eager = False
 
+        interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"]
         sliding_window = getattr(self.hf_text_config, "sliding_window", None)
         has_interleaved_attention = (sliding_window is not None) and (
             isinstance(sliding_window, list) or
-            (self.hf_text_config.model_type in ["gemma2", "cohere2"]))
+            (self.hf_text_config.model_type in interleaved_attn_models))
 
         if (not self.disable_sliding_window and has_interleaved_attention):
             if (backend :=
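For readers skimming the diff: the check this hunk generalizes can be read as a small standalone predicate. The sketch below is illustrative, not code from the patch; the helper name has_interleaved_attention and the SimpleNamespace stand-in for the HF text config are assumptions made for the demo.

from types import SimpleNamespace

def has_interleaved_attention(hf_text_config) -> bool:
    # Models whose layers alternate between sliding-window and full
    # attention, per the list introduced in the hunk above.
    interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"]
    sliding_window = getattr(hf_text_config, "sliding_window", None)
    # Interleaving is signaled either by a per-layer list of window
    # sizes or by membership in a known model family.
    return (sliding_window is not None) and (
        isinstance(sliding_window, list)
        or hf_text_config.model_type in interleaved_attn_models)

# Demo with a minimal stand-in config:
cfg = SimpleNamespace(model_type="gemma3_text", sliding_window=1024)
print(has_interleaved_attention(cfg))  # True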
@@ -2501,11 +2502,11 @@ def _get_and_verify_dtype(
         dtype = dtype.lower()
         if dtype == "auto":
             if config_dtype == torch.float32:
-                if config.model_type == "gemma2":
+                if config.model_type in ("gemma2", "gemma3", "gemma3_text"):
                     logger.info(
-                        "For Gemma 2, we downcast float32 to bfloat16 instead "
-                        "of float16 by default. Please specify `dtype` if you "
-                        "want to use float16.")
+                        "For Gemma 2 and 3, we downcast float32 to bfloat16 "
+                        "instead of float16 by default. Please specify `dtype` "
+                        "if you want to use float16.")
                     torch_dtype = torch.bfloat16
                 else:
                     # Following the common practice, we use float16 for float32
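The `auto` dtype branch this hunk changes can be sketched as a pure function. This is a simplified illustration, not the real _get_and_verify_dtype; the name resolve_auto_dtype and its signature are invented here, and the real function handles more cases around this branch.

import torch

def resolve_auto_dtype(model_type: str,
                       config_dtype: torch.dtype) -> torch.dtype:
    if config_dtype == torch.float32:
        if model_type in ("gemma2", "gemma3", "gemma3_text"):
            # Per the log message above, Gemma 2 and 3 downcast float32
            # to bfloat16 rather than float16 by default.
            return torch.bfloat16
        # Following the common practice, float32 checkpoints are
        # otherwise downcast to float16.
        return torch.float16
    return config_dtype

print(resolve_auto_dtype("gemma3", torch.float32))  # torch.bfloat16
print(resolve_auto_dtype("llama", torch.float32))   # torch.float16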
@@ -2637,7 +2638,9 @@ def _get_and_verify_max_len(
             derived_max_model_len = default_max_len
 
     rope_scaling = getattr(hf_config, "rope_scaling", None)
-    if rope_scaling is not None:
+    # NOTE(woosuk): Gemma3's max_model_len (128K) is already scaled by RoPE
+    # scaling, so we skip applying the scaling factor again.
+    if rope_scaling is not None and "gemma3" not in hf_config.model_type:
         # No need to consider "type" key because of patch_rope_scaling when
         # loading HF config
         rope_type = rope_scaling["rope_type"]
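The behavioral change in this hunk, skipping RoPE scaling for Gemma3, can be summarized in a short sketch. The function name scaled_max_model_len is invented, and the factor lookup is simplified; the real code branches on rope_type as shown above.

from types import SimpleNamespace

def scaled_max_model_len(hf_config, derived_max_model_len: float) -> int:
    rope_scaling = getattr(hf_config, "rope_scaling", None)
    # Gemma3's reported max_model_len (128K) already includes the RoPE
    # scaling factor, so applying it again would double-count.
    if rope_scaling is not None and "gemma3" not in hf_config.model_type:
        derived_max_model_len *= rope_scaling.get("factor", 1.0)
    return int(derived_max_model_len)

gemma3 = SimpleNamespace(model_type="gemma3_text",
                         rope_scaling={"rope_type": "linear", "factor": 8.0})
print(scaled_max_model_len(gemma3, 131072))  # 131072: factor is skipped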