Debug: can we see the weights going in (WT-IN)?

This commit is contained in:
2026-05-15 05:41:12 +00:00
parent 6a4f52cedc
commit c5d800f133

View File

@@ -79,6 +79,11 @@ def transform_nvfp4_weights_for_mega_moe(
l1_weight, l1_weight_scale = l1_tuple
l2_weight, l2_weight_scale = l2_tuple
# DEBUG: check weights BEFORE transform
print(f"[WT-IN] l1_w shape={l1_weight.shape} absmax={l1_weight.view(torch.int8).abs().max().item()} "
f"l1_sf shape={l1_weight_scale.shape} sf_absmax={l1_weight_scale.view(torch.uint8).abs().max().item()} "
f"l2_w shape={l2_weight.shape} absmax={l2_weight.view(torch.int8).abs().max().item()}")
# Fold global scales into block scales
# The logical_widths branch was wrong: it treated the global scales (gs) as
# per-projection scalars and used only the scales of experts 0 and 1 for ALL experts.
@@ -107,4 +112,8 @@ def transform_nvfp4_weights_for_mega_moe(
l1_sf_out = l1_sf_out.transpose(-2, -1).contiguous()
l2_sf_out = l2_sf_out.transpose(-2, -1).contiguous()
# DEBUG: check weights AFTER transform
print(f"[WT-OUT] l1_w shape={l1_weight_out.shape} absmax={l1_weight_out.view(torch.int8).abs().max().item()} "
f"l1_sf shape={l1_sf_out.shape} sf_absmax={l1_sf_out.view(torch.uint8).abs().max().item()}")
return (l1_weight_out, l1_sf_out), (l2_weight_out, l2_sf_out)