60 lines
2.2 KiB
Python
60 lines
2.2 KiB
Python
|
|
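#!/usr/bin/env python3
"""Probe actual checkpoint weight keys and shapes for DSV4-Pro NVFP4.

Focus on layer 0 (HCA) and layer 2 (CSA) non-expert weights. Also lists the
global (non-layer) keys and the layer 0 expert, shared-expert, and
gate/router key names found in the checkpoint index.
"""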
|
||
|
|
import json
import os
from pathlib import Path

from safetensors import safe_open
from safetensors.torch import load_file
|
||
|
|
|
||
|
|
# Location of the checkpoint being probed.
CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
cdir = Path(CHECKPOINT_DIR)

# The index's "weight_map" entry is the lookup table used by every section
# below: its keys are weight names and its values are shard file names.
index_path = cdir / "model.safetensors.index.json"
with index_path.open() as index_file:
    index = json.load(index_file)
wmap = index["weight_map"]
|
||
|
|
|
||
|
|
# Print the index entries of layers 0 and 2, leaving out any name that
# contains 'experts.' or 'shared_expert'.
for layer_idx in (0, 2):
    layer_prefix = f"model.layers.{layer_idx}."
    layer_keys = []
    for name in wmap:
        if not name.startswith(layer_prefix):
            continue
        if 'experts.' in name or 'shared_expert' in name:
            continue
        layer_keys.append(name)
    layer_keys.sort()

    print(f"\n=== Layer {layer_idx} non-expert keys ({len(layer_keys)}) ===")
    for name in layer_keys:
        print(f" {name}")
|
||
|
|
|
||
|
|
# Keys that do not live under the per-layer "model.layers." namespace.
global_keys = [name for name in wmap if not name.startswith("model.layers.")]
global_keys.sort()

print(f"\n=== Global keys ({len(global_keys)}) ===")
for name in global_keys:
    print(f" {name}")
|
||
|
|
|
||
|
|
# Report the shape and dtype of every non-expert tensor in layers 0 and 2.
#
# Shards are opened lazily with safe_open and only the tensors that are
# actually printed are fetched. Loading a whole shard would materialize every
# tensor it contains (including the ones filtered out below) just to read a
# handful of shapes.
for li in [0, 2]:
    pfx = f"model.layers.{li}."
    keys = [k for k in wmap if k.startswith(pfx) and 'experts.' not in k and 'shared_expert' not in k]
    # Only visit shards that the index says hold a non-expert tensor of this layer.
    shards = {wmap[k] for k in keys}
    print(f"\n=== Layer {li} shapes ===")
    for sn in sorted(shards):
        shard_path = cdir / sn
        if not shard_path.exists():
            # A partially downloaded checkpoint is tolerated, but report the
            # gap instead of silently printing an incomplete listing.
            print(f" [skipped: shard {sn} not found in {cdir}]")
            continue
        with safe_open(str(shard_path), framework="pt", device="cpu") as shard:
            # Iterate the shard's own key list (not the index) so the output
            # reflects what is really stored in the file.
            for k in sorted(shard.keys()):
                if k.startswith(pfx) and 'experts.' not in k and 'shared_expert' not in k:
                    tensor = shard.get_tensor(k)
                    print(f" {k}: {tuple(tensor.shape)} {tensor.dtype}")
|
||
|
|
|
||
|
|
# Show a small sample (first 12 in sorted order) of the layer 0 MLP keys whose
# name contains 'experts.'. pfx0 is reused by the sections that follow.
pfx0 = "model.layers.0.mlp."
all_expert_keys = sorted(
    name for name in wmap if 'experts.' in name and name.startswith(pfx0)
)
expert_keys = all_expert_keys[:12]

print("\n=== Layer 0 expert keys (sample) ===")
for name in expert_keys:
    print(f" {name}")
|
||
|
|
|
||
|
|
# Layer 0 MLP keys whose name mentions 'shared_expert'.
shared_keys = [
    name for name in wmap if name.startswith(pfx0) and 'shared_expert' in name
]
shared_keys.sort()

print("\n=== Layer 0 shared expert keys ===")
for name in shared_keys:
    print(f" {name}")
|
||
|
|
|
||
|
|
# Layer 0 MLP keys that look gate/router related, selected by substring.
# NOTE(review): 'gate' is a plain substring test, so a name such as
# '...gate_proj...' would be listed here too — confirm that is intended.
gate_markers = ('gate', 'e_score', 'tid2eid', 'router')
gate_keys = sorted(
    name
    for name in wmap
    if name.startswith(pfx0) and any(marker in name for marker in gate_markers)
)

print("\n=== Layer 0 gate/router keys ===")
for name in gate_keys:
    print(f" {name}")
|