fix: handle model. prefix normalization in checkpoint keys

This commit is contained in:
2026-05-16 02:18:52 +00:00
parent bdf9f31ae2
commit 55d9a24bf6

View File

@@ -56,6 +56,7 @@ def load_layer_tensors(model_dir, layer_idx):
"""Load all tensors for a specific layer from the checkpoint.
Returns dict of {key: tensor} for all keys matching the layer.
Keys are normalized to NOT have 'model.' prefix.
"""
key_to_shard = find_shards(model_dir)
layer_prefix = f"layers.{layer_idx}."
@@ -63,15 +64,17 @@ def load_layer_tensors(model_dir, layer_idx):
# Group by shard to minimize file opens
shard_to_keys = {}
for key, shard in key_to_shard.items():
if not key.startswith(layer_prefix):
# Normalize: strip 'model.' prefix if present
norm_key = key.removeprefix("model.")
if not norm_key.startswith(layer_prefix):
continue
shard_to_keys.setdefault(shard, []).append(key)
shard_to_keys.setdefault(shard, []).append((key, norm_key))
tensors = {}
for shard, keys in shard_to_keys.items():
with safe_open(shard, framework="pt") as f:
for key in keys:
tensors[key] = f.get_tensor(key)
for orig_key, norm_key in keys:
tensors[norm_key] = f.get_tensor(orig_key)
return tensors
@@ -118,17 +121,19 @@ def dequantize_mxfp4_weight(packed_uint8, scale_e8m0):
def dequantize_mxfp4_experts(orig_tensors, layer_idx, expert_indices):
"""Dequantize expert weights from original MXFP4 checkpoint.
Original checkpoint key format: layers.{L}.ffn.experts.{E}.{w1/w2/w3}.{weight/scale}
w1 = gate_proj, w3 = up_proj, w2 = down_proj
Returns dict: {expert_id: {gate_proj, up_proj, down_proj}} each as BF16.
"""
experts = {}
for e in expert_indices:
expert = {}
for proj in ["gate_proj", "up_proj", "down_proj"]:
weight_key = f"layers.{layer_idx}.mlp.experts.{e}.{proj}.weight"
scale_key = f"layers.{layer_idx}.mlp.experts.{e}.{proj}.scale"
for proj, shard in [("gate_proj", "w1"), ("up_proj", "w3"), ("down_proj", "w2")]:
weight_key = f"layers.{layer_idx}.ffn.experts.{e}.{shard}.weight"
scale_key = f"layers.{layer_idx}.ffn.experts.{e}.{shard}.scale"
if weight_key not in orig_tensors:
# Expert 211 has no down_proj
if proj == "down_proj" and e == 211:
continue
raise KeyError(f"Missing {weight_key}")
@@ -174,6 +179,8 @@ def dequantize_nvfp4_weight(packed_uint8, scale_e4m3, global_scale):
def dequantize_nvfp4_experts(nvfp4_tensors, layer_idx, expert_indices):
"""Dequantize expert weights from NVFP4 checkpoint.
NVFP4 checkpoint key format: layers.{L}.mlp.experts.{E}.{gate_proj/up_proj/down_proj}.{weight/weight_scale/weight_scale_2}
Returns dict: {expert_id: {gate_proj, up_proj, down_proj}} each as BF16.
"""
experts = {}