Make Gemma and Gemma 2 accept inputs_embeds like Gemma 3 (#36787)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2026-03-11 18:12:43 +00:00
committed by GitHub
parent 9556af87d5
commit 65986db6ba
4 changed files with 25 additions and 4 deletions

View File

@@ -284,7 +284,7 @@ class Gemma2Model(nn.Module):
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
return self.embed_tokens(input_ids) * self.normalizer
def forward(
self,
@@ -298,7 +298,6 @@ class Gemma2Model(nn.Module):
hidden_states = inputs_embeds
else:
hidden_states = self.embed_input_ids(input_ids)
hidden_states *= self.normalizer
residual = None
else:
assert intermediate_tensors is not None