Files
nvfp4-megamoe-kernel/probe_keys.py
2026-05-31 21:42:52 +00:00

60 lines
2.2 KiB
Python

#!/usr/bin/env python3
"""Probe actual checkpoint weight keys and shapes for DSV4-Pro NVFP4.
Focus on layer 0 (HCA) and layer 2 (CSA) non-expert weights.
"""
import json, os
from safetensors.torch import load_file
from pathlib import Path
CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"

# Layers whose non-expert weights are listed and shape-probed.
PROBE_LAYERS = (0, 2)

# Substrings used to pick out router/gate related keys under layer 0's MLP.
GATE_MARKERS = ("gate", "e_score", "tid2eid", "router")


def _is_non_expert(key):
    """Return True when *key* contains neither 'experts.' nor 'shared_expert'."""
    return "experts." not in key and "shared_expert" not in key


def _non_expert_layer_keys(keys, layer):
    """Return the sorted non-expert keys from *keys* that belong to *layer*.

    Args:
        keys: Iterable of weight key strings (a weight map or a shard's keys).
        layer: Integer layer index; keys must start with ``model.layers.<layer>.``.
    """
    prefix = f"model.layers.{layer}."
    return sorted(k for k in keys if k.startswith(prefix) and _is_non_expert(k))


def _print_keys(title, keys):
    """Print a ``=== title ===`` banner followed by one key per line."""
    print(f"\n=== {title} ===")
    for key in keys:
        print(f" {key}")


def _print_shard_shapes(shard_path, layer):
    """Print shape and dtype of every non-expert tensor of *layer* in one shard.

    The shard is opened lazily and only the matching tensors are read, so the
    (potentially very large) expert tensors stored in the same shard are never
    loaded just to report a handful of shapes.
    """
    # Imported here so the key-listing part of the probe works without
    # safetensors being importable; the package is already a dependency of
    # this file (see the top-level `load_file` import).
    from safetensors import safe_open

    with safe_open(str(shard_path), framework="pt", device="cpu") as shard:
        for key in _non_expert_layer_keys(shard.keys(), layer):
            tensor = shard.get_tensor(key)
            print(f" {key}: {tuple(tensor.shape)} {tensor.dtype}")


def main(checkpoint_dir=CHECKPOINT_DIR):
    """Print weight keys (and selected shapes) found in a checkpoint directory.

    Reads ``model.safetensors.index.json`` from *checkpoint_dir* and prints, in
    order: the non-expert keys of each probed layer, the keys outside
    ``model.layers.``, the shapes/dtypes of the probed layers' non-expert
    tensors, and layer 0's expert, shared-expert and gate/router keys.

    Args:
        checkpoint_dir: Directory holding the index file and the shards.

    Raises:
        FileNotFoundError: If the index file does not exist.
        KeyError: If the index has no ``weight_map`` entry.
    """
    cdir = Path(checkpoint_dir)
    with open(cdir / "model.safetensors.index.json", encoding="utf-8") as f:
        wmap = json.load(f)["weight_map"]

    # Non-expert keys of each probed layer.
    for layer in PROBE_LAYERS:
        layer_keys = _non_expert_layer_keys(wmap, layer)
        _print_keys(f"Layer {layer} non-expert keys ({len(layer_keys)})", layer_keys)

    # Keys that do not belong to any layer.
    global_keys = sorted(k for k in wmap if not k.startswith("model.layers."))
    _print_keys(f"Global keys ({len(global_keys)})", global_keys)

    # Actual shapes/dtypes, read from the shards that hold the probed layers.
    for layer in PROBE_LAYERS:
        shard_names = sorted({wmap[k] for k in _non_expert_layer_keys(wmap, layer)})
        print(f"\n=== Layer {layer} shapes ===")
        for shard_name in shard_names:
            shard_path = cdir / shard_name
            if not shard_path.exists():
                # Report the gap instead of skipping silently, so an empty
                # section is not mistaken for "this layer has no tensors".
                print(f" (shard {shard_name} not found, skipped)")
                continue
            _print_shard_shapes(shard_path, layer)

    # Layer 0 MLP keys, sorted once; the filters below keep that order.
    mlp_prefix = "model.layers.0.mlp."
    mlp_keys = sorted(k for k in wmap if k.startswith(mlp_prefix))

    # NOTE(review): the substring 'experts.' also matches any key containing
    # 'shared_experts.', so such keys would show up in this sample too.
    expert_sample = [k for k in mlp_keys if "experts." in k][:12]
    _print_keys("Layer 0 expert keys (sample)", expert_sample)

    shared_keys = [k for k in mlp_keys if "shared_expert" in k]
    _print_keys("Layer 0 shared expert keys", shared_keys)

    # NOTE(review): 'gate' is a plain substring test, so names such as
    # 'gate_proj' (if the checkpoint uses them) are listed here as well.
    gate_keys = [k for k in mlp_keys if any(marker in k for marker in GATE_MARKERS)]
    _print_keys("Layer 0 gate/router keys", gate_keys)


if __name__ == "__main__":
    main()