548 lines
21 KiB
Python
548 lines
21 KiB
Python
#!/usr/bin/env python3
|
|
"""Streaming FP8 → NVFP4 converter for DeepSeek V4 Pro (sgl-project FP8 repackage).

Path A: pure tensor-level conversion. No model loading via transformers, no
calibration. Reads FP8 safetensors shards, dequantizes per-block FP8 to FP32,
re-quantizes to NVFP4 (E2M1 packed in uint8 with FP8 E4M3 per-block scales and
an FP32 per-tensor global scale), and writes new shards.

Key behaviors:
  - Joint global scale_2 across (gate_proj, up_proj) pairs of each expert,
    required for vLLM fused MoE kernels.
  - Preserves lm_head, embeddings, MoE router gates, norms, V4 indexer/scoring,
    and mHC residual mixing weights at original precision.
  - Streams shard-by-shard. Peak working memory is one tensor pair dequantized
    to FP32 (a few hundred MB at most for the largest weights).
  - Resumable per output shard.

NVFP4 format reference:
  value = packed_fp4 * weight_scale * weight_scale_2
  where:
    packed_fp4:     E2M1 in {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}, 2 per byte
    weight_scale:   FP8 E4M3, one per 16-element block
    weight_scale_2: FP32 scalar per tensor, global

Usage:
  python fp8_to_nvfp4_streaming.py \\
      --src DeepSeek-V4-Pro-FP8 \\
      --dst DeepSeek-V4-Pro-NVFP4-streaming \\
      --workers 8

Optional:
  --gpu N            Use CUDA device N for the math (default: 0; -1 for CPU)
  --shard-size-gb 5  Target output shard size
  --dry-run          Print what would be done; don't write
"""
|
|
|
|
import argparse
import json
import re
import shutil
import sys
import time
from collections import OrderedDict, defaultdict
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import torch
from safetensors import safe_open
from safetensors.torch import save_file
from tqdm import tqdm
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Classification: which tensors do we quantize, which do we preserve?
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Policy: quantize everything to NVFP4 except tensors whose name matches one of
# the patterns below. The per-pattern notes are the original author's stated
# reasons for keeping each family untouched.
PRESERVE_REGEXES = [
    # Generic families
    r".*embed_tokens.*",  # embeddings, kept in their original precision
    r".*\.(mlp|ffn)\.gate(\.weight)?$",  # MoE router gate (not a GEMM weight)
    r".*norm.*",  # every norm tensor
    r".*indexer.*",  # V4 CSA indexer
    r".*scoring.*",  # V4 scoring tensors
    r".*attn_sink.*",  # V4 attention sink
    r".*compressor\.ape.*",  # V4 compressor APE
    r".*tid2eid.*",  # V4 MoE token-to-expert mapping
    r".*\.bias$",  # any bias
    # V4 hyper-connection tensors
    r".*hc_attn_base.*",
    r".*hc_attn_fn.*",
    r".*hc_ffn_base.*",
    r".*hc_ffn_fn.*",
    r".*hc_head_scale.*",
    # Small projections that are deliberately left alone
    r".*compressor\.wgate\.weight$",  # V4 compressor gate
    r".*compressor\.wkv\.weight$",  # V4 compressor KV projection
    r".*indexer\.wq_b\.weight$",  # V4 indexer projections
    r".*indexer\.wkv\.weight$",
    r".*indexer\.compressor\.wkv\.weight$",
    r".*indexer\.gate_proj\.weight$",
    r".*indexer\.compressor\.wgate\.weight$",
    r".*indexer\.q_b_proj\.weight$",
]

# One alternation over all patterns; each pattern sits in its own
# non-capturing group so its internal `|` cannot leak into its neighbours.
PRESERVE_RE = re.compile(
    "|".join("(?:" + pattern + ")" for pattern in PRESERVE_REGEXES)
)

# Matches the w1 / w3 weight of a single expert. Group 1 is the expert's base
# name and group 2 is which of the two it is; both halves of a pair are later
# quantized with one joint global scale.
EXPERT_PAIR_RE = re.compile(r"(.*experts\.\d+)\.(w1|w3)\.weight$")
|
|
|
|
|
|
def is_preserve(name: str) -> bool:
    """Tell whether the tensor called `name` matches any PRESERVE_REGEXES entry.

    Matching is anchored at the start of the name (`re.match` semantics).
    """
    return PRESERVE_RE.match(name) is not None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# FP8 dequantization (per-block)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _infer_block_extent(extent: int, n_blocks: int, preferred: int) -> int:
    """Pick the block length along one axis given the number of scale entries.

    The preferred length wins whenever it is consistent with the scale grid
    (ceil(extent / preferred) == n_blocks). This matters when `extent` is not a
    multiple of the block length: 576 rows with 5 scale rows is a 128-row
    blocking with a short last block, not a 116-row blocking.
    """
    if preferred > 0 and -(-extent // preferred) == n_blocks:
        return preferred
    # Fallback: evenly spread the scale entries over the axis.
    return -(-extent // n_blocks)


def dequant_fp8_to_fp32(
    weight_fp8: torch.Tensor,
    scale_inv: torch.Tensor,
    block_size: tuple[int, int] = (128, 128),
) -> torch.Tensor:
    """Dequantize a per-block FP8 E4M3 weight to FP32 using its inverse-scale tensor.

    DeepSeek convention: the scale tensor stores the dequant scale (multiply by
    it to recover FP32).

    Args:
        weight_fp8: 2D weight of shape [M, N]; it is upcast with `.float()`.
        scale_inv: dequant scale. Accepted layouts:
            0-D  - one scale for the whole tensor;
            1-D  - one scale per row (numel == M) or per column (numel == N),
                   rows taking precedence when M == N;
            2-D  - one scale per block, shape [ceil(M / bm), ceil(N / bn)].
        block_size: preferred (rows, cols) block shape for the 2-D layout. It is
            used on an axis whenever it is consistent with the scale grid;
            otherwise the block length falls back to ceil(dim / n_scales).

    Returns:
        FP32 tensor of shape [M, N].

    Raises:
        ValueError: if `weight_fp8` is not 2D or `scale_inv` cannot be aligned
            with it.
    """
    if weight_fp8.dim() != 2:
        raise ValueError(f"Expected 2D weight, got shape {weight_fp8.shape}")
    M, N = weight_fp8.shape

    weight = weight_fp8.float()
    scale = scale_inv.float()

    if scale.dim() == 0:
        # Per-tensor scale
        return weight * scale

    if scale.dim() == 1:
        # Per-row or per-col — unusual for DeepSeek but handle it
        if scale.numel() == M:
            return weight * scale.unsqueeze(1)
        if scale.numel() == N:
            return weight * scale.unsqueeze(0)
        raise ValueError(f"Cannot align 1D scale_inv {scale_inv.shape} to weight {weight_fp8.shape}")

    if scale.dim() != 2:
        raise ValueError(
            f"Cannot align {scale.dim()}D scale_inv {scale_inv.shape} to weight {weight_fp8.shape}"
        )

    # 2D block scaling: expand every scale entry over its block, then trim the
    # overhang of a partial last block.
    sm, sn = scale.shape
    bm = _infer_block_extent(M, sm, block_size[0])
    bn = _infer_block_extent(N, sn, block_size[1])
    scale_full = scale.repeat_interleave(bm, dim=0).repeat_interleave(bn, dim=1)
    return weight * scale_full[:M, :N]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# NVFP4 quantization
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Non-negative magnitudes representable in FP4 E2M1, in ascending order. The
# position of a magnitude in this table is its 3-bit magnitude code.
FP4_E2M1_VALUES = torch.tensor(
    (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0), dtype=torch.float32
)

# Decision thresholds used when rounding to the nearest magnitude: the midpoint
# of each pair of adjacent table entries, i.e.
# 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0 (all exact in float32).
FP4_BOUNDARIES = (FP4_E2M1_VALUES[:-1] + FP4_E2M1_VALUES[1:]) / 2

# Largest finite magnitudes of the two formats involved.
FP4_MAX = 6.0
FP8_E4M3_MAX = 448.0
|
|
|
|
|
|
def round_to_fp4_e2m1_index(x: torch.Tensor) -> torch.Tensor:
    """Map every element of `x` to the 4-bit code of its nearest FP4 E2M1 value.

    Code layout: bit 3 is the sign, bits 0..2 index into FP4_E2M1_VALUES.
    Magnitudes above FP4_MAX saturate to the largest code. The result is a
    uint8 tensor with the same shape as `x`, with values in [0..15].
    """
    # Saturate the magnitude first so the lookup never runs past the table.
    magnitude = x.abs().clamp_(max=FP4_MAX)

    # The thresholds must live on the same device as the data being bucketed.
    edges = FP4_BOUNDARIES.to(x.device)
    magnitude_code = torch.searchsorted(edges, magnitude.contiguous()).to(torch.uint8)

    sign_bit = (x < 0).to(torch.uint8) << 3
    return sign_bit | magnitude_code
|
|
|
|
|
|
def quantize_to_nvfp4(
    x_fp32: torch.Tensor,
    scale_2: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize an FP32 weight to NVFP4 given a (possibly joint) global scale.

    Args:
        x_fp32: [M, N] FP32 tensor, N must be divisible by 16. It does not have
            to be contiguous.
        scale_2: scalar FP32 tensor (the per-tensor global scale).

    Returns:
        packed: [M, N//2] uint8, two FP4 values per byte (low nibble first)
        weight_scale: [M, N//16] FP8 E4M3 per-block scales

    Raises:
        ValueError: if N is not a multiple of 16.
    """
    M, N = x_fp32.shape
    if N % 16 != 0:
        raise ValueError(f"NVFP4 requires N % 16 == 0; got {x_fp32.shape}")

    # Per-block (16-element) amax. reshape() instead of view() so that a
    # non-contiguous input is copied rather than rejected.
    blocks = x_fp32.reshape(M, N // 16, 16)
    block_amax = blocks.abs().amax(dim=-1)  # [M, N//16]

    # Per-block scale in FP32, then cast to FP8 E4M3 (this is the lossy step).
    # The upper bound keeps the scale inside the E4M3 range even when `scale_2`
    # was not derived from this tensor's own amax. The lower bound is kept from
    # the original code; note that such tiny values still become 0 in FP8, so
    # zero scales are possible and are dealt with at the division below.
    block_scale_fp32 = block_amax / (FP4_MAX * scale_2)
    block_scale_fp32 = block_scale_fp32.clamp_(min=1e-30, max=FP8_E4M3_MAX)
    block_scale_fp8 = block_scale_fp32.to(torch.float8_e4m3fn)

    # Recover the effective scale that the kernel will actually use
    effective = scale_2 * block_scale_fp8.float()  # [M, N//16]

    # Quantize values: divide, clamp, round to E2M1. The floor on the divisor
    # prevents 0/0 for blocks whose FP8 scale came out as zero.
    scaled = blocks / effective.clamp_(min=1e-30).unsqueeze(-1)
    fp4_idx = round_to_fp4_e2m1_index(scaled)  # [M, N//16, 16] uint8
    fp4_idx = fp4_idx.reshape(M, N).contiguous()

    # Pack two nibbles per byte: low = even-index element, high = odd-index element
    low = fp4_idx[:, ::2]
    high = fp4_idx[:, 1::2]
    packed = (low | (high << 4)).to(torch.uint8)

    return packed, block_scale_fp8
|
|
|
|
|
|
def compute_global_scale(*tensors_fp32: torch.Tensor) -> torch.Tensor:
    """Derive one NVFP4 global scale (scale_2) shared by all given FP32 tensors.

    The scale is the largest absolute value found in any of the tensors divided
    by FP4_MAX * FP8_E4M3_MAX, floored at 1e-30 so that it is never zero.
    Returns a 0-D FP32 tensor.
    """
    per_tensor_peak = [t.abs().max() for t in tensors_fp32]
    joint_amax = torch.stack(per_tensor_peak).max()

    denominator = FP4_MAX * FP8_E4M3_MAX
    global_scale = joint_amax / denominator

    # Guard against an all-zero input producing a zero scale.
    return global_scale.clamp_(min=1e-30).float()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Sharded output writer
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class ShardedSafetensorsWriter:
    """Writes tensors to a sequence of safetensors shards, building an index map.

    Tensors are buffered in memory until adding the next one would exceed
    `max_shard_bytes`, at which point the buffer is written out as one shard.
    Shards are first written under a placeholder name because the total shard
    count is only known at the end; `close()` renames them to the usual
    `model-XXXXX-of-YYYYY.safetensors` scheme.
    """

    def __init__(self, out_dir: Path, max_shard_bytes: int):
        """Create `out_dir` if needed and start with an empty shard buffer."""
        self.out_dir = out_dir
        self.out_dir.mkdir(parents=True, exist_ok=True)
        self.max_shard_bytes = max_shard_bytes
        self.current = {}  # name -> tensor (CPU) buffered for the next shard
        self.current_bytes = 0
        self.shard_idx = 0
        self.weight_map: dict[str, str] = {}  # name -> shard filename
        self.shard_filenames: list[str] = []
        self._finalized = False  # set once close() has renamed the shards

    def _flush(self):
        """Write the buffered tensors as the next shard and reset the buffer."""
        if not self.current:
            return
        self.shard_idx += 1
        # Use placeholder total; we'll rename at the end
        fname = f"model-{self.shard_idx:05d}-of-PLACEHOLDER.safetensors"
        path = self.out_dir / fname
        # The "format" entry mirrors what PyTorch-oriented safetensors writers
        # emit; some loaders check for it.
        save_file(self.current, str(path), metadata={"format": "pt"})
        for name in self.current:
            self.weight_map[name] = fname
        self.shard_filenames.append(fname)
        self.current.clear()
        self.current_bytes = 0

    def add(self, name: str, tensor: torch.Tensor):
        """Buffer `tensor` under `name`, flushing first if the shard would overflow.

        The buffered tensor is always a private, contiguous CPU copy.
        safetensors refuses to serialize tensors that share storage, and the
        same tensor object (e.g. one joint global scale) may legitimately be
        added under several names.
        """
        t = tensor.detach()
        if t.device.type == "cpu":
            # .cpu()/.contiguous() would return the same storage here, so copy
            # explicitly.
            t = t.clone(memory_format=torch.contiguous_format)
        else:
            # The device-to-host transfer already yields fresh storage.
            t = t.cpu().contiguous()
        size = t.numel() * t.element_size()
        if self.current and self.current_bytes + size > self.max_shard_bytes:
            self._flush()
        self.current[name] = t
        self.current_bytes += size

    def close(self):
        """Flush the last shard, give all shards their final names, return the map.

        Returns:
            dict mapping every tensor name to its final shard filename.
            Calling close() again returns the same map without touching disk.
        """
        if self._finalized:
            return self.weight_map

        self._flush()
        # Now rename shards to use proper of-N suffix
        total = len(self.shard_filenames)
        renamed = {}
        for old_fname in self.shard_filenames:
            idx = int(old_fname.split("-")[1])
            new_fname = f"model-{idx:05d}-of-{total:05d}.safetensors"
            (self.out_dir / old_fname).rename(self.out_dir / new_fname)
            renamed[old_fname] = new_fname

        # Patch weight_map and the shard list to the final names
        self.weight_map = {k: renamed[v] for k, v in self.weight_map.items()}
        self.shard_filenames = [renamed[f] for f in self.shard_filenames]
        self._finalized = True
        return self.weight_map
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Shard-level conversion plan
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def build_plan(src_dir: Path):
    """Build the conversion plan from index.json.

    Every tensor name in the source index ends up in exactly one of: an expert
    pair group, `solo_quantize`, `preserve`, or `scale_companions`.

    Args:
        src_dir: directory containing `model.safetensors.index.json`.

    Returns:
        dict with the keys:
        weight_map: name -> shard filename
        shard_to_names: shard filename -> list of names in that shard
        expert_pair_groups: list of (group_name, name_w1, name_w3)
            For each expert, the w1/w3 pair gets a shared scale_2.
        solo_quantize: list of names to quantize independently
        preserve: list of names to copy unchanged
        scale_companions: `.scale` tensors consumed while dequantizing a
            weight that is being quantized (they are not written out)
    """
    with open(src_dir / "model.safetensors.index.json") as f:
        index = json.load(f)
    weight_map = index["weight_map"]

    shard_to_names = defaultdict(list)
    for name, shard in weight_map.items():
        shard_to_names[shard].append(name)

    # Identify expert pairs. EXPERT_PAIR_RE itself requires the `.weight`
    # suffix, so no separate suffix filter is needed.
    expert_pairs = defaultdict(dict)  # base -> {"w1": name, "w3": name}
    for name in weight_map:
        m = EXPERT_PAIR_RE.match(name)
        if m:
            expert_pairs[m.group(1)][m.group(2)] = name

    # Only complete pairs are grouped; a lone w1 or w3 is classified below.
    expert_pair_groups = [
        (base, parts["w1"], parts["w3"])
        for base, parts in expert_pairs.items()
        if "w1" in parts and "w3" in parts
    ]
    paired_names = {n for _, w1, w3 in expert_pair_groups for n in (w1, w3)}

    def _will_be_quantized(weight_name: str) -> bool:
        # Mirrors the classification order used below: pair membership wins
        # over the preserve patterns.
        return weight_name in paired_names or not is_preserve(weight_name)

    # Classify everything else
    solo_quantize = []
    preserve = []
    scale_companions = []  # .scale tensors that get consumed during dequant

    scale_suffix = ".scale"
    for name in weight_map:
        if name.endswith(scale_suffix):
            companion = name[: -len(scale_suffix)] + ".weight"
            if companion in weight_map:
                if _will_be_quantized(companion):
                    scale_companions.append(name)
                else:
                    # The weight is copied through unchanged, so it still
                    # needs its scale; dropping it would leave the preserved
                    # weight undecodable.
                    preserve.append(name)
                continue
        if name in paired_names:
            continue
        if is_preserve(name):
            preserve.append(name)
            continue
        # Anything else with .weight gets quantized solo, otherwise preserved
        if name.endswith(".weight"):
            solo_quantize.append(name)
        else:
            preserve.append(name)

    return {
        "weight_map": weight_map,
        "shard_to_names": dict(shard_to_names),
        "expert_pair_groups": expert_pair_groups,
        "solo_quantize": solo_quantize,
        "preserve": preserve,
        "scale_companions": scale_companions,
    }
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tensor loading helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class ShardCache:
    """Lazy per-shard safe_open cache so we don't re-open shards repeatedly.

    At most `max_open` shards are kept open. When room is needed the least
    recently used handle is closed, so a shard that keeps being read is not
    pushed out by shards that were merely opened after it.
    """

    def __init__(self, src_dir: Path, max_open: int = 4):
        """Remember the source directory; nothing is opened until get()."""
        self.src_dir = src_dir
        self.max_open = max_open
        # shard filename -> open handle, ordered from least to most recently used
        self.handles = OrderedDict()

    def get(self, shard_fname: str):
        """Return an open handle for `shard_fname`, opening it if necessary."""
        handle = self.handles.get(shard_fname)
        if handle is not None:
            # Mark as most recently used.
            self.handles.move_to_end(shard_fname)
            return handle

        # Make room first. The emptiness check keeps this safe for max_open <= 0.
        while self.handles and len(self.handles) >= self.max_open:
            _, stale = self.handles.popitem(last=False)
            stale.__exit__(None, None, None)

        handle = safe_open(self.src_dir / shard_fname, framework="pt")
        handle.__enter__()
        self.handles[shard_fname] = handle
        return handle

    def close(self):
        """Close every open handle and empty the cache."""
        while self.handles:
            _, handle = self.handles.popitem()
            handle.__exit__(None, None, None)
|
|
|
|
|
|
def load_weight_and_scale(cache: ShardCache, weight_map, name):
|
|
"""Load an FP8 weight with its scale companion (if any)."""
|
|
weight = cache.get(weight_map[name]).get_tensor(name)
|
|
scale_name = name.replace(".weight", ".scale")
|
|
scale = None
|
|
if scale_name in weight_map:
|
|
try:
|
|
scale = cache.get(weight_map[scale_name]).get_tensor(scale_name)
|
|
except Exception:
|
|
# Scale listed in index but not in shard (BF16 weights have no scale)
|
|
pass
|
|
return weight, scale
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Main
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _to_fp32_on_device(weight, scale, device):
    """Return `weight` as FP32 on `device`, applying `scale` when there is one."""
    if scale is None:
        # Already non-FP8 (e.g. BF16), just upcast
        return weight.float().to(device)
    return dequant_fp8_to_fp32(weight.to(device), scale.to(device))


def _write_nvfp4(writer, name, packed, block_scale, scale_2):
    """Emit the three output tensors of one quantized weight.

    `name` ends in `.weight`, so appending yields `.weight_scale` and
    `.weight_scale_2`. The global scale is cloned because the same tensor is
    written for both halves of an expert pair and shard entries must not share
    storage.
    """
    writer.add(name, packed)
    writer.add(f"{name}_scale", block_scale)
    writer.add(f"{name}_scale_2", scale_2.clone())


def _copy_sidecar_files(src: Path, dst: Path):
    """Copy everything that is not a tensor shard or the index (config, tokenizer, ...)."""
    for entry in src.iterdir():
        if entry.is_dir():
            # encoding/, inference/, assets/ — copy whole tree
            target = dst / entry.name
            if not target.exists():
                shutil.copytree(entry, target)
            continue
        if entry.suffix == ".safetensors":
            continue
        if entry.name == "model.safetensors.index.json":
            continue
        shutil.copy2(entry, dst / entry.name)


def _patch_config(dst: Path):
    """Add quantization metadata to the copied config.json, if there is one."""
    cfg_path = dst / "config.json"
    if not cfg_path.exists():
        return
    with open(cfg_path) as f:
        cfg = json.load(f)
    # NOTE(review): the tensors written by this script are named `weight`,
    # `weight_scale` and `weight_scale_2`. Confirm that the loader selected by
    # this quant_method/format expects that naming, and whether `ignore`
    # entries need a regex marker to be treated as patterns — TODO confirm.
    cfg["quantization_config"] = {
        "quant_method": "compressed-tensors",
        "format": "nvfp4-pack-quantized",
        "config_groups": {
            "group_0": {
                "targets": ["Linear"],
                "weights": {
                    "num_bits": 4,
                    "type": "float",
                    "strategy": "tensor_group",
                    "group_size": 16,
                    "symmetric": True,
                },
            }
        },
        "ignore": PRESERVE_REGEXES,
    }
    with open(cfg_path, "w") as f:
        json.dump(cfg, f, indent=2)


def main():
    """Command-line entry point: convert an FP8 checkpoint directory to NVFP4.

    Steps: build the plan from the source index, copy preserved tensors,
    quantize expert pairs with a joint global scale, quantize the remaining
    weights independently, then write the new index, copy the non-tensor files
    and patch config.json.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--src", required=True, help="Source FP8 model directory")
    ap.add_argument("--dst", required=True, help="Output NVFP4 model directory")
    ap.add_argument("--gpu", type=int, default=0, help="CUDA device, -1 for CPU")
    ap.add_argument("--shard-size-gb", type=float, default=5.0)
    # NOTE(review): --workers is parsed but nothing below uses it; the
    # conversion loops run sequentially.
    ap.add_argument("--workers", type=int, default=4,
                    help="Concurrent tensor-conversion workers (lots of small tensors benefit; "
                         "actual GPU compute is serialized by torch)")
    ap.add_argument("--dry-run", action="store_true")
    args = ap.parse_args()

    src = Path(args.src).resolve()
    dst = Path(args.dst).resolve()
    if not (src / "model.safetensors.index.json").exists():
        sys.exit(f"No index.json at {src}")
    if src == dst:
        # Writing in place would overwrite the source index while it is in use.
        sys.exit(f"--dst must differ from --src ({src})")

    # Falls back to CPU when CUDA is requested but unavailable.
    device = torch.device(f"cuda:{args.gpu}" if args.gpu >= 0 and torch.cuda.is_available() else "cpu")
    print(f"Compute device: {device}")

    # Move FP4_BOUNDARIES to device once
    global FP4_BOUNDARIES
    FP4_BOUNDARIES = FP4_BOUNDARIES.to(device)

    print("Building conversion plan...")
    plan = build_plan(src)
    n_pairs = len(plan["expert_pair_groups"])
    n_solo = len(plan["solo_quantize"])
    n_preserve = len(plan["preserve"])
    n_scales = len(plan["scale_companions"])
    print(f" Expert pair groups (joint scale_2): {n_pairs:,}")
    print(f" Solo quantize tensors: {n_solo:,}")
    print(f" Preserved tensors: {n_preserve:,}")
    print(f" Scale companions consumed: {n_scales:,}")

    if args.dry_run:
        print("\nDry run — exiting before any writes.")
        return

    dst.mkdir(parents=True, exist_ok=True)
    cache = ShardCache(src, max_open=8)
    writer = ShardedSafetensorsWriter(dst, max_shard_bytes=int(args.shard_size_gb * 1024**3))

    weight_map = plan["weight_map"]
    t_start = time.time()

    try:
        # ------------------------------------------------------------------
        # 1. Preserved tensors — copy unchanged
        # ------------------------------------------------------------------
        for name in tqdm(plan["preserve"], desc="Preserve", unit="tensor"):
            writer.add(name, cache.get(weight_map[name]).get_tensor(name))

        # ------------------------------------------------------------------
        # 2. Expert pairs — joint scale_2 across (w1, w3)
        # ------------------------------------------------------------------
        for base, name_w1, name_w3 in tqdm(plan["expert_pair_groups"], desc="Expert pairs", unit="pair"):
            w1_raw, s1 = load_weight_and_scale(cache, weight_map, name_w1)
            w3_raw, s3 = load_weight_and_scale(cache, weight_map, name_w3)

            with torch.no_grad():
                w1 = _to_fp32_on_device(w1_raw, s1, device)
                w3 = _to_fp32_on_device(w3_raw, s3, device)

                scale_2 = compute_global_scale(w1, w3)

                packed1, blk1 = quantize_to_nvfp4(w1, scale_2)
                packed3, blk3 = quantize_to_nvfp4(w3, scale_2)

            _write_nvfp4(writer, name_w1, packed1, blk1, scale_2)
            _write_nvfp4(writer, name_w3, packed3, blk3, scale_2)

        # ------------------------------------------------------------------
        # 3. Solo quantize tensors — independent scale_2 per tensor
        # ------------------------------------------------------------------
        for name in tqdm(plan["solo_quantize"], desc="Solo quantize", unit="tensor"):
            w_raw, s = load_weight_and_scale(cache, weight_map, name)
            with torch.no_grad():
                w = _to_fp32_on_device(w_raw, s, device)
                scale_2 = compute_global_scale(w)
                packed, blk = quantize_to_nvfp4(w, scale_2)
            _write_nvfp4(writer, name, packed, blk, scale_2)

        # Finalize shards & index
        final_weight_map = writer.close()
    finally:
        # Release the source shard handles on failure as well as on success.
        cache.close()

    # ------------------------------------------------------------------
    # 4. Write model.safetensors.index.json
    # ------------------------------------------------------------------
    shard_files = set(final_weight_map.values())
    total_size = sum((dst / fn).stat().st_size for fn in shard_files)
    new_index = {
        "metadata": {"total_size": total_size},
        "weight_map": final_weight_map,
    }
    with open(dst / "model.safetensors.index.json", "w") as f:
        json.dump(new_index, f, indent=2)

    # ------------------------------------------------------------------
    # 5. Copy non-tensor files (config, tokenizer, etc.)
    # ------------------------------------------------------------------
    _copy_sidecar_files(src, dst)

    # ------------------------------------------------------------------
    # 6. Patch config.json with quantization metadata so loaders know
    # ------------------------------------------------------------------
    _patch_config(dst)

    elapsed = time.time() - t_start
    print(f"\nDone in {elapsed/3600:.2f}h")
    print(f"Output: {dst}")
    print(f"Total size: {total_size/1024**3:.1f} GB across {len(shard_files)} shards")
|
|
|
|
|
|
# Run the converter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()