Make Gemma and Gemma 2 accept inputs_embeds like Gemma 3 (#36787)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -11,6 +11,8 @@ from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from packaging.version import Version
|
||||
from transformers import __version__ as TRANSFORMERS_VERSION
|
||||
|
||||
from vllm import LLM
|
||||
from vllm.platforms import current_platform
|
||||
@@ -91,6 +93,15 @@ def test_models(
|
||||
if enable_prompt_embeds:
|
||||
with torch.no_grad():
|
||||
prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)
|
||||
if model == "hmellor/tiny-random-Gemma2ForCausalLM" and (
|
||||
Version(TRANSFORMERS_VERSION) < Version("5.3.0.dev0")
|
||||
):
|
||||
# For Gemma 1/2 models with Transformers 5.3.0+, the prompt embeddings
|
||||
# are normalised in `get_prompt_embeddings`, like Gemma 3.
|
||||
# For older versions, we need to manually normalise.
|
||||
embed_scale = hf_model.config.hidden_size**0.5
|
||||
normalizer = torch.tensor(embed_scale, dtype=prompt_embeds[0].dtype)
|
||||
prompt_embeds = [p_e * normalizer for p_e in prompt_embeds]
|
||||
|
||||
with VllmRunner(
|
||||
model,
|
||||
|
||||
Reference in New Issue
Block a user