diag: print SE global scales for first 3 layers

This commit is contained in:
2026-06-01 02:49:55 +00:00
parent 48d93a6d2e
commit b85fcf4d6f

View File

@@ -422,6 +422,7 @@ def moe_forward(x, li, moe_runner, se_runner, router, token_id):
shared_out = se_runner.run(x)
if li < 3:
print(f" L{li} MoE shared: |out|={shared_out.abs().max().item():.4f} has_nan={torch.isnan(shared_out).any().item()}", flush=True)
print(f" L{li} SE gsa: l1={se_runner._l1_activation_global_scale:.6f} l2={se_runner._l2_activation_global_scale:.6f} gsb: l1={se_runner._l1_gsb[0].item():.6f} l2={se_runner._l2_gsb[0].item():.6f}", flush=True)
return routed_out + shared_out
# =====================================================================