we're getting close
This commit is contained in:
@@ -105,6 +105,32 @@ def transform_nvfp4_weights_for_mega_moe(
|
||||
print(f"[SF-DEBUG] gate gs unique={torch.unique(l1_gs_raw[:, 0]).numel()} "
|
||||
f"up gs unique={torch.unique(l1_gs_raw[:, 1]).numel()}")
|
||||
|
||||
# DEBUG: check L2 scales
|
||||
l2_sf_f32_raw = l2_weight_scale.to(torch.float32)
|
||||
l2_gs_raw = l2_weight_scale_2.to(torch.float32) if l2_weight_scale_2 is not None else None
|
||||
if not getattr(transform_nvfp4_weights_for_mega_moe, '_sf_debug_l2', False):
|
||||
transform_nvfp4_weights_for_mega_moe._sf_debug_l2 = True
|
||||
print(f"[SF-DEBUG-L2] raw l2_sf dtype={l2_weight_scale.dtype} range=[{l2_sf_f32_raw.min().item():.4e}, {l2_sf_f32_raw.max().item():.4e}] "
|
||||
f"unique_raw={torch.unique(l2_weight_scale.view(torch.uint8)).numel()}")
|
||||
if l2_gs_raw is not None:
|
||||
print(f"[SF-DEBUG-L2] l2_gs dtype={l2_weight_scale_2.dtype} shape={tuple(l2_weight_scale_2.shape)} "
|
||||
f"range=[{l2_gs_raw.min().item():.4e}, {l2_gs_raw.max().item():.4e}] "
|
||||
f"unique_gs={torch.unique(l2_gs_raw).numel()}")
|
||||
|
||||
# Post-fold diagnostics — one-time
|
||||
if not getattr(transform_nvfp4_weights_for_mega_moe, '_sf_debug_fold', False):
|
||||
transform_nvfp4_weights_for_mega_moe._sf_debug_fold = True
|
||||
l1_sf_folded = _fold_global_scale(l1_weight_scale, l1_weight_scale_2) if l1_weight_scale_2 is not None else l1_weight_scale.to(torch.float32)
|
||||
l1_sf_out_check = l1_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn)
|
||||
l2_sf_folded = _fold_global_scale(l2_weight_scale, l2_weight_scale_2) if l2_weight_scale_2 is not None else l2_weight_scale.to(torch.float32)
|
||||
l2_sf_out_check = l2_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn)
|
||||
print(f"[SF-FOLD] l1 pre-fold unique_u8={torch.unique(l1_weight_scale.view(torch.uint8)).numel()} "
|
||||
f"post-fold unique_u8={torch.unique(l1_sf_out_check.view(torch.uint8)).numel()} "
|
||||
f"range=[{l1_sf_folded.min().item():.4e}, {l1_sf_folded.max().item():.4e}]")
|
||||
print(f"[SF-FOLD] l2 pre-fold unique_u8={torch.unique(l2_weight_scale.view(torch.uint8)).numel()} "
|
||||
f"post-fold unique_u8={torch.unique(l2_sf_out_check.view(torch.uint8)).numel()} "
|
||||
f"range=[{l2_sf_folded.min().item():.4e}, {l2_sf_folded.max().item():.4e}]")
|
||||
|
||||
# Fold global scales into block scales
|
||||
# The logical_widths branch was wrong: it treated gs as per-projection
|
||||
# scalars and only used experts 0 and 1's scales for ALL experts.
|
||||
|
||||
Reference in New Issue
Block a user