[Chore] Replace swish with silu (#32459)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2026-01-16 16:22:45 +08:00
Committed by: GitHub
Parent: b84c426a8c
Commit: 180e981d56


@@ -36,10 +36,13 @@ def get_activation(name: str = "relu") -> torch.nn.Module:
     if name == "gelu":
         return nn.GELU()
     if name == "swish":
-        return Swish()
+        return nn.SiLU()
     if name == "sigmoid":
-        return torch.nn.Sigmoid()
-    return nn.Identity()
+        return nn.Sigmoid()
+    if name == "identity":
+        return nn.Identity()
+    raise NotImplementedError(name)


 def adaptive_enc_mask(
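
For reference, PyTorch's built-in nn.SiLU computes x * sigmoid(x), which is exactly what the removed Swish module implemented (SiLU and Swish with beta = 1 are the same function), so nn.SiLU is a drop-in replacement. A minimal sanity check, illustrative only and not part of this commit:

    # Sanity check: nn.SiLU matches the removed Swish, which returned
    # x * sigmoid(x) in its forward.
    import torch
    import torch.nn as nn

    x = torch.randn(8)
    assert torch.allclose(nn.SiLU()(x), x * torch.sigmoid(x))
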
@@ -93,44 +96,14 @@ def adaptive_enc_mask(
     return mask_left & mask_right


-class Swish(nn.Module):
-    """Implement Swish activation module.
-    From https://arxiv.org/pdf/2005.03191.pdf
-    """
-
-    def __init__(self) -> None:
-        super().__init__()
-        self.act_fn = nn.Sigmoid()
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Apply Swish function
-
-        Args:
-            x: torch.Tensor
-                Input.
-        """
-        return x * self.act_fn(x)
-
-
 class GLU(nn.Module):
     """Implement Gated Linear Unit (GLU) module"""

     def __init__(self, dim: int = -1, act_name: str = "sigmoid") -> None:
         super().__init__()
-        self.dim = dim
-        self.act_name = act_name.lower()
-
-        if self.act_name == "relu":
-            self.act_fn = nn.ReLU(inplace=True)
-        elif self.act_name == "gelu":
-            self.act_fn = nn.GELU()
-        elif self.act_name == "swish":
-            self.act_fn = Swish()
-        elif self.act_name == "sigmoid":
-            self.act_fn = nn.Sigmoid()
-        else:
-            self.act_fn = nn.Identity()
+        self.dim = dim
+        self.act_fn = get_activation(act_name)

     def forward(self, x: Tensor) -> Tensor:
         """GLU forward
@@ -204,16 +177,7 @@ class GLUPointWiseConv(nn.Module):
             padding=(kernel_size - 1) // 2,
         )

-        if glu_type == "sigmoid":
-            self.glu_act = nn.Sigmoid()
-        elif glu_type == "relu":
-            self.glu_act = nn.ReLU()
-        elif glu_type == "gelu":
-            self.glu_act = nn.GELU()
-        elif glu_type == "swish":
-            self.glu_act = Swish()
-        else:
-            raise ValueError(f"Unsupported activation type {self.glu_act}")
+        self.glu_act = get_activation(glu_type)

         if bias_in_glu:
             self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1))
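
For GLUPointWiseConv this also changes the error path: an unsupported glu_type now raises NotImplementedError(name) from get_activation rather than the old ValueError, whose message referenced self.glu_act before it was ever assigned. A quick smoke test over the names visible in these hunks, assuming it runs in the module that defines get_activation:

    # Illustrative smoke test for the centralized lookup.
    import torch.nn as nn

    for name, cls in [("gelu", nn.GELU), ("swish", nn.SiLU),
                      ("sigmoid", nn.Sigmoid), ("identity", nn.Identity)]:
        assert isinstance(get_activation(name), cls)
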