Files
deepseek-v4-quant/verify_nvfp4.py
2026-05-06 23:47:07 +00:00

179 lines
6.3 KiB
Python

#!/usr/bin/env python3
"""Sanity check an NVFP4 DeepSeek V4 Pro checkpoint.
Two modes:
1) --tensor-only (default): no model loading. Just inspects the safetensors
shards: confirms NVFP4 packing structure (uint8 weight + FP8 weight_scale
+ FP32 weight_scale_2), checks for NaN/Inf in scales, samples a few
dequantizations to confirm they look plausible.
2) --vllm: tries to load the model with vLLM and generate a few tokens.
Requires vLLM with NVFP4 support (SM100+ Blackwell GPU).
Usage:
python verify_nvfp4.py DeepSeek-V4-Pro-NVFP4-streaming
python verify_nvfp4.py DeepSeek-V4-Pro-NVFP4-streaming --vllm
"""
import argparse
import json
import sys
from pathlib import Path
import torch
from safetensors import safe_open
# Lookup table mapping a 4-bit FP4 (E2M1) code to its float value.
# Codes 0-7 are the non-negative magnitudes; codes 8-15 are the same
# magnitudes negated (including -0.0 at code 8).
_FP4_E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
FP4_E2M1_VALUES = torch.tensor(
    [*_FP4_E2M1_MAGNITUDES, *(-m for m in _FP4_E2M1_MAGNITUDES)],
    dtype=torch.float32,
)
def unpack_fp4(packed: torch.Tensor) -> torch.Tensor:
    """Split packed bytes into individual 4-bit FP4 codes.

    Each input byte stores two codes as ``low | (high << 4)``. In the output
    the low nibble of a byte comes first, immediately followed by its high
    nibble, so the last dimension doubles in size.

    Args:
        packed: Integer tensor of shape ``[..., N // 2]`` (any number of
            leading dimensions, at least one dimension in total).

    Returns:
        A ``torch.uint8`` tensor of shape ``[..., N]`` holding codes in the
        range 0-15, located on the same device as ``packed``.
    """
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    # Stacking on a new trailing axis and flattening it interleaves the
    # nibbles as (low0, high0, low1, high1, ...). The result inherits the
    # device of `packed`, unlike a freshly allocated CPU buffer.
    nibbles = torch.stack((low, high), dim=-1)
    # An explicit target size (rather than -1) keeps zero-sized inputs valid.
    out_shape = (*packed.shape[:-1], packed.shape[-1] * 2)
    return nibbles.reshape(out_shape).to(torch.uint8)
def dequant_nvfp4(packed_uint8, weight_scale_fp8, weight_scale_2_fp32, block_size=16):
    """Reconstruct FP32 values from NVFP4 storage.

    Args:
        packed_uint8: 2-D tensor ``[M, N // 2]`` of bytes, two FP4 codes per
            byte (see ``unpack_fp4``).
        weight_scale_fp8: Per-block scales of shape ``[M, N // block_size]``;
            converted to float32 before use.
        weight_scale_2_fp32: Second-level scale multiplied into every element
            (broadcast against the ``[M, N]`` result).
        block_size: Number of consecutive elements along the last dimension
            that share one entry of ``weight_scale_fp8``. Defaults to 16.

    Returns:
        A float32 tensor of shape ``[M, N]``.

    Raises:
        ValueError: If ``block_size`` is not positive, or the scale tensor's
            shape does not match ``[M, N // block_size]``.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    fp4_idx = unpack_fp4(packed_uint8)
    # Index with a table on the same device as the codes; indexing a CPU
    # table with indices from another device would fail.
    lut = FP4_E2M1_VALUES.to(fp4_idx.device)
    values = lut[fp4_idx.long()]  # [M, N]
    M, N = values.shape

    scale_blocks = weight_scale_fp8.float()  # [M, N // block_size]
    expected_shape = (M, N // block_size)
    if N % block_size != 0 or tuple(scale_blocks.shape) != expected_shape:
        raise ValueError(
            f"weight_scale shape {tuple(scale_blocks.shape)} does not match "
            f"expected {expected_shape} for {N} unpacked columns and "
            f"block_size={block_size}"
        )

    # Broadcast each per-block scale back over its block_size elements.
    scale_per_elem = (
        scale_blocks.unsqueeze(-1).expand(-1, -1, block_size).reshape(M, N)
    )
    return values * scale_per_elem * weight_scale_2_fp32.float()
def tensor_only_check(model_dir: Path):
index_path = model_dir / "model.safetensors.index.json"
if not index_path.exists():
sys.exit(f"No index.json at {model_dir}")
with open(index_path) as f:
index = json.load(f)
weight_map = index["weight_map"]
# Find one quantized weight to sample
sample = None
for name, fn in weight_map.items():
if name.endswith(".weight") and (name.replace(".weight", ".weight_scale") in weight_map):
sample = name
break
if not sample:
sys.exit("Couldn't find an NVFP4-quantized weight (expected *.weight_scale companion).")
print(f"Sampling: {sample}")
shard_fn = weight_map[sample]
scale_name = sample.replace(".weight", ".weight_scale")
scale_2_name = sample.replace(".weight", ".weight_scale_2")
scale_shard = weight_map[scale_name]
scale_2_shard = weight_map[scale_2_name]
def open_get(fn, name):
with safe_open(model_dir / fn, framework="pt") as f:
return f.get_tensor(name)
packed = open_get(shard_fn, sample)
weight_scale = open_get(scale_shard, scale_name)
weight_scale_2 = open_get(scale_2_shard, scale_2_name)
print(f" packed: shape={tuple(packed.shape)} dtype={packed.dtype}")
print(f" weight_scale: shape={tuple(weight_scale.shape)} dtype={weight_scale.dtype}")
print(f" weight_scale_2: shape={tuple(weight_scale_2.shape)} dtype={weight_scale_2.dtype} "
f"value={weight_scale_2.float().item():.6e}")
# Structural assertions
M = packed.shape[0]
assert packed.dtype == torch.uint8, f"packed should be uint8, got {packed.dtype}"
assert weight_scale.dtype == torch.float8_e4m3fn, \
f"weight_scale should be FP8 E4M3, got {weight_scale.dtype}"
assert weight_scale.shape == (M, packed.shape[1] * 2 // 16), \
f"weight_scale shape {weight_scale.shape} doesn't match expected (M, N/16)"
# Check for NaN/Inf in scales
s_fp32 = weight_scale.float()
assert torch.isfinite(s_fp32).all(), "weight_scale contains NaN/Inf"
assert torch.isfinite(weight_scale_2.float()).all(), "weight_scale_2 is NaN/Inf"
print(f" scales: all finite ✓")
print(f" weight_scale stats: min={s_fp32.min().item():.3e} max={s_fp32.max().item():.3e} "
f"mean={s_fp32.mean().item():.3e}")
# Spot-check dequantization
print("\nDequantizing first 4x32 block for visual check:")
rec = dequant_nvfp4(packed[:4, :16], weight_scale[:4, :2], weight_scale_2)
print(rec)
assert torch.isfinite(rec).all(), "Dequantized values contain NaN/Inf"
print(f" dequant: all finite ✓")
print(f" dequant range: [{rec.min().item():.4f}, {rec.max().item():.4f}]")
# Count what's quantized vs preserved across the whole model
quantized_weights = []
preserved = []
for name in weight_map:
if name.endswith(".weight"):
if name.replace(".weight", ".weight_scale") in weight_map:
quantized_weights.append(name)
else:
preserved.append(name)
print(f"\nWhole-model summary:")
print(f" Quantized .weight tensors: {len(quantized_weights):,}")
print(f" Preserved .weight tensors: {len(preserved):,}")
print(f" Total tensors in index: {len(weight_map):,}")
# Show a few preserved names to confirm the right things stayed in higher precision
print(f"\n Sample preserved tensors (should be lm_head, embed, gates, norms, etc.):")
for n in preserved[:10]:
print(f" {n}")
def vllm_check(model_dir: Path):
    """Load the checkpoint with vLLM and print completions for a few prompts.

    Args:
        model_dir: Directory of the checkpoint; passed to vLLM as a string.
    """
    print("Loading model with vLLM... (requires Blackwell GPU + vLLM with NVFP4 support)")
    # Imported lazily so the tensor-only path does not need vLLM installed.
    from vllm import LLM, SamplingParams

    # NOTE(review): quantization is pinned to "compressed-tensors" — confirm
    # this matches the quantization config the checkpoint was exported with.
    engine_options = {
        "model": str(model_dir),
        "trust_remote_code": True,
        "quantization": "compressed-tensors",
        "dtype": "auto",
        "tensor_parallel_size": 8,
        "max_model_len": 8192,
    }
    llm = LLM(**engine_options)

    sampling = SamplingParams(temperature=1.0, top_p=1.0, max_tokens=64)
    prompts = [
        "Write a short poem about quantization:",
        "What is 17 * 23?",
        "Explain MoE routing in one sentence.",
    ]

    separator = "=" * 60
    for completion in llm.generate(prompts, sampling):
        print(separator)
        print("PROMPT:", completion.prompt)
        print("OUTPUT:", completion.outputs[0].text)
def main():
    """Parse the command line and run the requested checks.

    The tensor-level check always runs. ``--vllm`` additionally loads the
    model with vLLM afterwards. ``--tensor-only`` is accepted so the usage
    described in the module docstring works; it selects the default behavior.
    """
    ap = argparse.ArgumentParser(
        description="Sanity check an NVFP4 DeepSeek V4 Pro checkpoint."
    )
    ap.add_argument("model_dir", help="directory containing the checkpoint")
    mode = ap.add_mutually_exclusive_group()
    # Explicit spelling of the default mode; its value is intentionally unused.
    mode.add_argument(
        "--tensor-only",
        action="store_true",
        help="only inspect the safetensors shards (default)",
    )
    mode.add_argument(
        "--vllm",
        action="store_true",
        help="after the tensor check, load the model with vLLM and generate",
    )
    args = ap.parse_args()

    model_dir = Path(args.model_dir)
    tensor_only_check(model_dir)
    if args.vllm:
        print("\n" + "=" * 60)
        vllm_check(model_dir)


if __name__ == "__main__":
    main()