[Optimization] Implement fused add rmsnorm (#1667)
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
"""Custom normalization layers."""
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@@ -21,7 +23,19 @@ class RMSNorm(nn.Module):
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
if residual is not None:
|
||||
layernorm_ops.fused_add_rms_norm(
|
||||
x,
|
||||
residual,
|
||||
self.weight.data,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return x, residual
|
||||
out = torch.empty_like(x)
|
||||
layernorm_ops.rms_norm(
|
||||
out,
|
||||
|
||||
Reference in New Issue
Block a user