[Bugfix] Add missing activation attr to RMSNormGated (#35423)
Signed-off-by: tibG <naps@qubes.milou> Co-authored-by: tibG <naps@qubes.milou>
This commit is contained in:
@@ -510,6 +510,7 @@ class RMSNormGated(CustomOp):
|
||||
norm_before_gate: bool = False,
|
||||
device: torch.device | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
activation: str = "swish",
|
||||
):
|
||||
"""Initialize RMSNormGated.
|
||||
|
||||
@@ -524,10 +525,12 @@ class RMSNormGated(CustomOp):
|
||||
If False and z is provided: out = norm(x * silu(z))
|
||||
device: Device to create parameters on
|
||||
dtype: Data type for parameters
|
||||
activation: Activation function name for gating
|
||||
"""
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.activation = activation
|
||||
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
||||
self.register_parameter("bias", None)
|
||||
self.group_size = group_size
|
||||
|
||||
Reference in New Issue
Block a user