#!/usr/bin/env python3
"""Probe actual checkpoint weight keys and shapes for DSV4-Pro NVFP4.

Focus on layer 0 (HCA) and layer 2 (CSA) non-expert weights.

The script reads ``model.safetensors.index.json`` from the checkpoint
directory and prints, in order:

1. the non-expert keys of each probed layer,
2. the global keys (everything outside ``model.layers.*``),
3. the shape and dtype of each non-expert tensor of the probed layers,
4. a sample of layer-0 expert keys, the layer-0 shared-expert keys and the
   layer-0 gate/router keys.

Reports go to stdout; warnings about missing shard files go to stderr.
"""
import json
import os
import sys
from pathlib import Path

CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"

# Layers whose non-expert weights are listed and measured.
PROBE_LAYERS = (0, 2)

# Number of layer-0 expert keys shown in the sample section.
EXPERT_SAMPLE_SIZE = 12


def layer_prefix(layer_idx):
    """Return the key prefix shared by all weights of layer *layer_idx*."""
    return f"model.layers.{layer_idx}."


def is_non_expert_key(key):
    """Return True when *key* names neither a routed nor a shared expert weight."""
    return "experts." not in key and "shared_expert" not in key


def non_expert_layer_keys(weight_map, layer_idx):
    """Return the sorted non-expert keys of one layer found in *weight_map*."""
    prefix = layer_prefix(layer_idx)
    return sorted(
        k for k in weight_map if k.startswith(prefix) and is_non_expert_key(k)
    )


def print_keys(title, keys):
    """Print a ``=== title ===`` header followed by one key per line."""
    print(f"\n=== {title} ===")
    for key in keys:
        print(f" {key}")


def _print_shard_shapes(shard_path, prefix):
    """Print shape and dtype of the non-expert tensors under *prefix* in one shard."""
    # Imported lazily so that listing keys from the index works on machines
    # without safetensors/torch installed, and so a missing index fails fast
    # before the heavy import.
    from safetensors import safe_open

    # Only the tensors that are actually printed are fetched; the previous
    # whole-shard load_file call created a tensor for every entry, including
    # the expert weights this script filters out.
    with safe_open(os.fspath(shard_path), framework="pt", device="cpu") as shard:
        wanted = sorted(
            k for k in shard.keys() if k.startswith(prefix) and is_non_expert_key(k)
        )
        for key in wanted:
            tensor = shard.get_tensor(key)
            print(f" {key}: {tuple(tensor.shape)} {tensor.dtype}")


def print_layer_shapes(checkpoint_dir, weight_map, layer_idx):
    """Print shapes of the non-expert tensors of one layer, shard by shard.

    Shards named in the index but absent from *checkpoint_dir* are skipped
    with a warning on stderr, so a partially downloaded checkpoint can still
    be probed without the gap going unnoticed.
    """
    prefix = layer_prefix(layer_idx)
    shard_names = sorted(
        {weight_map[k] for k in non_expert_layer_keys(weight_map, layer_idx)}
    )
    print(f"\n=== Layer {layer_idx} shapes ===")
    for shard_name in shard_names:
        shard_path = checkpoint_dir / shard_name
        if not shard_path.exists():
            print(
                f"warning: shard {shard_name} not found in {checkpoint_dir}; "
                f"skipping layer {layer_idx} shapes stored there",
                file=sys.stderr,
            )
            continue
        _print_shard_shapes(shard_path, prefix)


def main(checkpoint_dir=CHECKPOINT_DIR):
    """Run the full probe against *checkpoint_dir*.

    Raises:
        FileNotFoundError: if ``model.safetensors.index.json`` is missing.
        KeyError: if the index has no ``weight_map`` entry.
    """
    cdir = Path(checkpoint_dir)
    with open(cdir / "model.safetensors.index.json", encoding="utf-8") as f:
        wmap = json.load(f)["weight_map"]

    # Find all keys for the probed layers, excluding expert weights.
    for li in PROBE_LAYERS:
        keys = non_expert_layer_keys(wmap, li)
        print_keys(f"Layer {li} non-expert keys ({len(keys)})", keys)

    # Global keys: everything that does not belong to a decoder layer.
    global_keys = sorted(k for k in wmap if not k.startswith("model.layers."))
    print_keys(f"Global keys ({len(global_keys)})", global_keys)

    # Load actual shapes for the probed layers.
    for li in PROBE_LAYERS:
        print_layer_shapes(cdir, wmap, li)

    mlp_prefix = "model.layers.0.mlp."
    layer0_mlp_keys = sorted(k for k in wmap if k.startswith(mlp_prefix))

    # Expert key pattern for layer 0 (just the first few).
    # NOTE(review): the 'experts.' substring also matches 'shared_experts.'
    # keys; they only show up here when fewer than EXPERT_SAMPLE_SIZE other
    # matches sort ahead of them.
    expert_keys = [k for k in layer0_mlp_keys if "experts." in k]
    print_keys("Layer 0 expert keys (sample)", expert_keys[:EXPERT_SAMPLE_SIZE])

    # Shared expert keys for layer 0.
    shared_keys = [k for k in layer0_mlp_keys if "shared_expert" in k]
    print_keys("Layer 0 shared expert keys", shared_keys)

    # Router/gate keys for layer 0.
    # NOTE(review): a bare 'gate' substring would also match per-expert
    # projections if the checkpoint names them e.g. '*.gate_proj.*' --
    # confirm against the printed output before relying on this section.
    router_markers = ("gate", "e_score", "tid2eid", "router")
    gate_keys = [
        k for k in layer0_mlp_keys if any(marker in k for marker in router_markers)
    ]
    print_keys("Layer 0 gate/router keys", gate_keys)


if __name__ == "__main__":
    main()