[Optimization] Implement fused add rmsnorm (#1667)

This commit is contained in:
ljss
2023-11-19 10:18:02 +08:00
committed by GitHub
parent 8d17774f92
commit e1054247ba
9 changed files with 166 additions and 61 deletions

View File

@@ -1,4 +1,6 @@
"""Custom normalization layers."""
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -21,7 +23,19 @@ class RMSNorm(nn.Module):
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if residual is not None:
layernorm_ops.fused_add_rms_norm(
x,
residual,
self.weight.data,
self.variance_epsilon,
)
return x, residual
out = torch.empty_like(x)
layernorm_ops.rms_norm(
out,