diag: fix SE scale print (cast to float first)
This commit is contained in:
@@ -429,8 +429,8 @@ def moe_forward(x, li, moe_runner, se_runner, router, token_id):
|
||||
wb = se_runner._l1_mat_b.view(torch.uint8)
|
||||
print(f" L{li} SE l1 weight: shape={list(se_runner._l1_mat_b.shape)} dtype={se_runner._l1_mat_b.dtype} uint8_range=[{wb.min().item()},{wb.max().item()}]", flush=True)
|
||||
if hasattr(se_runner, '_l1_scale_b') and se_runner._l1_scale_b is not None:
|
||||
sb = se_runner._l1_scale_b
|
||||
print(f" L{li} SE l1 scale: shape={list(sb.shape)} dtype={sb.dtype} range=[{sb.min().item():.6f},{sb.max().item():.6f}] has_nan={torch.isnan(sb).any().item()}", flush=True)
|
||||
sb = se_runner._l1_scale_b.float()
|
||||
print(f" L{li} SE l1 scale: shape={list(se_runner._l1_scale_b.shape)} dtype={se_runner._l1_scale_b.dtype} float_range=[{sb.min().item():.6f},{sb.max().item():.6f}] has_nan={torch.isnan(sb).any().item()}", flush=True)
|
||||
print(f" L{li} SE gsa: l1={se_runner._l1_activation_global_scale:.6f} l2={se_runner._l2_activation_global_scale:.6f} gsb: l1={se_runner._l1_gsb[0].item():.6f} l2={se_runner._l2_gsb[0].item():.6f}", flush=True)
|
||||
return routed_out + shared_out
|
||||
|
||||
|
||||
Reference in New Issue
Block a user