Files
deepseek-v4-quant/fp8_to_nvfp4_streaming.py

548 lines
21 KiB
Python

#!/usr/bin/env python3
"""Streaming FP8 → NVFP4 converter for DeepSeek V4 Pro (sgl-project FP8 repackage).
Path A: pure tensor-level conversion. No model loading via transformers, no
calibration. Reads FP8 safetensors shards, dequantizes per-block FP8 to FP32,
re-quantizes to NVFP4 (E2M1 packed in uint8 with FP8 E4M3 per-block scales and
an FP32 per-tensor global scale), and writes new shards.
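Key behaviors:
- Joint global scale_2 across each expert's (w1, w3) pair (the gate and up
  projections), required for vLLM fused MoE kernels.
- Preserves embeddings, MoE router gates, norms, V4 indexer/scoring/compressor
  tensors, and mHC residual mixing weights at original precision (see
  PRESERVE_REGEXES). lm_head is not in that list, so it is quantized like any
  other `.weight` tensor.
- Streams tensor-by-tensor. Device working memory is roughly one weight pair
  dequantized to FP32 plus quantization temporaries; the CPU side additionally
  buffers up to --shard-size-gb of output tensors before each shard is written.
- Not resumable: every run rewrites all output shards from scratch.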
NVFP4 format reference:
value = packed_fp4 * weight_scale * weight_scale_2
where:
packed_fp4: E2M1 in {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}, 2 per byte
weight_scale: FP8 E4M3, one per 16-element block
weight_scale_2: FP32 scalar per tensor, global
Usage:
python fp8_to_nvfp4_streaming.py \\
--src DeepSeek-V4-Pro-FP8 \\
--dst DeepSeek-V4-Pro-NVFP4-streaming
Optional:
--gpu N Use CUDA device N for the math (default: 0; -1 for CPU). Falls back to CPU when CUDA is unavailable.
--shard-size-gb 5 Target output shard size
--workers N Accepted but currently unused; tensors are converted one at a time
--dry-run Print the conversion plan summary; don't write
"""
import argparse
import json
import re
import shutil
import sys
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import torch
from safetensors import safe_open
from safetensors.torch import save_file
from tqdm import tqdm
# ---------------------------------------------------------------------------
# Classification: which tensors do we quantize, which do we preserve?
# ---------------------------------------------------------------------------
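# Tensors whose names match any pattern below are copied through unchanged; every
# other `.weight` is quantized to NVFP4. Besides 1-D / non-GEMM tensors, this also
# keeps the embeddings and several small V4 compressor/indexer projection weights.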
PRESERVE_REGEXES = [
    r".*embed_tokens.*",  # token embedding table
    r".*\.(mlp|ffn)\.gate(\.weight)?$",  # MoE router gate (module name or its .weight)
    r".*norm.*",  # anything with "norm" in its name
    r".*indexer.*",  # everything under a V4 CSA indexer
    r".*scoring.*",  # V4 scoring tensors
    r".*attn_sink.*",  # V4 attention sink
    r".*compressor\.ape.*",  # V4 compressor APE
    r".*tid2eid.*",  # V4 MoE token-to-expert mapping
    r".*\.bias$",  # every bias
    r".*hc_attn_base.*",  # V4 hyper-connection (mHC) parameters
    r".*hc_attn_fn.*",
    r".*hc_ffn_base.*",
    r".*hc_ffn_fn.*",
    r".*hc_head_scale.*",
    r".*compressor\.wgate\.weight$",  # V4 compressor gate projection
    r".*compressor\.wkv\.weight$",  # V4 compressor KV projection
    # The indexer entries below are already covered by ".*indexer.*" above;
    # they are redundant but harmless.
    r".*indexer\.wq_b\.weight$",
    r".*indexer\.wkv\.weight$",
    r".*indexer\.compressor\.wkv\.weight$",
    r".*indexer\.gate_proj\.weight$",
    r".*indexer\.compressor\.wgate\.weight$",
    r".*indexer\.q_b_proj\.weight$",
]

# All preserve patterns fused into one alternation of non-capturing groups, so
# a single .match() call tests every pattern at once.
PRESERVE_RE = re.compile("|".join(map("(?:{})".format, PRESERVE_REGEXES)))

# Routed-expert gate/up projections: group(1) is the expert prefix
# ("...experts.<i>"), group(2) is "w1" or "w3". Both members of a pair are
# quantized with one shared scale_2.
EXPERT_PAIR_RE = re.compile(r"(.*experts\.\d+)\.(w1|w3)\.weight$")
def is_preserve(name: str) -> bool:
    """Return True when `name` matches one of the PRESERVE_REGEXES patterns.

    Matching uses re.match, so every pattern is anchored at the start of the
    tensor name (the patterns themselves begin with ".*").
    """
    return PRESERVE_RE.match(name) is not None
# ---------------------------------------------------------------------------
# FP8 dequantization (per-block)
# ---------------------------------------------------------------------------
def _block_edge(extent: int, n_blocks: int, preferred: int) -> int:
    """Return the block edge length that tiles `extent` into exactly `n_blocks` blocks.

    `preferred` wins whenever it is consistent with the scale grid. Otherwise the
    edge is inferred as ceil(extent / n_blocks) and validated.
    """
    if n_blocks < 1:
        raise ValueError(f"Scale grid has no blocks along an axis of extent {extent}")
    if preferred > 0 and -(-extent // preferred) == n_blocks:
        return preferred
    edge = -(-extent // n_blocks)
    if edge < 1 or -(-extent // edge) != n_blocks:
        raise ValueError(f"Cannot tile extent {extent} into {n_blocks} scale blocks")
    return edge


def dequant_fp8_to_fp32(
    weight_fp8: torch.Tensor,
    scale_inv: torch.Tensor,
    block_size: int = 128,
) -> torch.Tensor:
    """Dequantize an FP8 E4M3 weight to FP32 using its dequantization-scale tensor.

    DeepSeek convention: `scale_inv` stores the dequant scale (multiply by it
    to recover FP32). Supported scale layouts:
      * 0-D: one scale for the whole tensor
      * 1-D: one scale per row (length M) or per column (length N); rows win if M == N
      * 2-D [sm, sn]: one scale per rectangular block

    For 2-D scales the block edge along each axis is `block_size` whenever that
    agrees with the scale grid (ceil(M / block_size) == sm). Only otherwise is it
    inferred from the shapes. Inferring it unconditionally is wrong when the last
    block is partial: M=576 with 128-row blocks gives sm=5, and ceil(576/5)=116.

    Args:
        weight_fp8: [M, N] weight (FP8 or any dtype castable to float32).
        scale_inv: 0-D, 1-D or 2-D dequantization scales.
        block_size: preferred block edge for 2-D scales (default 128).

    Returns:
        [M, N] float32 tensor on weight_fp8's device.

    Raises:
        ValueError: if the weight is not 2-D, or the scale shape cannot be aligned.
    """
    if weight_fp8.dim() != 2:
        raise ValueError(f"Expected 2D weight, got shape {weight_fp8.shape}")
    M, N = weight_fp8.shape
    w = weight_fp8.float()
    s = scale_inv.float()

    if s.dim() == 0:
        # Per-tensor scale
        return w * s
    if s.dim() == 1:
        # Per-row or per-col — unusual for DeepSeek but handle it
        if s.numel() == M:
            return w * s.unsqueeze(1)
        if s.numel() == N:
            return w * s.unsqueeze(0)
        raise ValueError(f"Cannot align 1D scale_inv {scale_inv.shape} to weight {weight_fp8.shape}")
    if s.dim() != 2:
        raise ValueError(f"Unsupported scale_inv rank {s.dim()} for weight {weight_fp8.shape}")

    # 2-D block scaling: expand each scale over its block, then trim the
    # padding of partial trailing blocks.
    sm, sn = s.shape
    bm = _block_edge(M, sm, block_size)
    bn = _block_edge(N, sn, block_size)
    scale_full = s.repeat_interleave(bm, dim=0).repeat_interleave(bn, dim=1)[:M, :N]
    return w * scale_full
# ---------------------------------------------------------------------------
# NVFP4 quantization
# ---------------------------------------------------------------------------
FP4_E2M1_VALUES = torch.tensor(
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
)
# Boundaries between adjacent magnitudes (round-to-nearest with ties to even-ish)
FP4_BOUNDARIES = torch.tensor(
[0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0], dtype=torch.float32
)
FP4_MAX = 6.0
FP8_E4M3_MAX = 448.0
def round_to_fp4_e2m1_index(x: torch.Tensor) -> torch.Tensor:
    """Round x elementwise to the nearest FP4 E2M1 value and return its 4-bit code.

    Code layout: bit 3 = sign, bits 0..2 = magnitude index into FP4_E2M1_VALUES.
    Magnitudes above FP4_MAX saturate to FP4_MAX. A value exactly halfway between
    two representable magnitudes goes to the even index (even mantissa bit), i.e.
    round-half-to-even: 0.25 -> 0, 0.75 -> 1.0, 1.25 -> 1.0, 1.75 -> 2.0,
    2.5 -> 2.0, 3.5 -> 4.0, 5.0 -> 4.0.

    Returns:
        uint8 tensor with x's shape.
    """
    sign = (x < 0).to(torch.uint8)
    abs_x = x.abs().clamp_(max=FP4_MAX).contiguous()
    # Match device and dtype so searchsorted never has to compare mixed dtypes.
    boundaries = FP4_BOUNDARIES.to(device=x.device, dtype=abs_x.dtype)
    # Left insertion sends a value lying exactly on a boundary to the smaller
    # magnitude, right insertion to the larger one; elsewhere the two agree.
    lo = torch.searchsorted(boundaries, abs_x, out_int32=True)
    hi = torch.searchsorted(boundaries, abs_x, right=True, out_int32=True)
    # On a tie, an odd `lo` means the smaller neighbour has an odd mantissa,
    # so round up to the even neighbour `hi`. Off-boundary lo == hi anyway.
    mag_idx = torch.where(lo % 2 == 1, hi, lo).to(torch.uint8)
    return (sign << 3) | mag_idx
def quantize_to_nvfp4(
    x_fp32: torch.Tensor,
    scale_2: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D FP32 weight to NVFP4 under a (possibly shared) global scale.

    Each run of 16 consecutive elements along the last dim is a block with its
    own FP8 E4M3 scale; elements become FP4 E2M1 codes, two per byte. The
    intended dequantization is fp4_value * weight_scale * scale_2.

    Args:
        x_fp32: [M, N] FP32 tensor; N must be divisible by 16 (need not be contiguous).
        scale_2: scalar FP32 tensor.

    Returns:
        packed: [M, N//2] uint8, two FP4 codes per byte (even-index element in
            the low nibble, odd-index element in the high nibble).
        weight_scale: [M, N//16] float8_e4m3fn per-block scales.

    Raises:
        ValueError: if N is not a multiple of 16.
    """
    # Smallest positive float8_e4m3fn value (the subnormal 2**-9).
    fp8_min_positive = 2.0 ** -9

    M, N = x_fp32.shape
    if N % 16 != 0:
        raise ValueError(f"NVFP4 requires N % 16 == 0; got {x_fp32.shape}")

    # reshape (not view) so non-contiguous inputs work too.
    blocks = x_fp32.reshape(M, N // 16, 16)
    block_amax = blocks.abs().amax(dim=-1)  # [M, N//16]

    # Ideal per-block scale maps the block amax onto FP4_MAX. Keep it inside
    # FP8 E4M3's finite range before the lossy cast. Scales below about 2**-10
    # would round to 0 and the whole block would dequantize to zeros; a tiny
    # FP32 floor does not help because the cast still flushes it. Values past
    # the finite maximum would become NaN.
    block_scale_fp32 = (block_amax / (FP4_MAX * scale_2)).clamp_(
        min=fp8_min_positive, max=FP8_E4M3_MAX
    )
    block_scale_fp8 = block_scale_fp32.to(torch.float8_e4m3fn)

    # Quantize against the scale the kernel will actually use (after FP8 rounding).
    effective = scale_2 * block_scale_fp8.float()  # [M, N//16]
    scaled = blocks / effective.unsqueeze(-1).clamp(min=1e-30)
    fp4_idx = round_to_fp4_e2m1_index(scaled).reshape(M, N)

    # Pack two nibbles per byte: low = even-index element, high = odd-index element.
    low = fp4_idx[:, 0::2]
    high = fp4_idx[:, 1::2]
    packed = (low | (high << 4)).to(torch.uint8)
    return packed, block_scale_fp8
def compute_global_scale(*tensors_fp32: torch.Tensor) -> torch.Tensor:
    """Return the NVFP4 per-tensor scale (scale_2) shared by all given tensors.

    scale_2 = max_i amax(|t_i|) / (FP4_MAX * FP8_E4M3_MAX), chosen so that
    per-block scales derived from it stay within FP8 E4M3's range. The result
    is floored at 1e-30 (all-zero input) and returned as a 0-dim float32 tensor.
    """
    per_tensor_amax = torch.stack([t.abs().max() for t in tensors_fp32])
    joint_amax = per_tensor_amax.amax()
    denominator = FP4_MAX * FP8_E4M3_MAX
    return torch.clamp(joint_amax / denominator, min=1e-30).float()
# ---------------------------------------------------------------------------
# Sharded output writer
# ---------------------------------------------------------------------------
class ShardedSafetensorsWriter:
    """Buffer tensors in CPU memory and write them out as numbered safetensors shards.

    Tensors accumulate until adding the next one would push the buffer past
    `max_shard_bytes`; the buffer is then written as one shard. Shards are first
    written as ``model-XXXXX-of-PLACEHOLDER.safetensors`` and renamed to
    ``model-XXXXX-of-NNNNN.safetensors`` by close(), once the total is known.
    """

    def __init__(self, out_dir: Path, max_shard_bytes: int):
        self.out_dir = out_dir
        self.out_dir.mkdir(parents=True, exist_ok=True)
        self.max_shard_bytes = max_shard_bytes
        self.current = {}  # name -> CPU tensor waiting for the next flush
        self.current_bytes = 0  # payload bytes currently held in self.current
        self.shard_idx = 0  # number of shards written so far
        self.weight_map: dict[str, str] = {}  # name -> shard filename
        self.shard_filenames: list[str] = []  # placeholder filenames, in write order

    def _flush(self):
        """Write the buffered tensors as the next shard and clear the buffer (no-op if empty)."""
        if not self.current:
            return
        self.shard_idx += 1
        # The final shard count is unknown until close(), which renames these files.
        fname = f"model-{self.shard_idx:05d}-of-PLACEHOLDER.safetensors"
        path = self.out_dir / fname
        # "format": "pt" is the metadata HF save_pretrained writes; transformers
        # checks this field when it loads safetensors shards.
        save_file(self.current, str(path), metadata={"format": "pt"})
        for name in self.current:
            self.weight_map[name] = fname
        self.shard_filenames.append(fname)
        self.current.clear()
        self.current_bytes = 0

    def add(self, name: str, tensor: torch.Tensor):
        """Buffer `tensor` under `name`, flushing first if it would overflow the shard budget.

        A tensor larger than the budget is still written, alone in its own shard.
        """
        # Always take a private contiguous CPU copy. `.cpu()` alone returns the
        # same storage for tensors already on the CPU, so adding one tensor
        # under two names (e.g. a shared scale_2) would alias. safetensors'
        # save_file refuses to write tensors that share memory.
        t = tensor.detach().to("cpu", copy=True).contiguous()
        size = t.numel() * t.element_size()
        if self.current and self.current_bytes + size > self.max_shard_bytes:
            self._flush()
        self.current[name] = t
        self.current_bytes += size

    def close(self):
        """Flush remaining tensors, give shards their final names, and return name -> shard."""
        self._flush()
        total = len(self.shard_filenames)
        renamed = {}
        for old_fname in self.shard_filenames:
            idx = int(old_fname.split("-")[1])
            new_fname = f"model-{idx:05d}-of-{total:05d}.safetensors"
            (self.out_dir / old_fname).rename(self.out_dir / new_fname)
            renamed[old_fname] = new_fname
        self.weight_map = {k: renamed[v] for k, v in self.weight_map.items()}
        return self.weight_map
# ---------------------------------------------------------------------------
# Shard-level conversion plan
# ---------------------------------------------------------------------------
def build_plan(src_dir: Path):
    """Build the conversion plan from model.safetensors.index.json.

    Only the index is read (no tensor data). Every tensor name lands in exactly
    one of: an expert pair, solo_quantize, preserve, or scale_companions.

    Args:
        src_dir: directory containing model.safetensors.index.json.

    Returns:
        dict with keys:
            weight_map: name -> shard filename (straight from the index)
            shard_to_names: shard filename -> list of names in that shard
            expert_pair_groups: list of (expert_prefix, name_w1, name_w3); both
                tensors of a pair share one scale_2
            solo_quantize: `.weight` names quantized with their own scale_2
            preserve: names copied unchanged. This includes the `.scale`
                companion of every preserved weight, so a weight copied at its
                original (possibly FP8) precision keeps its dequant scale.
            scale_companions: `.scale` names consumed while dequantizing a
                quantized weight (not written to the output)
    """
    with open(src_dir / "model.safetensors.index.json", encoding="utf-8") as f:
        index = json.load(f)
    weight_map = index["weight_map"]

    shard_to_names = defaultdict(list)
    for name, fn in weight_map.items():
        shard_to_names[fn].append(name)

    def weight_for_scale(scale_name: str) -> str:
        # Rewrite only the trailing ".scale"; str.replace would also rewrite an
        # earlier ".scale" inside the name.
        return scale_name[: -len(".scale")] + ".weight"

    # `.scale` tensors that have a matching `.weight` in the index.
    scale_names = {
        n for n in weight_map
        if n.endswith(".scale") and weight_for_scale(n) in weight_map
    }

    # Group routed-expert projections: expert prefix -> {"w1": name, "w3": name}.
    # EXPERT_PAIR_RE only matches names ending in ".weight".
    expert_pairs = defaultdict(dict)
    for n in weight_map:
        m = EXPERT_PAIR_RE.match(n)
        if m:
            expert_pairs[m.group(1)][m.group(2)] = n

    paired_names = set()
    expert_pair_groups = []
    for base, parts in expert_pairs.items():
        if "w1" in parts and "w3" in parts:
            expert_pair_groups.append((base, parts["w1"], parts["w3"]))
            paired_names.update((parts["w1"], parts["w3"]))

    # Every other `.weight` not matched by a preserve pattern is quantized alone.
    solo_quantize = [
        n for n in weight_map
        if n.endswith(".weight") and n not in paired_names and not is_preserve(n)
    ]
    quantized = paired_names.union(solo_quantize)

    preserve = []
    scale_companions = []
    for n in weight_map:
        if n in scale_names:
            if weight_for_scale(n) in quantized:
                scale_companions.append(n)
            else:
                # Its weight is copied unchanged (possibly still FP8), so the
                # scale must be copied as well or that weight is unusable.
                preserve.append(n)
        elif n not in quantized:
            preserve.append(n)

    return {
        "weight_map": weight_map,
        "shard_to_names": dict(shard_to_names),
        "expert_pair_groups": expert_pair_groups,
        "solo_quantize": solo_quantize,
        "preserve": preserve,
        "scale_companions": scale_companions,
    }
# ---------------------------------------------------------------------------
# Tensor loading helpers
# ---------------------------------------------------------------------------
class ShardCache:
    """LRU cache of open safetensors handles, so shards aren't re-opened for every tensor."""

    def __init__(self, src_dir: Path, max_open: int = 4):
        self.src_dir = src_dir
        # At least one handle must fit, otherwise eviction would loop on an empty dict.
        self.max_open = max(1, max_open)
        # Insertion order doubles as recency order: oldest first, newest last.
        self.handles: dict[str, "safe_open"] = {}

    def get(self, shard_fname: str):
        """Return an open handle for `shard_fname`, opening it (and evicting the LRU handle) if needed."""
        handle = self.handles.pop(shard_fname, None)
        if handle is None:
            while len(self.handles) >= self.max_open:
                oldest = next(iter(self.handles))
                self.handles.pop(oldest).__exit__(None, None, None)
            handle = safe_open(str(self.src_dir / shard_fname), framework="pt")
            handle.__enter__()
        # (Re)insert at the end: this shard is now the most recently used.
        self.handles[shard_fname] = handle
        return handle

    def close(self):
        """Close every open handle and empty the cache."""
        while self.handles:
            _, handle = self.handles.popitem()
            handle.__exit__(None, None, None)
def load_weight_and_scale(cache: ShardCache, weight_map, name):
"""Load an FP8 weight with its scale companion (if any)."""
weight = cache.get(weight_map[name]).get_tensor(name)
scale_name = name.replace(".weight", ".scale")
scale = None
if scale_name in weight_map:
try:
scale = cache.get(weight_map[scale_name]).get_tensor(scale_name)
except Exception:
# Scale listed in index but not in shard (BF16 weights have no scale)
pass
return weight, scale
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _quant_output_names(name: str) -> tuple[str, str]:
    """Return the (weight_scale, weight_scale_2) output names for weight `name`.

    Only a trailing ".weight" is rewritten; str.replace would also rewrite an
    earlier ".weight" substring (e.g. in "weights_proj").
    """
    stem = name[: -len(".weight")] if name.endswith(".weight") else name
    return f"{stem}.weight_scale", f"{stem}.weight_scale_2"


def _to_fp32_on_device(weight: torch.Tensor, scale, device: torch.device) -> torch.Tensor:
    """Return `weight` as FP32 on `device`, applying its FP8 dequant scale when one was loaded."""
    if scale is None:
        # Already non-FP8 (e.g. BF16): a plain upcast is enough.
        return weight.to(device).float()
    return dequant_fp8_to_fp32(weight.to(device), scale.to(device))


def _add_quantized(writer, name: str, packed, block_scale, scale_2) -> None:
    """Queue the three NVFP4 output tensors (packed codes, block scales, scale_2) for `name`."""
    scale_name, scale_2_name = _quant_output_names(name)
    writer.add(name, packed)
    writer.add(scale_name, block_scale)
    # A private copy per entry: the pair path passes the same scale_2 for w1 and
    # w3, and two shard entries sharing storage are rejected by safetensors.
    writer.add(scale_2_name, scale_2.clone())


def _compressed_tensors_ignore() -> list[str]:
    """Translate PRESERVE_REGEXES into compressed-tensors `ignore` targets.

    compressed-tensors treats an ignore entry as a regex only with a "re:"
    prefix (anything else must equal the name exactly). It matches module
    names, which carry no trailing ".weight".
    """
    return ["re:" + pattern.replace(r"\.weight$", "$") for pattern in PRESERVE_REGEXES]


def main():
    """Command-line entry point: convert an FP8 checkpoint directory to NVFP4.

    Steps: build the plan from the source index; copy preserved tensors;
    quantize expert pairs (shared scale_2) and solo weights; write the new
    index; copy auxiliary files; patch config.json with quantization metadata.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--src", required=True, help="Source FP8 model directory")
    ap.add_argument("--dst", required=True, help="Output NVFP4 model directory")
    ap.add_argument("--gpu", type=int, default=0, help="CUDA device, -1 for CPU")
    ap.add_argument("--shard-size-gb", type=float, default=5.0)
    # NOTE(review): --workers is parsed but never used; conversion is sequential.
    ap.add_argument("--workers", type=int, default=4,
                    help="Concurrent tensor-conversion workers (lots of small tensors benefit; "
                         "actual GPU compute is serialized by torch)")
    ap.add_argument("--dry-run", action="store_true")
    args = ap.parse_args()

    src = Path(args.src).resolve()
    dst = Path(args.dst).resolve()
    if not (src / "model.safetensors.index.json").exists():
        sys.exit(f"No index.json at {src}")
    if dst == src:
        # Output shard names can coincide with the source shards, and the
        # auxiliary-file copy below would copy files onto themselves.
        sys.exit("--dst must be a different directory from --src")

    device = torch.device(f"cuda:{args.gpu}" if args.gpu >= 0 and torch.cuda.is_available() else "cpu")
    print(f"Compute device: {device}")
    # Move FP4_BOUNDARIES to device once
    global FP4_BOUNDARIES
    FP4_BOUNDARIES = FP4_BOUNDARIES.to(device)

    print("Building conversion plan...")
    plan = build_plan(src)
    n_pairs = len(plan["expert_pair_groups"])
    n_solo = len(plan["solo_quantize"])
    n_preserve = len(plan["preserve"])
    n_scales = len(plan["scale_companions"])
    print(f" Expert pair groups (joint scale_2): {n_pairs:,}")
    print(f" Solo quantize tensors: {n_solo:,}")
    print(f" Preserved tensors: {n_preserve:,}")
    print(f" Scale companions consumed: {n_scales:,}")
    if args.dry_run:
        print("\nDry run — exiting before any writes.")
        return

    dst.mkdir(parents=True, exist_ok=True)
    cache = ShardCache(src, max_open=8)
    writer = ShardedSafetensorsWriter(dst, max_shard_bytes=int(args.shard_size_gb * 1024**3))
    weight_map = plan["weight_map"]
    t_start = time.time()

    try:
        # 1. Preserved tensors — copy unchanged
        for name in tqdm(plan["preserve"], desc="Preserve", unit="tensor"):
            writer.add(name, cache.get(weight_map[name]).get_tensor(name))

        # 2. Expert pairs — joint scale_2 across (w1, w3)
        for base, name_w1, name_w3 in tqdm(plan["expert_pair_groups"], desc="Expert pairs", unit="pair"):
            w1_raw, s1 = load_weight_and_scale(cache, weight_map, name_w1)
            w3_raw, s3 = load_weight_and_scale(cache, weight_map, name_w3)
            with torch.no_grad():
                w1 = _to_fp32_on_device(w1_raw, s1, device)
                w3 = _to_fp32_on_device(w3_raw, s3, device)
                scale_2 = compute_global_scale(w1, w3)
                packed1, blk1 = quantize_to_nvfp4(w1, scale_2)
                packed3, blk3 = quantize_to_nvfp4(w3, scale_2)
            _add_quantized(writer, name_w1, packed1, blk1, scale_2)
            _add_quantized(writer, name_w3, packed3, blk3, scale_2)
            # Free this pair's device buffers before the next pair is dequantized.
            del w1, w3, packed1, packed3, blk1, blk3, scale_2

        # 3. Solo quantize tensors — independent scale_2 per tensor
        for name in tqdm(plan["solo_quantize"], desc="Solo quantize", unit="tensor"):
            w_raw, s = load_weight_and_scale(cache, weight_map, name)
            with torch.no_grad():
                w = _to_fp32_on_device(w_raw, s, device)
                scale_2 = compute_global_scale(w)
                packed, blk = quantize_to_nvfp4(w, scale_2)
            _add_quantized(writer, name, packed, blk, scale_2)
            del w, packed, blk, scale_2
    finally:
        cache.close()

    # Finalize shards & index
    final_weight_map = writer.close()

    # 4. Write model.safetensors.index.json
    total_size = sum(
        (dst / fn).stat().st_size for fn in set(final_weight_map.values())
    )
    new_index = {
        "metadata": {"total_size": total_size},
        "weight_map": final_weight_map,
    }
    with open(dst / "model.safetensors.index.json", "w", encoding="utf-8") as f:
        json.dump(new_index, f, indent=2)

    # 5. Copy non-tensor files (config, tokenizer, etc.)
    for entry in src.iterdir():
        if entry.resolve() == dst:
            continue  # dst lives inside src: never copy the output into itself
        target = dst / entry.name
        if entry.is_dir():
            # e.g. encoding/, inference/, assets/ — copy whole tree
            if not target.exists():
                shutil.copytree(entry, target)
            continue
        if entry.suffix == ".safetensors" or entry.name == "model.safetensors.index.json":
            continue
        shutil.copy2(entry, target)

    # 6. Patch config.json with quantization metadata so loaders know
    cfg_path = dst / "config.json"
    if cfg_path.exists():
        with open(cfg_path, encoding="utf-8") as f:
            cfg = json.load(f)
        # NOTE(review): the tensors written above use ModelOpt-style names
        # (packed uint8 `weight`, `weight_scale`, `weight_scale_2` = amax/(6*448)).
        # compressed-tensors' nvfp4-pack-quantized format is believed to expect
        # `weight_packed`, `weight_scale` and `weight_global_scale` (the
        # reciprocal of scale_2). Confirm against the target loader.
        cfg["quantization_config"] = {
            "quant_method": "compressed-tensors",
            "format": "nvfp4-pack-quantized",
            "config_groups": {
                "group_0": {
                    "targets": ["Linear"],
                    "weights": {
                        "num_bits": 4,
                        "type": "float",
                        "strategy": "tensor_group",
                        "group_size": 16,
                        "symmetric": True,
                    },
                }
            },
            "ignore": _compressed_tensors_ignore(),
        }
        with open(cfg_path, "w", encoding="utf-8") as f:
            json.dump(cfg, f, indent=2)

    elapsed = time.time() - t_start
    print(f"\nDone in {elapsed/3600:.2f}h")
    print(f"Output: {dst}")
    print(f"Total size: {total_size/1024**3:.1f} GB across {len(set(final_weight_map.values()))} shards")


if __name__ == "__main__":
    main()