diff --git a/tests/layertest.py b/tests/layertest.py index bba40dab..56beeb00 100644 --- a/tests/layertest.py +++ b/tests/layertest.py @@ -56,6 +56,7 @@ def load_layer_tensors(model_dir, layer_idx): """Load all tensors for a specific layer from the checkpoint. Returns dict of {key: tensor} for all keys matching the layer. + Keys are normalized to NOT have 'model.' prefix. """ key_to_shard = find_shards(model_dir) layer_prefix = f"layers.{layer_idx}." @@ -63,15 +64,17 @@ def load_layer_tensors(model_dir, layer_idx): # Group by shard to minimize file opens shard_to_keys = {} for key, shard in key_to_shard.items(): - if not key.startswith(layer_prefix): + # Normalize: strip 'model.' prefix if present + norm_key = key.removeprefix("model.") + if not norm_key.startswith(layer_prefix): continue - shard_to_keys.setdefault(shard, []).append(key) + shard_to_keys.setdefault(shard, []).append((key, norm_key)) tensors = {} for shard, keys in shard_to_keys.items(): with safe_open(shard, framework="pt") as f: - for key in keys: - tensors[key] = f.get_tensor(key) + for orig_key, norm_key in keys: + tensors[norm_key] = f.get_tensor(orig_key) return tensors @@ -118,17 +121,19 @@ def dequantize_mxfp4_weight(packed_uint8, scale_e8m0): def dequantize_mxfp4_experts(orig_tensors, layer_idx, expert_indices): """Dequantize expert weights from original MXFP4 checkpoint. + Original checkpoint key format: layers.{L}.ffn.experts.{E}.{w1/w2/w3}.{weight/scale} + w1 = gate_proj, w3 = up_proj, w2 = down_proj + Returns dict: {expert_id: {gate_proj, up_proj, down_proj}} each as BF16. """ experts = {} for e in expert_indices: expert = {} - for proj in ["gate_proj", "up_proj", "down_proj"]: - weight_key = f"layers.{layer_idx}.mlp.experts.{e}.{proj}.weight" - scale_key = f"layers.{layer_idx}.mlp.experts.{e}.{proj}.scale" + for proj, shard in [("gate_proj", "w1"), ("up_proj", "w3"), ("down_proj", "w2")]: + weight_key = f"layers.{layer_idx}.ffn.experts.{e}.{shard}.weight" + scale_key = f"layers.{layer_idx}.ffn.experts.{e}.{shard}.scale" if weight_key not in orig_tensors: - # Expert 211 has no down_proj if proj == "down_proj" and e == 211: continue raise KeyError(f"Missing {weight_key}") @@ -174,6 +179,8 @@ def dequantize_nvfp4_weight(packed_uint8, scale_e4m3, global_scale): def dequantize_nvfp4_experts(nvfp4_tensors, layer_idx, expert_indices): """Dequantize expert weights from NVFP4 checkpoint. + NVFP4 checkpoint key format: layers.{L}.mlp.experts.{E}.{gate_proj/up_proj/down_proj}.{weight/weight_scale/weight_scale_2} + Returns dict: {expert_id: {gate_proj, up_proj, down_proj}} each as BF16. """ experts = {}