[Model] Support Nemotron models (Nemotron-3, Nemotron-4, Minitron) (#6611)

Michael Goin
2024-07-26 14:33:42 -04:00
committed by GitHub
parent 85ad7e2d01
commit 07278c37dd
9 changed files with 776 additions and 1 deletion

@@ -159,6 +159,21 @@ class QuickGELU(CustomOp):
    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:


class ReLUSquaredActivation(CustomOp):
    """
    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        relu_applied = nn.functional.relu(x)
        squared = torch.square(relu_applied)
        return squared

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_native(x)


class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.
@@ -207,6 +222,7 @@ _ACTIVATION_REGISTRY = {
    "gelu_new": NewGELU(),
    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
    "relu": nn.ReLU(),
    "relu2": ReLUSquaredActivation(),
    "quick_gelu": QuickGELU(),
}
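Below is a simplified sketch of how a registry like this resolves a model config's activation string (e.g. "relu2" for Nemotron) to a module. The ReLUSquared class and get_activation helper are illustrative stand-ins under that assumption, not vLLM's actual API:

import torch
import torch.nn as nn

class ReLUSquared(nn.Module):
    # Plain-PyTorch stand-in for the CustomOp-based ReLUSquaredActivation in the diff.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.square(nn.functional.relu(x))

_ACTIVATION_REGISTRY = {
    "relu": nn.ReLU(),
    "relu2": ReLUSquared(),
}

def get_activation(name: str) -> nn.Module:
    # Illustrative lookup; raises for names that are not registered.
    if name not in _ACTIVATION_REGISTRY:
        raise ValueError(f"Unsupported activation: {name!r}")
    return _ACTIVATION_REGISTRY[name]

act = get_activation("relu2")
print(act(torch.tensor([-1.0, 2.0])))  # tensor([0., 4.])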