fix: transpose checkpoint block scales (N,K_sf)→(K_sf,N) for bridge

The bridge's assemble_scales_3d_side expects (K_sf, N) input and
transposes to (N, K_sf) internally before swizzling. The checkpoint
stores scales as (N, K_sf). Without this transpose, the kernel was
reading completely wrong scale data — cosine dropped to 0.713.

Also fixed dual global scale normalization: after transpose, gate/up
are along dim 1 (columns), not dim 0 (rows).
This commit is contained in:
2026-05-16 03:43:30 +00:00
parent 389453fbf4
commit a0ff8a3278
2 changed files with 17 additions and 11 deletions

View File

@@ -147,14 +147,17 @@ def prepare_nvfp4_weights_direct(nvfp4_tensors, layer_idx, expert_indices, inter
fused_w_fp4 = fused_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
# (hidden//2, 2*intermediate) — K=hidden packed, N=2*intermediate
fused_sf = torch.cat([gate_sf, up_sf], dim=0) # (2*intermediate, hidden//16)
# Fuse block scales: checkpoint is (N, K_sf), bridge expects (K_sf, N)
fused_sf = torch.cat([gate_sf, up_sf], dim=0) # (2*intermediate, hidden//16) = (N, K_sf)
fused_sf = fused_sf.permute(1, 0).contiguous() # → (K_sf, N)
# Normalize dual global scales
l1_max_gs = max(gate_gs, up_gs)
if gate_gs != up_gs:
fused_sf_f32 = fused_sf.float()
fused_sf_f32[:intermediate_size] *= (gate_gs / l1_max_gs)
fused_sf_f32[intermediate_size:] *= (up_gs / l1_max_gs)
# Gate is first intermediate cols, up is second (after transpose)
fused_sf_f32[:, :intermediate_size] *= (gate_gs / l1_max_gs)
fused_sf_f32[:, intermediate_size:] *= (up_gs / l1_max_gs)
fused_sf = fused_sf_f32.to(torch.float8_e4m3fn)
l1_fp4.append(fused_w_fp4)
@@ -172,12 +175,12 @@ def prepare_nvfp4_weights_direct(nvfp4_tensors, layer_idx, expert_indices, inter
# (intermediate//2, hidden) — K=intermediate packed, N=hidden
l2_fp4.append(down_w_fp4)
l2_sf.append(down_sf)
l2_sf.append(down_sf.permute(1, 0).contiguous()) # (N, K_sf) → (K_sf, N)
l2_gs.append(down_gs)
else:
# Expert 211 has no down_proj
l2_fp4.append(torch.zeros(3072 // 2, 7168, dtype=torch.float4_e2m1fn_x2, device=DEVICE))
l2_sf.append(torch.ones(7168, 3072 // 16, dtype=torch.float8_e4m3fn, device=DEVICE))
l2_sf.append(torch.ones(3072 // 16, 7168, dtype=torch.float8_e4m3fn, device=DEVICE)) # (K_sf, N)
l2_gs.append(1.0)
return {

View File

@@ -450,16 +450,18 @@ class DeepseekV4MegaMoEExperts(nn.Module):
fused_w_fp4 = fused_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
# shape: (hidden//2, 2*intermediate) — K=hidden packed, N=2*intermediate
# Fuse block scales: (2*intermediate, hidden//16) = (N, K_sf) ✓
fused_sf = torch.cat([gate_sf, up_sf], dim=0)
# Fuse block scales: checkpoint is (N, K_sf), bridge expects (K_sf, N)
fused_sf = torch.cat([gate_sf, up_sf], dim=0) # (2*intermediate, hidden//16) = (N, K_sf)
# Transpose to (K_sf, N) for assemble_scales_3d_side
fused_sf = fused_sf.permute(1, 0).contiguous()
# Handle dual global scales: normalize to max, fold ratio into block scales
l1_max_gs = max(gate_gs, up_gs)
if gate_gs != up_gs:
fused_sf_f32 = fused_sf.float()
# Gate is the first intermediate rows, up is the second
fused_sf_f32[:self.intermediate_size] *= (gate_gs / l1_max_gs)
fused_sf_f32[self.intermediate_size:] *= (up_gs / l1_max_gs)
# After transpose to (K_sf, N): gate is first intermediate cols, up is next
fused_sf_f32[:, :self.intermediate_size] *= (gate_gs / l1_max_gs)
fused_sf_f32[:, self.intermediate_size:] *= (up_gs / l1_max_gs)
fused_sf = fused_sf_f32.to(torch.float8_e4m3fn)
l1_fp4.append(fused_w_fp4)
@@ -476,7 +478,8 @@ class DeepseekV4MegaMoEExperts(nn.Module):
down_w_fp4 = down_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
# shape: (intermediate//2, hidden) — K=intermediate packed, N=hidden
# Block scales: (hidden, intermediate//16) = (N, K_sf) ✓ — already correct
# Block scales: checkpoint is (N, K_sf), bridge expects (K_sf, N)
down_sf = down_sf.permute(1, 0).contiguous()
l2_fp4.append(down_w_fp4)
l2_sf.append(down_sf)