diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index ff78f0886..72f42de06 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -510,6 +510,7 @@ class RMSNormGated(CustomOp): norm_before_gate: bool = False, device: torch.device | None = None, dtype: torch.dtype | None = None, + activation: str = "swish", ): """Initialize RMSNormGated. @@ -524,10 +525,12 @@ class RMSNormGated(CustomOp): If False and z is provided: out = norm(x * silu(z)) device: Device to create parameters on dtype: Data type for parameters + activation: Activation function name for gating """ factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.eps = eps + self.activation = activation self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) self.register_parameter("bias", None) self.group_size = group_size