#!/usr/bin/env python3
"""Probe actual checkpoint weight shapes for DSV4-Pro NVFP4.

Run on B200. Prints:
  * the non-expert key names of layers 0 and 2, plus routed/shared expert key counts,
  * the global (non-layer) key names,
  * shape and dtype of every layer-0 tensor (routed experts excluded) found in the
    shards that hold layer-0 attention-like weights (attention, compressor, mHC,
    norms, and any shared-expert FFN weights in those shards).
"""
import json
import os
import sys
from pathlib import Path

from safetensors import safe_open
from safetensors.torch import load_file

CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
cdir = Path(CHECKPOINT_DIR)

# Substrings that mark a layer-0 key as "attention-like" for section 3.
# 'attn' already covers 'self_attn'.
ATTN_LIKE_TOKENS = (
    "attn",
    "compressor",
    "hyper",
    "mhc",
    "input_layernorm",
    "post_attention_layernorm",
)


def _is_routed_expert(key):
    """Return True if *key* belongs to a routed expert.

    The leading dot matters: a bare 'experts.' substring also matches
    'shared_experts.', which would miscount shared experts as routed ones.
    """
    return ".experts." in key


def _is_shared_expert(key):
    """Return True if *key* belongs to a shared expert."""
    return "shared_expert" in key


# Load index: "weight_map" maps each tensor key to the shard file storing it.
with open(cdir / "model.safetensors.index.json", encoding="utf-8") as f:
    wmap = json.load(f)["weight_map"]

# 1. Find actual key patterns for layer 0 and 2 (HCA and CSA)
for li in (0, 2):
    print(f"\n=== Layer {li} keys ===")
    pfx = f"model.layers.{li}."
    layer_keys = [k for k in wmap if k.startswith(pfx)]
    for k in sorted(layer_keys):
        if not (_is_routed_expert(k) or _is_shared_expert(k)):
            print(f" {k}")
    # Expert/shared count (bools sum as 0/1)
    n_expert = sum(map(_is_routed_expert, layer_keys))
    n_shared = sum(map(_is_shared_expert, layer_keys))
    print(f" + {n_expert} expert keys, {n_shared} shared expert keys")

# 2. Global keys: everything outside the per-layer namespace
print("\n=== Global keys ===")
for k in sorted(k for k in wmap if not k.startswith("model.layers.")):
    print(f" {k}")

# 3. Actual shapes for layer 0, read from the shards holding its
#    attention-like tensors.
pfx0 = "model.layers.0."
attn_like_keys = sorted(
    k for k in wmap
    if k.startswith(pfx0) and any(tok in k for tok in ATTN_LIKE_TOKENS)
)
print("\n=== Layer 0 attn-like keys ===")
for k in attn_like_keys:
    print(f" {k} -> shard {wmap[k]}")

shards_needed = sorted({wmap[k] for k in attn_like_keys})
print(f"\nShards with layer 0 attn data: {shards_needed}")
# Visit every such shard (not just the first two) so layer-0 tensors split
# across more shards are not silently dropped. safe_open loads lazily and
# get_tensor() materializes only the requested tensors, unlike load_file(),
# which reads every tensor in the shard.
for sn in shards_needed:
    shard_path = cdir / sn
    if not shard_path.exists():
        print(f" [skip] shard not found: {shard_path}")
        continue
    with safe_open(str(shard_path), framework="pt", device="cpu") as sf:
        for k in sorted(sf.keys()):
            # Skip other layers and routed experts; shared-expert FFN shapes are kept.
            if not k.startswith(pfx0) or _is_routed_expert(k):
                continue
            t = sf.get_tensor(k)
            print(f" {k}: {tuple(t.shape)} dtype={t.dtype}")
# 4. Config: print selected hyper-parameters from config.json.
CONFIG_KEYS = (
    'num_hidden_layers', 'hidden_size', 'num_attention_heads', 'head_dim',
    'qk_rope_head_dim', 'n_routed_experts', 'num_experts_per_tok',
    'routed_scaling_factor', 'num_output_groups', 'output_group_dim',
    'query_compression_dim', 'compress_ratios', 'rope_theta',
    'intermediate_size', 'kv_lora_rank', 'q_lora_rank', 'num_key_value_heads',
    'indexer_num_heads', 'indexer_head_dim', 'n_shared_experts',
    'first_k_dense_replace',
)

print("\n=== Config ===")
with open(cdir / "config.json", encoding="utf-8") as f:
    cfg = json.load(f)
for k in CONFIG_KEYS:
    if k not in cfg:
        # Report absent keys rather than skipping them silently, so renamed or
        # missing fields in this checkpoint's config are visible.
        print(f" {k}: <absent>")
        continue
    v = cfg[k]
    # Summarize long lists to keep the output readable.
    if isinstance(v, list) and len(v) > 10:
        print(f" {k}: len={len(v)}, first5={v[:5]}")
    else:
        print(f" {k}: {v}")