diff --git a/single_shot_inference.py b/single_shot_inference.py
index 9dff6bcc..371c3f32 100644
--- a/single_shot_inference.py
+++ b/single_shot_inference.py
@@ -107,7 +107,14 @@ class mHCBlock:
         mixed_residual = torch.einsum('tij,tjh->tjh', comb_mix, residual.float())
         post_term = post_mix.unsqueeze(-1) * F_out.unsqueeze(1).float()
         residual_next = mixed_residual + post_term
-        return residual_next.bfloat16()
+        # Soft magnitude cap, not a full RMSNorm: rows of the (T, n_hc, -1) view are only ever
+        # scaled down. Meant to stop runaway growth (MoE is expected to handle most balancing).
+        _T = residual_next.shape[0]
+        rn_f = residual_next.reshape(_T, self.n_hc, -1)  # one row per (dim-0 index, n_hc index)
+        rms = rn_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()  # NOTE: reciprocal RMS, 1/sqrt(mean(x^2) + 1e-6), despite the name
+        # scale = min(1, sqrt(d) / RMS), d = last-dim size: rows with RMS <= sqrt(d) pass through unchanged, larger rows end at RMS ~= sqrt(d)
+        scale = (rms * math.sqrt(rn_f.shape[-1])).clamp(max=1.0)  # assumes `math` is imported at module level (this hunk adds no import) -- TODO confirm
+        return (rn_f * scale).bfloat16()  # returned with the (T, n_hc, -1) view shape, cast to bfloat16
 
 # =====================================================================
 # RoPE