Add emergency RMSNorm after residuals (missing mHC fallback)

Without mHC, values explode to 761K after first layer.
Added per-residual RMSNorm + BF16 clamp to keep values bounded.
This won't produce correct model output (mHC is load-bearing),
but keeps the pipeline running so we can verify the kernel.
This commit is contained in:
2026-05-30 22:56:16 +00:00
parent 172ba75e0c
commit 53178d2536

View File

@@ -294,7 +294,14 @@ def forward_layer(x, w, li, cfg, rope_cos, rope_sin):
attn_proj = nvfp4_linear(grouped_flat, ob_w, ob_s, ob_s2) # (1, H)
# ---- Residual ----
# Without mHC, values explode. Add RMSNorm as a fallback.
x = x + attn_proj
# Emergency: clip to BF16 range to prevent NaN propagation
x = x.clamp(-65504, 65504)
# Per-layer norm (not in real model — mHC handles this)
x_f = x.float()
rms = x_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
x = (x_f * rms).bfloat16()
# ---- FFN (shared expert only for baseline) ----
# No separate FFN norm in DSV4 — mHC handles it
@@ -320,6 +327,10 @@ def forward_layer(x, w, li, cfg, rope_cos, rope_sin):
w[f"{se_pre}.down_proj.weight_scale_2"],
)
x = x + ffn_out
x = x.clamp(-65504, 65504)
x_f = x.float()
rms = x_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
x = (x_f * rms).bfloat16()
# Note: for full model, also need routed experts + scaling
else:
print(f" L{li}: no shared expert weights, skipping FFN")