Fix: correct intermediate_size=3072, weight key prefix, dequantize shapes

This commit is contained in:
2026-05-17 21:18:20 +00:00
parent 4ef345773d
commit 7fff5fd39b

View File

@@ -27,7 +27,7 @@ MODEL_PATH = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
LAYER_IDX = 0
NUM_EXPERTS = 48  # local experts per rank (256/8=32, but model uses 48)
HIDDEN_SIZE = 7168
# Width of ONE routed expert. 18432 is the shared-expert width and must not
# be used for routed-expert weight shapes.
INTERMEDIATE_SIZE = 3072
NUM_TOKENS = 8
TOP_K = 6
SWIGLU_LIMIT = 10.0
@@ -46,13 +46,15 @@ ENABLE_FULL_RUNNER = True
# ============================================================
def load_layer_tensors(model_dir, layer_idx):
    """Collect the routed-expert tensors of one layer from a checkpoint dir.

    Every ``*.safetensors`` file directly inside ``model_dir`` is loaded and
    scanned. A tensor is kept when its key contains ``layers.{layer_idx}.``
    and ``mlp.experts``; a leading ``model.`` is stripped from the key, so
    callers can always look tensors up as ``layers.{layer_idx}.mlp.experts...``.

    Args:
        model_dir: Directory holding the ``*.safetensors`` shards.
        layer_idx: Index of the layer whose expert tensors are wanted.

    Returns:
        Dict mapping normalized key -> tensor. Empty when nothing matches.
    """
    tensors = {}
    layer_tag = f"layers.{layer_idx}."
    for sf in glob.glob(os.path.join(model_dir, "*.safetensors")):
        # Imported lazily so the module (and an empty directory) does not
        # require safetensors; repeated imports are served from sys.modules.
        from safetensors.torch import load_file

        data = load_file(sf)
        for k, v in data.items():
            # Match both "layers.X." and "model.layers.X.", experts only.
            if layer_tag in k and "mlp.experts" in k:
                # Normalize: strip "model." prefix if present.
                tensors[k.removeprefix("model.")] = v
    return tensors
@@ -96,8 +98,8 @@ def prepare_nvfp4_weights(nvfp4_tensors, layer_idx, expert_indices, intermediate
l2_sf.append(down_sf.permute(1, 0).contiguous())
l2_gs.append(down_gs)
else:
l2_fp4.append(torch.zeros(intermediate_size // 2, hidden_size, dtype=torch.float4_e2m1fn_x2, device=DEVICE))
l2_sf.append(torch.ones(intermediate_size // 16, hidden_size, dtype=torch.float8_e4m3fn, device=DEVICE))
l2_fp4.append(torch.zeros(intermediate_size // 2, HIDDEN_SIZE, dtype=torch.float4_e2m1fn_x2, device=DEVICE))
l2_sf.append(torch.ones(intermediate_size // 16, HIDDEN_SIZE, dtype=torch.float8_e4m3fn, device=DEVICE))
l2_gs.append(1.0)
return {
@@ -107,29 +109,28 @@ def prepare_nvfp4_weights(nvfp4_tensors, layer_idx, expert_indices, intermediate
def dequantize_nvfp4_weight(packed_uint8, scale_e4m3, global_scale):
    """Dequantize an NVFP4 weight to BF16 for reference computation.

    Each byte of ``packed_uint8`` holds two 4-bit codes; the low nibble is
    the even-indexed element and the high nibble the odd-indexed one. Codes
    are decoded through a 16-entry lookup table, then multiplied by the
    per-block scale and by the scalar ``global_scale``.

    Args:
        packed_uint8: (N, K_packed) integer tensor, K_packed = K // 2.
        scale_e4m3: (N, K_sf) block scales; each entry covers K // K_sf
            consecutive elements of a row (K_sf = K // 16 in this script).
        global_scale: Scalar multiplier applied to the whole weight.

    Returns:
        (N, K) tensor of dtype ``torch.bfloat16``.

    Raises:
        ValueError: If K is not a positive multiple of K_sf.
    """
    # 4-bit code -> value table; codes 8..15 are the negated codes 0..7.
    lut = torch.tensor([
        0., 0.5, 1., 1.5, 2., 3., 4., 6.,
        -0., -0.5, -1., -1.5, -2., -3., -4., -6.
    ], dtype=torch.float32, device=packed_uint8.device)

    lower = lut[(packed_uint8 & 0x0F).long()]
    upper = lut[((packed_uint8 >> 4) & 0x0F).long()]

    N = packed_uint8.shape[0]
    K = packed_uint8.shape[1] * 2
    # Interleave low/high nibbles back into element order along K.
    bf16_vals = torch.stack([lower, upper], dim=-1).reshape(N, K)

    # scale_e4m3 is (N, K_sf): stretch every block scale across its block
    # along dim=1 so it lines up element-for-element with bf16_vals.
    K_sf = scale_e4m3.shape[1]
    if K_sf == 0 or K % K_sf != 0:
        raise ValueError(
            f"scale width {K_sf} does not evenly divide weight width {K}"
        )
    scale_2d = scale_e4m3.float().repeat_interleave(K // K_sf, dim=1)  # (N, K)

    dequant = bf16_vals * scale_2d * global_scale
    return dequant.to(torch.bfloat16)
@@ -195,8 +196,8 @@ def reference_moe_bf16(hidden_states, nvfp4_tensors, layer_idx, expert_indices,
gate_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale_2"].item()
up_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale_2"].item()
gate_bf16 = dequantize_nvfp4_weight(gate_w, gate_sf.T if gate_sf.shape[0] == gate_w.shape[1] else gate_sf, gate_gs)
up_bf16 = dequantize_nvfp4_weight(up_w, up_sf.T if up_sf.shape[0] == up_w.shape[1] else up_sf, up_gs)
gate_bf16 = dequantize_nvfp4_weight(gate_w, gate_sf, gate_gs) # (intermediate, hidden)
up_bf16 = dequantize_nvfp4_weight(up_w, up_sf, up_gs) # (intermediate, hidden)
gate = x @ gate_bf16.T # (T, intermediate)
up = x @ up_bf16.T # (T, intermediate)
@@ -218,7 +219,7 @@ def reference_moe_bf16(hidden_states, nvfp4_tensors, layer_idx, expert_indices,
down_w = nvfp4_tensors[down_key].to(DEVICE)
down_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale"].to(DEVICE)
down_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale_2"].item()
down_bf16 = dequantize_nvfp4_weight(down_w, down_sf.T if down_sf.shape[0] == down_w.shape[1] else down_sf, down_gs)
down_bf16 = dequantize_nvfp4_weight(down_w, down_sf, down_gs) # (hidden, intermediate)
l2_out = activated @ down_bf16.T # (T, H)
else:
l2_out = activated[:, :HIDDEN_SIZE]