Defensive quantization: snapshot amax to CPU immediately after calibration

Key changes:
- snapshot_amax_to_cpu(): copies all quantizer _amax to CPU and saves
  to disk (~50MB) right after mtq.quantize() returns, before any other
  GPU operation can corrupt them
- force_all_amax_to_cpu(): nuclear option, moves _pre_quant_scale and
  _global_amax to CPU too
- _FORCE_AMAX_CPU flag + patched amax setter: after calibration, any
  future amax writes go to CPU instead of GPU
- --validate-only mode to check saved state without running anything
- restore_amax_from_snapshot() for --export-only recovery
- torch.cuda.empty_cache() + gc.collect() between steps
- Patches: export_amax CPU fallback, get_activation_scaling_factor
  clamp instead of assert
This commit is contained in:
2026-05-09 06:31:08 +00:00
parent 3907838409
commit 6eaba26914

View File

@@ -1,10 +1,17 @@
#!/usr/bin/env python3
"""
DeepSeek V4 Pro → NVFP4 quantization.
DeepSeek V4 Pro → NVFP4 quantization — defensive edition.
Runs the full ModelOpt PTQ pipeline in-process (not wrapping the shell script),
saves model state after calibration (so we don't lose 6 hours of work to an
export crash), and patches the export path to handle stale GPU tensors.
Runs the full ModelOpt PTQ pipeline with maximum protection against GPU tensor
corruption that crashes the export after 6 hours of calibration.
Key defense: immediately after calibration, every quantizer _amax tensor is
snapshotted to CPU. Then the model state is saved to disk. If export crashes,
the state can be reloaded and export retried without re-calibrating.
The _amax tensors are tiny (scalars and small vectors). Snapshotting ~49K of them
to CPU costs almost nothing in memory and guarantees we have valid calibration
data regardless of what CUDA does to the GPU copies afterward.
Must be run from the modelopt example directory for imports:
cd /root/nvidia-meeting/modelopt-repo/examples/llm_ptq
@@ -16,10 +23,14 @@ Usage:
# Re-run export only (after a calibration save exists):
python3 /root/nvidia-meeting/deepseek-v4-quant/scripts/quantize_nvfp4.py --export-only
# Validate saved calibration state (check amax values):
python3 /root/nvidia-meeting/deepseek-v4-quant/scripts/quantize_nvfp4.py --validate-only
"""
import argparse
import copy
import gc
import os
import sys
import time
@@ -43,37 +54,66 @@ HF_TOKEN = "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"
EXAMPLE_DIR = "/root/nvidia-meeting/modelopt-repo/examples/llm_ptq"
EXPORT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
CALIB_SAVE_PATH = "/root/nvidia-meeting/v4_nvfp4_calibrated_state.pt"
AMAX_SNAPSHOT_PATH = "/root/nvidia-meeting/v4_nvfp4_amax_snapshots.pt"
# Flag: when True, force all new _amax writes to CPU
_FORCE_AMAX_CPU = False
def apply_patches():
"""Apply runtime patches for V4 compatibility."""
"""Apply runtime patches for V4 compatibility and GPU tensor safety."""
from modelopt.torch.quantization.nn.modules import tensor_quantizer as tq_module
# 1. Patch TensorQuantizer.export_amax to move _amax to CPU before reading
# ── Patch 1: Force _amax to CPU after calibration completes ──
#
# This replaces the setter of the TensorQuantizer.amax property.
# (Per the original note, load_calib_amax() goes through this setter at
# the end of calibration — confirm against the installed ModelOpt.)
#
# While _FORCE_AMAX_CPU is False the value stays on whatever device it
# arrives on, so fake quantization during calibration keeps working on
# GPU. Once the flag is True, every write lands in a CPU tensor.
amax_prop = tq_module.TensorQuantizer.amax
orig_amax_setter = amax_prop.fset  # kept only as a handle to the original

def patched_amax_setter(self, value):
    """Store ``value`` in the ``_amax`` buffer, on CPU when _FORCE_AMAX_CPU is set.

    Args:
        value: New amax; anything ``torch.tensor`` accepts, or a tensor.

    Raises:
        AssertionError: If ``value`` is None.
        RuntimeError: If ``_amax`` already exists with a different shape.
    """
    assert value is not None, "amax cannot be set to None."
    if not isinstance(value, torch.Tensor):
        value = torch.tensor(value)

    new_amax = value.clone().detach()
    if _FORCE_AMAX_CPU:
        new_amax = new_amax.cpu()

    if not hasattr(self, "_amax"):
        self.register_buffer("_amax", new_amax)
        return

    if self._amax.shape != new_amax.shape:
        raise RuntimeError("Changing shape when setting amax is not allowed.")

    if _FORCE_AMAX_CPU and self._amax.device.type != "cpu":
        # copy_() would write into the existing GPU storage and leave the
        # buffer on GPU; rebind the buffer to the CPU tensor instead.
        self._amax = new_amax
    else:
        self._amax.data.copy_(new_amax.to(self._amax.device))

# property.fset is a read-only attribute, so the setter cannot be swapped
# in place. Build a new property around the patched setter and rebind it.
tq_module.TensorQuantizer.amax = property(
    amax_prop.fget, patched_amax_setter, amax_prop.fdel, amax_prop.__doc__
)
print("✓ Patched TensorQuantizer.amax setter (CPU mode controlled by _FORCE_AMAX_CPU)")
# ── Patch 2: export_amax — CPU safety ──
# If any _amax is still on GPU at export time, move it before reading.
orig_export_amax = tq_module.TensorQuantizer.export_amax
def patched_export_amax(self):
"""Move _amax to CPU before export to prevent CUDA illegal memory access
on tensors that have been sitting in VRAM for hours during calibration."""
if self.amax is not None and self.amax.is_cuda:
self._amax = self._amax.cpu()
return orig_export_amax(self)
tq_module.TensorQuantizer.export_amax = patched_export_amax
print("✓ Patched TensorQuantizer.export_amax (CPU safety)")
print("✓ Patched TensorQuantizer.export_amax (CPU fallback)")
# 2. Patch NVFP4QTensor.get_activation_scaling_factor for graceful degradation
# ── Patch 3: NVFP4QTensor.get_activation_scaling_factor graceful degradation ──
from modelopt.torch.quantization.qtensor import nvfp4_tensor
orig_get_asf = nvfp4_tensor.NVFP4QTensor.get_activation_scaling_factor
@classmethod
def patched_get_activation_scaling_factor(cls, quantizer):
"""Move amax to CPU before export; clamp instead of assert on bad values."""
if not quantizer.is_enabled:
return None
try:
amax = quantizer.export_amax()
except (torch.cuda.CudaError, RuntimeError) as e:
@@ -84,40 +124,108 @@ def apply_patches():
if amax is None:
return None
amax = amax.cpu()
activation_scaling_factor = amax.float() / (quantizer.maxbound * 448.0)
# Replace hard assert with warning + clamp
# Clamp instead of hard assert — bad values from GPU corruption should
# not kill the entire 6-hour run
if not torch.all(activation_scaling_factor > 0):
n_bad = (activation_scaling_factor <= 0).sum().item()
n_total = activation_scaling_factor.numel()
print(f" WARNING: {n_bad}/{n_total} activation scaling factors <= 0, clamping to tiny")
print(f" WARNING: {n_bad}/{n_total} activation scaling factors <= 0, clamping")
activation_scaling_factor = activation_scaling_factor.clamp(min=torch.finfo(torch.float32).tiny)
return activation_scaling_factor
nvfp4_tensor.NVFP4QTensor.get_activation_scaling_factor = patched_get_activation_scaling_factor
print("✓ Patched NVFP4QTensor.get_activation_scaling_factor (CPU safety + graceful degradation)")
print("✓ Patched NVFP4QTensor.get_activation_scaling_factor (CPU + clamp)")
def move_quantizers_to_cpu(model):
"""Move all quantizer amax tensors to CPU to prevent stale GPU reads during export."""
def snapshot_amax_to_cpu(model, snapshot_path):
    """Copy every quantizer ``_amax`` to CPU, rebind it there, and save to disk.

    Walks ``model.named_modules()``, and for each TensorQuantizer that has a
    non-None ``_amax`` takes an independent CPU copy. The copies are written
    to ``snapshot_path`` with ``torch.save`` so calibration results survive
    even if the GPU tensors become unusable later.

    Args:
        model: Module tree containing the calibrated quantizers.
        snapshot_path: Destination file for the snapshot dict.

    Returns:
        dict mapping quantizer module name -> amax tensor on CPU.
    """
    from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer

    print("\nSnapshotting quantizer _amax to CPU...")
    t0 = time.time()

    snapshots = {}
    for name, module in model.named_modules():
        if not isinstance(module, TensorQuantizer):
            continue
        amax = getattr(module, "_amax", None)
        if amax is None:
            continue

        # Independent CPU copy; this is what gets persisted and returned.
        amax_cpu = amax.detach().cpu().clone()
        snapshots[name] = amax_cpu

        # Rebind the buffer to a CPU tensor. An in-place copy_() would only
        # write the same values back into the GPU storage and leave the
        # buffer on GPU. A second clone keeps later in-place edits of the
        # module's buffer from mutating the returned snapshot.
        module._amax = amax_cpu.clone()

    n_moved = len(snapshots)

    # Save snapshots to disk
    torch.save(snapshots, snapshot_path)
    size_mb = os.path.getsize(snapshot_path) / (1024**2)
    print(f"✓ Snapshotted {n_moved} quantizer _amax tensors to CPU ({time.time()-t0:.1f}s)")
    print(f" Saved to: {snapshot_path} ({size_mb:.1f} MB)")
    return snapshots
def restore_amax_from_snapshot(model, snapshot_path):
    """Overwrite quantizer ``_amax`` values from a saved CPU snapshot.

    Loads the ``{module_name: tensor}`` dict at ``snapshot_path`` and copies
    each entry into the ``_amax`` of the TensorQuantizer with the same name,
    on whatever device that buffer currently lives. Quantizers without an
    ``_amax`` buffer (missing or None) are left untouched.

    Args:
        model: Module tree whose quantizers should receive the saved values.
        snapshot_path: File previously written with ``torch.save``.

    Raises:
        RuntimeError: Propagated from ``copy_`` if a saved tensor cannot be
            copied into the existing buffer (e.g. incompatible shape).
    """
    from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer

    print(f"Restoring _amax from snapshot: {snapshot_path}")
    snapshots = torch.load(snapshot_path, map_location='cpu')

    n_restored = 0
    for name, module in model.named_modules():
        if not isinstance(module, TensorQuantizer):
            continue
        saved = snapshots.get(name)
        current = getattr(module, "_amax", None)
        # Skip when there is nothing saved for this quantizer or when it has
        # no buffer to copy into.
        if saved is None or current is None:
            continue
        current.data.copy_(saved.to(current.device))
        n_restored += 1

    print(f"✓ Restored {n_restored} _amax tensors from snapshot")

    # A partial restore is easy to miss, so report it explicitly.
    n_unapplied = len(snapshots) - n_restored
    if n_unapplied:
        print(f" WARNING: {n_unapplied} snapshot entries were not applied "
              f"(no matching quantizer with an _amax buffer)")
def force_all_amax_to_cpu(model):
    """Move quantizer calibration tensors that live on CUDA over to CPU.

    For every TensorQuantizer in ``model`` the attributes ``_amax``,
    ``_pre_quant_scale`` and ``_global_amax`` are inspected; each one that is
    a CUDA tensor is replaced by its CPU copy. Attributes that are missing,
    None, non-tensors, or already on CPU are left alone.

    Args:
        model: Module tree containing the quantizers.
    """
    from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer

    count = 0
    for _name, module in model.named_modules():
        if not isinstance(module, TensorQuantizer):
            continue
        for attr in ('_amax', '_pre_quant_scale', '_global_amax'):
            val = getattr(module, attr, None)
            if isinstance(val, torch.Tensor) and val.is_cuda:
                setattr(module, attr, val.cpu())
                count += 1

    print(f"Forced {count} quantizer tensors to CPU")
def save_calibrated_state(model, path):
"""Save model state dict after calibration.
Insurance policy: if export crashes, we can reload and retry
The insurance policy: if export crashes, we can reload and retry
without re-running 6 hours of calibration.
"""
print(f"\n{'='*60}")
@@ -126,9 +234,7 @@ def save_calibrated_state(model, path):
start = time.time()
# Move quantizers to CPU first
move_quantizers_to_cpu(model)
# All quantizer state should already be on CPU from snapshot_amax_to_cpu
state = {
'model_state_dict': model.state_dict(),
'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
@@ -138,21 +244,20 @@ def save_calibrated_state(model, path):
size_gb = os.path.getsize(path) / (1024**3)
print(f"✓ Saved calibrated state: {size_gb:.1f} GB ({time.time()-start:.0f}s)")
print(f" Path: {path}")
print(f" Re-run with --export-only to retry export without recalibrating.\n")
print(f" Re-run with --export-only to retry export.\n")
def run_calibration(model_path, export_dir, calib_save_path, calib_size, calib_seq):
"""Full pipeline: load → quantize → calibrate → save → export."""
def run_calibration(model_path, export_dir, calib_save_path, amax_snapshot_path, calib_size, calib_seq):
"""Full pipeline: load → quantize → calibrate → snapshot → save → export."""
global _FORCE_AMAX_CPU
# Must be in the example dir for relative imports (example_utils, etc.)
os.chdir(EXAMPLE_DIR)
sys.path.insert(0, EXAMPLE_DIR)
# Set HF token for gated datasets
os.environ["HF_TOKEN"] = HF_TOKEN
os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
# These imports depend on the example dir being in sys.path
from hf_ptq import (
get_model, get_tokenizer, make_calib_dataloader,
build_quant_cfg, load_mtp_weights, copy_custom_model_files,
@@ -164,7 +269,6 @@ def run_calibration(model_path, export_dir, calib_save_path, calib_size, calib_s
from modelopt.torch.export import export_hf_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer
# Apply patches before loading model
apply_patches()
# ── Load model ──
@@ -203,7 +307,6 @@ def run_calibration(model_path, export_dir, calib_save_path, calib_size, calib_s
print(f"✓ Using calibration batch_size={batch_size}")
# ── Prepare dataloader ──
# Build a minimal args namespace for make_calib_dataloader
args = argparse.Namespace(
calib_size=[calib_size],
calib_seq=calib_seq,
@@ -221,18 +324,37 @@ def run_calibration(model_path, export_dir, calib_save_path, calib_size, calib_s
print(f"{'='*60}")
t0 = time.time()
# _FORCE_AMAX_CPU is False during calibration — amax stays on GPU for
# fake quantization during the forward passes
model = mtq.quantize(model, quant_cfg, forward_loop=calib_dataloader)
print(f"✓ Quantization + calibration complete in {time.time()-t0:.0f}s")
# ── SAVE STATE (the whole point of this script) ──
# ── IMMEDIATELY snapshot all _amax to CPU ──
# This is the critical defensive step. Right after mtq.quantize() returns,
# the _amax tensors are fresh and valid on GPU. We copy them to CPU NOW,
# before any other GPU operation has a chance to corrupt them.
snapshots = snapshot_amax_to_cpu(model, amax_snapshot_path)
# ── Force ALL quantizer state to CPU ──
# After snapshotting, force remaining GPU tensors to CPU too
force_all_amax_to_cpu(model)
# ── Enable CPU mode for any future amax writes ──
_FORCE_AMAX_CPU = True
# ── Free GPU memory ──
torch.cuda.empty_cache()
gc.collect()
# ── SAVE STATE ──
save_calibrated_state(model, calib_save_path)
# ── Export ──
run_export(model, tokenizer, model_path, export_dir)
run_export(model, tokenizer, model_path, export_dir, amax_snapshot_path)
def run_export(model, tokenizer, model_path, export_dir):
def run_export(model, tokenizer, model_path, export_dir, amax_snapshot_path=None):
"""Export the quantized model to HF safetensors format."""
from modelopt.torch.export import export_hf_checkpoint
from hf_ptq import load_mtp_weights, copy_custom_model_files
@@ -241,7 +363,14 @@ def run_export(model, tokenizer, model_path, export_dir):
print(f"EXPORTING → {export_dir}")
print(f"{'='*60}")
move_quantizers_to_cpu(model)
# Ensure all quantizer state is on CPU
force_all_amax_to_cpu(model)
if amax_snapshot_path and os.path.exists(amax_snapshot_path):
restore_amax_from_snapshot(model, amax_snapshot_path)
# Free GPU memory before export
torch.cuda.empty_cache()
gc.collect()
t0 = time.time()
@@ -263,13 +392,17 @@ def run_export(model, tokenizer, model_path, export_dir):
except Exception as e:
print(f"\n✗ EXPORT FAILED: {e}")
print(f" Calibrated state is saved at: {CALIB_SAVE_PATH}")
print(f" Re-run with --export-only to retry export")
print(f" Calibrated state: {CALIB_SAVE_PATH}")
print(f" Amax snapshots: {AMAX_SNAPSHOT_PATH}")
print(f" Re-run with --export-only to retry")
raise
def run_export_only(calib_save_path, model_path, export_dir):
"""Load previously saved calibration state and run export only."""
def run_export_only(calib_save_path, amax_snapshot_path, model_path, export_dir):
"""Load saved calibration state and run export only."""
global _FORCE_AMAX_CPU
_FORCE_AMAX_CPU = True # Force CPU for any amax writes
os.chdir(EXAMPLE_DIR)
sys.path.insert(0, EXAMPLE_DIR)
@@ -289,35 +422,83 @@ def run_export_only(calib_save_path, model_path, export_dir):
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Load the calibrated state
print(f"Loading calibrated state from {calib_save_path}...")
state = torch.load(calib_save_path, map_location='cpu')
model.load_state_dict(state['model_state_dict'])
print(f"✓ Loaded calibrated state (saved at {state['timestamp']})")
run_export(model, tokenizer, model_path, export_dir)
run_export(model, tokenizer, model_path, export_dir, amax_snapshot_path)
def run_validate(calib_save_path, amax_snapshot_path):
    """Validate saved calibration artifacts and print a summary.

    Each tensor in the amax snapshot is placed in exactly one bucket, checked
    in this order: contains NaN, contains +/-Inf, contains a negative value,
    is entirely zero, otherwise valid. The calibrated state file is only
    checked for existence and size; it is not loaded.

    Args:
        calib_save_path: Path of the full calibrated state file.
        amax_snapshot_path: Path of the ``{name: amax}`` snapshot file.

    Returns:
        True when the snapshot exists, is non-empty, every entry is valid,
        and the calibrated state file exists; False otherwise.
    """
    print("\nValidating calibration state...")

    # Check amax snapshots
    snapshot_ok = False
    if os.path.exists(amax_snapshot_path):
        snapshots = torch.load(amax_snapshot_path, map_location='cpu')
        n_total = len(snapshots)
        n_valid = n_zero = n_nan = n_inf = n_neg = 0

        for amax in snapshots.values():
            if torch.any(torch.isnan(amax)):
                n_nan += 1
            elif torch.any(torch.isinf(amax)):
                # An infinite amax would yield an infinite scaling factor.
                n_inf += 1
            elif torch.any(amax < 0):
                n_neg += 1
            elif torch.all(amax == 0):
                n_zero += 1
            else:
                n_valid += 1

        print("\nAmax snapshot validation:")
        print(f" Total quantizers: {n_total}")
        print(f" Valid: {n_valid}")
        print(f" All zeros: {n_zero}")
        print(f" Negative: {n_neg}")
        print(f" NaN: {n_nan}")
        print(f" Inf: {n_inf}")

        if n_total == 0:
            # An empty snapshot must not be reported as "all valid".
            print("\n✗ Amax snapshot contains no entries")
        elif n_valid == n_total:
            snapshot_ok = True
            print(f"\n✓ All {n_total} amax snapshots are valid!")
        else:
            print(f"\n{n_total - n_valid} quantizers have invalid amax!")
    else:
        print(f"✗ No amax snapshot found at {amax_snapshot_path}")

    # Check full state dict
    state_exists = os.path.exists(calib_save_path)
    if state_exists:
        size_gb = os.path.getsize(calib_save_path) / (1024**3)
        print(f"\nCalibrated state: {calib_save_path} ({size_gb:.1f} GB)")
    else:
        print(f"\n✗ No calibrated state found at {calib_save_path}")

    return snapshot_ok and state_exists
def main():
    """Parse command-line options and dispatch to one of three modes.

    Modes, checked in this order:
      * ``--validate-only``: inspect saved artifacts via ``run_validate``.
      * ``--export-only``: reload saved calibration state and export; exits
        with status 1 if the calibration state file does not exist.
      * default: run the full calibration pipeline via ``run_calibration``.
    """
    parser = argparse.ArgumentParser(description="DeepSeek V4 Pro NVFP4 Quantization")
    parser.add_argument("--export-only", action="store_true",
                        help="Skip calibration, load saved state and run export only")
    # "store_true" is the boolean-flag action; argparse rejects unknown
    # action names with ValueError as soon as add_argument() is called.
    parser.add_argument("--validate-only", action="store_true",
                        help="Validate saved calibration state without running anything")
    parser.add_argument("--model", default=MODEL, help="Path to BF16 model")
    parser.add_argument("--export-dir", default=EXPORT_DIR, help="Export output directory")
    parser.add_argument("--calib-save", default=CALIB_SAVE_PATH, help="Calibration state save path")
    parser.add_argument("--amax-snapshot", default=AMAX_SNAPSHOT_PATH, help="Amax snapshot path")
    parser.add_argument("--calib-size", type=int, default=CALIB_SIZE, help="Calibration samples")
    parser.add_argument("--calib-seq", type=int, default=CALIB_SEQ, help="Calibration sequence length")
    args = parser.parse_args()

    if args.validate_only:
        run_validate(args.calib_save, args.amax_snapshot)
    elif args.export_only:
        if not os.path.exists(args.calib_save):
            print(f"ERROR: No calibration state found at {args.calib_save}")
            print("Run without --export-only first to calibrate.")
            sys.exit(1)
        run_export_only(args.calib_save, args.amax_snapshot, args.model, args.export_dir)
    else:
        run_calibration(args.model, args.export_dir, args.calib_save,
                        args.amax_snapshot, args.calib_size, args.calib_seq)
if __name__ == "__main__":