[Model] Add Gemma 2 (#5908)
This commit is contained in:
@@ -95,3 +95,49 @@ class RMSNorm(CustomOp):
|
||||
s = f"hidden_size={self.weight.data.size(0)}"
|
||||
s += f", eps={self.variance_epsilon}"
|
||||
return s
|
||||
|
||||
|
||||
class GemmaRMSNorm(CustomOp):
|
||||
"""RMS normalization for Gemma.
|
||||
|
||||
Two differences from the above RMSNorm:
|
||||
1. x * (1 + w) instead of x * w.
|
||||
2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
eps: float = 1e-6,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.zeros(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
orig_dtype = x.dtype
|
||||
if residual is not None:
|
||||
x = x + residual
|
||||
residual = x
|
||||
|
||||
x = x.float()
|
||||
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
# Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
|
||||
# See https://github.com/huggingface/transformers/pull/29402
|
||||
x = x * (1.0 + self.weight.float())
|
||||
x = x.to(orig_dtype)
|
||||
return x if residual is None else (x, residual)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
# TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm.
|
||||
return self.forward_native(x, residual)
|
||||
|
||||
Reference in New Issue
Block a user