[Model] Add has_weight to RMSNorm and re-enable weights loading tracker for Mamba (#10739)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py
2024-12-10 10:23:07 +08:00
committed by GitHub
parent 6d525288c1
commit d1f6d1c8af
3 changed files with 32 additions and 10 deletions

View File

@@ -20,6 +20,7 @@ class RMSNorm(CustomOp):
hidden_size: int,
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
has_weight: bool = True,
) -> None:
super().__init__()
@@ -27,7 +28,11 @@ class RMSNorm(CustomOp):
self.variance_epsilon = eps
self.variance_size_override = (None if var_hidden_size == hidden_size
else var_hidden_size)
self.weight = nn.Parameter(torch.ones(hidden_size))
self.has_weight = has_weight
self.weight = torch.ones(hidden_size)
if self.has_weight:
self.weight = nn.Parameter(self.weight)
def forward_native(
self,
@@ -59,7 +64,9 @@ class RMSNorm(CustomOp):
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype) * self.weight
x = x.to(orig_dtype)
if self.has_weight:
x = x * self.weight
if residual is None:
return x
else: