#!/usr/bin/env python3
"""Inspect a DeepSeek FP8 model directory and report on tensor structure.

Usage: python inspect_model.py <model_dir> [--show-samples N]

Prints:
  - Total tensor and shard counts, plus a dtype histogram sampled from one shard
  - Sample of tensor names by category (lm_head, embeddings, attention, MoE experts, norms, etc.)
  - FP8 block scaling structure (block size detection)
  - MoE expert layer count and experts per layer
  - Any "unusual" tensors that need manual classification
"""
|
|
|
|
import argparse
|
|
import json
|
|
import re
|
|
import sys
|
|
from collections import Counter, defaultdict
|
|
from pathlib import Path
|
|
|
|
from safetensors import safe_open
|
|
|
|
|
|
# Tensors whose names match any of these are preserved (quantization is skipped).
# Each entry is (compiled name regex, category label); categorize() tries them
# in order and reports the label of the first one that matches.
PRESERVE_PATTERNS = [
    (re.compile(regex), label)
    for regex, label in (
        (r".*lm_head.*", "lm_head"),
        (r".*embed_tokens.*", "embeddings"),
        (r".*\.mlp\.gate(\.weight)?$", "moe_router_gate"),
        (r".*norm.*", "normalization"),
        (r".*indexer.*", "attention_indexer"),  # V3.2 DSA / V4 CSA?
        (r".*hyper_conn.*", "mhc_hyper_conn"),  # V4 mHC
        (r".*mhc.*", "mhc_other"),
        (r".*scoring.*", "scoring"),
    )
]
|
|
|
|
# MoE expert projection weights -- these are the tensors we WILL quantize.
# Routed experts are addressed as "experts.<N>.<proj>"; the shared-expert
# patterns accept "shared_expert" or "shared_experts" with no numeric index.
# Each entry is (compiled name regex, category label).
EXPERT_PATTERNS = [
    (re.compile(prefix + proj + r".*"), f"{kind}_{proj}")
    for prefix, kind in (
        (r".*experts\.\d+\.", "expert"),
        (r".*shared_experts?\.", "shared"),
    )
    for proj in ("gate_proj", "up_proj", "down_proj")
]
|
|
|
|
|
|
def categorize(name):
    """Classify a tensor name into a ``(kind, category)`` pair.

    Checks run in priority order and the first hit wins:

    1. ``PRESERVE_PATTERNS``   -> ``("preserve", <label>)``
    2. ``*.weight_scale_inv``  -> ``("scale_metadata", "fp8_block_scale")``
    3. ``EXPERT_PATTERNS``     -> ``("quantize_expert", <label>)``
    4. ``*.weight``            -> ``("quantize_other", "linear_weight")``
    5. anything else           -> ``("other", "uncategorized")``

    Args:
        name: Tensor name as it appears in the safetensors weight map.

    Returns:
        A 2-tuple of strings ``(kind, category)``.
    """
    for pat, cat in PRESERVE_PATTERNS:
        if pat.match(name):
            return ("preserve", cat)

    # Scale tensors must be tested before the expert patterns: those patterns
    # end in ".*", so "...experts.0.gate_proj.weight_scale_inv" would otherwise
    # be reported as an expert weight, inflating the quantize_expert buckets
    # and hiding expert scales from the scale_metadata bucket.
    if name.endswith(".weight_scale_inv"):
        return ("scale_metadata", "fp8_block_scale")

    for pat, cat in EXPERT_PATTERNS:
        if pat.match(name):
            return ("quantize_expert", cat)

    if name.endswith(".weight"):
        return ("quantize_other", "linear_weight")
    return ("other", "uncategorized")
|
|
|
|
|
|
# Matches "...layers.<L>. ... experts.<E>. ..." and captures the layer index
# and the expert index. Compiled once instead of on every tensor name.
_LAYER_EXPERT_RE = re.compile(r".*layers\.(\d+)\..*experts\.(\d+)\..*")


def _block_dim(weight_dim, scale_dim):
    """Return ``weight_dim / scale_dim`` as an int when exact, else as a float.

    Returns None when ``scale_dim`` is zero so that an empty scale tensor
    cannot raise ZeroDivisionError.
    """
    if not scale_dim:
        return None
    quotient, remainder = divmod(weight_dim, scale_dim)
    return quotient if remainder == 0 else weight_dim / scale_dim


def _load_index(model_dir):
    """Read ``model.safetensors.index.json`` and return ``(weight_map, total_size)``.

    Prints an error to stderr and exits with status 1 when the index file is
    missing or holds no ``weight_map`` entries. ``total_size`` is None when the
    index carries no ``metadata.total_size`` value.
    """
    index_path = model_dir / "model.safetensors.index.json"
    if not index_path.exists():
        print(f"ERROR: {index_path} not found", file=sys.stderr)
        sys.exit(1)

    with open(index_path, encoding="utf-8") as f:
        index = json.load(f)

    weight_map = index.get("weight_map")
    if not weight_map:
        # Without entries there is nothing to categorize and no shard to sample.
        print(f"ERROR: {index_path} has no 'weight_map' entries", file=sys.stderr)
        sys.exit(1)
    return weight_map, index.get("metadata", {}).get("total_size")


def _print_categorization(weight_map, show_samples):
    """Group tensor names by ``categorize()`` result, print them, return the groups.

    Only names are examined; no tensor data is loaded.
    """
    categories = defaultdict(list)
    for name in weight_map:
        categories[categorize(name)].append(name)

    # A negative count would turn the slice below into "all but the last N".
    limit = max(show_samples, 0)

    print("=== Tensor categorization ===")
    for (kind, cat), names in sorted(categories.items()):
        print(f" [{kind:18s}] {cat:25s} count={len(names):,}")
        for n in names[:limit]:
            print(f" {n}")
        if len(names) > limit:
            print(f" ... and {len(names) - limit} more")
        print()
    return categories


def _inspect_sample_shard(sample_shard):
    """Print a dtype histogram and the FP8 block-scale layout of one shard.

    Every tensor in the shard is loaded once to read its dtype. A ``.weight``
    tensor with a one-byte floating-point dtype is treated as FP8 and paired
    with ``<name>_scale_inv`` when that tensor lives in the same shard.
    """
    print(f"=== Sampling dtypes from {sample_shard.name} ===")
    if not sample_shard.exists():
        # The index-only sections of the report are still useful, so warn
        # and carry on instead of aborting with a traceback.
        print(f"WARNING: {sample_shard} not found; skipping dtype sampling",
              file=sys.stderr)
        return

    dtype_hist = Counter()
    fp8_block_sizes = Counter()
    weight_with_scale = []

    with safe_open(sample_shard, framework="pt") as f:
        shard_names = list(f.keys())
        # Set for O(1) scale lookups; the list keeps iteration order stable.
        shard_name_set = set(shard_names)
        for name in shard_names:
            t = f.get_tensor(name)
            dtype_hist[str(t.dtype)] += 1

            is_fp8_weight = (
                name.endswith(".weight")
                and t.dtype.is_floating_point
                and t.element_size() == 1
            )
            if not is_fp8_weight:
                continue

            # Append to the suffix only; str.replace would also rewrite any
            # earlier ".weight" substring in the name.
            scale_name = name + "_scale_inv"
            if scale_name not in shard_name_set:
                continue

            scale_t = f.get_tensor(scale_name)
            bm = bn = None
            if scale_t.dim() == 2 and t.dim() >= 1:
                bm = _block_dim(t.shape[0], scale_t.shape[0])
                if t.dim() == 2:
                    bn = _block_dim(t.shape[1], scale_t.shape[1])
            fp8_block_sizes[(bm, bn)] += 1
            if len(weight_with_scale) < 3:
                weight_with_scale.append(
                    (name, t.shape, t.dtype, scale_t.shape, scale_t.dtype)
                )

    print(" Dtype histogram (this shard only):")
    for d, c in dtype_hist.most_common():
        print(f" {d:20s} {c:,}")

    print()
    print(" FP8 block-scale dimensions detected:")
    for (bm, bn), c in fp8_block_sizes.most_common():
        print(f" block_size = ({bm}, {bn}) count={c}")

    print()
    print(" Sample FP8 weight + scale_inv pairs:")
    for name, wshape, wdt, sshape, sdt in weight_with_scale:
        print(f" {name}")
        print(f" weight: shape={tuple(wshape)} dtype={wdt}")
        print(f" scale: shape={tuple(sshape)} dtype={sdt}")


def _print_moe_summary(weight_map):
    """Print how many layers hold indexed experts and how many experts each has."""
    print()
    print("=== MoE structure summary ===")
    layer_experts = defaultdict(set)
    for name in weight_map:
        m = _LAYER_EXPERT_RE.match(name)
        if m:
            layer_experts[int(m.group(1))].add(int(m.group(2)))

    if not layer_experts:
        print(" No '.experts.N.' pattern found — MoE structure may use different naming.")
        return

    expert_counts = [len(v) for v in layer_experts.values()]
    # The lowest MoE layer is not necessarily layer 0, so label the sample
    # with the layer index it was actually taken from.
    first_layer = min(layer_experts)
    print(f" Layers with MoE experts: {len(layer_experts)}")
    print(f" Experts per layer: min={min(expert_counts)} max={max(expert_counts)}")
    print(f" Sample layer {first_layer} experts: {sorted(layer_experts[first_layer])[:5]}...")


def _print_uncategorized(categories):
    """List up to 20 tensor names that fell into ("other", "uncategorized")."""
    print()
    print("=== Uncategorized tensors (review these manually) ===")
    uncat = categories.get(("other", "uncategorized"), [])
    if not uncat:
        print(" None — every tensor matched a known pattern.")
        return

    print(f" {len(uncat):,} tensors:")
    for n in uncat[:20]:
        print(f" {n}")
    if len(uncat) > 20:
        print(f" ... and {len(uncat) - 20} more")


def main():
    """Command-line entry point: inspect a model directory and print a report.

    Arguments (parsed from ``sys.argv``):
        model_dir: Directory containing ``model.safetensors.index.json`` and
            the shard files it references.
        --show-samples N: How many sample tensor names to print per category
            (default 5).

    Exits with status 1 when the index file is missing or has no weight map.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("model_dir")
    ap.add_argument("--show-samples", type=int, default=5,
                    help="How many sample names to show per category")
    args = ap.parse_args()

    model_dir = Path(args.model_dir)
    weight_map, total_size = _load_index(model_dir)
    shards = sorted(set(weight_map.values()))

    print(f"=== {model_dir} ===")
    print(f"Total tensors: {len(weight_map):,}")
    print(f"Total shards: {len(shards)}")
    if total_size:
        print(f"Reported size: {total_size / 1024**3:.1f} GB")
    print()

    # Categorize names (cheap, no tensor loading)
    categories = _print_categorization(weight_map, args.show_samples)

    # Inspect dtypes and FP8 block scaling on the first shard in sorted order
    _inspect_sample_shard(model_dir / shards[0])

    _print_moe_summary(weight_map)

    # Flag uncategorized for human review
    _print_uncategorized(categories)
|
|
|
|
|
|
# Run the inspection only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()