Add gentle RMSNorm: only clamps when values exceed unit norm

This commit is contained in:
2026-05-31 00:31:34 +00:00
parent dcbb74841a
commit 523b0e47b1

View File

@@ -107,7 +107,14 @@ class mHCBlock:
mixed_residual = torch.einsum('tij,tjh->tjh', comb_mix, residual.float())
post_term = post_mix.unsqueeze(-1) * F_out.unsqueeze(1).float()
residual_next = mixed_residual + post_term
return residual_next.bfloat16()
# Gentle normalization: an RMSNorm-style rescale whose per-stream scale is
# capped at 1.0, so streams are only ever shrunk, never amplified.
# It is meant to kick in only to stop runaway growth (MoE should handle most balance).
_T = residual_next.shape[0]
rn_f = residual_next.reshape(_T, self.n_hc, -1)
rms = rn_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
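# NOTE: `rms` actually holds the reciprocal RMS (rsqrt of the mean square), so
# scale = sqrt(d) / RMS with d = rn_f.shape[-1], capped at 1.0. A stream is left
# untouched until its RMS exceeds sqrt(d); beyond that it is scaled down to RMS = sqrt(d).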
scale = (rms * math.sqrt(rn_f.shape[-1])).clamp(max=1.0)
return (rn_f * scale).bfloat16()
# =====================================================================
# RoPE