fix: normalize the 'model.' prefix in checkpoint keys
This commit is contained in:
@@ -56,6 +56,7 @@ def load_layer_tensors(model_dir, layer_idx):
|
||||
"""Load all tensors for a specific layer from the checkpoint.
|
||||
|
||||
Returns dict of {key: tensor} for all keys matching the layer.
|
||||
Keys are normalized to NOT have 'model.' prefix.
|
||||
"""
|
||||
key_to_shard = find_shards(model_dir)
|
||||
layer_prefix = f"layers.{layer_idx}."
|
||||
@@ -63,15 +64,17 @@ def load_layer_tensors(model_dir, layer_idx):
|
||||
# Group by shard to minimize file opens
|
||||
shard_to_keys = {}
|
||||
for key, shard in key_to_shard.items():
|
||||
if not key.startswith(layer_prefix):
|
||||
# Normalize: strip 'model.' prefix if present
|
||||
norm_key = key.removeprefix("model.")
|
||||
if not norm_key.startswith(layer_prefix):
|
||||
continue
|
||||
shard_to_keys.setdefault(shard, []).append(key)
|
||||
shard_to_keys.setdefault(shard, []).append((key, norm_key))
|
||||
|
||||
tensors = {}
|
||||
for shard, keys in shard_to_keys.items():
|
||||
with safe_open(shard, framework="pt") as f:
|
||||
for key in keys:
|
||||
tensors[key] = f.get_tensor(key)
|
||||
for orig_key, norm_key in keys:
|
||||
tensors[norm_key] = f.get_tensor(orig_key)
|
||||
|
||||
return tensors
|
||||
|
||||
@@ -118,17 +121,19 @@ def dequantize_mxfp4_weight(packed_uint8, scale_e8m0):
|
||||
def dequantize_mxfp4_experts(orig_tensors, layer_idx, expert_indices):
|
||||
"""Dequantize expert weights from original MXFP4 checkpoint.
|
||||
|
||||
Original checkpoint key format: layers.{L}.ffn.experts.{E}.{w1/w2/w3}.{weight/scale}
|
||||
w1 = gate_proj, w3 = up_proj, w2 = down_proj
|
||||
|
||||
Returns dict: {expert_id: {gate_proj, up_proj, down_proj}} each as BF16.
|
||||
"""
|
||||
experts = {}
|
||||
for e in expert_indices:
|
||||
expert = {}
|
||||
for proj in ["gate_proj", "up_proj", "down_proj"]:
|
||||
weight_key = f"layers.{layer_idx}.mlp.experts.{e}.{proj}.weight"
|
||||
scale_key = f"layers.{layer_idx}.mlp.experts.{e}.{proj}.scale"
|
||||
for proj, shard in [("gate_proj", "w1"), ("up_proj", "w3"), ("down_proj", "w2")]:
|
||||
weight_key = f"layers.{layer_idx}.ffn.experts.{e}.{shard}.weight"
|
||||
scale_key = f"layers.{layer_idx}.ffn.experts.{e}.{shard}.scale"
|
||||
|
||||
if weight_key not in orig_tensors:
|
||||
# Expert 211 has no down_proj
|
||||
if proj == "down_proj" and e == 211:
|
||||
continue
|
||||
raise KeyError(f"Missing {weight_key}")
|
||||
@@ -174,6 +179,8 @@ def dequantize_nvfp4_weight(packed_uint8, scale_e4m3, global_scale):
|
||||
def dequantize_nvfp4_experts(nvfp4_tensors, layer_idx, expert_indices):
|
||||
"""Dequantize expert weights from NVFP4 checkpoint.
|
||||
|
||||
NVFP4 checkpoint key format: layers.{L}.mlp.experts.{E}.{gate_proj/up_proj/down_proj}.{weight/weight_scale/weight_scale_2}
|
||||
|
||||
Returns dict: {expert_id: {gate_proj, up_proj, down_proj}} each as BF16.
|
||||
"""
|
||||
experts = {}
|
||||
|
||||
Reference in New Issue
Block a user