fix: checkpoint keys don't have 'model.' prefix

This commit is contained in:
2026-05-16 02:17:13 +00:00
parent ea5ee7c1f7
commit bdf9f31ae2

View File

@@ -52,21 +52,19 @@ def find_shards(model_dir):
return key_to_shard
def load_layer_tensors(model_dir, layer_idx, prefix_filter=None):
def load_layer_tensors(model_dir, layer_idx):
"""Load all tensors for a specific layer from the checkpoint.
Returns dict of {key: tensor} for all keys matching the layer.
"""
key_to_shard = find_shards(model_dir)
layer_prefix = f"model.layers.{layer_idx}."
layer_prefix = f"layers.{layer_idx}."
# Group by shard to minimize file opens
shard_to_keys = {}
for key, shard in key_to_shard.items():
if not key.startswith(layer_prefix):
continue
if prefix_filter and prefix_filter not in key:
continue
shard_to_keys.setdefault(shard, []).append(key)
tensors = {}
@@ -126,8 +124,8 @@ def dequantize_mxfp4_experts(orig_tensors, layer_idx, expert_indices):
for e in expert_indices:
expert = {}
for proj in ["gate_proj", "up_proj", "down_proj"]:
weight_key = f"model.layers.{layer_idx}.mlp.experts.{e}.{proj}.weight"
scale_key = f"model.layers.{layer_idx}.mlp.experts.{e}.{proj}.scale"
weight_key = f"layers.{layer_idx}.mlp.experts.{e}.{proj}.weight"
scale_key = f"layers.{layer_idx}.mlp.experts.{e}.{proj}.scale"
if weight_key not in orig_tensors:
# Expert 211 has no down_proj
@@ -182,9 +180,9 @@ def dequantize_nvfp4_experts(nvfp4_tensors, layer_idx, expert_indices):
for e in expert_indices:
expert = {}
for proj in ["gate_proj", "up_proj", "down_proj"]:
weight_key = f"model.layers.{layer_idx}.mlp.experts.{e}.{proj}.weight"
scale_key = f"model.layers.{layer_idx}.mlp.experts.{e}.{proj}.weight_scale"
gs_key = f"model.layers.{layer_idx}.mlp.experts.{e}.{proj}.weight_scale_2"
weight_key = f"layers.{layer_idx}.mlp.experts.{e}.{proj}.weight"
scale_key = f"layers.{layer_idx}.mlp.experts.{e}.{proj}.weight_scale"
gs_key = f"layers.{layer_idx}.mlp.experts.{e}.{proj}.weight_scale_2"
if weight_key not in nvfp4_tensors:
if proj == "down_proj" and e == 211:
@@ -279,12 +277,12 @@ def moe_forward_nvfp4(hidden_states, nvfp4_tensors, layer_idx, expert_ids, exper
for e in unique_experts:
# L1: gate_proj + up_proj fused
gate_w_key = f"model.layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight"
gate_sf_key = f"model.layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale"
gate_gs_key = f"model.layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale_2"
up_w_key = f"model.layers.{layer_idx}.mlp.experts.{e}.up_proj.weight"
up_sf_key = f"model.layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale"
up_gs_key = f"model.layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale_2"
gate_w_key = f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight"
gate_sf_key = f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale"
gate_gs_key = f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale_2"
up_w_key = f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight"
up_sf_key = f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale"
up_gs_key = f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale_2"
gate_w = nvfp4_tensors[gate_w_key].view(torch.int8).to(DEVICE)
gate_sf = nvfp4_tensors[gate_sf_key].to(DEVICE)
@@ -303,11 +301,11 @@ def moe_forward_nvfp4(hidden_states, nvfp4_tensors, layer_idx, expert_ids, exper
l1_global_scales.append(l1_gs)
# L2: down_proj
down_w_key = f"model.layers.{layer_idx}.mlp.experts.{e}.down_proj.weight"
down_w_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight"
if down_w_key in nvfp4_tensors:
down_w = nvfp4_tensors[down_w_key].view(torch.int8).to(DEVICE)
down_sf_key = f"model.layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale"
down_gs_key = f"model.layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale_2"
down_sf_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale"
down_gs_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale_2"
down_sf = nvfp4_tensors[down_sf_key].to(DEVICE)
down_gs = nvfp4_tensors[down_gs_key].item()
else:
@@ -436,7 +434,7 @@ def main():
# Verify dtype of weight_scale (should be float8_e4m3fn, NOT float8_e8m0fnu)
for e in expert_indices[:1]:
for proj in ["gate_proj", "up_proj", "down_proj"]:
key = f"model.layers.{LAYER_IDX}.mlp.experts.{e}.{proj}.weight_scale"
key = f"layers.{LAYER_IDX}.mlp.experts.{e}.{proj}.weight_scale"
if key in nvfp4_tensors:
dt = nvfp4_tensors[key].dtype
print(f" {proj}.weight_scale dtype = {dt} {'✓ E4M3' if dt == torch.float8_e4m3fn else '✗ WRONG (expected float8_e4m3fn)'}")