[Model] Support Nemotron models (Nemotron-3, Nemotron-4, Minitron) (#6611)
@@ -159,6 +159,21 @@ class QuickGELU(CustomOp):
     # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+
+
+class ReLUSquaredActivation(CustomOp):
+    """
+    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
+    """
+
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+        relu_applied = nn.functional.relu(x)
+        squared = torch.square(relu_applied)
+        return squared
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        return self.forward_native(x)
 
 
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.
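For reference (not part of the diff): a minimal standalone sketch of what the new forward_native computes. ReLU zeroes out negative entries and the result is squared elementwise, so relu^2(x) = max(x, 0)^2.

```python
import torch
import torch.nn as nn

# Same two steps as ReLUSquaredActivation.forward_native above.
x = torch.tensor([-2.0, -0.5, 0.0, 1.5, 3.0])
relu_applied = nn.functional.relu(x)   # tensor([0.0, 0.0, 0.0, 1.5, 3.0])
squared = torch.square(relu_applied)   # tensor([0.00, 0.00, 0.00, 2.25, 9.00])
print(squared)
```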
@@ -207,6 +222,7 @@ _ACTIVATION_REGISTRY = {
     "gelu_new": NewGELU(),
     "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
     "relu": nn.ReLU(),
+    "relu2": ReLUSquaredActivation(),
     "quick_gelu": QuickGELU(),
 }
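For reference (not part of the diff): a sketch of how the new registry entry is reached. A model config's hidden_act string (presumably "relu2" for the Nemotron models added in this PR) selects the activation module from _ACTIVATION_REGISTRY. The plain nn.Module stand-in below is hypothetical; vLLM's class derives from CustomOp.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for ReLUSquaredActivation (illustration only).
class ReLUSquared(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.square(nn.functional.relu(x))

# Mock of the registry lookup added in the hunk above.
_ACTIVATION_REGISTRY = {
    "relu": nn.ReLU(),
    "relu2": ReLUSquared(),
}

hidden_act = "relu2"                      # e.g. taken from a model config
act_fn = _ACTIVATION_REGISTRY[hidden_act]
print(act_fn(torch.tensor([-1.0, 2.0])))  # tensor([0., 4.])
```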