fix: transpose checkpoint block scales (N,K_sf)→(K_sf,N) for bridge
The bridge's assemble_scales_3d_side expects (K_sf, N) input and transposes to (N, K_sf) internally before swizzling. The checkpoint stores scales as (N, K_sf). Without this transpose, the kernel was reading completely wrong scale data — cosine dropped to 0.713. Also fixed dual global scale normalization: after transpose, gate/up are along dim 1 (columns), not dim 0 (rows).
This commit is contained in:
@@ -147,14 +147,17 @@ def prepare_nvfp4_weights_direct(nvfp4_tensors, layer_idx, expert_indices, inter
|
||||
fused_w_fp4 = fused_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
|
||||
# (hidden//2, 2*intermediate) — K=hidden packed, N=2*intermediate
|
||||
|
||||
fused_sf = torch.cat([gate_sf, up_sf], dim=0) # (2*intermediate, hidden//16)
|
||||
# Fuse block scales: checkpoint is (N, K_sf), bridge expects (K_sf, N)
|
||||
fused_sf = torch.cat([gate_sf, up_sf], dim=0) # (2*intermediate, hidden//16) = (N, K_sf)
|
||||
fused_sf = fused_sf.permute(1, 0).contiguous() # → (K_sf, N)
|
||||
|
||||
# Normalize dual global scales
|
||||
l1_max_gs = max(gate_gs, up_gs)
|
||||
if gate_gs != up_gs:
|
||||
fused_sf_f32 = fused_sf.float()
|
||||
fused_sf_f32[:intermediate_size] *= (gate_gs / l1_max_gs)
|
||||
fused_sf_f32[intermediate_size:] *= (up_gs / l1_max_gs)
|
||||
# Gate is first intermediate cols, up is second (after transpose)
|
||||
fused_sf_f32[:, :intermediate_size] *= (gate_gs / l1_max_gs)
|
||||
fused_sf_f32[:, intermediate_size:] *= (up_gs / l1_max_gs)
|
||||
fused_sf = fused_sf_f32.to(torch.float8_e4m3fn)
|
||||
|
||||
l1_fp4.append(fused_w_fp4)
|
||||
@@ -172,12 +175,12 @@ def prepare_nvfp4_weights_direct(nvfp4_tensors, layer_idx, expert_indices, inter
|
||||
# (intermediate//2, hidden) — K=intermediate packed, N=hidden
|
||||
|
||||
l2_fp4.append(down_w_fp4)
|
||||
l2_sf.append(down_sf)
|
||||
l2_sf.append(down_sf.permute(1, 0).contiguous()) # (N, K_sf) → (K_sf, N)
|
||||
l2_gs.append(down_gs)
|
||||
else:
|
||||
# Expert 211 has no down_proj
|
||||
l2_fp4.append(torch.zeros(3072 // 2, 7168, dtype=torch.float4_e2m1fn_x2, device=DEVICE))
|
||||
l2_sf.append(torch.ones(7168, 3072 // 16, dtype=torch.float8_e4m3fn, device=DEVICE))
|
||||
l2_sf.append(torch.ones(3072 // 16, 7168, dtype=torch.float8_e4m3fn, device=DEVICE)) # (K_sf, N)
|
||||
l2_gs.append(1.0)
|
||||
|
||||
return {
|
||||
|
||||
@@ -450,16 +450,18 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
fused_w_fp4 = fused_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
|
||||
# shape: (hidden//2, 2*intermediate) — K=hidden packed, N=2*intermediate
|
||||
|
||||
# Fuse block scales: (2*intermediate, hidden//16) = (N, K_sf) ✓
|
||||
fused_sf = torch.cat([gate_sf, up_sf], dim=0)
|
||||
# Fuse block scales: checkpoint is (N, K_sf), bridge expects (K_sf, N)
|
||||
fused_sf = torch.cat([gate_sf, up_sf], dim=0) # (2*intermediate, hidden//16) = (N, K_sf)
|
||||
# Transpose to (K_sf, N) for assemble_scales_3d_side
|
||||
fused_sf = fused_sf.permute(1, 0).contiguous()
|
||||
|
||||
# Handle dual global scales: normalize to max, fold ratio into block scales
|
||||
l1_max_gs = max(gate_gs, up_gs)
|
||||
if gate_gs != up_gs:
|
||||
fused_sf_f32 = fused_sf.float()
|
||||
# Gate is the first intermediate rows, up is the second
|
||||
fused_sf_f32[:self.intermediate_size] *= (gate_gs / l1_max_gs)
|
||||
fused_sf_f32[self.intermediate_size:] *= (up_gs / l1_max_gs)
|
||||
# After transpose to (K_sf, N): gate is first intermediate cols, up is next
|
||||
fused_sf_f32[:, :self.intermediate_size] *= (gate_gs / l1_max_gs)
|
||||
fused_sf_f32[:, self.intermediate_size:] *= (up_gs / l1_max_gs)
|
||||
fused_sf = fused_sf_f32.to(torch.float8_e4m3fn)
|
||||
|
||||
l1_fp4.append(fused_w_fp4)
|
||||
@@ -476,7 +478,8 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
down_w_fp4 = down_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
|
||||
# shape: (intermediate//2, hidden) — K=intermediate packed, N=hidden
|
||||
|
||||
# Block scales: (hidden, intermediate//16) = (N, K_sf) ✓ — already correct
|
||||
# Block scales: checkpoint is (N, K_sf), bridge expects (K_sf, N)
|
||||
down_sf = down_sf.permute(1, 0).contiguous()
|
||||
|
||||
l2_fp4.append(down_w_fp4)
|
||||
l2_sf.append(down_sf)
|
||||
|
||||
Reference in New Issue
Block a user