diff --git a/tests/layertest.py b/tests/layertest.py index d6da390c..83e645d8 100644 --- a/tests/layertest.py +++ b/tests/layertest.py @@ -147,14 +147,17 @@ def prepare_nvfp4_weights_direct(nvfp4_tensors, layer_idx, expert_indices, inter fused_w_fp4 = fused_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() # (hidden//2, 2*intermediate) — K=hidden packed, N=2*intermediate - fused_sf = torch.cat([gate_sf, up_sf], dim=0) # (2*intermediate, hidden//16) + # Fuse block scales: checkpoint is (N, K_sf), bridge expects (K_sf, N) + fused_sf = torch.cat([gate_sf, up_sf], dim=0) # (2*intermediate, hidden//16) = (N, K_sf) + fused_sf = fused_sf.permute(1, 0).contiguous() # → (K_sf, N) # Normalize dual global scales l1_max_gs = max(gate_gs, up_gs) if gate_gs != up_gs: fused_sf_f32 = fused_sf.float() - fused_sf_f32[:intermediate_size] *= (gate_gs / l1_max_gs) - fused_sf_f32[intermediate_size:] *= (up_gs / l1_max_gs) + # Gate is first intermediate cols, up is second (after transpose) + fused_sf_f32[:, :intermediate_size] *= (gate_gs / l1_max_gs) + fused_sf_f32[:, intermediate_size:] *= (up_gs / l1_max_gs) fused_sf = fused_sf_f32.to(torch.float8_e4m3fn) l1_fp4.append(fused_w_fp4) @@ -172,12 +175,12 @@ def prepare_nvfp4_weights_direct(nvfp4_tensors, layer_idx, expert_indices, inter # (intermediate//2, hidden) — K=intermediate packed, N=hidden l2_fp4.append(down_w_fp4) - l2_sf.append(down_sf) + l2_sf.append(down_sf.permute(1, 0).contiguous()) # (N, K_sf) → (K_sf, N) l2_gs.append(down_gs) else: # Expert 211 has no down_proj l2_fp4.append(torch.zeros(3072 // 2, 7168, dtype=torch.float4_e2m1fn_x2, device=DEVICE)) - l2_sf.append(torch.ones(7168, 3072 // 16, dtype=torch.float8_e4m3fn, device=DEVICE)) + l2_sf.append(torch.ones(3072 // 16, 7168, dtype=torch.float8_e4m3fn, device=DEVICE)) # (K_sf, N) l2_gs.append(1.0) return { diff --git a/vllm/patches/deepseek_v4.py b/vllm/patches/deepseek_v4.py index 02b700c5..9836456b 100644 --- a/vllm/patches/deepseek_v4.py +++ b/vllm/patches/deepseek_v4.py @@ -450,16 +450,18 @@ class DeepseekV4MegaMoEExperts(nn.Module): fused_w_fp4 = fused_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() # shape: (hidden//2, 2*intermediate) — K=hidden packed, N=2*intermediate - # Fuse block scales: (2*intermediate, hidden//16) = (N, K_sf) ✓ - fused_sf = torch.cat([gate_sf, up_sf], dim=0) + # Fuse block scales: checkpoint is (N, K_sf), bridge expects (K_sf, N) + fused_sf = torch.cat([gate_sf, up_sf], dim=0) # (2*intermediate, hidden//16) = (N, K_sf) + # Transpose to (K_sf, N) for assemble_scales_3d_side + fused_sf = fused_sf.permute(1, 0).contiguous() # Handle dual global scales: normalize to max, fold ratio into block scales l1_max_gs = max(gate_gs, up_gs) if gate_gs != up_gs: fused_sf_f32 = fused_sf.float() - # Gate is the first intermediate rows, up is the second - fused_sf_f32[:self.intermediate_size] *= (gate_gs / l1_max_gs) - fused_sf_f32[self.intermediate_size:] *= (up_gs / l1_max_gs) + # After transpose to (K_sf, N): gate is first intermediate cols, up is next + fused_sf_f32[:, :self.intermediate_size] *= (gate_gs / l1_max_gs) + fused_sf_f32[:, self.intermediate_size:] *= (up_gs / l1_max_gs) fused_sf = fused_sf_f32.to(torch.float8_e4m3fn) l1_fp4.append(fused_w_fp4) @@ -476,7 +478,8 @@ class DeepseekV4MegaMoEExperts(nn.Module): down_w_fp4 = down_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() # shape: (intermediate//2, hidden) — K=intermediate packed, N=hidden - # Block scales: (hidden, intermediate//16) = (N, K_sf) ✓ — already correct + # Block scales: checkpoint is (N, K_sf), bridge expects (K_sf, N) + down_sf = down_sf.permute(1, 0).contiguous() l2_fp4.append(down_w_fp4) l2_sf.append(down_sf)