[Kernel] Add CUDA kernel for gpt_oss activation (#22951)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@@ -1621,17 +1621,6 @@ def fused_experts_impl(
                                 block_shape=block_shape,
                                 B_bias=w1_bias)
 
-        # TODO fused kernel
-        def swiglu_oai(gate_up):
-            alpha = 1.702
-            limit = 7.0
-            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
-            gate = gate.clamp(min=None, max=limit)
-            up = up.clamp(min=-limit, max=limit)
-            glu = gate * torch.sigmoid(gate * alpha)
-            gated_output = (up + 1) * glu
-            return gated_output
-
         # Activation function with multiplication
         if activation == "silu" and is_act_and_mul:
             torch.ops._C.silu_and_mul(intermediate_cache2,
@@ -1639,13 +1628,16 @@ def fused_experts_impl(
         elif activation == "gelu" and is_act_and_mul:
             torch.ops._C.gelu_and_mul(intermediate_cache2,
                                       intermediate_cache1.view(-1, N))
+        elif activation == "swigluoai" and is_act_and_mul:
+            # alpha = 1.702, limit = 7.0
+            torch.ops._C.swigluoai_and_mul(intermediate_cache2,
+                                           intermediate_cache1.view(-1, N))
         # Activation function without multiplication
         elif activation == "silu":
             intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
         elif activation == "gelu":
             intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
-        elif activation == "swiglu_oai":
-            intermediate_cache2 = swiglu_oai(intermediate_cache1.view(-1, N))
+
         else:
             raise ValueError(f"Unsupported FusedMoe activation: {activation}, "
                              f"with is_act_and_mul={is_act_and_mul}.")
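Note: a minimal sanity check of the fused op against the eager reference above. This is a sketch, not part of the commit: it assumes a CUDA build of vLLM that registers the _C custom ops (the exact import that loads them varies across versions), and it assumes the kernel consumes the same interleaved layout as the deleted helper. The (out, input) calling convention mirrors the silu_and_mul call shown in the diff; the shapes, dtype, and tolerances are assumptions.

import torch
import vllm._C  # noqa: F401  (assumed import path; it registers torch.ops._C.*)

x = torch.randn(128, 2048, device="cuda", dtype=torch.bfloat16)
out = torch.empty(128, 1024, device="cuda", dtype=torch.bfloat16)  # half the last dim
torch.ops._C.swigluoai_and_mul(out, x)

# Compare against the eager reference (swiglu_oai as defined above),
# computed in float32 for a cleaner baseline.
ref = swiglu_oai(x.float()).to(out.dtype)
torch.testing.assert_close(out, ref, atol=2e-2, rtol=2e-2)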