[Model] Add has_weight to RMSNorm and re-enable weights loading tracker for Mamba (#10739)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py
2024-12-10 10:23:07 +08:00
committed by GitHub
parent 6d525288c1
commit d1f6d1c8af
3 changed files with 32 additions and 10 deletions

View File

@@ -40,6 +40,7 @@ class MambaMixer(CustomOp):
                  use_conv_bias: bool,
                  use_bias: bool,
                  use_rms_norm: bool,
+                 rms_norm_has_weight: bool = True,
                  rms_norm_eps: float = 1e-5,
                  activation="silu"):
         super().__init__()
@@ -105,14 +106,23 @@ class MambaMixer(CustomOp):
             input_is_parallel=True,
         )
-        self.dt_layernorm = RMSNorm(time_step_rank,
-                                    eps=rms_norm_eps) if use_rms_norm else None
+        self.dt_layernorm = RMSNorm(
+            time_step_rank,
+            eps=rms_norm_eps,
+            has_weight=rms_norm_has_weight,
+        ) if use_rms_norm else None
-        self.b_layernorm = RMSNorm(ssm_state_size,
-                                   eps=rms_norm_eps) if use_rms_norm else None
+        self.b_layernorm = RMSNorm(
+            ssm_state_size,
+            eps=rms_norm_eps,
+            has_weight=rms_norm_has_weight,
+        ) if use_rms_norm else None
-        self.c_layernorm = RMSNorm(ssm_state_size,
-                                   eps=rms_norm_eps) if use_rms_norm else None
+        self.c_layernorm = RMSNorm(
+            ssm_state_size,
+            eps=rms_norm_eps,
+            has_weight=rms_norm_has_weight,
+        ) if use_rms_norm else None

     def forward_native(self, hidden_states: torch.Tensor,
                        attn_metadata: AttentionMetadata,