Files
deepseek-v4-quant/inspect_model.py

173 lines
6.5 KiB
Python
Raw Permalink Normal View History

2026-05-06 23:47:07 +00:00
#!/usr/bin/env python3
"""Inspect a DeepSeek FP8 model directory and report on tensor structure.
Usage: python inspect_model.py <model_dir>
Prints:
- Total tensor count and dtype histogram
- Sample of tensor names by category (lm_head, embeddings, attention, MoE experts, norms, etc.)
- FP8 block scaling structure (block size detection)
- MoE expert layer count and routing structure
- Any "unusual" tensors that need manual classification
"""
import argparse
import json
import re
import sys
from collections import Counter, defaultdict
from pathlib import Path
from safetensors import safe_open
# Tensors whose names match one of these patterns are preserved, i.e. skipped
# by quantization.  categorize() consults this table first and returns the
# label of the first pattern that matches, so order matters.
PRESERVE_PATTERNS = [
    (re.compile(r".*lm_head.*"), "lm_head"),
    (re.compile(r".*embed_tokens.*"), "embeddings"),
    (re.compile(r".*\.mlp\.gate(\.weight)?$"), "moe_router_gate"),
    (re.compile(r".*norm.*"), "normalization"),
    (re.compile(r".*indexer.*"), "attention_indexer"),  # V3.2 DSA / V4 CSA?
    (re.compile(r".*hyper_conn.*"), "mhc_hyper_conn"),  # V4 mHC
    (re.compile(r".*mhc.*"), "mhc_other"),
    (re.compile(r".*scoring.*"), "scoring"),
]

# Patterns for MoE expert weights (these are what we WILL quantize)
EXPERT_PATTERNS = [
    (re.compile(r".*experts\.\d+\.gate_proj.*"), "expert_gate_proj"),
    (re.compile(r".*experts\.\d+\.up_proj.*"), "expert_up_proj"),
    (re.compile(r".*experts\.\d+\.down_proj.*"), "expert_down_proj"),
    (re.compile(r".*shared_experts?\.gate_proj.*"), "shared_gate_proj"),
    (re.compile(r".*shared_experts?\.up_proj.*"), "shared_up_proj"),
    (re.compile(r".*shared_experts?\.down_proj.*"), "shared_down_proj"),
]


def categorize(name):
    """Classify a tensor name into a ``(kind, category)`` pair.

    Checks run in this order and the first hit wins:

    1. ``PRESERVE_PATTERNS``            -> ``("preserve", <label>)``
    2. name ends in ``.weight_scale_inv`` -> ``("scale_metadata", "fp8_block_scale")``
    3. ``EXPERT_PATTERNS``              -> ``("quantize_expert", <label>)``
    4. name ends in ``.weight``         -> ``("quantize_other", "linear_weight")``
    5. anything else                    -> ``("other", "uncategorized")``

    Args:
        name: Tensor name as it appears in the safetensors index.

    Returns:
        A 2-tuple of strings ``(kind, category)``.
    """
    for pattern, label in PRESERVE_PATTERNS:
        if pattern.match(name):
            return ("preserve", label)

    # The scale suffix must be tested before EXPERT_PATTERNS: those patterns
    # end in ".*", so "...experts.N.gate_proj.weight_scale_inv" would otherwise
    # be reported as an expert weight to quantize instead of as scale metadata.
    if name.endswith(".weight_scale_inv"):
        return ("scale_metadata", "fp8_block_scale")

    for pattern, label in EXPERT_PATTERNS:
        if pattern.match(name):
            return ("quantize_expert", label)

    if name.endswith(".weight"):
        return ("quantize_other", "linear_weight")
    return ("other", "uncategorized")
def _block_dim(weight_dim, scale_dim):
    """Return how many weight elements one scale entry spans along an axis.

    Returns an int when ``weight_dim`` is an exact multiple of ``scale_dim``,
    the float ratio when it is not, and None when ``scale_dim`` is zero.
    """
    if scale_dim == 0:
        return None
    whole, leftover = divmod(weight_dim, scale_dim)
    return whole if leftover == 0 else weight_dim / scale_dim


def _inspect_sample_shard(sample_shard):
    """Print the dtype histogram and FP8 block-scale statistics of one shard.

    Args:
        sample_shard: Path of the safetensors shard file to open.
    """
    print(f"=== Sampling dtypes from {sample_shard.name} ===")
    if not sample_shard.exists():
        # The index can reference shards that are absent on disk; report it
        # instead of letting safe_open raise.
        print(f" Shard file not found: {sample_shard} — skipping dtype inspection.")
        return

    dtype_hist = Counter()
    fp8_block_sizes = Counter()
    weight_with_scale = []
    with safe_open(sample_shard, framework="pt") as f:
        names_in_shard = list(f.keys())
        # Set copy gives O(1) lookups for the scale-tensor membership test.
        name_lookup = set(names_in_shard)
        for name in names_in_shard:
            t = f.get_tensor(name)
            dtype_hist[str(t.dtype)] += 1
            # Check for FP8 weight + scale_inv pair: a 1-byte floating dtype.
            if not (name.endswith(".weight")
                    and t.dtype.is_floating_point
                    and t.element_size() == 1):
                continue
            # Append to the suffix; str.replace would also rewrite any other
            # ".weight" occurrence earlier in the name.
            scale_name = name + "_scale_inv"
            if scale_name not in name_lookup:
                continue
            scale_t = f.get_tensor(scale_name)
            if t.dim() == 2 and scale_t.dim() == 2:
                bm = _block_dim(t.shape[0], scale_t.shape[0])
                bn = _block_dim(t.shape[1], scale_t.shape[1])
            else:
                bm = bn = None
            fp8_block_sizes[(bm, bn)] += 1
            if len(weight_with_scale) < 3:
                weight_with_scale.append(
                    (name, t.shape, t.dtype, scale_t.shape, scale_t.dtype))

    print(" Dtype histogram (this shard only):")
    for d, c in dtype_hist.most_common():
        print(f" {d:20s} {c:,}")
    print()
    print(" FP8 block-scale dimensions detected:")
    for (bm, bn), c in fp8_block_sizes.most_common():
        print(f" block_size = ({bm}, {bn}) count={c}")
    print()
    print(" Sample FP8 weight + scale_inv pairs:")
    for name, wshape, wdt, sshape, sdt in weight_with_scale:
        print(f" {name}")
        print(f" weight: shape={tuple(wshape)} dtype={wdt}")
        print(f" scale: shape={tuple(sshape)} dtype={sdt}")


def _summarize_moe(weight_map):
    """Print per-layer expert counts derived from tensor names only.

    Args:
        weight_map: Mapping whose keys are tensor names.
    """
    print("=== MoE structure summary ===")
    expert_name = re.compile(r".*layers\.(\d+)\..*experts\.(\d+)\..*")
    layer_experts = defaultdict(set)
    for name in weight_map:
        m = expert_name.match(name)
        if m:
            layer_experts[int(m.group(1))].add(int(m.group(2)))

    if not layer_experts:
        print(" No '.experts.N.' pattern found — MoE structure may use different naming.")
        return

    expert_counts = [len(v) for v in layer_experts.values()]
    # The lowest-numbered layer that has experts is not necessarily layer 0,
    # so label the sample with the index actually used.
    first_layer = min(layer_experts)
    print(f" Layers with MoE experts: {len(layer_experts)}")
    print(f" Experts per layer: min={min(expert_counts)} max={max(expert_counts)}")
    print(f" Sample layer {first_layer} experts: {sorted(layer_experts[first_layer])[:5]}...")


def main():
    """Inspect a model directory and print a report on its tensor structure.

    Reads ``model.safetensors.index.json`` from the directory given on the
    command line, categorizes every tensor name, samples the first shard (in
    sorted order) for dtypes and FP8 block-scale shapes, summarizes the MoE
    expert layout, and lists uncategorized tensors.

    Exits with status 1 when the index file is missing or contains no
    ``weight_map`` entries.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("model_dir")
    ap.add_argument("--show-samples", type=int, default=5,
                    help="How many sample names to show per category")
    args = ap.parse_args()
    # A negative count would slice from the end and print a negative
    # "... and N more"; treat it as zero.
    show_samples = max(args.show_samples, 0)

    model_dir = Path(args.model_dir)
    index_path = model_dir / "model.safetensors.index.json"
    if not index_path.exists():
        print(f"ERROR: {index_path} not found", file=sys.stderr)
        sys.exit(1)
    with open(index_path, encoding="utf-8") as f:
        index = json.load(f)

    weight_map = index.get("weight_map") if isinstance(index, dict) else None
    if not weight_map:
        # Without at least one entry there is no shard to sample.
        print(f"ERROR: {index_path} has no 'weight_map' entries", file=sys.stderr)
        sys.exit(1)
    total_size = index.get("metadata", {}).get("total_size")
    shards = sorted(set(weight_map.values()))

    print(f"=== {model_dir} ===")
    print(f"Total tensors: {len(weight_map):,}")
    print(f"Total shards: {len(shards)}")
    if total_size:
        print(f"Reported size: {total_size / 1024**3:.1f} GB")
    print()

    # Categorize names (cheap, no tensor loading)
    categories = defaultdict(list)
    for name in weight_map:
        categories[categorize(name)].append(name)
    print("=== Tensor categorization ===")
    for (kind, cat), names in sorted(categories.items()):
        print(f" [{kind:18s}] {cat:25s} count={len(names):,}")
        for n in names[:show_samples]:
            print(f" {n}")
        if len(names) > show_samples:
            print(f" ... and {len(names) - show_samples} more")
    print()

    # Inspect dtypes and FP8 block scaling on a sample shard
    _inspect_sample_shard(model_dir / shards[0])

    # MoE structure summary
    print()
    _summarize_moe(weight_map)

    # Flag uncategorized for human review
    print()
    print("=== Uncategorized tensors (review these manually) ===")
    uncat = categories.get(("other", "uncategorized"), [])
    if uncat:
        print(f" {len(uncat):,} tensors:")
        for n in uncat[:20]:
            print(f" {n}")
        if len(uncat) > 20:
            print(f" ... and {len(uncat) - 20} more")
    else:
        print(" None — every tensor matched a known pattern.")
# Run the report only when executed as a script, not when imported.
if __name__ == "__main__":
    main()