Add PyTorch-native implementation of custom layers (#1898)

Woosuk Kwon
2023-12-02 21:18:40 -08:00
committed by GitHub
parent 5313c2cb8b
commit 9b294976a2
6 changed files with 149 additions and 184 deletions

@@ -23,6 +23,26 @@ class RMSNorm(nn.Module):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
+    def _forward(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """PyTorch-native implementation equivalent to forward()."""
+        orig_dtype = x.dtype
+        x = x.to(torch.float32)
+        if residual is not None:
+            x = x + residual.to(torch.float32)
+            residual = x.to(orig_dtype)
+
+        variance = x.pow(2).mean(dim=-1, keepdim=True)
+        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        x = x.to(orig_dtype) * self.weight
+        if residual is None:
+            return x
+        else:
+            return x, residual
+
     def forward(
         self,
         x: torch.Tensor,
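
The added _forward() is a pure-PyTorch equivalent of the fused kernel behind forward(): it upcasts to float32, optionally folds the residual addition into the norm (returning the updated residual for the next layer), scales by rsqrt(mean(x^2) + eps), and applies the learned weight. Below is a minimal standalone sketch of the same math for experimentation; the NativeRMSNorm name, the hidden size, and the fp16 test tensors are illustrative assumptions, not part of the commit.

    from typing import Optional, Tuple, Union

    import torch
    import torch.nn as nn


    class NativeRMSNorm(nn.Module):
        """Illustrative RMSNorm mirroring the _forward() math in the diff above."""

        def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.variance_epsilon = eps

        def forward(
            self,
            x: torch.Tensor,
            residual: Optional[torch.Tensor] = None,
        ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
            orig_dtype = x.dtype
            x = x.to(torch.float32)  # accumulate in float32 for numerical stability
            if residual is not None:
                x = x + residual.to(torch.float32)  # residual add folded into the norm
                residual = x.to(orig_dtype)         # updated residual for the next layer
            variance = x.pow(2).mean(dim=-1, keepdim=True)
            x = x * torch.rsqrt(variance + self.variance_epsilon)
            x = x.to(orig_dtype) * self.weight
            return x if residual is None else (x, residual)


    # Hypothetical usage: hidden_size=8 and the fp16 inputs are arbitrary.
    norm = NativeRMSNorm(hidden_size=8).half()  # cast weight to the activation dtype
    x = torch.randn(2, 8, dtype=torch.float16)
    res = torch.randn(2, 8, dtype=torch.float16)
    out, new_res = norm(x, res)

    # Sanity check against the unfused reference formula in float32.
    ref = x.float() + res.float()
    ref = ref * torch.rsqrt(ref.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
    torch.testing.assert_close(out.float(), ref * norm.weight.float(),
                               rtol=1e-2, atol=1e-2)

One likely reason the method takes and returns the residual is that fusing the addition with the normalization saves an extra read/write of the hidden states; the native version mirrors that interface so it can drop in where the fused kernel is used.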