add checkpoint key dump script
This commit is contained in:
36
dump_checkpoint_keys.py
Normal file
36
dump_checkpoint_keys.py
Normal file
@@ -0,0 +1,36 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Dump checkpoint key names and shapes to help understand the model structure."""
|
||||
import json
|
||||
from pathlib import Path
|
||||
from safetensors.torch import load_file
|
||||
|
||||
CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
|
||||
|
||||
def main(checkpoint_dir=None) -> None:
    """Print checkpoint tensor key names for a few selected layers.

    Keys are selected by the prefix ``model.layers.<i>.`` and printed in
    sorted order under one header per layer.

    Two sources are tried, in order:

    1. ``model.safetensors.index.json`` inside the checkpoint directory.
       Only the keys of its ``"weight_map"`` entry are read, so this branch
       prints names only (no tensors are loaded, hence no shapes).
    2. If the index file does not exist, the first file (in sorted order)
       matching ``model-*.safetensors`` is loaded and both the key names and
       the ``.shape`` of each tensor are printed.

    Args:
        checkpoint_dir: Directory to inspect (``str`` or ``pathlib.Path``).
            When ``None`` (the default), the module-level ``CHECKPOINT_DIR``
            is used, which preserves the original zero-argument behavior.
    """
    cdir = Path(CHECKPOINT_DIR if checkpoint_dir is None else checkpoint_dir)
    index_path = cdir / "model.safetensors.index.json"

    if index_path.exists():
        # JSON is UTF-8 by specification; do not rely on the locale default.
        with open(index_path, encoding="utf-8") as f:
            weight_map = json.load(f).get("weight_map", {})

        # List every key recorded for layers 0-3 and layer 60.
        for li in [0, 1, 2, 3, 60]:
            # The trailing dot keeps e.g. layer 6 from matching layer 60.
            prefix = f"model.layers.{li}."
            keys = sorted(k for k in weight_map if k.startswith(prefix))
            print(f"\n=== Layer {li} keys ===")
            for k in keys:
                print(f" {k}")
        return

    print("No index file found, loading first shard...")
    shards = sorted(cdir.glob("model-*.safetensors"))
    if not shards:
        # Previously the script went silent here, which looked like a hang
        # or an empty checkpoint rather than a wrong path.
        print(f"No model-*.safetensors shards found in {cdir}")
        return

    first_shard = shards[0]
    # NOTE(review): load_file appears to read every tensor of the shard into
    # memory just to report shapes; a lazy reader may be cheaper - confirm.
    data = load_file(str(first_shard))

    # Only layers 0-2 are inspected in this fallback branch.
    for li in [0, 1, 2]:
        prefix = f"model.layers.{li}."
        keys = sorted(k for k in data if k.startswith(prefix))
        print(f"\n=== Layer {li} keys (from {first_shard.name}) ===")
        for k in keys:
            print(f" {k}: {data[k].shape}")
|
||||
|
||||
if __name__ == "__main__":
    # Run the dump only when executed as a script, not when imported.
    main()
|
||||
Reference in New Issue
Block a user