[Model] Add support for Gemma 3 (#14660)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Roger Wang <ywang@roblox.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Woosuk Kwon
2025-03-12 08:36:33 -07:00
committed by GitHub
parent 45f3f3f59e
commit c0c25e25fa
10 changed files with 1071 additions and 9 deletions

View File

@@ -350,10 +350,11 @@ class ModelConfig:
if self.enforce_eager is None:
self.enforce_eager = False
interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"]
sliding_window = getattr(self.hf_text_config, "sliding_window", None)
has_interleaved_attention = (sliding_window is not None) and (
isinstance(sliding_window, list) or
(self.hf_text_config.model_type in ["gemma2", "cohere2"]))
(self.hf_text_config.model_type in interleaved_attn_models))
if (not self.disable_sliding_window and has_interleaved_attention):
if (backend :=
@@ -2501,11 +2502,11 @@ def _get_and_verify_dtype(
dtype = dtype.lower()
if dtype == "auto":
if config_dtype == torch.float32:
if config.model_type == "gemma2":
if config.model_type in ("gemma2", "gemma3", "gemma3_text"):
logger.info(
"For Gemma 2, we downcast float32 to bfloat16 instead "
"of float16 by default. Please specify `dtype` if you "
"want to use float16.")
"For Gemma 2 and 3, we downcast float32 to bfloat16 "
"instead of float16 by default. Please specify `dtype` "
"if you want to use float16.")
torch_dtype = torch.bfloat16
else:
# Following the common practice, we use float16 for float32
@@ -2637,7 +2638,9 @@ def _get_and_verify_max_len(
derived_max_model_len = default_max_len
rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling is not None:
# NOTE(woosuk): Gemma3's max_model_len (128K) is already scaled by RoPE
# scaling, so we skip applying the scaling factor again.
if rope_scaling is not None and "gemma3" not in hf_config.model_type:
# No need to consider "type" key because of patch_rope_scaling when
# loading HF config
rope_type = rope_scaling["rope_type"]