Make Gemma and Gemma 2 accept inputs_embeds like Gemma 3 (#36787)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2026-03-11 18:12:43 +00:00
committed by GitHub
parent 9556af87d5
commit 65986db6ba
4 changed files with 25 additions and 4 deletions

View File

@@ -3,6 +3,8 @@
import pytest
import torch
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.platforms import current_platform
@@ -151,6 +153,16 @@ def test_models(
if prompt_embeds is not None:
embed = hf_model.model.get_input_embeddings()(token_ids)
if "gemma" in model.lower() and (
Version(TRANSFORMERS_VERSION) < Version("5.3.0.dev0")
):
# For Gemma 1/2 models with Transformers 5.4.0+, the prompt
# embeddings are normalised in `get_prompt_embeddings`,
# like Gemma 3. For older versions, we need to manually normalise.
embed_scale = hf_model.config.hidden_size**0.5
normalizer = torch.tensor(embed_scale, dtype=embed.dtype)
embed *= normalizer
# MiniCPM models apply scale_emb to embeddings internally.
# vLLM expects pre-scaled embeddings when using inputs_embeds.
if model in EMBED_SCALING_MODELS: