[Chore] Replace swish with silu (#32459)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
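
The swap is behavior-preserving: the hand-rolled Swish module removed below computes x * sigmoid(x), which is exactly what PyTorch's built-in nn.SiLU computes. A quick sanity check (illustrative only, not part of this commit):

    import torch
    from torch import nn

    x = torch.randn(8)
    # nn.SiLU() is the built-in replacement for the deleted Swish class.
    assert torch.allclose(nn.SiLU()(x), x * torch.sigmoid(x))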
@@ -36,10 +36,13 @@ def get_activation(name: str = "relu") -> torch.nn.Module:
     if name == "gelu":
         return nn.GELU()
     if name == "swish":
-        return Swish()
+        return nn.SiLU()
     if name == "sigmoid":
-        return torch.nn.Sigmoid()
-    return nn.Identity()
+        return nn.Sigmoid()
+    if name == "identity":
+        return nn.Identity()
+
+    raise NotImplementedError(name)
 
 
 def adaptive_enc_mask(
@@ -93,44 +96,14 @@ def adaptive_enc_mask(
     return mask_left & mask_right
 
 
-class Swish(nn.Module):
-    """Implement Swish activation module.
-    From https://arxiv.org/pdf/2005.03191.pdf
-
-    """
-
-    def __init__(self) -> None:
-        super().__init__()
-        self.act_fn = nn.Sigmoid()
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Apply Swish function
-
-        Args:
-            x: torch.Tensor
-                Input.
-        """
-        return x * self.act_fn(x)
-
-
 class GLU(nn.Module):
     """Implement Gated Linear Unit (GLU) module"""
 
     def __init__(self, dim: int = -1, act_name: str = "sigmoid") -> None:
         super().__init__()
-        self.dim = dim
-        self.act_name = act_name.lower()
-
-        if self.act_name == "relu":
-            self.act_fn = nn.ReLU(inplace=True)
-        elif self.act_name == "gelu":
-            self.act_fn = nn.GELU()
-        elif self.act_name == "swish":
-            self.act_fn = Swish()
-        elif self.act_name == "sigmoid":
-            self.act_fn = nn.Sigmoid()
-        else:
-            self.act_fn = nn.Identity()
+        self.dim = dim
+        self.act_fn = get_activation(act_name)
 
     def forward(self, x: Tensor) -> Tensor:
         """GLU forward
@@ -204,16 +177,7 @@ class GLUPointWiseConv(nn.Module):
             padding=(kernel_size - 1) // 2,
         )
 
-        if glu_type == "sigmoid":
-            self.glu_act = nn.Sigmoid()
-        elif glu_type == "relu":
-            self.glu_act = nn.ReLU()
-        elif glu_type == "gelu":
-            self.glu_act = nn.GELU()
-        elif glu_type == "swish":
-            self.glu_act = Swish()
-        else:
-            raise ValueError(f"Unsupported activation type {self.glu_act}")
+        self.glu_act = get_activation(glu_type)
 
         if bias_in_glu:
             self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1))
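
For context, the GLU class in the second hunk keeps its gating behavior and only changes how the gate activation is resolved. Its forward is not shown in the diff; below is a minimal sketch of a gated linear unit of this shape, assuming the usual split-and-gate formulation (the name GLUSketch and its body are illustrative, not the file's actual code):

    import torch
    from torch import Tensor, nn

    class GLUSketch(nn.Module):
        """Split the input along `dim` and gate one half with the activated other half."""

        def __init__(self, dim: int = -1, act_name: str = "sigmoid") -> None:
            super().__init__()
            self.dim = dim
            # After this commit, "swish" resolves to nn.SiLU() via get_activation().
            self.act_fn = nn.SiLU() if act_name == "swish" else nn.Sigmoid()

        def forward(self, x: Tensor) -> Tensor:
            a, b = x.chunk(2, dim=self.dim)
            return a * self.act_fn(b)

    out = GLUSketch(act_name="swish")(torch.randn(4, 6))  # output shape (4, 3)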