diag: validate router output before MoE

This commit is contained in:
2026-06-01 01:27:16 +00:00
parent f5fa84016e
commit 7fbbdc5204

View File

@@ -380,6 +380,9 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
# =====================================================================
def moe_forward(x, li, moe_runner, se_runner, router, token_id):
topk_w, topk_ids = router(x, token_ids=token_id)
# Diag: validate router output before MoE
if topk_ids.max().item() >= 384 or topk_ids.min().item() < 0:
print(f" L{li} BAD topk_ids: min={topk_ids.min().item()} max={topk_ids.max().item()}", flush=True)
routed_out = moe_runner.run(x, topk_w, topk_ids)
shared_out = se_runner.run(x)
return routed_out + shared_out