[Bugfix] Add missing activation attr to RMSNormGated (#35423)

Signed-off-by: tibG <naps@qubes.milou>
Co-authored-by: tibG <naps@qubes.milou>
This commit is contained in:
Tib
2026-02-27 13:53:35 +01:00
committed by GitHub
parent 9c3fe9936b
commit 6467b635b6

View File

@@ -510,6 +510,7 @@ class RMSNormGated(CustomOp):
norm_before_gate: bool = False, norm_before_gate: bool = False,
device: torch.device | None = None, device: torch.device | None = None,
dtype: torch.dtype | None = None, dtype: torch.dtype | None = None,
activation: str = "swish",
): ):
"""Initialize RMSNormGated. """Initialize RMSNormGated.
@@ -524,10 +525,12 @@ class RMSNormGated(CustomOp):
If False and z is provided: out = norm(x * silu(z)) If False and z is provided: out = norm(x * silu(z))
device: Device to create parameters on device: Device to create parameters on
dtype: Data type for parameters dtype: Data type for parameters
activation: Activation function name for gating
""" """
factory_kwargs = {"device": device, "dtype": dtype} factory_kwargs = {"device": device, "dtype": dtype}
super().__init__() super().__init__()
self.eps = eps self.eps = eps
self.activation = activation
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.register_parameter("bias", None) self.register_parameter("bias", None)
self.group_size = group_size self.group_size = group_size