were getting close

This commit is contained in:
2026-05-15 08:28:40 +00:00
parent d493193d25
commit d22dae2df3

View File

@@ -105,6 +105,32 @@ def transform_nvfp4_weights_for_mega_moe(
print(f"[SF-DEBUG] gate gs unique={torch.unique(l1_gs_raw[:, 0]).numel()} "
f"up gs unique={torch.unique(l1_gs_raw[:, 1]).numel()}")
# DEBUG: check L2 scales
l2_sf_f32_raw = l2_weight_scale.to(torch.float32)
l2_gs_raw = l2_weight_scale_2.to(torch.float32) if l2_weight_scale_2 is not None else None
if not getattr(transform_nvfp4_weights_for_mega_moe, '_sf_debug_l2', False):
transform_nvfp4_weights_for_mega_moe._sf_debug_l2 = True
print(f"[SF-DEBUG-L2] raw l2_sf dtype={l2_weight_scale.dtype} range=[{l2_sf_f32_raw.min().item():.4e}, {l2_sf_f32_raw.max().item():.4e}] "
f"unique_raw={torch.unique(l2_weight_scale.view(torch.uint8)).numel()}")
if l2_gs_raw is not None:
print(f"[SF-DEBUG-L2] l2_gs dtype={l2_weight_scale_2.dtype} shape={tuple(l2_weight_scale_2.shape)} "
f"range=[{l2_gs_raw.min().item():.4e}, {l2_gs_raw.max().item():.4e}] "
f"unique_gs={torch.unique(l2_gs_raw).numel()}")
# Post-fold diagnostics — one-time
if not getattr(transform_nvfp4_weights_for_mega_moe, '_sf_debug_fold', False):
transform_nvfp4_weights_for_mega_moe._sf_debug_fold = True
l1_sf_folded = _fold_global_scale(l1_weight_scale, l1_weight_scale_2) if l1_weight_scale_2 is not None else l1_weight_scale.to(torch.float32)
l1_sf_out_check = l1_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn)
l2_sf_folded = _fold_global_scale(l2_weight_scale, l2_weight_scale_2) if l2_weight_scale_2 is not None else l2_weight_scale.to(torch.float32)
l2_sf_out_check = l2_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn)
print(f"[SF-FOLD] l1 pre-fold unique_u8={torch.unique(l1_weight_scale.view(torch.uint8)).numel()} "
f"post-fold unique_u8={torch.unique(l1_sf_out_check.view(torch.uint8)).numel()} "
f"range=[{l1_sf_folded.min().item():.4e}, {l1_sf_folded.max().item():.4e}]")
print(f"[SF-FOLD] l2 pre-fold unique_u8={torch.unique(l2_weight_scale.view(torch.uint8)).numel()} "
f"post-fold unique_u8={torch.unique(l2_sf_out_check.view(torch.uint8)).numel()} "
f"range=[{l2_sf_folded.min().item():.4e}, {l2_sf_folded.max().item():.4e}]")
# Fold global scales into block scales
# The logical_widths branch was wrong: it treated gs as per-projection
# scalars and only used experts 0 and 1's scales for ALL experts.