From 65986db6ba71abf4cf0639c5fd1477b0d8df8f5e Mon Sep 17 00:00:00 2001
From: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date: Wed, 11 Mar 2026 18:12:43 +0000
Subject: [PATCH] Make Gemma and Gemma 2 accept `inputs_embeds` like Gemma 3
 (#36787)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
---
 tests/basic_correctness/test_basic_correctness.py | 11 +++++++++++
 tests/models/language/generation/test_common.py   | 12 ++++++++++++
 vllm/model_executor/models/gemma.py               |  3 +--
 vllm/model_executor/models/gemma2.py              |  3 +--
 4 files changed, 25 insertions(+), 4 deletions(-)

diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py
index 70c58ad96..1a07ac6da 100644
--- a/tests/basic_correctness/test_basic_correctness.py
+++ b/tests/basic_correctness/test_basic_correctness.py
@@ -11,6 +11,8 @@ from unittest.mock import Mock
 
 import pytest
 import torch
+from packaging.version import Version
+from transformers import __version__ as TRANSFORMERS_VERSION
 
 from vllm import LLM
 from vllm.platforms import current_platform
@@ -91,6 +93,15 @@ def test_models(
     if enable_prompt_embeds:
         with torch.no_grad():
             prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)
+        if model == "hmellor/tiny-random-Gemma2ForCausalLM" and (
+            Version(TRANSFORMERS_VERSION) < Version("5.3.0.dev0")
+        ):
+            # For Gemma 1/2 models with Transformers 5.3.0+, the prompt
+            # embeddings are normalised in `get_prompt_embeddings`, like
+            # Gemma 3. For older versions, we need to manually normalise.
+            embed_scale = hf_model.config.hidden_size**0.5
+            normalizer = torch.tensor(embed_scale, dtype=prompt_embeds[0].dtype)
+            prompt_embeds = [p_e * normalizer for p_e in prompt_embeds]
 
     with VllmRunner(
         model,
diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py
index 474d71797..ec8949b00 100644
--- a/tests/models/language/generation/test_common.py
+++ b/tests/models/language/generation/test_common.py
@@ -3,6 +3,8 @@
 
 import pytest
 import torch
+from packaging.version import Version
+from transformers import __version__ as TRANSFORMERS_VERSION
 
 from vllm.platforms import current_platform
 
@@ -151,6 +153,16 @@ def test_models(
 
     if prompt_embeds is not None:
         embed = hf_model.model.get_input_embeddings()(token_ids)
+        if "gemma" in model.lower() and (
+            Version(TRANSFORMERS_VERSION) < Version("5.3.0.dev0")
+        ):
+            # For Gemma 1/2 models with Transformers 5.3.0+, the prompt
+            # embeddings are normalised in `get_prompt_embeddings`, like
+            # Gemma 3. For older versions, we need to manually normalise.
+            embed_scale = hf_model.config.hidden_size**0.5
+            normalizer = torch.tensor(embed_scale, dtype=embed.dtype)
+            embed *= normalizer
+
         # MiniCPM models apply scale_emb to embeddings internally.
         # vLLM expects pre-scaled embeddings when using inputs_embeds.
         if model in EMBED_SCALING_MODELS:
diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index b3ae5f5ac..6e35020a6 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -293,7 +293,7 @@ class GemmaModel(nn.Module):
         )
 
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.embed_tokens(input_ids)
+        return self.embed_tokens(input_ids) * self.normalizer
 
     def forward(
         self,
@@ -307,7 +307,6 @@ class GemmaModel(nn.Module):
                 hidden_states = inputs_embeds
             else:
                 hidden_states = self.embed_input_ids(input_ids)
-                hidden_states *= self.normalizer
             residual = None
         else:
             hidden_states = intermediate_tensors["hidden_states"]
diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py
index 3b0a6a492..425ecc651 100644
--- a/vllm/model_executor/models/gemma2.py
+++ b/vllm/model_executor/models/gemma2.py
@@ -284,7 +284,7 @@ class Gemma2Model(nn.Module):
         )
 
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.embed_tokens(input_ids)
+        return self.embed_tokens(input_ids) * self.normalizer
 
     def forward(
         self,
@@ -298,7 +298,6 @@ class Gemma2Model(nn.Module):
                 hidden_states = inputs_embeds
             else:
                 hidden_states = self.embed_input_ids(input_ids)
-                hidden_states *= self.normalizer
             residual = None
         else:
             assert intermediate_tensors is not None
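
Note on the scaling this patch moves: Gemma 1/2 multiply token embeddings by
sqrt(hidden_size) before the first decoder layer. Moving that multiplication
into `embed_input_ids` means caller-supplied `inputs_embeds` and internally
embedded `input_ids` go through the same scale, matching Gemma 3. A minimal
standalone sketch of that invariant follows; the hidden/vocab sizes are
hypothetical toy values, not taken from the patch or any real model config:

import torch

hidden_size = 64   # hypothetical toy size; a real model reads this from config
vocab_size = 128   # hypothetical toy size

embed_tokens = torch.nn.Embedding(vocab_size, hidden_size)
normalizer = hidden_size**0.5  # sqrt(hidden_size), the Gemma embedding scale

def embed_input_ids(input_ids: torch.Tensor) -> torch.Tensor:
    # After this patch, scaling lives here, so token-id prompts and
    # pre-computed prompt embeddings are normalised identically.
    return embed_tokens(input_ids) * normalizer

input_ids = torch.tensor([[1, 2, 3]])
inputs_embeds = embed_input_ids(input_ids)
# inputs_embeds is now pre-scaled, which is why the tests above only apply
# the manual normalisation on older Transformers versions.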