Files
nvfp4-megamoe-kernel/dump_checkpoint_keys.py

37 lines
1.3 KiB
Python

#!/usr/bin/env python3
"""Dump checkpoint key names and shapes to help understand the model structure."""
import json
import os
import struct
from pathlib import Path

from safetensors.torch import load_file
# Default checkpoint location. Set NVFP4_CHECKPOINT_DIR to point the script at
# another copy without editing the source; when it is unset the original path
# is used unchanged.
CHECKPOINT_DIR = os.environ.get(
    "NVFP4_CHECKPOINT_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
)
def main():
cdir = Path(CHECKPOINT_DIR)
index_path = cdir / "model.safetensors.index.json"
if index_path.exists():
with open(index_path) as f:
weight_map = json.load(f).get("weight_map", {})
# Collect unique key prefixes for layer 0 and layer 2 (CSA)
for li in [0, 1, 2, 3, 60]:
prefix = f"model.layers.{li}."
keys = sorted(k for k in weight_map if k.startswith(prefix))
print(f"\n=== Layer {li} keys ===")
for k in keys:
print(f" {k}")
else:
print("No index file found, loading first shard...")
shards = sorted(cdir.glob("model-*.safetensors"))
if shards:
data = load_file(str(shards[0]))
# Print layer 0 and 2 keys
for li in [0, 1, 2]:
prefix = f"model.layers.{li}."
keys = sorted(k for k in data if k.startswith(prefix))
print(f"\n=== Layer {li} keys (from {shards[0].name}) ===")
for k in keys:
print(f" {k}: {data[k].shape}")
# Dump the keys only when run as a script, not when this module is imported.
if __name__ == "__main__":
    main()