From d22dae2df3c3fa8b3bb51c5de0d37b74d2362bc9 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 15 May 2026 08:28:40 +0000 Subject: [PATCH] were getting close --- src/nvfp4_megamoe_kernel/weight_transform.py | 26 ++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/nvfp4_megamoe_kernel/weight_transform.py b/src/nvfp4_megamoe_kernel/weight_transform.py index 6794a39d..c56345df 100644 --- a/src/nvfp4_megamoe_kernel/weight_transform.py +++ b/src/nvfp4_megamoe_kernel/weight_transform.py @@ -105,6 +105,32 @@ def transform_nvfp4_weights_for_mega_moe( print(f"[SF-DEBUG] gate gs unique={torch.unique(l1_gs_raw[:, 0]).numel()} " f"up gs unique={torch.unique(l1_gs_raw[:, 1]).numel()}") + # DEBUG: check L2 scales + l2_sf_f32_raw = l2_weight_scale.to(torch.float32) + l2_gs_raw = l2_weight_scale_2.to(torch.float32) if l2_weight_scale_2 is not None else None + if not getattr(transform_nvfp4_weights_for_mega_moe, '_sf_debug_l2', False): + transform_nvfp4_weights_for_mega_moe._sf_debug_l2 = True + print(f"[SF-DEBUG-L2] raw l2_sf dtype={l2_weight_scale.dtype} range=[{l2_sf_f32_raw.min().item():.4e}, {l2_sf_f32_raw.max().item():.4e}] " + f"unique_raw={torch.unique(l2_weight_scale.view(torch.uint8)).numel()}") + if l2_gs_raw is not None: + print(f"[SF-DEBUG-L2] l2_gs dtype={l2_weight_scale_2.dtype} shape={tuple(l2_weight_scale_2.shape)} " + f"range=[{l2_gs_raw.min().item():.4e}, {l2_gs_raw.max().item():.4e}] " + f"unique_gs={torch.unique(l2_gs_raw).numel()}") + + # Post-fold diagnostics — one-time + if not getattr(transform_nvfp4_weights_for_mega_moe, '_sf_debug_fold', False): + transform_nvfp4_weights_for_mega_moe._sf_debug_fold = True + l1_sf_folded = _fold_global_scale(l1_weight_scale, l1_weight_scale_2) if l1_weight_scale_2 is not None else l1_weight_scale.to(torch.float32) + l1_sf_out_check = l1_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn) + l2_sf_folded = _fold_global_scale(l2_weight_scale, l2_weight_scale_2) if l2_weight_scale_2 is not None else l2_weight_scale.to(torch.float32) + l2_sf_out_check = l2_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn) + print(f"[SF-FOLD] l1 pre-fold unique_u8={torch.unique(l1_weight_scale.view(torch.uint8)).numel()} " + f"post-fold unique_u8={torch.unique(l1_sf_out_check.view(torch.uint8)).numel()} " + f"range=[{l1_sf_folded.min().item():.4e}, {l1_sf_folded.max().item():.4e}]") + print(f"[SF-FOLD] l2 pre-fold unique_u8={torch.unique(l2_weight_scale.view(torch.uint8)).numel()} " + f"post-fold unique_u8={torch.unique(l2_sf_out_check.view(torch.uint8)).numel()} " + f"range=[{l2_sf_folded.min().item():.4e}, {l2_sf_folded.max().item():.4e}]") + # Fold global scales into block scales # The logical_widths branch was wrong: it treated gs as per-projection # scalars and only used experts 0 and 1's scales for ALL experts.