[Bugfix] Add missing activation attr to RMSNormGated (#35423)
Signed-off-by: tibG <naps@qubes.milou> Co-authored-by: tibG <naps@qubes.milou>
This commit is contained in:
@@ -510,6 +510,7 @@ class RMSNormGated(CustomOp):
|
|||||||
norm_before_gate: bool = False,
|
norm_before_gate: bool = False,
|
||||||
device: torch.device | None = None,
|
device: torch.device | None = None,
|
||||||
dtype: torch.dtype | None = None,
|
dtype: torch.dtype | None = None,
|
||||||
|
activation: str = "swish",
|
||||||
):
|
):
|
||||||
"""Initialize RMSNormGated.
|
"""Initialize RMSNormGated.
|
||||||
|
|
||||||
@@ -524,10 +525,12 @@ class RMSNormGated(CustomOp):
|
|||||||
If False and z is provided: out = norm(x * silu(z))
|
If False and z is provided: out = norm(x * silu(z))
|
||||||
device: Device to create parameters on
|
device: Device to create parameters on
|
||||||
dtype: Data type for parameters
|
dtype: Data type for parameters
|
||||||
|
activation: Activation function name for gating
|
||||||
"""
|
"""
|
||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
|
self.activation = activation
|
||||||
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
||||||
self.register_parameter("bias", None)
|
self.register_parameter("bias", None)
|
||||||
self.group_size = group_size
|
self.group_size = group_size
|
||||||
|
|||||||
Reference in New Issue
Block a user