diag: remove broken SE reference comparison, add gsa/gsb print
This commit is contained in:
@@ -421,20 +421,10 @@ def moe_forward(x, li, moe_runner, se_runner, router, token_id):
|
||||
print(f" L{li} MoE routed: |out|={routed_out.abs().max().item():.4f} has_nan={torch.isnan(routed_out).any().item()}", flush=True)
|
||||
shared_out = se_runner.run(x)
|
||||
if li < 3:
|
||||
# Compare shared expert with PyTorch reference
|
||||
se_pfx = f'model.layers.{li}.mlp.shared_experts'
|
||||
se_gw, se_gws, se_gws2, se_gisc = get_nvfp4_weight(layer_w.get(li, {}), se_pfx, 'gate_proj')
|
||||
se_uw, se_uws, se_uws2, se_uisc = get_nvfp4_weight(layer_w.get(li, {}), se_pfx, 'up_proj')
|
||||
if se_gw is not None and se_uw is not None:
|
||||
se_gate = do_nvfp4_linear(x, layer_w.get(li, {}), se_pfx, 'gate_proj')
|
||||
se_up = do_nvfp4_linear(x, layer_w.get(li, {}), se_pfx, 'up_proj')
|
||||
if se_gate is not None and se_up is not None:
|
||||
se_ref = torch.nn.functional.silu(se_gate) * se_up
|
||||
se_dw, se_dws, se_dws2, se_disc = get_nvfp4_weight(layer_w.get(li, {}), se_pfx, 'down_proj')
|
||||
if se_dw is not None:
|
||||
se_ref = do_nvfp4_linear(se_ref, layer_w.get(li, {}), se_pfx, 'down_proj')
|
||||
cos_se = torch.nn.functional.cosine_similarity(shared_out.flatten().float(), se_ref.flatten().float(), dim=0).item() if not torch.isnan(shared_out).any().item() else -1.0
|
||||
print(f" L{li} SE ref: |ref|={se_ref.abs().max().item():.4f} |prod|={'NaN' if torch.isnan(shared_out).any().item() else f'{shared_out.abs().max().item():.4f}'} cos={cos_se:.4f}", flush=True)
|
||||
has_nan = torch.isnan(shared_out).any().item()
|
||||
out_max = shared_out.abs().max().item() if not has_nan else float('nan')
|
||||
print(f" L{li} MoE shared: |out|={out_max:.4f} has_nan={has_nan}", flush=True)
|
||||
print(f" L{li} SE gsa: l1={se_runner._l1_activation_global_scale:.6f} l2={se_runner._l2_activation_global_scale:.6f} gsb: l1={se_runner._l1_gsb[0].item():.6f} l2={se_runner._l2_gsb[0].item():.6f}", flush=True)
|
||||
return routed_out + shared_out
|
||||
|
||||
# =====================================================================
|
||||
|
||||
Reference in New Issue
Block a user