Optimize GeGLU layer in Gemma (#2975)

This commit is contained in:
Woosuk Kwon
2024-02-21 20:17:52 -08:00
committed by GitHub
parent 93dc5a2870
commit fd5dcc5c81
6 changed files with 107 additions and 76 deletions

View File

@@ -37,6 +37,29 @@ class SiluAndMul(nn.Module):
return out
class GeluAndMul(nn.Module):
"""An activation function for GeGLU.
The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
Shapes:
x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
return: (batch_size, seq_len, d) or (num_tokens, d)
"""
def _forward(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return F.gelu(x[..., :d]) * x[..., d:]
def forward(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
ops.gelu_and_mul(out, x)
return out
class NewGELU(nn.Module):
def _forward(self, x: torch.Tensor) -> torch.Tensor: