Add emergency RMSNorm after residuals (missing mHC fallback)
Without mHC, values explode to 761K after the first layer. Added a per-residual RMSNorm plus a clamp to ±65504 (note: this is the FP16 maximum, not the BF16 range, which is far wider) to keep values bounded. This won't produce correct model output (mHC is load-bearing), but it keeps the pipeline running so we can verify the kernel.
This commit is contained in:
@@ -294,7 +294,14 @@ def forward_layer(x, w, li, cfg, rope_cos, rope_sin):
|
||||
attn_proj = nvfp4_linear(grouped_flat, ob_w, ob_s, ob_s2) # (1, H)
|
||||
|
||||
# ---- Residual ----
|
||||
# Without mHC, values explode. Add RMSNorm as a fallback.
|
||||
x = x + attn_proj
|
||||
# Emergency: clamp to ±65504 (the FP16 max; BF16's own range is far wider) so overflowed infs cannot turn into NaN downstream
|
||||
x = x.clamp(-65504, 65504)
|
||||
# Per-residual RMSNorm with no learned weight (not in real model — mHC handles this)
|
||||
x_f = x.float()
|
||||
rms = x_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
|
||||
x = (x_f * rms).bfloat16()
|
||||
|
||||
# ---- FFN (shared expert only for baseline) ----
|
||||
# No separate FFN norm in DSV4 — mHC handles it
|
||||
@@ -320,6 +327,10 @@ def forward_layer(x, w, li, cfg, rope_cos, rope_sin):
|
||||
w[f"{se_pre}.down_proj.weight_scale_2"],
|
||||
)
|
||||
x = x + ffn_out
|
||||
x = x.clamp(-65504, 65504)
|
||||
x_f = x.float()
|
||||
rms = x_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
|
||||
x = (x_f * rms).bfloat16()
|
||||
# Note: for full model, also need routed experts + scaling
|
||||
else:
|
||||
print(f" L{li}: no shared expert weights, skipping FFN")
|
||||
|
||||
Reference in New Issue
Block a user