179 lines
6.3 KiB
Python
179 lines
6.3 KiB
Python
#!/usr/bin/env python3
|
|
"""Sanity check an NVFP4 DeepSeek V4 Pro checkpoint.
|
|
|
|
Two modes:
|
|
|
|
1) --tensor-only (default): no model loading. Just inspects the safetensors
|
|
shards: confirms NVFP4 packing structure (uint8 weight + FP8 weight_scale
|
|
+ FP32 weight_scale_2), checks for NaN/Inf in scales, samples a few
|
|
dequantizations to confirm they look plausible.
|
|
|
|
2) --vllm: tries to load the model with vLLM and generate a few tokens.
|
|
Requires vLLM with NVFP4 support (SM100+ Blackwell GPU).
|
|
|
|
Usage:
|
|
python verify_nvfp4.py DeepSeek-V4-Pro-NVFP4-streaming
|
|
python verify_nvfp4.py DeepSeek-V4-Pro-NVFP4-streaming --vllm
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
from safetensors import safe_open
|
|
|
|
|
|
# Lookup table mapping a 4-bit FP4 (E2M1) code to its float value.
# Codes 0-7 are the non-negative magnitudes; codes 8-15 are the same
# magnitudes negated (so code 8 is -0.0).
_FP4_E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
FP4_E2M1_VALUES = torch.tensor(
    [*_FP4_E2M1_MAGNITUDES, *(-m for m in _FP4_E2M1_MAGNITUDES)],
    dtype=torch.float32,
)
|
|
|
|
|
|
def unpack_fp4(packed: torch.Tensor) -> torch.Tensor:
    """Reverse the ``low | (high << 4)`` byte pack into one FP4 code per element.

    Each input byte holds two 4-bit codes: the low nibble is the even-indexed
    element along the last axis and the high nibble is the odd-indexed one.

    Args:
        packed: Integer tensor of shape ``[..., N // 2]`` (normally uint8).
            Any number of leading dimensions is accepted, so stacked or
            batched weights can be unpacked in a single call.

    Returns:
        A uint8 tensor of shape ``[..., N]`` with values in ``0..15``.
    """
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    # Stacking on a new trailing axis and then flattening that axis into the
    # previous one interleaves the nibbles as low0, high0, low1, high1, ...
    interleaved = torch.stack((low, high), dim=-1)
    # The output width is spelled out (instead of -1) so that zero-width
    # inputs reshape without ambiguity.
    out_shape = (*packed.shape[:-1], packed.shape[-1] * 2)
    return interleaved.reshape(out_shape).to(torch.uint8)
|
|
|
|
|
|
def dequant_nvfp4(packed_uint8, weight_scale_fp8, weight_scale_2_fp32):
    """Rebuild float32 weights from their NVFP4 storage tensors.

    Args:
        packed_uint8: Packed FP4 codes, two per byte, shape ``[M, N // 2]``.
        weight_scale_fp8: Per-block scales, one per 16 consecutive elements
            of a row, shape ``[M, N // 16]``.
        weight_scale_2_fp32: Second-level scale multiplied into every element.

    Returns:
        Float32 tensor of shape ``[M, N]``.
    """
    codes = unpack_fp4(packed_uint8).long()
    decoded = FP4_E2M1_VALUES[codes]  # [M, N] table lookup
    rows, cols = decoded.shape

    # Repeat every block scale across the 16 elements it covers.
    block_scales = weight_scale_fp8.float().unsqueeze(-1)  # [M, N // 16, 1]
    per_element = block_scales.expand(-1, -1, 16).reshape(rows, cols)

    global_scale = weight_scale_2_fp32.float()
    return decoded * per_element * global_scale
|
|
|
|
|
|
def _nvfp4_companion(weight_name: str, suffix: str) -> str:
    """Return the companion tensor name for a ``*.weight`` entry.

    Only the trailing ``.weight`` is swapped for ``suffix``. ``str.replace``
    would also rewrite any earlier ``.weight`` substring in the module path
    (e.g. ``x.weight_proj.weight``) and produce a name that does not exist.
    """
    return weight_name[: -len(".weight")] + suffix


def _require(condition, message: str) -> None:
    """Raise ``AssertionError(message)`` unless ``condition`` is truthy.

    Used in place of ``assert`` statements so that the checks still run when
    Python is started with ``-O`` (which strips ``assert``).
    """
    if not condition:
        raise AssertionError(message)


def tensor_only_check(model_dir: Path):
    """Inspect the safetensors shards of an NVFP4 checkpoint without loading a model.

    Reads ``model.safetensors.index.json``, picks the first ``*.weight`` entry
    that has a ``*.weight_scale`` companion, loads it together with its two
    scale tensors, validates dtypes/shapes/finiteness, prints a small
    dequantized sample, and finally prints how many ``*.weight`` tensors are
    quantized versus preserved.

    Args:
        model_dir: Directory containing the index file and the shards.

    Raises:
        SystemExit: If the index file is missing, no quantized weight is
            listed, or the sampled weight has no ``*.weight_scale_2`` entry.
        AssertionError: If a structural or finiteness check fails.
    """
    index_path = model_dir / "model.safetensors.index.json"
    if not index_path.exists():
        sys.exit(f"No index.json at {model_dir}")
    with open(index_path, encoding="utf-8") as f:
        index = json.load(f)
    weight_map = index["weight_map"]

    # Classify every *.weight entry in one pass; the first quantized entry
    # (in index order) doubles as the sample to inspect.
    quantized_weights = []
    preserved = []
    for name in weight_map:
        if name.endswith(".weight"):
            if _nvfp4_companion(name, ".weight_scale") in weight_map:
                quantized_weights.append(name)
            else:
                preserved.append(name)
    if not quantized_weights:
        sys.exit("Couldn't find an NVFP4-quantized weight (expected *.weight_scale companion).")
    sample = quantized_weights[0]

    print(f"Sampling: {sample}")
    scale_name = _nvfp4_companion(sample, ".weight_scale")
    scale_2_name = _nvfp4_companion(sample, ".weight_scale_2")
    if scale_2_name not in weight_map:
        sys.exit(f"Index has {scale_name} but no {scale_2_name} entry.")
    shard_fn = weight_map[sample]
    scale_shard = weight_map[scale_name]
    scale_2_shard = weight_map[scale_2_name]

    def open_get(fn, name):
        # Each tensor may live in a different shard, so open per lookup.
        with safe_open(model_dir / fn, framework="pt") as f:
            return f.get_tensor(name)

    packed = open_get(shard_fn, sample)
    weight_scale = open_get(scale_shard, scale_name)
    weight_scale_2 = open_get(scale_2_shard, scale_2_name)

    print(f" packed: shape={tuple(packed.shape)} dtype={packed.dtype}")
    print(f" weight_scale: shape={tuple(weight_scale.shape)} dtype={weight_scale.dtype}")
    # .item() below only works on a single-element tensor; fail with a clear
    # message instead of an opaque conversion error.
    _require(
        weight_scale_2.numel() == 1,
        f"weight_scale_2 should hold a single value, got shape {tuple(weight_scale_2.shape)}",
    )
    print(f" weight_scale_2: shape={tuple(weight_scale_2.shape)} dtype={weight_scale_2.dtype} "
          f"value={weight_scale_2.float().item():.6e}")

    # Structural checks.
    _require(packed.dtype == torch.uint8, f"packed should be uint8, got {packed.dtype}")
    _require(packed.dim() == 2, f"packed should be 2-D, got shape {tuple(packed.shape)}")
    M = packed.shape[0]
    _require(
        weight_scale.dtype == torch.float8_e4m3fn,
        f"weight_scale should be FP8 E4M3, got {weight_scale.dtype}",
    )
    # Two FP4 values per packed byte, one scale per 16 values.
    _require(
        weight_scale.shape == (M, packed.shape[1] * 2 // 16),
        f"weight_scale shape {weight_scale.shape} doesn't match expected (M, N/16)",
    )

    # Check for NaN/Inf in scales.
    s_fp32 = weight_scale.float()
    _require(torch.isfinite(s_fp32).all(), "weight_scale contains NaN/Inf")
    _require(torch.isfinite(weight_scale_2.float()).all(), "weight_scale_2 is NaN/Inf")
    print(" scales: all finite ✓")
    print(f" weight_scale stats: min={s_fp32.min().item():.3e} max={s_fp32.max().item():.3e} "
          f"mean={s_fp32.mean().item():.3e}")

    # Spot-check dequantization: 16 packed bytes = 32 values = 2 scale blocks per row.
    print("\nDequantizing first 4x32 block for visual check:")
    rec = dequant_nvfp4(packed[:4, :16], weight_scale[:4, :2], weight_scale_2)
    print(rec)
    _require(torch.isfinite(rec).all(), "Dequantized values contain NaN/Inf")
    print(" dequant: all finite ✓")
    print(f" dequant range: [{rec.min().item():.4f}, {rec.max().item():.4f}]")

    print("\nWhole-model summary:")
    print(f" Quantized .weight tensors: {len(quantized_weights):,}")
    print(f" Preserved .weight tensors: {len(preserved):,}")
    print(f" Total tensors in index: {len(weight_map):,}")

    # Show a few preserved names to confirm the right things stayed in higher precision.
    print("\n Sample preserved tensors (should be lm_head, embed, gates, norms, etc.):")
    for n in preserved[:10]:
        print(f" {n}")
|
|
|
|
|
|
def vllm_check(model_dir: Path):
    """Load the checkpoint with vLLM and print completions for three fixed prompts.

    Args:
        model_dir: Directory of the checkpoint to load.
    """
    print("Loading model with vLLM... (requires Blackwell GPU + vLLM with NVFP4 support)")
    # Imported here rather than at module level so the tensor-only path can
    # run on machines without vLLM installed.
    from vllm import LLM, SamplingParams

    # NOTE(review): quantization backend and parallelism are hard-coded;
    # confirm they match the checkpoint's quantization config and the host.
    engine_kwargs = {
        "model": str(model_dir),
        "trust_remote_code": True,
        "quantization": "compressed-tensors",
        "dtype": "auto",
        "tensor_parallel_size": 8,
        "max_model_len": 8192,
    }
    engine = LLM(**engine_kwargs)
    params = SamplingParams(temperature=1.0, top_p=1.0, max_tokens=64)

    prompts = [
        "Write a short poem about quantization:",
        "What is 17 * 23?",
        "Explain MoE routing in one sentence.",
    ]
    separator = "=" * 60
    for result in engine.generate(prompts, params):
        print(separator)
        print("PROMPT:", result.prompt)
        print("OUTPUT:", result.outputs[0].text)
|
|
|
|
|
|
def main(argv=None):
    """Parse the command line and run the requested checks.

    The tensor-level check always runs. ``--vllm`` additionally loads the
    model with vLLM afterwards. ``--tensor-only`` names the default mode
    explicitly and therefore cannot be combined with ``--vllm``.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    ap = argparse.ArgumentParser(
        description="Sanity check an NVFP4 checkpoint.",
    )
    ap.add_argument("model_dir", help="directory containing the safetensors shards")
    mode = ap.add_mutually_exclusive_group()
    # The module docstring documents --tensor-only as the default mode, so it
    # must be accepted; it changes nothing beyond forbidding --vllm.
    mode.add_argument(
        "--tensor-only",
        action="store_true",
        help="only inspect the safetensors shards (default)",
    )
    mode.add_argument(
        "--vllm",
        action="store_true",
        help="also load the model with vLLM and generate a few tokens",
    )
    args = ap.parse_args(argv)
    model_dir = Path(args.model_dir)

    tensor_only_check(model_dir)
    if args.vllm:
        print("\n" + "=" * 60)
        vllm_check(model_dir)
|
|
|
|
|
|
# Script entry point: run the checks only when executed directly, not on import.
if __name__ == "__main__":
    main()