[Model] Add Gemma 2 (#5908)

This commit is contained in:
Woosuk Kwon
2024-06-27 13:33:56 -07:00
committed by GitHub
parent 736ed38849
commit 79c92c7c8a
9 changed files with 499 additions and 9 deletions

View File

@@ -22,7 +22,8 @@ class LogitsProcessor(nn.Module):
vocab_size: int,
org_vocab_size: Optional[int] = None,
scale: float = 1.0,
logits_as_input: bool = False) -> None:
logits_as_input: bool = False,
soft_cap: Optional[float] = None) -> None:
"""
Args:
scale: A scaling factor to apply to the logits.
@@ -34,6 +35,8 @@ class LogitsProcessor(nn.Module):
self.logits_as_input = logits_as_input
# original vocabulary size (without LoRA).
self.org_vocab_size = org_vocab_size or vocab_size
# Soft cap the logits. Used in Gemma 2.
self.soft_cap = soft_cap
def forward(
self,
@@ -52,6 +55,11 @@ class LogitsProcessor(nn.Module):
logits = self._get_logits(hidden_states, embedding, embedding_bias)
if logits is not None:
if self.soft_cap is not None:
logits = logits / self.soft_cap
logits = torch.tanh(logits)
logits = logits * self.soft_cap
if self.scale != 1.0:
logits *= self.scale