Can we see the weights going in?
This commit is contained in:
@@ -79,6 +79,11 @@ def transform_nvfp4_weights_for_mega_moe(
|
||||
l1_weight, l1_weight_scale = l1_tuple
|
||||
l2_weight, l2_weight_scale = l2_tuple
|
||||
|
||||
# DEBUG: check weights BEFORE transform
|
||||
print(f"[WT-IN] l1_w shape={l1_weight.shape} absmax={l1_weight.view(torch.int8).abs().max().item()} "
|
||||
f"l1_sf shape={l1_weight_scale.shape} sf_absmax={l1_weight_scale.view(torch.uint8).abs().max().item()} "
|
||||
f"l2_w shape={l2_weight.shape} absmax={l2_weight.view(torch.int8).abs().max().item()}")
|
||||
|
||||
# Fold global scales into block scales
|
||||
# The logical_widths branch was wrong: it treated gs as per-projection
|
||||
# scalars and only used experts 0 and 1's scales for ALL experts.
|
||||
@@ -107,4 +112,8 @@ def transform_nvfp4_weights_for_mega_moe(
|
||||
l1_sf_out = l1_sf_out.transpose(-2, -1).contiguous()
|
||||
l2_sf_out = l2_sf_out.transpose(-2, -1).contiguous()
|
||||
|
||||
# DEBUG: check weights AFTER transform
|
||||
print(f"[WT-OUT] l1_w shape={l1_weight_out.shape} absmax={l1_weight_out.view(torch.int8).abs().max().item()} "
|
||||
f"l1_sf shape={l1_sf_out.shape} sf_absmax={l1_sf_out.view(torch.uint8).abs().max().item()}")
|
||||
|
||||
return (l1_weight_out, l1_sf_out), (l2_weight_out, l2_sf_out)
|
||||
|
||||
Reference in New Issue
Block a user