Files
deepseek-v4-quant/scripts/quantize_nvfp4.py

389 lines
15 KiB
Python

#!/usr/bin/env python3
"""
DeepSeek V4 Pro → NVFP4 quantization — defensive edition.
This script:
1. Applies runtime patches for GPU tensor safety (before modelopt runs)
2. Calls the SAME hf_ptq.py pipeline that the shell script uses
3. After calibration, snapshots amax to CPU and saves model state
The key insight: we don't rewrite the pipeline. We let hf_ptq do its thing
with all its args, defaults, and edge cases handled correctly. We just add
our defensive patches and post-calibration saves.
Must be run from the modelopt example directory:
cd /root/nvidia-meeting/modelopt-repo/examples/llm_ptq
python3 /root/nvidia-meeting/deepseek-v4-quant/scripts/quantize_nvfp4.py
Usage:
# Full run (calibrate + export):
python3 /root/nvidia-meeting/deepseek-v4-quant/scripts/quantize_nvfp4.py
# Re-run export only (after a calibration save exists):
python3 /root/nvidia-meeting/deepseek-v4-quant/scripts/quantize_nvfp4.py --export-only
# Validate saved calibration state (check amax values):
python3 /root/nvidia-meeting/deepseek-v4-quant/scripts/quantize_nvfp4.py --validate-only
"""
import argparse
import gc
import os
import sys
import time
import torch
# ── Config ──────────────────────────────────────────────────────────────────
# Default for --model: path of the BF16 checkpoint to quantize.
MODEL = "/root/nvidia-meeting/DeepSeek-V4-Pro-BF16"
# Quantization format handed to hf_ptq.
QUANT = "nvfp4"
# Tensor-parallel size handed to hf_ptq.
TP = 8
# Defaults for --calib-size / --calib-seq.
CALIB_SIZE = 128
CALIB_SEQ = 512
# KV-cache quantization format handed to hf_ptq.
KV_CACHE_QUANT = "fp8_cast"
# Value handed to hf_ptq's --gpu_max_mem_percentage.
GPU_MEM_PCT = 0.7
# SECURITY: credentials must never live in source control. The token is taken
# from the environment (either of the two variable names this script exports
# later); it is empty when neither is set. A literal token was previously
# committed on this line and should be treated as leaked and revoked.
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") or ""
# Paths
# Directory containing hf_ptq.py; the run_* functions chdir into it.
EXAMPLE_DIR = "/root/nvidia-meeting/modelopt-repo/examples/llm_ptq"
# Default for --export-dir.
EXPORT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
# Default for --calib-save (full state dict written after calibration).
CALIB_SAVE_PATH = "/root/nvidia-meeting/v4_nvfp4_calibrated_state.pt"
# Default for --amax-snapshot (name -> CPU amax tensor mapping).
AMAX_SNAPSHOT_PATH = "/root/nvidia-meeting/v4_nvfp4_amax_snapshots.pt"
def apply_patches():
    """Install runtime monkey-patches on modelopt's quantizer classes.

    Three attributes are replaced in-process; no modelopt source file is
    modified:

    1. ``TensorQuantizer.load_calib_amax`` -- after the original runs, a
       non-None ``_amax`` is rebound to its CPU copy.
    2. ``TensorQuantizer.export_amax`` -- a CUDA-resident ``amax`` is moved to
       CPU before the original is called.
    3. ``NVFP4QTensor.get_activation_scaling_factor`` -- replaced by a
       classmethod that retries ``export_amax()`` once after moving ``_amax``
       to CPU, computes the factor on CPU, and warns about and repairs any
       entry that is not strictly positive (including NaN).

    Call this once, before the PTQ pipeline runs. Each call wraps whatever is
    currently installed, so a second call would stack wrappers.
    """
    from modelopt.torch.quantization.nn.modules import tensor_quantizer as tq_module
    from modelopt.torch.quantization.qtensor import nvfp4_tensor

    # ── Patch 1: load_calib_amax — force _amax to CPU after calibration ──
    orig_load_calib_amax = tq_module.TensorQuantizer.load_calib_amax

    def patched_load_calib_amax(self, *args, **kwargs):
        # Hand back whatever the wrapped method returns so the wrapper stays
        # transparent to callers.
        result = orig_load_calib_amax(self, *args, **kwargs)
        if getattr(self, '_amax', None) is not None:
            self._amax = self._amax.cpu()
        return result

    tq_module.TensorQuantizer.load_calib_amax = patched_load_calib_amax
    print("✓ Patched TensorQuantizer.load_calib_amax (force _amax to CPU)")

    # ── Patch 2: export_amax — CPU safety ──
    orig_export_amax = tq_module.TensorQuantizer.export_amax

    def patched_export_amax(self):
        if self.amax is not None and self.amax.is_cuda:
            self._amax = self._amax.cpu()
        return orig_export_amax(self)

    tq_module.TensorQuantizer.export_amax = patched_export_amax
    print("✓ Patched TensorQuantizer.export_amax (CPU fallback)")

    # ── Patch 3: NVFP4QTensor.get_activation_scaling_factor — graceful degradation ──
    @classmethod
    def patched_get_activation_scaling_factor(cls, quantizer):
        if not quantizer.is_enabled:
            return None
        try:
            amax = quantizer.export_amax()
        except (torch.cuda.CudaError, RuntimeError) as e:
            print(f" WARNING: export_amax() failed ({e}), attempting CPU recovery...")
            if getattr(quantizer, '_amax', None) is not None:
                quantizer._amax = quantizer._amax.cpu()
            # Second attempt; a failure here propagates to the caller.
            amax = quantizer.export_amax()
        if amax is None:
            return None
        amax = amax.cpu()
        activation_scaling_factor = amax.float() / (quantizer.maxbound * 448.0)
        positive = activation_scaling_factor > 0
        if not torch.all(positive):
            # ``~positive`` also catches NaN, which a ``<= 0`` test misses.
            n_bad = (~positive).sum().item()
            n_total = activation_scaling_factor.numel()
            print(f" WARNING: {n_bad}/{n_total} activation scaling factors <= 0 or NaN, clamping")
            tiny = torch.finfo(torch.float32).tiny
            # clamp() propagates NaN, so NaN entries are replaced explicitly
            # before the lower bound is applied.
            activation_scaling_factor = activation_scaling_factor.masked_fill(
                torch.isnan(activation_scaling_factor), tiny
            ).clamp(min=tiny)
        return activation_scaling_factor

    nvfp4_tensor.NVFP4QTensor.get_activation_scaling_factor = patched_get_activation_scaling_factor
    print("✓ Patched NVFP4QTensor.get_activation_scaling_factor (CPU + clamp)")
def snapshot_amax_to_cpu(model, snapshot_path):
    """Copy every quantizer's ``_amax`` to CPU and save the copies to disk.

    Walks ``model.named_modules()``, and for each ``TensorQuantizer`` whose
    ``_amax`` is set stores a detached CPU clone under the module's name. The
    live ``_amax`` tensors are not modified or moved.

    The mapping is written to a temporary sibling file and then renamed over
    ``snapshot_path``, so an interrupted save never leaves a truncated
    snapshot at the final path.

    Args:
        model: module tree to scan for quantizers.
        snapshot_path: destination file for the ``{name: cpu_tensor}`` dict.

    Returns:
        The ``{module_name: cpu_tensor}`` dict that was saved.
    """
    from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer

    print("\nSnapshotting quantizer _amax to CPU...")
    t0 = time.time()
    snapshots = {}
    for name, module in model.named_modules():
        if not isinstance(module, TensorQuantizer):
            continue
        amax = getattr(module, '_amax', None)
        if amax is not None:
            # clone() decouples the snapshot from the live tensor even when
            # it already lives on CPU (where .cpu() returns the same object).
            snapshots[name] = amax.detach().cpu().clone()
    n_moved = len(snapshots)

    tmp_path = f"{snapshot_path}.tmp"
    try:
        torch.save(snapshots, tmp_path)
        os.replace(tmp_path, snapshot_path)
    finally:
        # Only present if the save or rename failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    size_mb = os.path.getsize(snapshot_path) / (1024**2)
    print(f"✓ Snapshotted {n_moved} quantizer _amax tensors to CPU ({time.time()-t0:.1f}s)")
    print(f" Saved to: {snapshot_path} ({size_mb:.1f} MB)")
    return snapshots
def restore_amax_from_snapshot(model, snapshot_path):
    """Copy saved amax values back into the model's quantizers.

    Loads the ``{module_name: tensor}`` dict from ``snapshot_path`` and, for
    every ``TensorQuantizer`` in ``model`` that has a non-None ``_amax`` and a
    snapshot entry under the same name, copies the saved values in place
    (on whatever device the live tensor is on).

    Quantizers without an existing ``_amax`` are skipped; snapshot entries
    that matched nothing are reported in a warning line.

    Args:
        model: module tree whose quantizers should be restored.
        snapshot_path: file previously written with the amax snapshot dict.
    """
    from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer

    print(f"Restoring _amax from snapshot: {snapshot_path}")
    snapshots = torch.load(snapshot_path, map_location='cpu')
    n_restored = 0
    for name, module in model.named_modules():
        if not isinstance(module, TensorQuantizer):
            continue
        current = getattr(module, '_amax', None)
        # Same "present and not None" guard the other amax helpers use; a
        # None _amax has no .data to copy into.
        if current is None or name not in snapshots:
            continue
        current.data.copy_(snapshots[name].to(current.device))
        n_restored += 1
    print(f"✓ Restored {n_restored} _amax tensors from snapshot")
    n_unmatched = len(snapshots) - n_restored
    if n_unmatched > 0:
        print(f" WARNING: {n_unmatched} snapshot entries had no matching quantizer _amax")
def force_all_amax_to_cpu(model):
    """Rebind CUDA-resident quantizer tensors to CPU copies.

    For each ``TensorQuantizer`` found in ``model``, the attributes ``_amax``,
    ``_pre_quant_scale`` and ``_global_amax`` are inspected; any that is a
    CUDA tensor is replaced with ``tensor.cpu()``. Prints how many tensors
    were moved.
    """
    from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer

    tensor_attrs = ('_amax', '_pre_quant_scale', '_global_amax')
    moved = 0
    quantizers = (
        mod for _, mod in model.named_modules() if isinstance(mod, TensorQuantizer)
    )
    for quantizer in quantizers:
        for attr_name in tensor_attrs:
            # A missing attribute reads as None and is skipped by the
            # isinstance check.
            tensor = getattr(quantizer, attr_name, None)
            if isinstance(tensor, torch.Tensor) and tensor.is_cuda:
                setattr(quantizer, attr_name, tensor.cpu())
                moved += 1
    print(f"✓ Forced {moved} quantizer tensors to CPU")
def save_calibrated_state(model, path):
    """Save the model's state dict (plus a timestamp) after calibration.

    The file holds a dict with two keys: ``'model_state_dict'``
    (``model.state_dict()``) and ``'timestamp'`` (local time,
    ``%Y-%m-%d %H:%M:%S``).

    The data is written to ``<path>.tmp`` and renamed over ``path`` only once
    the write has finished, so a crash mid-save neither destroys an earlier
    good checkpoint nor leaves a truncated file where ``--export-only`` would
    pick it up. The trade-off: while saving, the temporary file and any
    existing checkpoint occupy disk space at the same time.

    Args:
        model: object exposing ``state_dict()``.
        path: destination file path.

    Raises:
        Whatever ``torch.save`` / ``os.replace`` raise; the temporary file is
        removed before the exception propagates.
    """
    print(f"\n{'='*60}")
    print(f"SAVING CALIBRATED STATE → {path}")
    print(f"{'='*60}")
    start = time.time()
    state = {
        'model_state_dict': model.state_dict(),
        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
    }
    tmp_path = f"{path}.tmp"
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # Only present if the save or rename failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    size_gb = os.path.getsize(path) / (1024**3)
    print(f"✓ Saved calibrated state: {size_gb:.1f} GB ({time.time()-start:.0f}s)")
    print(f" Path: {path}")
    print(f" Re-run with --export-only to retry export.\n")
def run_calibration(model_path, export_dir, calib_save_path, amax_snapshot_path, calib_size, calib_seq):
    """Run hf_ptq's pipeline with the runtime patches and a pre-export save hook.

    Steps, in order:
      1. chdir into ``EXAMPLE_DIR`` and put it on ``sys.path`` so ``hf_ptq``
         can be imported; export the Hugging Face token variables.
      2. Apply the runtime patches.
      3. Build the argument namespace with hf_ptq's own ``parse_args`` by
         temporarily substituting ``sys.argv``.
      4. Replace ``hf_ptq.export_quantized`` with a wrapper that snapshots
         amax, moves quantizer tensors to CPU, saves the calibrated state,
         and only then calls the real export.
      5. Call ``hf_ptq.main(args)``.

    The working directory, ``sys.path``, environment and ``hf_ptq`` module
    changes are process-wide and are not undone.

    Args:
        model_path: checkpoint path passed to hf_ptq.
        export_dir: export directory passed to hf_ptq.
        calib_save_path: where the calibrated state dict is saved.
        amax_snapshot_path: where the CPU amax snapshot is saved.
        calib_size: number of calibration samples.
        calib_seq: calibration sequence length.
    """
    os.chdir(EXAMPLE_DIR)
    sys.path.insert(0, EXAMPLE_DIR)
    os.environ["HF_TOKEN"] = HF_TOKEN
    os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
    import hf_ptq
    from hf_ptq import parse_args, main as hf_main
    apply_patches()

    # ── Build args using hf_ptq's own parser ──
    # This guarantees ALL attributes exist with correct defaults.
    # We temporarily replace sys.argv so parse_args() sees our config.
    # NOTE(review): the flag names below must match the parser of the
    # checked-out hf_ptq.py (upstream examples have spelled some of these
    # --pyt_ckpt_path / --qformat / --inference_tensor_parallel) — confirm;
    # argparse exits the process on an unknown or missing flag.
    saved_argv = sys.argv
    sys.argv = [
        "hf_ptq.py",
        "--model", model_path,
        "--quant", QUANT,
        "--calib_size", str(calib_size),
        "--calib_seq", str(calib_seq),
        "--kv_cache_qformat", KV_CACHE_QUANT,
        "--tp", str(TP),
        "--export_path", export_dir,
        "--trust_remote_code",
        "--use_seq_device_map",
        "--gpu_max_mem_percentage", str(GPU_MEM_PCT),
        "--batch_size", "0",
    ]
    try:
        args = parse_args()
    finally:
        # Restore even when argparse raises SystemExit, so the caller's
        # argv is never left pointing at the synthetic command line.
        sys.argv = saved_argv

    # ── Post-calibration hook ──
    # We monkey-patch export_quantized to add our defensive saves before export.
    orig_export_quantized = hf_ptq.export_quantized

    def patched_export_quantized(exp_args, full_model, language_model, model_type,
                                 tokenizer, default_padding_side, default_pad_token):
        """Wrapper that snapshots amax and saves state before calling the real export."""
        print("\n" + "="*60)
        print("POST-CALIBRATION: Snapshotting amax and saving state")
        print("="*60)
        # Snapshot amax to CPU
        snapshot_amax_to_cpu(language_model, amax_snapshot_path)
        # Force all quantizer state to CPU
        force_all_amax_to_cpu(language_model)
        # Free GPU memory: collect garbage first so blocks released by the
        # collector are already free when the CUDA cache is emptied.
        gc.collect()
        torch.cuda.empty_cache()
        # Save calibrated state
        save_calibrated_state(language_model, calib_save_path)
        # Now run the real export, passing its result through unchanged.
        return orig_export_quantized(exp_args, full_model, language_model, model_type,
                                     tokenizer, default_padding_side, default_pad_token)

    hf_ptq.export_quantized = patched_export_quantized
    print("✓ Hooked export_quantized with amax snapshot + state save")

    # ── Run hf_ptq's full pipeline ──
    # This handles model loading, quantization, calibration, and export
    # using the exact same code path as the shell script.
    hf_main(args)
def run_export_only(calib_save_path, amax_snapshot_path, model_path, export_dir):
    """Load a saved calibration state and run only the export step.

    Loads the model and tokenizer from ``model_path`` on CPU, loads the state
    dict saved at ``calib_save_path`` into it, moves quantizer tensors to CPU,
    restores amax from ``amax_snapshot_path`` when that file exists, and then
    exports a Hugging Face checkpoint (plus tokenizer and custom model files)
    into ``export_dir``.

    The working directory, ``sys.path`` and environment changes are
    process-wide and are not undone.

    Args:
        calib_save_path: file holding ``{'model_state_dict', 'timestamp'}``.
        amax_snapshot_path: optional amax snapshot file; skipped if falsy or
            missing.
        model_path: checkpoint path to load the model and tokenizer from.
        export_dir: destination directory for the exported checkpoint.

    Raises:
        Any exception raised during export is re-raised after printing where
        the calibrated state and amax snapshot live.
    """
    os.chdir(EXAMPLE_DIR)
    sys.path.insert(0, EXAMPLE_DIR)
    os.environ["HF_TOKEN"] = HF_TOKEN
    os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
    apply_patches()
    from example_utils import get_model, get_tokenizer

    print(f"Loading model from {model_path}...")
    model = get_model(
        model_path,
        device="cpu",
        trust_remote_code=True,
    )
    tokenizer = get_tokenizer(model_path, trust_remote_code=True)

    print(f"Loading calibrated state from {calib_save_path}...")
    state = torch.load(calib_save_path, map_location='cpu')
    # NOTE(review): nothing in this function inserts quantizer modules into
    # the freshly loaded model before this call; whether load_state_dict
    # accepts the quantizer entries saved after calibration depends on what
    # get_model returns — confirm.
    model.load_state_dict(state['model_state_dict'])
    print(f"✓ Loaded calibrated state (saved at {state['timestamp']})")

    force_all_amax_to_cpu(model)
    if amax_snapshot_path and os.path.exists(amax_snapshot_path):
        restore_amax_from_snapshot(model, amax_snapshot_path)
    # Collect garbage first so blocks released by the collector are already
    # free when the CUDA cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()

    from modelopt.torch.export import export_hf_checkpoint
    from hf_ptq import load_mtp_weights, copy_custom_model_files

    print(f"\n{'='*60}")
    print(f"EXPORTING → {export_dir}")
    print(f"{'='*60}")
    t0 = time.time()
    try:
        mtp_layer_prefixes, mtp_state_dict = load_mtp_weights(model, model_path)
        if mtp_layer_prefixes:
            model._mtp_layer_prefixes = mtp_layer_prefixes
        export_hf_checkpoint(model, export_dir=export_dir, extra_state_dict=mtp_state_dict)
        tokenizer.save_pretrained(export_dir)
        copy_custom_model_files(model_path, export_dir, True)
        print(f"\n✓ Export complete in {time.time()-t0:.0f}s → {export_dir}")
    except Exception as e:
        print(f"\n✗ EXPORT FAILED: {e}")
        # Report the paths this run actually used, which may differ from the
        # module-level defaults when overridden on the command line.
        print(f" Calibrated state: {calib_save_path}")
        print(f" Amax snapshots: {amax_snapshot_path}")
        raise
def run_validate(calib_save_path, amax_snapshot_path):
    """Check the saved amax snapshot and report on the calibrated-state file.

    Each snapshot entry is put in exactly one bucket, tested in this order:
    NaN (any element is NaN), Neg (any element < 0), Inf (any element is
    infinite), Zero (every element is 0), otherwise Valid. A summary is
    printed, followed by the size of the calibrated-state file if it exists.

    Args:
        calib_save_path: path of the saved calibrated state (existence and
            size only; the file is not loaded).
        amax_snapshot_path: path of the ``{name: tensor}`` amax snapshot.

    Returns:
        ``True`` only when the snapshot exists, is non-empty, every entry is
        Valid, and the calibrated-state file exists; ``False`` otherwise.
    """
    print("\nValidating calibration state...")
    snapshot_ok = False
    if os.path.exists(amax_snapshot_path):
        snapshots = torch.load(amax_snapshot_path, map_location='cpu')
        n_total = len(snapshots)
        n_valid = n_zero = n_nan = n_neg = n_inf = 0
        for amax in snapshots.values():
            if torch.any(torch.isnan(amax)):
                n_nan += 1
            elif torch.any(amax < 0):
                n_neg += 1
            elif torch.any(torch.isinf(amax)):
                # Only +inf reaches this branch; -inf is already counted as Neg.
                n_inf += 1
            elif torch.all(amax == 0):
                n_zero += 1
            else:
                n_valid += 1
        print("\nAmax snapshot validation:")
        print(f" Total: {n_total} Valid: {n_valid} Zero: {n_zero} Neg: {n_neg} NaN: {n_nan} Inf: {n_inf}")
        if n_total == 0:
            # An empty mapping means no quantizer amax was captured; that is
            # not a successful validation.
            print("\n✗ Amax snapshot is empty — nothing to validate")
        elif n_valid == n_total:
            print(f"\n✓ All {n_total} amax snapshots are valid!")
            snapshot_ok = True
        else:
            print(f"\n{n_total - n_valid} quantizers have invalid amax!")
    else:
        print(f"✗ No amax snapshot found at {amax_snapshot_path}")

    state_ok = os.path.exists(calib_save_path)
    if state_ok:
        size_gb = os.path.getsize(calib_save_path) / (1024**3)
        print(f"\nCalibrated state: {calib_save_path} ({size_gb:.1f} GB)")
    else:
        print(f"\n✗ No calibrated state found at {calib_save_path}")
    return snapshot_ok and state_ok
def main():
    """Parse the command line and dispatch to one of the three run modes.

    ``--validate-only`` takes precedence over ``--export-only``; with neither
    flag the full calibrate-and-export pipeline runs. In export-only mode the
    process exits with status 1 when the calibration state file is missing.
    """
    cli = argparse.ArgumentParser(description="DeepSeek V4 Pro NVFP4 Quantization")

    # Boolean mode switches.
    mode_flags = (
        ("--export-only", "Skip calibration, load saved state and run export only"),
        ("--validate-only", "Validate saved calibration state without running anything"),
    )
    for flag, description in mode_flags:
        cli.add_argument(flag, action="store_true", help=description)

    # String-valued path options, defaulting to the module-level constants.
    path_options = (
        ("--model", MODEL, "Path to BF16 model"),
        ("--export-dir", EXPORT_DIR, "Export output directory"),
        ("--calib-save", CALIB_SAVE_PATH, "Calibration state save path"),
        ("--amax-snapshot", AMAX_SNAPSHOT_PATH, "Amax snapshot path"),
    )
    for flag, fallback, description in path_options:
        cli.add_argument(flag, default=fallback, help=description)

    # Integer calibration knobs.
    int_options = (
        ("--calib-size", CALIB_SIZE, "Calibration samples"),
        ("--calib-seq", CALIB_SEQ, "Calibration sequence length"),
    )
    for flag, fallback, description in int_options:
        cli.add_argument(flag, type=int, default=fallback, help=description)

    opts = cli.parse_args()

    if opts.validate_only:
        run_validate(opts.calib_save, opts.amax_snapshot)
        return

    if not opts.export_only:
        run_calibration(opts.model, opts.export_dir, opts.calib_save,
                        opts.amax_snapshot, opts.calib_size, opts.calib_seq)
        return

    # Export-only: a saved calibration state is a hard prerequisite.
    if not os.path.exists(opts.calib_save):
        print(f"ERROR: No calibration state found at {opts.calib_save}")
        sys.exit(1)
    run_export_only(opts.calib_save, opts.amax_snapshot, opts.model, opts.export_dir)


if __name__ == "__main__":
    main()