[Kernel] Add cuda kernel for gpt_oss activation (#22951)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -239,6 +239,35 @@ class GeluAndMul(CustomOp):
|
||||
return f'approximate={repr(self.approximate)}'
|
||||
|
||||
|
||||
@CustomOp.register("swigluoai_and_mul")
class SwigluOAIAndMul(CustomOp):
    """Clamped, sigmoid-gated activation-and-multiply (gpt_oss style).

    The last dimension of the input interleaves two streams: even
    positions carry the gate values and odd positions carry the up
    values. The gate is clamped from above at ``limit``, the up values
    are clamped to ``[-limit, limit]``, and the result is
    ``(up + 1) * gate * sigmoid(alpha * gate)``. The output's last
    dimension is therefore half the size of the input's.
    """

    # Reference implementation:
    # https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110

    def __init__(self, alpha: float = 1.702, limit: float = 7.0):
        """Store the sigmoid scale ``alpha`` and the clamp bound ``limit``."""
        super().__init__()
        self.alpha = alpha
        self.limit = limit

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        # De-interleave and clamp in one step: the gate only has an upper
        # bound, while the up projection is bounded on both sides.
        gate = x[..., ::2].clamp(min=None, max=self.limit)
        up = x[..., 1::2].clamp(min=-self.limit, max=self.limit)
        return (up + 1) * (gate * torch.sigmoid(gate * self.alpha))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the registered ``_C.swigluoai_and_mul`` op on ``x``.

        Allocates an uninitialised tensor shaped like ``x`` but with the
        last dimension halved, then hands it to the custom op together
        with ``alpha`` and ``limit``.
        """
        half = x.shape[-1] // 2
        out = torch.empty(x.shape[:-1] + (half, ),
                          dtype=x.dtype,
                          device=x.device)
        # NOTE(review): the op is presumably an out-variant that fills
        # `out` in place -- confirm against the kernel binding.
        torch.ops._C.swigluoai_and_mul(out, x, self.alpha, self.limit)
        return out

    def extra_repr(self) -> str:
        """Return the ``alpha``/``limit`` settings for the module repr."""
        return f"alpha={self.alpha!r}, limit={self.limit!r}"
|
||||
|
||||
|
||||
@CustomOp.register("gelu_new")
|
||||
class NewGELU(CustomOp):
|
||||
|
||||
@@ -330,6 +359,7 @@ class ReLUSquaredActivation(CustomOp):
|
||||
return torch.square(F.relu(x))
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# TODO: implement CUDA kernels; until then this delegates to forward_native.
|
||||
return self.forward_native(x)
|
||||
|
||||
|
||||
@@ -406,9 +436,14 @@ def get_act_fn(act_fn_name: str) -> nn.Module:
|
||||
|
||||
|
||||
# Registry of fused "activation-and-multiply" layers, keyed by activation
# name. Each value is a factory that builds a fresh module instance.
# NOTE(review): LazyDict presumably calls the factory on first lookup --
# confirm against its definition.
# Every key appears exactly once: repeated keys in a dict literal are
# silently overwritten by the last occurrence, so duplicates only add noise.
_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
    "gelu":
    lambda: GeluAndMul(),
    "silu":
    lambda: SiluAndMul(),
    "geglu":
    lambda: GeluAndMul(),
    # Accepts optional constructor arguments (alpha, limit) and forwards
    # them unchanged to SwigluOAIAndMul.
    "swigluoai":
    lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs),
})
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user