From 7fff5fd39b24f0d3a7dad005c85fbb38c2d71790 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 17 May 2026 21:18:20 +0000 Subject: [PATCH] Fix: correct intermediate_size=3072, weight key prefix, dequantize shapes --- tests/test_pipeline_real_weights.py | 51 +++++++++++++++-------------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/tests/test_pipeline_real_weights.py b/tests/test_pipeline_real_weights.py index 8d70ea2d..a7174f39 100644 --- a/tests/test_pipeline_real_weights.py +++ b/tests/test_pipeline_real_weights.py @@ -27,7 +27,7 @@ MODEL_PATH = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4" LAYER_IDX = 0 NUM_EXPERTS = 48 # local experts per rank (256/8=32, but model uses 48) HIDDEN_SIZE = 7168 -INTERMEDIATE_SIZE = 18432 +INTERMEDIATE_SIZE = 3072 # per routed expert (18432 is shared expert) NUM_TOKENS = 8 TOP_K = 6 SWIGLU_LIMIT = 10.0 @@ -46,13 +46,15 @@ ENABLE_FULL_RUNNER = True # ============================================================ def load_layer_tensors(model_dir, layer_idx): tensors = {} - pattern = os.path.join(model_dir, f"layers.{layer_idx}.mlp.experts.*") for sf in glob.glob(os.path.join(model_dir, "*.safetensors")): from safetensors.torch import load_file data = load_file(sf) for k, v in data.items(): - if f"layers.{layer_idx}." in k: - tensors[k] = v + # Match both "layers.X." and "model.layers.X." + if f"layers.{layer_idx}." in k and "mlp.experts" in k: + # Normalize: strip "model." prefix if present + norm_key = k.removeprefix("model.") + tensors[norm_key] = v return tensors @@ -96,8 +98,8 @@ def prepare_nvfp4_weights(nvfp4_tensors, layer_idx, expert_indices, intermediate l2_sf.append(down_sf.permute(1, 0).contiguous()) l2_gs.append(down_gs) else: - l2_fp4.append(torch.zeros(intermediate_size // 2, hidden_size, dtype=torch.float4_e2m1fn_x2, device=DEVICE)) - l2_sf.append(torch.ones(intermediate_size // 16, hidden_size, dtype=torch.float8_e4m3fn, device=DEVICE)) + l2_fp4.append(torch.zeros(intermediate_size // 2, HIDDEN_SIZE, dtype=torch.float4_e2m1fn_x2, device=DEVICE)) + l2_sf.append(torch.ones(intermediate_size // 16, HIDDEN_SIZE, dtype=torch.float8_e4m3fn, device=DEVICE)) l2_gs.append(1.0) return { @@ -107,29 +109,28 @@ def prepare_nvfp4_weights(nvfp4_tensors, layer_idx, expert_indices, intermediate def dequantize_nvfp4_weight(packed_uint8, scale_e4m3, global_scale): - """Dequantize NVFP4 weight to BF16 for reference computation.""" - # FP4 lookup table + """Dequantize NVFP4 weight to BF16 for reference computation. + + packed_uint8: (N, K_packed) where K_packed = K//2 + scale_e4m3: (N, K_sf) where K_sf = K//16 + Returns: (N, K) BF16 + """ lut = torch.tensor([ 0., 0.5, 1., 1.5, 2., 3., 4., 6., -0., -0.5, -1., -1.5, -2., -3., -4., -6. - ], dtype=torch.float32) - device = packed_uint8.device - lut = lut.to(device) + ], dtype=torch.float32, device=packed_uint8.device) lower = lut[(packed_uint8 & 0x0F).long()] upper = lut[((packed_uint8 >> 4) & 0x0F).long()] - out_features = packed_uint8.shape[0] - in_features = packed_uint8.shape[1] * 2 + N = packed_uint8.shape[0] + K = packed_uint8.shape[1] * 2 + + bf16_vals = torch.stack([lower, upper], dim=-1).reshape(N, K) + + # scale_e4m3 is (N, K_sf) where K_sf = K//16 + K_sf = scale_e4m3.shape[1] + scale_2d = scale_e4m3.float().repeat_interleave(K // K_sf, dim=1) # (N, K) - bf16_vals = torch.stack([lower, upper], dim=-1).reshape(out_features, in_features) - scale_2d = scale_e4m3.float().reshape(-1, 1).expand(-1, in_features // scale_e4m3.shape[0] if scale_e4m3.shape[0] < in_features else 1) - # scale is (K_sf, N), expand to match (K, N) where K_sf = K/16 - K, N = packed_uint8.shape[0], packed_uint8.shape[1] * 2 - K_sf = scale_e4m3.shape[0] - if K_sf != K: - scale_2d = scale_e4m3.float().repeat_interleave(K // K_sf, dim=0) - else: - scale_2d = scale_e4m3.float() dequant = bf16_vals * scale_2d * global_scale return dequant.to(torch.bfloat16) @@ -195,8 +196,8 @@ def reference_moe_bf16(hidden_states, nvfp4_tensors, layer_idx, expert_indices, gate_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale_2"].item() up_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale_2"].item() - gate_bf16 = dequantize_nvfp4_weight(gate_w, gate_sf.T if gate_sf.shape[0] == gate_w.shape[1] else gate_sf, gate_gs) - up_bf16 = dequantize_nvfp4_weight(up_w, up_sf.T if up_sf.shape[0] == up_w.shape[1] else up_sf, up_gs) + gate_bf16 = dequantize_nvfp4_weight(gate_w, gate_sf, gate_gs) # (intermediate, hidden) + up_bf16 = dequantize_nvfp4_weight(up_w, up_sf, up_gs) # (intermediate, hidden) gate = x @ gate_bf16.T # (T, intermediate) up = x @ up_bf16.T # (T, intermediate) @@ -218,7 +219,7 @@ def reference_moe_bf16(hidden_states, nvfp4_tensors, layer_idx, expert_indices, down_w = nvfp4_tensors[down_key].to(DEVICE) down_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale"].to(DEVICE) down_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale_2"].item() - down_bf16 = dequantize_nvfp4_weight(down_w, down_sf.T if down_sf.shape[0] == down_w.shape[1] else down_sf, down_gs) + down_bf16 = dequantize_nvfp4_weight(down_w, down_sf, down_gs) # (hidden, intermediate) l2_out = activated @ down_bf16.T # (T, H) else: l2_out = activated[:, :HIDDEN_SIZE]