Add kernel for GeGLU with approximate GELU (#3337)
This commit is contained in:
@@ -47,16 +47,25 @@ class GeluAndMul(nn.Module):
|
||||
return: (batch_size, seq_len, d) or (num_tokens, d)
|
||||
"""
|
||||
|
||||
def __init__(self, approximate: str = "none"):
|
||||
super().__init__()
|
||||
self.approximate = approximate
|
||||
if approximate not in ("none", "tanh"):
|
||||
raise ValueError(f"Unknown approximate mode: {approximate}")
|
||||
|
||||
def _forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
d = x.shape[-1] // 2
|
||||
return F.gelu(x[..., :d]) * x[..., d:]
|
||||
return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
ops.gelu_and_mul(out, x)
|
||||
if self.approximate == "none":
|
||||
ops.gelu_and_mul(out, x)
|
||||
elif self.approximate == "tanh":
|
||||
ops.gelu_tanh_and_mul(out, x)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user