From bdf9f31ae2622fe370b44deae4faba4bec377e05 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 16 May 2026 02:17:13 +0000 Subject: [PATCH] fix: checkpoint keys don't have 'model.' prefix --- tests/layertest.py | 36 +++++++++++++++++------------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/tests/layertest.py b/tests/layertest.py index fc0f178b..bba40dab 100644 --- a/tests/layertest.py +++ b/tests/layertest.py @@ -52,21 +52,19 @@ def find_shards(model_dir): return key_to_shard -def load_layer_tensors(model_dir, layer_idx, prefix_filter=None): +def load_layer_tensors(model_dir, layer_idx): """Load all tensors for a specific layer from the checkpoint. Returns dict of {key: tensor} for all keys matching the layer. """ key_to_shard = find_shards(model_dir) - layer_prefix = f"model.layers.{layer_idx}." + layer_prefix = f"layers.{layer_idx}." # Group by shard to minimize file opens shard_to_keys = {} for key, shard in key_to_shard.items(): if not key.startswith(layer_prefix): continue - if prefix_filter and prefix_filter not in key: - continue shard_to_keys.setdefault(shard, []).append(key) tensors = {} @@ -126,8 +124,8 @@ def dequantize_mxfp4_experts(orig_tensors, layer_idx, expert_indices): for e in expert_indices: expert = {} for proj in ["gate_proj", "up_proj", "down_proj"]: - weight_key = f"model.layers.{layer_idx}.mlp.experts.{e}.{proj}.weight" - scale_key = f"model.layers.{layer_idx}.mlp.experts.{e}.{proj}.scale" + weight_key = f"layers.{layer_idx}.mlp.experts.{e}.{proj}.weight" + scale_key = f"layers.{layer_idx}.mlp.experts.{e}.{proj}.scale" if weight_key not in orig_tensors: # Expert 211 has no down_proj @@ -182,9 +180,9 @@ def dequantize_nvfp4_experts(nvfp4_tensors, layer_idx, expert_indices): for e in expert_indices: expert = {} for proj in ["gate_proj", "up_proj", "down_proj"]: - weight_key = f"model.layers.{layer_idx}.mlp.experts.{e}.{proj}.weight" - scale_key = f"model.layers.{layer_idx}.mlp.experts.{e}.{proj}.weight_scale" - gs_key = f"model.layers.{layer_idx}.mlp.experts.{e}.{proj}.weight_scale_2" + weight_key = f"layers.{layer_idx}.mlp.experts.{e}.{proj}.weight" + scale_key = f"layers.{layer_idx}.mlp.experts.{e}.{proj}.weight_scale" + gs_key = f"layers.{layer_idx}.mlp.experts.{e}.{proj}.weight_scale_2" if weight_key not in nvfp4_tensors: if proj == "down_proj" and e == 211: @@ -279,12 +277,12 @@ def moe_forward_nvfp4(hidden_states, nvfp4_tensors, layer_idx, expert_ids, exper for e in unique_experts: # L1: gate_proj + up_proj fused - gate_w_key = f"model.layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight" - gate_sf_key = f"model.layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale" - gate_gs_key = f"model.layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale_2" - up_w_key = f"model.layers.{layer_idx}.mlp.experts.{e}.up_proj.weight" - up_sf_key = f"model.layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale" - up_gs_key = f"model.layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale_2" + gate_w_key = f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight" + gate_sf_key = f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale" + gate_gs_key = f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale_2" + up_w_key = f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight" + up_sf_key = f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale" + up_gs_key = f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale_2" gate_w = nvfp4_tensors[gate_w_key].view(torch.int8).to(DEVICE) gate_sf = nvfp4_tensors[gate_sf_key].to(DEVICE) @@ -303,11 +301,11 @@ def moe_forward_nvfp4(hidden_states, nvfp4_tensors, layer_idx, expert_ids, exper l1_global_scales.append(l1_gs) # L2: down_proj - down_w_key = f"model.layers.{layer_idx}.mlp.experts.{e}.down_proj.weight" + down_w_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight" if down_w_key in nvfp4_tensors: down_w = nvfp4_tensors[down_w_key].view(torch.int8).to(DEVICE) - down_sf_key = f"model.layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale" - down_gs_key = f"model.layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale_2" + down_sf_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale" + down_gs_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale_2" down_sf = nvfp4_tensors[down_sf_key].to(DEVICE) down_gs = nvfp4_tensors[down_gs_key].item() else: @@ -436,7 +434,7 @@ def main(): # Verify dtype of weight_scale (should be float8_e4m3fn, NOT float8_e8m0fnu) for e in expert_indices[:1]: for proj in ["gate_proj", "up_proj", "down_proj"]: - key = f"model.layers.{LAYER_IDX}.mlp.experts.{e}.{proj}.weight_scale" + key = f"layers.{LAYER_IDX}.mlp.experts.{e}.{proj}.weight_scale" if key in nvfp4_tensors: dt = nvfp4_tensors[key].dtype print(f" {proj}.weight_scale dtype = {dt} {'✓ E4M3' if dt == torch.float8_e4m3fn else '✗ WRONG (expected float8_e4m3fn)'}")