fix: checkpoint keys don't have 'model.' prefix
This commit is contained in:
@@ -52,21 +52,19 @@ def find_shards(model_dir):
|
||||
return key_to_shard
|
||||
|
||||
|
||||
def load_layer_tensors(model_dir, layer_idx, prefix_filter=None):
|
||||
def load_layer_tensors(model_dir, layer_idx):
|
||||
"""Load all tensors for a specific layer from the checkpoint.
|
||||
|
||||
Returns dict of {key: tensor} for all keys matching the layer.
|
||||
"""
|
||||
key_to_shard = find_shards(model_dir)
|
||||
layer_prefix = f"model.layers.{layer_idx}."
|
||||
layer_prefix = f"layers.{layer_idx}."
|
||||
|
||||
# Group by shard to minimize file opens
|
||||
shard_to_keys = {}
|
||||
for key, shard in key_to_shard.items():
|
||||
if not key.startswith(layer_prefix):
|
||||
continue
|
||||
if prefix_filter and prefix_filter not in key:
|
||||
continue
|
||||
shard_to_keys.setdefault(shard, []).append(key)
|
||||
|
||||
tensors = {}
|
||||
@@ -126,8 +124,8 @@ def dequantize_mxfp4_experts(orig_tensors, layer_idx, expert_indices):
|
||||
for e in expert_indices:
|
||||
expert = {}
|
||||
for proj in ["gate_proj", "up_proj", "down_proj"]:
|
||||
weight_key = f"model.layers.{layer_idx}.mlp.experts.{e}.{proj}.weight"
|
||||
scale_key = f"model.layers.{layer_idx}.mlp.experts.{e}.{proj}.scale"
|
||||
weight_key = f"layers.{layer_idx}.mlp.experts.{e}.{proj}.weight"
|
||||
scale_key = f"layers.{layer_idx}.mlp.experts.{e}.{proj}.scale"
|
||||
|
||||
if weight_key not in orig_tensors:
|
||||
# Expert 211 has no down_proj
|
||||
@@ -182,9 +180,9 @@ def dequantize_nvfp4_experts(nvfp4_tensors, layer_idx, expert_indices):
|
||||
for e in expert_indices:
|
||||
expert = {}
|
||||
for proj in ["gate_proj", "up_proj", "down_proj"]:
|
||||
weight_key = f"model.layers.{layer_idx}.mlp.experts.{e}.{proj}.weight"
|
||||
scale_key = f"model.layers.{layer_idx}.mlp.experts.{e}.{proj}.weight_scale"
|
||||
gs_key = f"model.layers.{layer_idx}.mlp.experts.{e}.{proj}.weight_scale_2"
|
||||
weight_key = f"layers.{layer_idx}.mlp.experts.{e}.{proj}.weight"
|
||||
scale_key = f"layers.{layer_idx}.mlp.experts.{e}.{proj}.weight_scale"
|
||||
gs_key = f"layers.{layer_idx}.mlp.experts.{e}.{proj}.weight_scale_2"
|
||||
|
||||
if weight_key not in nvfp4_tensors:
|
||||
if proj == "down_proj" and e == 211:
|
||||
@@ -279,12 +277,12 @@ def moe_forward_nvfp4(hidden_states, nvfp4_tensors, layer_idx, expert_ids, exper
|
||||
|
||||
for e in unique_experts:
|
||||
# L1: gate_proj + up_proj fused
|
||||
gate_w_key = f"model.layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight"
|
||||
gate_sf_key = f"model.layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale"
|
||||
gate_gs_key = f"model.layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale_2"
|
||||
up_w_key = f"model.layers.{layer_idx}.mlp.experts.{e}.up_proj.weight"
|
||||
up_sf_key = f"model.layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale"
|
||||
up_gs_key = f"model.layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale_2"
|
||||
gate_w_key = f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight"
|
||||
gate_sf_key = f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale"
|
||||
gate_gs_key = f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale_2"
|
||||
up_w_key = f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight"
|
||||
up_sf_key = f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale"
|
||||
up_gs_key = f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale_2"
|
||||
|
||||
gate_w = nvfp4_tensors[gate_w_key].view(torch.int8).to(DEVICE)
|
||||
gate_sf = nvfp4_tensors[gate_sf_key].to(DEVICE)
|
||||
@@ -303,11 +301,11 @@ def moe_forward_nvfp4(hidden_states, nvfp4_tensors, layer_idx, expert_ids, exper
|
||||
l1_global_scales.append(l1_gs)
|
||||
|
||||
# L2: down_proj
|
||||
down_w_key = f"model.layers.{layer_idx}.mlp.experts.{e}.down_proj.weight"
|
||||
down_w_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight"
|
||||
if down_w_key in nvfp4_tensors:
|
||||
down_w = nvfp4_tensors[down_w_key].view(torch.int8).to(DEVICE)
|
||||
down_sf_key = f"model.layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale"
|
||||
down_gs_key = f"model.layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale_2"
|
||||
down_sf_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale"
|
||||
down_gs_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale_2"
|
||||
down_sf = nvfp4_tensors[down_sf_key].to(DEVICE)
|
||||
down_gs = nvfp4_tensors[down_gs_key].item()
|
||||
else:
|
||||
@@ -436,7 +434,7 @@ def main():
|
||||
# Verify dtype of weight_scale (should be float8_e4m3fn, NOT float8_e8m0fnu)
|
||||
for e in expert_indices[:1]:
|
||||
for proj in ["gate_proj", "up_proj", "down_proj"]:
|
||||
key = f"model.layers.{LAYER_IDX}.mlp.experts.{e}.{proj}.weight_scale"
|
||||
key = f"layers.{LAYER_IDX}.mlp.experts.{e}.{proj}.weight_scale"
|
||||
if key in nvfp4_tensors:
|
||||
dt = nvfp4_tensors[key].dtype
|
||||
print(f" {proj}.weight_scale dtype = {dt} {'✓ E4M3' if dt == torch.float8_e4m3fn else '✗ WRONG (expected float8_e4m3fn)'}")
|
||||
|
||||
Reference in New Issue
Block a user