Make Gemma and Gemma 2 accept inputs_embeds like Gemma 3 (#36787)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -11,6 +11,8 @@ from unittest.mock import Mock
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
from packaging.version import Version
|
||||||
|
from transformers import __version__ as TRANSFORMERS_VERSION
|
||||||
|
|
||||||
from vllm import LLM
|
from vllm import LLM
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@@ -91,6 +93,15 @@ def test_models(
|
|||||||
if enable_prompt_embeds:
|
if enable_prompt_embeds:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)
|
prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)
|
||||||
|
if model == "hmellor/tiny-random-Gemma2ForCausalLM" and (
|
||||||
|
Version(TRANSFORMERS_VERSION) < Version("5.3.0.dev0")
|
||||||
|
):
|
||||||
|
# For Gemma 1/2 models with Transformers 5.3.0+, the prompt embeddings
|
||||||
|
# are normalised in `get_prompt_embeddings`, like Gemma 3.
|
||||||
|
# For older versions, we need to manually normalise.
|
||||||
|
embed_scale = hf_model.config.hidden_size**0.5
|
||||||
|
normalizer = torch.tensor(embed_scale, dtype=prompt_embeds[0].dtype)
|
||||||
|
prompt_embeds = [p_e * normalizer for p_e in prompt_embeds]
|
||||||
|
|
||||||
with VllmRunner(
|
with VllmRunner(
|
||||||
model,
|
model,
|
||||||
|
|||||||
@@ -3,6 +3,8 @@
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
from packaging.version import Version
|
||||||
|
from transformers import __version__ as TRANSFORMERS_VERSION
|
||||||
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
@@ -151,6 +153,16 @@ def test_models(
|
|||||||
if prompt_embeds is not None:
|
if prompt_embeds is not None:
|
||||||
embed = hf_model.model.get_input_embeddings()(token_ids)
|
embed = hf_model.model.get_input_embeddings()(token_ids)
|
||||||
|
|
||||||
|
if "gemma" in model.lower() and (
|
||||||
|
Version(TRANSFORMERS_VERSION) < Version("5.3.0.dev0")
|
||||||
|
):
|
||||||
|
# For Gemma 1/2 models with Transformers 5.3.0+, the prompt
|
||||||
|
# embeddings are normalised by `get_input_embeddings()`,
|
||||||
|
# like Gemma 3. For older versions, we need to manually normalise.
|
||||||
|
embed_scale = hf_model.config.hidden_size**0.5
|
||||||
|
normalizer = torch.tensor(embed_scale, dtype=embed.dtype)
|
||||||
|
embed *= normalizer
|
||||||
|
|
||||||
# MiniCPM models apply scale_emb to embeddings internally.
|
# MiniCPM models apply scale_emb to embeddings internally.
|
||||||
# vLLM expects pre-scaled embeddings when using inputs_embeds.
|
# vLLM expects pre-scaled embeddings when using inputs_embeds.
|
||||||
if model in EMBED_SCALING_MODELS:
|
if model in EMBED_SCALING_MODELS:
|
||||||
|
|||||||
@@ -293,7 +293,7 @@ class GemmaModel(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
return self.embed_tokens(input_ids)
|
return self.embed_tokens(input_ids) * self.normalizer
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -307,7 +307,6 @@ class GemmaModel(nn.Module):
|
|||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
else:
|
else:
|
||||||
hidden_states = self.embed_input_ids(input_ids)
|
hidden_states = self.embed_input_ids(input_ids)
|
||||||
hidden_states *= self.normalizer
|
|
||||||
residual = None
|
residual = None
|
||||||
else:
|
else:
|
||||||
hidden_states = intermediate_tensors["hidden_states"]
|
hidden_states = intermediate_tensors["hidden_states"]
|
||||||
|
|||||||
@@ -284,7 +284,7 @@ class Gemma2Model(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
return self.embed_tokens(input_ids)
|
return self.embed_tokens(input_ids) * self.normalizer
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -298,7 +298,6 @@ class Gemma2Model(nn.Module):
|
|||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
else:
|
else:
|
||||||
hidden_states = self.embed_input_ids(input_ids)
|
hidden_states = self.embed_input_ids(input_ids)
|
||||||
hidden_states *= self.normalizer
|
|
||||||
residual = None
|
residual = None
|
||||||
else:
|
else:
|
||||||
assert intermediate_tensors is not None
|
assert intermediate_tensors is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user