70 lines
2.8 KiB
Python
70 lines
2.8 KiB
Python
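#!/usr/bin/env python3
"""Probe actual checkpoint weight shapes for DSV4-Pro NVFP4.

Run on B200: prints the weight key names for layers 0 and 2 and for the
global (non-layer) weights, the shapes and dtypes of layer-0 non-expert
tensors (attention, compressor, mHC, FFN), and selected config.json fields.
"""
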
import json
import os
import sys
from pathlib import Path

from safetensors import safe_open
from safetensors.torch import load_file

# Checkpoint directory. The original hard-coded path remains the default; pass
# a directory as the first command-line argument, or set DSV4_CHECKPOINT_DIR,
# to probe a different copy (the command-line argument takes precedence).
_DEFAULT_CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
CHECKPOINT_DIR = (
    sys.argv[1]
    if len(sys.argv) > 1
    else os.environ.get("DSV4_CHECKPOINT_DIR", _DEFAULT_CHECKPOINT_DIR)
)
cdir = Path(CHECKPOINT_DIR)

# Load the shard index. Its "weight_map" maps each tensor name to the shard
# file (a path relative to cdir) in which that tensor is stored.
index_path = cdir / "model.safetensors.index.json"
with open(index_path) as index_file:
    index = json.load(index_file)
wmap = index["weight_map"]

# 1. Find actual key patterns for layer 0 and 2 (HCA and CSA).
#    Non-expert keys are listed in full; routed-expert and shared-expert keys
#    are only counted.
for li in [0, 2]:
    print(f"\n=== Layer {li} keys ===")
    pfx = f"model.layers.{li}."
    # Filter by layer prefix once, then classify within this layer only.
    layer_keys = [k for k in wmap if k.startswith(pfx)]
    # A shared-expert name such as "...shared_experts.<x>" also contains the
    # substring "experts.", so routed experts must explicitly exclude
    # 'shared_expert' or shared-expert tensors get counted in both totals.
    shared_keys = [k for k in layer_keys if 'shared_expert' in k]
    expert_keys = [
        k for k in layer_keys
        if 'experts.' in k and 'shared_expert' not in k
    ]
    keys = sorted(
        k for k in layer_keys
        if 'experts.' not in k and 'shared_expert' not in k
    )
    for k in keys:
        print(f" {k}")
    # Expert/shared count
    print(f" + {len(expert_keys)} expert keys, {len(shared_keys)} shared expert keys")

# 2. Global keys: every tensor whose name is not under "model.layers.",
#    printed in sorted order.
print("\n=== Global keys ===")
global_keys = sorted(
    name for name in wmap if not name.startswith("model.layers.")
)
for name in global_keys:
    print(f" {name}")

# 3. Actual shapes for layer 0.
# First find the layer-0 keys that look attention-related and the shards that
# store them; the shapes themselves are read from those shards afterwards.
pfx0 = "model.layers.0."
# Substring markers for "attention-like" tensors. 'attn' already covers
# 'self_attn', so the latter is not listed separately. 'hyper'/'mhc' are
# presumably the mHC (hyper-connection) weights — confirm against the keys.
ATTN_LIKE_MARKERS = (
    'attn', 'compressor', 'hyper', 'mhc',
    'input_layernorm', 'post_attention_layernorm',
)
attn_like_keys = [
    k for k in wmap
    if k.startswith(pfx0) and any(marker in k for marker in ATTN_LIKE_MARKERS)
]
print("\n=== Layer 0 attn-like keys ===")
for k in attn_like_keys:
    print(f" {k} -> shard {wmap[k]}")

# Shard files that hold layer-0 attention-like data.
shards_needed = {wmap[k] for k in attn_like_keys}
# Print sorted: a set of str iterates in an order that can change between runs
# (str hash randomization), which makes outputs hard to compare.
print(f"\nShards with layer 0 attn data: {sorted(shards_needed)}")

# Print shape and dtype of every non-expert layer-0 tensor found in at most
# MAX_PROBE_SHARDS of those shards. safe_open reads tensors lazily, so only
# the tensors requested via get_tensor are loaded, never the whole shard.
MAX_PROBE_SHARDS = 2
for sn in sorted(shards_needed)[:MAX_PROBE_SHARDS]:
    shard_path = cdir / sn
    if not shard_path.exists():
        # Report rather than skip silently, so an incomplete copy is visible.
        print(f" [missing shard: {shard_path}]")
        continue
    with safe_open(str(shard_path), framework="pt", device="cpu") as w:
        for k in sorted(w.keys()):
            if k.startswith(pfx0) and 'experts.' not in k:
                t = w.get_tensor(k)
                print(f" {k}: {tuple(t.shape)} dtype={t.dtype}")

# 4. Config: print selected fields from config.json. Lists longer than 10
#    items are summarised by their length and first five entries. Requested
#    fields that the config does not contain are reported as "<absent>".
print("\n=== Config ===")
with open(cdir / "config.json") as f:
    cfg = json.load(f)
CONFIG_FIELDS = [
    'num_hidden_layers', 'hidden_size', 'num_attention_heads', 'head_dim',
    'qk_rope_head_dim', 'n_routed_experts', 'num_experts_per_tok',
    'routed_scaling_factor', 'num_output_groups', 'output_group_dim',
    'query_compression_dim', 'compress_ratios', 'rope_theta',
    'intermediate_size', 'kv_lora_rank', 'q_lora_rank',
    'num_key_value_heads', 'indexer_num_heads', 'indexer_head_dim',
    'n_shared_experts', 'first_k_dense_replace',
]
for k in CONFIG_FIELDS:
    if k not in cfg:
        # Say so explicitly: when probing an unfamiliar checkpoint, knowing a
        # field is missing is as informative as its value.
        print(f" {k}: <absent>")
        continue
    v = cfg[k]
    if isinstance(v, list) and len(v) > 10:
        print(f" {k}: len={len(v)}, first5={v[:5]}")
    else:
        print(f" {k}: {v}")