add probe_shapes script

This commit is contained in:
2026-05-31 21:41:31 +00:00
parent c54dd15550
commit ba915dbd53

69
probe_shapes.py Normal file
View File

@@ -0,0 +1,69 @@
#!/usr/bin/env python3
"""Probe actual checkpoint weight shapes for DSV4-Pro NVFP4.
Run on B200: prints key/shape for attention, compressor, mHC, FFN, global weights.
"""
import json, os, sys
from safetensors.torch import load_file
from pathlib import Path
CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
cdir = Path(CHECKPOINT_DIR)
# Load index
with open(cdir / "model.safetensors.index.json") as f:
wmap = json.load(f)["weight_map"]
# 1. Find actual key patterns for layer 0 and 2 (HCA and CSA)
for li in [0, 2]:
print(f"\n=== Layer {li} keys ===")
pfx = f"model.layers.{li}."
keys = sorted([k for k in wmap if k.startswith(pfx) and not 'experts.' in k and not 'shared_expert' in k])
for k in keys:
print(f" {k}")
# Expert/shared count
expert_keys = [k for k in wmap if k.startswith(pfx) and 'experts.' in k]
shared_keys = [k for k in wmap if k.startswith(pfx) and 'shared_expert' in k]
print(f" + {len(expert_keys)} expert keys, {len(shared_keys)} shared expert keys")
# 2. Global keys
print(f"\n=== Global keys ===")
global_keys = sorted([k for k in wmap if not k.startswith("model.layers.")])
for k in global_keys:
print(f" {k}")
# 3. Load ONE shard to get actual shapes for layer 0
# Find shards containing layer 0 attention keys
pfx0 = f"model.layers.0."
attn_like_keys = [k for k in wmap if k.startswith(pfx0) and ('self_attn' in k or 'attn' in k or 'compressor' in k or 'hyper' in k or 'mhc' in k or 'input_layernorm' in k or 'post_attention_layernorm' in k)]
print(f"\n=== Layer 0 attn-like keys ===")
for k in attn_like_keys:
print(f" {k} -> shard {wmap[k]}")
# Load the first shard with layer 0 data
shards_needed = set(wmap[k] for k in attn_like_keys)
print(f"\nShards with layer 0 attn data: {shards_needed}")
for sn in sorted(shards_needed)[:2]:
if not (cdir / sn).exists(): continue
w = load_file(str(cdir / sn))
for k in sorted(w.keys()):
if k.startswith(pfx0) and 'experts.' not in k:
print(f" {k}: {tuple(w[k].shape)} dtype={w[k].dtype}")
# 4. Config
print(f"\n=== Config ===")
with open(cdir / "config.json") as f:
cfg = json.load(f)
for k in ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'head_dim',
'qk_rope_head_dim', 'n_routed_experts', 'num_experts_per_tok',
'routed_scaling_factor', 'num_output_groups', 'output_group_dim',
'query_compression_dim', 'compress_ratios', 'rope_theta',
'intermediate_size', 'kv_lora_rank', 'q_lora_rank',
'num_key_value_heads', 'indexer_num_heads', 'indexer_head_dim',
'n_shared_experts', 'first_k_dense_replace']:
if k in cfg:
v = cfg[k]
if isinstance(v, list) and len(v) > 10:
print(f" {k}: len={len(v)}, first5={v[:5]}")
else:
print(f" {k}: {v}")