[Model] Add has_weight to RMSNorm and re-enable weights loading tracker for Mamba (#10739)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -20,6 +20,7 @@ class RMSNorm(CustomOp):
|
||||
hidden_size: int,
|
||||
eps: float = 1e-6,
|
||||
var_hidden_size: Optional[int] = None,
|
||||
+ has_weight: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -27,7 +28,11 @@ class RMSNorm(CustomOp):
|
||||
self.variance_epsilon = eps
|
||||
self.variance_size_override = (None if var_hidden_size == hidden_size
|
||||
else var_hidden_size)
|
||||
- self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
+ self.has_weight = has_weight
|
||||
|
||||
+ self.weight = torch.ones(hidden_size)
|
||||
+ if self.has_weight:
|
||||
+ self.weight = nn.Parameter(self.weight)
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
@@ -59,7 +64,9 @@ class RMSNorm(CustomOp):
|
||||
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
|
||||
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
- x = x.to(orig_dtype) * self.weight
|
||||
+ x = x.to(orig_dtype)
|
||||
+ if self.has_weight:
|
||||
+ x = x * self.weight
|
||||
if residual is None:
|
||||
return x
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user