Add gentle RMSNorm: only clamps when values exceed unit norm
This commit is contained in:
@@ -107,7 +107,14 @@ class mHCBlock:
|
||||
mixed_residual = torch.einsum('tij,tjh->tjh', comb_mix, residual.float())
|
||||
post_term = post_mix.unsqueeze(-1) * F_out.unsqueeze(1).float()
|
||||
residual_next = mixed_residual + post_term
|
||||
return residual_next.bfloat16()
|
||||
# Gentle normalization: RMSNorm but preserving relative magnitudes
|
||||
# Only active to prevent runaway growth (MoE should handle most balance)
|
||||
_T = residual_next.shape[0]
|
||||
rn_f = residual_next.reshape(_T, self.n_hc, -1)
|
||||
rms = rn_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
|
||||
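# NOTE: `rms` holds the reciprocal RMS (rsqrt), so scale = min(1, sqrt(d) / RMS) with d = rn_f.shape[-1].
# Rows are left untouched until their RMS exceeds sqrt(d); beyond that they are rescaled to RMS ≈ sqrt(d).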
|
||||
scale = (rms * math.sqrt(rn_f.shape[-1])).clamp(max=1.0)
|
||||
return (rn_f * scale).bfloat16()
|
||||
|
||||
# =====================================================================
|
||||
# RoPE
|
||||
|
||||
Reference in New Issue
Block a user