[Model] Add Gemma 2 (#5908)
@@ -22,7 +22,8 @@ class LogitsProcessor(nn.Module):
                  vocab_size: int,
                  org_vocab_size: Optional[int] = None,
                  scale: float = 1.0,
-                 logits_as_input: bool = False) -> None:
+                 logits_as_input: bool = False,
+                 soft_cap: Optional[float] = None) -> None:
         """
         Args:
             scale: A scaling factor to apply to the logits.
@@ -34,6 +35,8 @@ class LogitsProcessor(nn.Module):
         self.logits_as_input = logits_as_input
         # original vocabulary size (without LoRA).
         self.org_vocab_size = org_vocab_size or vocab_size
+        # Soft cap the logits. Used in Gemma 2.
+        self.soft_cap = soft_cap

     def forward(
         self,
@@ -52,6 +55,11 @@ class LogitsProcessor(nn.Module):
             logits = self._get_logits(hidden_states, embedding, embedding_bias)

         if logits is not None:
+            if self.soft_cap is not None:
+                logits = logits / self.soft_cap
+                logits = torch.tanh(logits)
+                logits = logits * self.soft_cap
+
             if self.scale != 1.0:
                 logits *= self.scale
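For readers skimming the diff: the new soft_cap path rescales the logits with tanh so their magnitude is bounded by soft_cap before the usual scale factor is applied. Below is a minimal standalone sketch of that transform; the helper name soft_cap_logits and the example cap of 30.0 (Gemma 2's final-logit soft cap) are illustrative, not part of this patch.

import torch

def soft_cap_logits(logits: torch.Tensor, soft_cap: float) -> torch.Tensor:
    # Same math as the diff above: divide by the cap, squash with tanh,
    # then rescale, so every logit lands in (-soft_cap, soft_cap).
    return torch.tanh(logits / soft_cap) * soft_cap

logits = torch.tensor([-100.0, -10.0, 0.0, 10.0, 100.0])
print(soft_cap_logits(logits, soft_cap=30.0))
# Large-magnitude logits are squashed toward +/-30; small ones are left
# nearly unchanged.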