Optimize GeGLU layer in Gemma (#2975)
This commit is contained in:
@@ -37,6 +37,29 @@ class SiluAndMul(nn.Module):
|
||||
return out
|
||||
|
||||
|
||||
class GeluAndMul(nn.Module):
|
||||
"""An activation function for GeGLU.
|
||||
|
||||
The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
|
||||
|
||||
Shapes:
|
||||
x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
|
||||
return: (batch_size, seq_len, d) or (num_tokens, d)
|
||||
"""
|
||||
|
||||
def _forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
d = x.shape[-1] // 2
|
||||
return F.gelu(x[..., :d]) * x[..., d:]
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
ops.gelu_and_mul(out, x)
|
||||
return out
|
||||
|
||||
|
||||
class NewGELU(nn.Module):
|
||||
|
||||
def _forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
Reference in New Issue
Block a user