From a0ff8a3278acd32402d1deab2ce7669b352f0b72 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 16 May 2026 03:43:30 +0000 Subject: [PATCH] =?UTF-8?q?fix:=20transpose=20checkpoint=20block=20scales?= =?UTF-8?q?=20(N,K=5Fsf)=E2=86=92(K=5Fsf,N)=20for=20bridge?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The bridge's assemble_scales_3d_side expects (K_sf, N) input and transposes to (N, K_sf) internally before swizzling. The checkpoint stores scales as (N, K_sf). Without this transpose, the kernel was reading completely wrong scale data — cosine dropped to 0.713. Also fixed dual global scale normalization: after transpose, gate/up are along dim 1 (columns), not dim 0 (rows). --- tests/layertest.py | 13 ++++++++----- vllm/patches/deepseek_v4.py | 15 +++++++++------ 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/tests/layertest.py b/tests/layertest.py index d6da390c..83e645d8 100644 --- a/tests/layertest.py +++ b/tests/layertest.py @@ -147,14 +147,17 @@ def prepare_nvfp4_weights_direct(nvfp4_tensors, layer_idx, expert_indices, inter fused_w_fp4 = fused_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() # (hidden//2, 2*intermediate) — K=hidden packed, N=2*intermediate - fused_sf = torch.cat([gate_sf, up_sf], dim=0) # (2*intermediate, hidden//16) + # Fuse block scales: checkpoint is (N, K_sf), bridge expects (K_sf, N) + fused_sf = torch.cat([gate_sf, up_sf], dim=0) # (2*intermediate, hidden//16) = (N, K_sf) + fused_sf = fused_sf.permute(1, 0).contiguous() # → (K_sf, N) # Normalize dual global scales l1_max_gs = max(gate_gs, up_gs) if gate_gs != up_gs: fused_sf_f32 = fused_sf.float() - fused_sf_f32[:intermediate_size] *= (gate_gs / l1_max_gs) - fused_sf_f32[intermediate_size:] *= (up_gs / l1_max_gs) + # Gate is first intermediate cols, up is second (after transpose) + fused_sf_f32[:, :intermediate_size] *= (gate_gs / l1_max_gs) + fused_sf_f32[:, intermediate_size:] *= (up_gs / l1_max_gs) fused_sf = fused_sf_f32.to(torch.float8_e4m3fn) l1_fp4.append(fused_w_fp4) @@ -172,12 +175,12 @@ def prepare_nvfp4_weights_direct(nvfp4_tensors, layer_idx, expert_indices, inter # (intermediate//2, hidden) — K=intermediate packed, N=hidden l2_fp4.append(down_w_fp4) - l2_sf.append(down_sf) + l2_sf.append(down_sf.permute(1, 0).contiguous()) # (N, K_sf) → (K_sf, N) l2_gs.append(down_gs) else: # Expert 211 has no down_proj l2_fp4.append(torch.zeros(3072 // 2, 7168, dtype=torch.float4_e2m1fn_x2, device=DEVICE)) - l2_sf.append(torch.ones(7168, 3072 // 16, dtype=torch.float8_e4m3fn, device=DEVICE)) + l2_sf.append(torch.ones(3072 // 16, 7168, dtype=torch.float8_e4m3fn, device=DEVICE)) # (K_sf, N) l2_gs.append(1.0) return { diff --git a/vllm/patches/deepseek_v4.py b/vllm/patches/deepseek_v4.py index 02b700c5..9836456b 100644 --- a/vllm/patches/deepseek_v4.py +++ b/vllm/patches/deepseek_v4.py @@ -450,16 +450,18 @@ class DeepseekV4MegaMoEExperts(nn.Module): fused_w_fp4 = fused_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() # shape: (hidden//2, 2*intermediate) — K=hidden packed, N=2*intermediate - # Fuse block scales: (2*intermediate, hidden//16) = (N, K_sf) ✓ - fused_sf = torch.cat([gate_sf, up_sf], dim=0) + # Fuse block scales: checkpoint is (N, K_sf), bridge expects (K_sf, N) + fused_sf = torch.cat([gate_sf, up_sf], dim=0) # (2*intermediate, hidden//16) = (N, K_sf) + # Transpose to (K_sf, N) for assemble_scales_3d_side + fused_sf = fused_sf.permute(1, 0).contiguous() # Handle dual global scales: normalize to max, fold ratio into block scales l1_max_gs = max(gate_gs, up_gs) if gate_gs != up_gs: fused_sf_f32 = fused_sf.float() - # Gate is the first intermediate rows, up is the second - fused_sf_f32[:self.intermediate_size] *= (gate_gs / l1_max_gs) - fused_sf_f32[self.intermediate_size:] *= (up_gs / l1_max_gs) + # After transpose to (K_sf, N): gate is first intermediate cols, up is next + fused_sf_f32[:, :self.intermediate_size] *= (gate_gs / l1_max_gs) + fused_sf_f32[:, self.intermediate_size:] *= (up_gs / l1_max_gs) fused_sf = fused_sf_f32.to(torch.float8_e4m3fn) l1_fp4.append(fused_w_fp4) @@ -476,7 +478,8 @@ class DeepseekV4MegaMoEExperts(nn.Module): down_w_fp4 = down_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() # shape: (intermediate//2, hidden) — K=intermediate packed, N=hidden - # Block scales: (hidden, intermediate//16) = (N, K_sf) ✓ — already correct + # Block scales: checkpoint is (N, K_sf), bridge expects (K_sf, N) + down_sf = down_sf.permute(1, 0).contiguous() l2_fp4.append(down_w_fp4) l2_sf.append(down_sf)