Fix: correct intermediate_size=3072, weight key prefix, dequantize shapes
This commit is contained in:
@@ -27,7 +27,7 @@ MODEL_PATH = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
|
||||
LAYER_IDX = 0
|
||||
NUM_EXPERTS = 48 # local experts per rank (256/8=32, but model uses 48)
|
||||
HIDDEN_SIZE = 7168
|
||||
INTERMEDIATE_SIZE = 18432
|
||||
INTERMEDIATE_SIZE = 3072 # per routed expert (18432 is shared expert)
|
||||
NUM_TOKENS = 8
|
||||
TOP_K = 6
|
||||
SWIGLU_LIMIT = 10.0
|
||||
@@ -46,13 +46,15 @@ ENABLE_FULL_RUNNER = True
|
||||
# ============================================================
|
||||
def load_layer_tensors(model_dir, layer_idx):
|
||||
tensors = {}
|
||||
pattern = os.path.join(model_dir, f"layers.{layer_idx}.mlp.experts.*")
|
||||
for sf in glob.glob(os.path.join(model_dir, "*.safetensors")):
|
||||
from safetensors.torch import load_file
|
||||
data = load_file(sf)
|
||||
for k, v in data.items():
|
||||
if f"layers.{layer_idx}." in k:
|
||||
tensors[k] = v
|
||||
# Match both "layers.X." and "model.layers.X."
|
||||
if f"layers.{layer_idx}." in k and "mlp.experts" in k:
|
||||
# Normalize: strip "model." prefix if present
|
||||
norm_key = k.removeprefix("model.")
|
||||
tensors[norm_key] = v
|
||||
return tensors
|
||||
|
||||
|
||||
@@ -96,8 +98,8 @@ def prepare_nvfp4_weights(nvfp4_tensors, layer_idx, expert_indices, intermediate
|
||||
l2_sf.append(down_sf.permute(1, 0).contiguous())
|
||||
l2_gs.append(down_gs)
|
||||
else:
|
||||
l2_fp4.append(torch.zeros(intermediate_size // 2, hidden_size, dtype=torch.float4_e2m1fn_x2, device=DEVICE))
|
||||
l2_sf.append(torch.ones(intermediate_size // 16, hidden_size, dtype=torch.float8_e4m3fn, device=DEVICE))
|
||||
l2_fp4.append(torch.zeros(intermediate_size // 2, HIDDEN_SIZE, dtype=torch.float4_e2m1fn_x2, device=DEVICE))
|
||||
l2_sf.append(torch.ones(intermediate_size // 16, HIDDEN_SIZE, dtype=torch.float8_e4m3fn, device=DEVICE))
|
||||
l2_gs.append(1.0)
|
||||
|
||||
return {
|
||||
@@ -107,29 +109,28 @@ def prepare_nvfp4_weights(nvfp4_tensors, layer_idx, expert_indices, intermediate
|
||||
|
||||
|
||||
def dequantize_nvfp4_weight(packed_uint8, scale_e4m3, global_scale):
|
||||
"""Dequantize NVFP4 weight to BF16 for reference computation."""
|
||||
# FP4 lookup table
|
||||
"""Dequantize NVFP4 weight to BF16 for reference computation.
|
||||
|
||||
packed_uint8: (N, K_packed) where K_packed = K//2
|
||||
scale_e4m3: (N, K_sf) where K_sf = K//16
|
||||
Returns: (N, K) BF16
|
||||
"""
|
||||
lut = torch.tensor([
|
||||
0., 0.5, 1., 1.5, 2., 3., 4., 6.,
|
||||
-0., -0.5, -1., -1.5, -2., -3., -4., -6.
|
||||
], dtype=torch.float32)
|
||||
device = packed_uint8.device
|
||||
lut = lut.to(device)
|
||||
], dtype=torch.float32, device=packed_uint8.device)
|
||||
|
||||
lower = lut[(packed_uint8 & 0x0F).long()]
|
||||
upper = lut[((packed_uint8 >> 4) & 0x0F).long()]
|
||||
out_features = packed_uint8.shape[0]
|
||||
in_features = packed_uint8.shape[1] * 2
|
||||
N = packed_uint8.shape[0]
|
||||
K = packed_uint8.shape[1] * 2
|
||||
|
||||
bf16_vals = torch.stack([lower, upper], dim=-1).reshape(N, K)
|
||||
|
||||
# scale_e4m3 is (N, K_sf) where K_sf = K//16
|
||||
K_sf = scale_e4m3.shape[1]
|
||||
scale_2d = scale_e4m3.float().repeat_interleave(K // K_sf, dim=1) # (N, K)
|
||||
|
||||
bf16_vals = torch.stack([lower, upper], dim=-1).reshape(out_features, in_features)
|
||||
scale_2d = scale_e4m3.float().reshape(-1, 1).expand(-1, in_features // scale_e4m3.shape[0] if scale_e4m3.shape[0] < in_features else 1)
|
||||
# scale is (K_sf, N), expand to match (K, N) where K_sf = K/16
|
||||
K, N = packed_uint8.shape[0], packed_uint8.shape[1] * 2
|
||||
K_sf = scale_e4m3.shape[0]
|
||||
if K_sf != K:
|
||||
scale_2d = scale_e4m3.float().repeat_interleave(K // K_sf, dim=0)
|
||||
else:
|
||||
scale_2d = scale_e4m3.float()
|
||||
dequant = bf16_vals * scale_2d * global_scale
|
||||
return dequant.to(torch.bfloat16)
|
||||
|
||||
@@ -195,8 +196,8 @@ def reference_moe_bf16(hidden_states, nvfp4_tensors, layer_idx, expert_indices,
|
||||
gate_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale_2"].item()
|
||||
up_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale_2"].item()
|
||||
|
||||
gate_bf16 = dequantize_nvfp4_weight(gate_w, gate_sf.T if gate_sf.shape[0] == gate_w.shape[1] else gate_sf, gate_gs)
|
||||
up_bf16 = dequantize_nvfp4_weight(up_w, up_sf.T if up_sf.shape[0] == up_w.shape[1] else up_sf, up_gs)
|
||||
gate_bf16 = dequantize_nvfp4_weight(gate_w, gate_sf, gate_gs) # (intermediate, hidden)
|
||||
up_bf16 = dequantize_nvfp4_weight(up_w, up_sf, up_gs) # (intermediate, hidden)
|
||||
|
||||
gate = x @ gate_bf16.T # (T, intermediate)
|
||||
up = x @ up_bf16.T # (T, intermediate)
|
||||
@@ -218,7 +219,7 @@ def reference_moe_bf16(hidden_states, nvfp4_tensors, layer_idx, expert_indices,
|
||||
down_w = nvfp4_tensors[down_key].to(DEVICE)
|
||||
down_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale"].to(DEVICE)
|
||||
down_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale_2"].item()
|
||||
down_bf16 = dequantize_nvfp4_weight(down_w, down_sf.T if down_sf.shape[0] == down_w.shape[1] else down_sf, down_gs)
|
||||
down_bf16 = dequantize_nvfp4_weight(down_w, down_sf, down_gs) # (hidden, intermediate)
|
||||
l2_out = activated @ down_bf16.T # (T, H)
|
||||
else:
|
||||
l2_out = activated[:, :HIDDEN_SIZE]
|
||||
|
||||
Reference in New Issue
Block a user