diff --git a/single_shot_inference.py b/single_shot_inference.py index e552e7ba..f789520a 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -421,20 +421,10 @@ def moe_forward(x, li, moe_runner, se_runner, router, token_id): print(f" L{li} MoE routed: |out|={routed_out.abs().max().item():.4f} has_nan={torch.isnan(routed_out).any().item()}", flush=True) shared_out = se_runner.run(x) if li < 3: - # Compare shared expert with PyTorch reference - se_pfx = f'model.layers.{li}.mlp.shared_experts' - se_gw, se_gws, se_gws2, se_gisc = get_nvfp4_weight(layer_w.get(li, {}), se_pfx, 'gate_proj') - se_uw, se_uws, se_uws2, se_uisc = get_nvfp4_weight(layer_w.get(li, {}), se_pfx, 'up_proj') - if se_gw is not None and se_uw is not None: - se_gate = do_nvfp4_linear(x, layer_w.get(li, {}), se_pfx, 'gate_proj') - se_up = do_nvfp4_linear(x, layer_w.get(li, {}), se_pfx, 'up_proj') - if se_gate is not None and se_up is not None: - se_ref = torch.nn.functional.silu(se_gate) * se_up - se_dw, se_dws, se_dws2, se_disc = get_nvfp4_weight(layer_w.get(li, {}), se_pfx, 'down_proj') - if se_dw is not None: - se_ref = do_nvfp4_linear(se_ref, layer_w.get(li, {}), se_pfx, 'down_proj') - cos_se = torch.nn.functional.cosine_similarity(shared_out.flatten().float(), se_ref.flatten().float(), dim=0).item() if not torch.isnan(shared_out).any().item() else -1.0 - print(f" L{li} SE ref: |ref|={se_ref.abs().max().item():.4f} |prod|={'NaN' if torch.isnan(shared_out).any().item() else f'{shared_out.abs().max().item():.4f}'} cos={cos_se:.4f}", flush=True) + has_nan = torch.isnan(shared_out).any().item() + out_max = shared_out.abs().max().item() if not has_nan else float('nan') + print(f" L{li} MoE shared: |out|={out_max:.4f} has_nan={has_nan}", flush=True) + print(f" L{li} SE gsa: l1={se_runner._l1_activation_global_scale:.6f} l2={se_runner._l2_activation_global_scale:.6f} gsb: l1={se_runner._l1_gsb[0].item():.6f} l2={se_runner._l2_gsb[0].item():.6f}", flush=True) return routed_out + shared_out # =====================================================================