2026-05-09 06:07:22 +00:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
"""
|
2026-05-09 06:31:08 +00:00
|
|
|
DeepSeek V4 Pro → NVFP4 quantization — defensive edition.
|
2026-05-09 06:07:22 +00:00
|
|
|
|
2026-05-09 14:52:02 +00:00
|
|
|
This script:
|
|
|
|
|
1. Applies runtime patches for GPU tensor safety (before modelopt runs)
|
|
|
|
|
2. Calls the SAME hf_ptq.py pipeline that the shell script uses
|
|
|
|
|
3. After calibration, snapshots amax to CPU and saves model state
|
2026-05-09 06:31:08 +00:00
|
|
|
|
2026-05-09 14:52:02 +00:00
|
|
|
The key insight: we don't rewrite the pipeline. We let hf_ptq do its thing
|
|
|
|
|
with all its args, defaults, and edge cases handled correctly. We just add
|
|
|
|
|
our defensive patches and post-calibration saves.
|
2026-05-09 06:31:08 +00:00
|
|
|
|
2026-05-09 14:52:02 +00:00
|
|
|
Must be run from the modelopt example directory:
|
2026-05-09 06:08:35 +00:00
|
|
|
cd /root/nvidia-meeting/modelopt-repo/examples/llm_ptq
|
|
|
|
|
python3 /root/nvidia-meeting/deepseek-v4-quant/scripts/quantize_nvfp4.py
|
|
|
|
|
|
2026-05-09 06:07:22 +00:00
|
|
|
Usage:
|
|
|
|
|
# Full run (calibrate + export):
|
2026-05-09 06:08:35 +00:00
|
|
|
python3 /root/nvidia-meeting/deepseek-v4-quant/scripts/quantize_nvfp4.py
|
2026-05-09 06:07:22 +00:00
|
|
|
|
|
|
|
|
# Re-run export only (after a calibration save exists):
|
2026-05-09 06:08:35 +00:00
|
|
|
python3 /root/nvidia-meeting/deepseek-v4-quant/scripts/quantize_nvfp4.py --export-only
|
2026-05-09 06:31:08 +00:00
|
|
|
|
|
|
|
|
# Validate saved calibration state (check amax values):
|
|
|
|
|
python3 /root/nvidia-meeting/deepseek-v4-quant/scripts/quantize_nvfp4.py --validate-only
|
2026-05-09 06:07:22 +00:00
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import argparse
|
2026-05-09 06:31:08 +00:00
|
|
|
import gc
|
2026-05-09 06:07:22 +00:00
|
|
|
import os
|
|
|
|
|
import sys
|
|
|
|
|
import time
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
# ── Config ──────────────────────────────────────────────────────────────────

# Defaults for the command-line options and for the argv handed to hf_ptq.
MODEL = "/root/nvidia-meeting/DeepSeek-V4-Pro-BF16"
QUANT = "nvfp4"
TP = 8
CALIB_SIZE = 128
CALIB_SEQ = 512
KV_CACHE_QUANT = "fp8_cast"
GPU_MEM_PCT = 0.7

# SECURITY: a literal Hugging Face access token used to be committed on this
# line. Credentials must not live in source control, so the token is now read
# from the environment (either variable name is accepted). The previously
# committed token should be treated as leaked and revoked.
# An empty string means "no token configured".
HF_TOKEN = (
    os.environ.get("HF_TOKEN")
    or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    or ""
)

# Paths
EXAMPLE_DIR = "/root/nvidia-meeting/modelopt-repo/examples/llm_ptq"
EXPORT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
CALIB_SAVE_PATH = "/root/nvidia-meeting/v4_nvfp4_calibrated_state.pt"
AMAX_SNAPSHOT_PATH = "/root/nvidia-meeting/v4_nvfp4_amax_snapshots.pt"
|
|
|
|
|
|
2026-05-09 06:07:22 +00:00
|
|
|
|
|
|
|
|
# Tensor-valued quantizer attributes that the patched export functions move to CPU.
_QUANTIZER_TENSOR_ATTRS = ("_amax", "_pre_quant_scale", "global_amax", "_global_amax")


def _quantizer_state_to_cpu(quantizer):
    """Best-effort: move every CUDA tensor in ``_QUANTIZER_TENSOR_ATTRS`` to CPU."""
    for attr in _QUANTIZER_TENSOR_ATTRS:
        val = getattr(quantizer, attr, None)
        if not (isinstance(val, torch.Tensor) and val.is_cuda):
            continue
        try:
            setattr(quantizer, attr, val.cpu())
        except (torch.cuda.CudaError, RuntimeError) as exc:
            # Kept non-fatal on purpose (the wrapped function gets to try next),
            # but no longer silent so a failed move is visible in the log.
            print(f" WARNING: could not move {attr} to CPU ({exc})")


def _tensor_to_cpu_in_place(label, tensor):
    """Best-effort: re-point ``tensor`` at a CPU copy of its data (same object)."""
    if not (isinstance(tensor, torch.Tensor) and tensor.is_cuda):
        return
    try:
        with torch.no_grad():
            tensor.data = tensor.data.cpu()
    except (torch.cuda.CudaError, RuntimeError) as exc:
        print(f" WARNING: could not move {label} to CPU ({exc})")


def _weight_to_cpu(module, weight_name, where):
    """Replace ``module.<weight_name>`` with a CPU copy; re-raise if the copy fails.

    ``where`` is only used to word the warning printed before re-raising.
    """
    weight = getattr(module, weight_name, None)
    if not (isinstance(weight, torch.Tensor) and weight.is_cuda):
        return
    try:
        cpu_weight = weight.detach().cpu()
    except (torch.cuda.CudaError, RuntimeError) as exc:
        print(f" WARNING: weight.cpu() failed {where} ({exc})")
        raise
    if isinstance(weight, torch.nn.Parameter):
        # Carry requires_grad over: nn.Parameter defaults to requires_grad=True,
        # which raises for non-floating dtypes (e.g. integer-packed weights).
        cpu_weight = torch.nn.Parameter(cpu_weight, requires_grad=weight.requires_grad)
    setattr(module, weight_name, cpu_weight)


def apply_patches():
    """Monkey-patch modelopt so calibration/export work on CPU copies of tensors.

    Author's working diagnosis (not verifiable from this file): with
    ``use_seq_device_map`` the weights stay on GPU for hours during calibration
    and reading them at export time fails with ``cudaErrorIllegalAddress``.
    The strategy is to move weights and quantizer state to CPU at the earliest
    export entry points so downstream ``.to(weight.device)`` calls resolve to CPU.

    Call chain the patches are meant to cover (as traced by the author):

        _process_quantized_modules
          → _export_quantized_weight (or _export_fused_experts)
            → get_weight_scaling_factor
            → get_weight_scaling_factor_2
            → get_activation_scaling_factor
            → to_quantized_weight

    Patches installed (each prints a confirmation line):
      1. TensorQuantizer.load_calib_amax        — move ``_amax`` to CPU afterwards
      2. TensorQuantizer.export_amax            — move ``_amax`` to CPU beforehand
      3. NVFP4QTensor.get_activation_scaling_factor — CPU computation + clamp
      4. unified_export_hf._export_quantized_weight — weight + quantizers to CPU
      5. unified_export_hf._export_fused_experts    — params/buffers/quantizers to CPU
      6. quant_utils.to_quantized_weight        — arguments to CPU
      7. quant_utils.get_weight_scaling_factor  — weight + weight quantizer to CPU
      8. quant_utils.get_weight_scaling_factor_2 — weight quantizer to CPU

    NOTE(review): patches 6-8 rebind names on ``quant_utils`` only. If
    ``unified_export_hf`` imported those functions with ``from ... import``,
    its own references still point at the originals — confirm against the
    installed modelopt version.

    Calling this function more than once stacks the wrappers; call it once.
    """
    from modelopt.torch.quantization.nn.modules import tensor_quantizer as tq_module
    from modelopt.torch.quantization.qtensor import nvfp4_tensor
    from modelopt.torch.export import quant_utils
    from modelopt.torch.quantization.utils import quantizer_attr_names as _quantizer_attr_names
    import modelopt.torch.export.unified_export_hf as uehf

    # ══════════════════════════════════════════════════════════════════════
    # Patch 1: load_calib_amax — force _amax to CPU immediately after calibration
    # ══════════════════════════════════════════════════════════════════════
    orig_load_calib_amax = tq_module.TensorQuantizer.load_calib_amax

    def patched_load_calib_amax(self, *args, **kwargs):
        result = orig_load_calib_amax(self, *args, **kwargs)
        if getattr(self, "_amax", None) is not None:
            self._amax = self._amax.cpu()
        return result  # pass through whatever the original returns

    tq_module.TensorQuantizer.load_calib_amax = patched_load_calib_amax
    print("✓ Patch 1: TensorQuantizer.load_calib_amax → force _amax to CPU")

    # ══════════════════════════════════════════════════════════════════════
    # Patch 2: export_amax — CPU safety net at export time
    # ══════════════════════════════════════════════════════════════════════
    orig_export_amax = tq_module.TensorQuantizer.export_amax

    def patched_export_amax(self):
        current = getattr(self, "_amax", None)
        if current is not None and current.is_cuda:
            self._amax = current.cpu()
        return orig_export_amax(self)

    tq_module.TensorQuantizer.export_amax = patched_export_amax
    print("✓ Patch 2: TensorQuantizer.export_amax → CPU fallback")

    # ══════════════════════════════════════════════════════════════════════
    # Patch 3: get_activation_scaling_factor — CPU + clamp
    # ══════════════════════════════════════════════════════════════════════
    @classmethod
    def patched_get_activation_scaling_factor(cls, quantizer):
        if not quantizer.is_enabled:
            return None
        try:
            amax = quantizer.export_amax()
        except (torch.cuda.CudaError, RuntimeError) as e:
            print(f" WARNING: export_amax() failed ({e}), attempting CPU recovery...")
            if getattr(quantizer, "_amax", None) is None:
                # Nothing to recover from. Previously this path fell through
                # with `amax` unbound and died with UnboundLocalError.
                raise
            quantizer._amax = quantizer._amax.cpu()
            amax = quantizer.export_amax()

        if amax is None:
            return None
        amax = amax.cpu()
        activation_scaling_factor = amax.float() / (quantizer.maxbound * 448.0)

        if not torch.all(activation_scaling_factor > 0):
            n_bad = (activation_scaling_factor <= 0).sum().item()
            n_total = activation_scaling_factor.numel()
            print(f" WARNING: {n_bad}/{n_total} activation scaling factors <= 0, clamping")
            # Raise non-positive factors to the smallest positive normal float32.
            activation_scaling_factor = activation_scaling_factor.clamp(
                min=torch.finfo(torch.float32).tiny
            )

        return activation_scaling_factor

    nvfp4_tensor.NVFP4QTensor.get_activation_scaling_factor = patched_get_activation_scaling_factor
    print("✓ Patch 3: NVFP4QTensor.get_activation_scaling_factor → CPU + clamp")

    # ══════════════════════════════════════════════════════════════════════
    # Patch 4: _export_quantized_weight — THE KEY PATCH
    #
    # Entry point for exporting each quantized module. Moving the weight to
    # CPU here is intended to make every downstream `.to(weight.device)`
    # resolve to CPU. Quantizer state is moved as well.
    # ══════════════════════════════════════════════════════════════════════
    orig_export_quantized_weight = uehf._export_quantized_weight

    def patched_export_quantized_weight(sub_module, dtype, weight_name="weight"):
        _weight_to_cpu(sub_module, weight_name, f"for {weight_name}")

        qattrs = _quantizer_attr_names(weight_name)
        for qattr in (qattrs.weight_quantizer, qattrs.input_quantizer, qattrs.output_quantizer):
            if not qattr:
                continue
            quantizer = getattr(sub_module, qattr, None)
            if quantizer is None:
                continue
            _quantizer_state_to_cpu(quantizer)
            # Handle SequentialQuantizer (W4A8 path): children live in `.quantizers`.
            for sub_q in getattr(quantizer, "quantizers", ()):
                _quantizer_state_to_cpu(sub_q)

        return orig_export_quantized_weight(sub_module, dtype, weight_name)

    uehf._export_quantized_weight = patched_export_quantized_weight
    print("✓ Patch 4: _export_quantized_weight → force weight + quantizer state to CPU")

    # ══════════════════════════════════════════════════════════════════════
    # Patch 5: _export_fused_experts — same treatment for MoE expert weights
    # DeepseekV4Experts go through this different code path.
    # ══════════════════════════════════════════════════════════════════════
    orig_export_fused_experts = uehf._export_fused_experts

    def patched_export_fused_experts(sub_module, dtype):
        # named_parameters()/named_buffers() yield dotted names for nested
        # submodules; setattr()/register_buffer() reject such names with
        # KeyError. Moving each tensor in place avoids the name lookup and
        # keeps Parameter identity, requires_grad and buffer persistence.
        for name, param in list(sub_module.named_parameters()):
            _tensor_to_cpu_in_place(f"parameter {name}", param)
        for name, buf in list(sub_module.named_buffers()):
            _tensor_to_cpu_in_place(f"buffer {name}", buf)
        # Quantizer attributes that are plain attributes rather than buffers.
        for mod in sub_module.modules():
            _quantizer_state_to_cpu(mod)
        return orig_export_fused_experts(sub_module, dtype)

    uehf._export_fused_experts = patched_export_fused_experts
    print("✓ Patch 5: _export_fused_experts → force expert weights + quantizer state to CPU")

    # ══════════════════════════════════════════════════════════════════════
    # Patch 6: to_quantized_weight — force scaling factors to CPU
    # Belt-and-suspenders: with the weight already on CPU this is a no-op.
    # ══════════════════════════════════════════════════════════════════════
    orig_to_quantized_weight = quant_utils.to_quantized_weight

    def _cpu_if_cuda(value):
        if isinstance(value, torch.Tensor) and value.is_cuda:
            return value.cpu()
        return value

    def patched_to_quantized_weight(weight, weights_scaling_factor, quantization,
                                    weights_scaling_factor2=None, block_size=None):
        return orig_to_quantized_weight(
            _cpu_if_cuda(weight),
            _cpu_if_cuda(weights_scaling_factor),
            quantization,
            _cpu_if_cuda(weights_scaling_factor2),
            block_size,
        )

    quant_utils.to_quantized_weight = patched_to_quantized_weight
    print("✓ Patch 6: to_quantized_weight → force all tensors to CPU")

    # ══════════════════════════════════════════════════════════════════════
    # Patch 7: get_weight_scaling_factor — force weight + quantizer to CPU
    # Belt and suspenders for callers that do not come through Patch 4.
    # ══════════════════════════════════════════════════════════════════════
    orig_get_weight_scaling_factor = quant_utils.get_weight_scaling_factor

    def patched_get_weight_scaling_factor(module, weight_name="weight"):
        _weight_to_cpu(module, weight_name, "in get_weight_scaling_factor")
        weight_quantizer = getattr(
            module, _quantizer_attr_names(weight_name).weight_quantizer, None
        )
        if weight_quantizer is not None:
            _quantizer_state_to_cpu(weight_quantizer)
        return orig_get_weight_scaling_factor(module, weight_name)

    quant_utils.get_weight_scaling_factor = patched_get_weight_scaling_factor
    print("✓ Patch 7: get_weight_scaling_factor → force weight + quantizer to CPU")

    # ══════════════════════════════════════════════════════════════════════
    # Patch 8: get_weight_scaling_factor_2 — force quantizer to CPU
    # ══════════════════════════════════════════════════════════════════════
    orig_get_weight_scaling_factor_2 = quant_utils.get_weight_scaling_factor_2

    def patched_get_weight_scaling_factor_2(module, weight_name="weight"):
        weight_quantizer = getattr(
            module, _quantizer_attr_names(weight_name).weight_quantizer, None
        )
        if weight_quantizer is not None:
            _quantizer_state_to_cpu(weight_quantizer)
        return orig_get_weight_scaling_factor_2(module, weight_name)

    quant_utils.get_weight_scaling_factor_2 = patched_get_weight_scaling_factor_2
    print("✓ Patch 8: get_weight_scaling_factor_2 → force quantizer to CPU")
|
2026-05-09 22:43:48 +00:00
|
|
|
|
2026-05-09 06:07:22 +00:00
|
|
|
|
2026-05-09 06:31:08 +00:00
|
|
|
def snapshot_amax_to_cpu(model, snapshot_path):
    """Collect a CPU copy of every quantizer's ``_amax`` and write them to disk.

    Args:
        model: module whose ``named_modules()`` are scanned for TensorQuantizer
            instances.
        snapshot_path: destination file for ``torch.save``.

    Returns:
        dict mapping the quantizer's module name to its detached CPU ``_amax`` clone.
    """
    from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer

    print("\nSnapshotting quantizer _amax to CPU...")
    started = time.time()
    captured = {}

    for qname, qmod in model.named_modules():
        if not isinstance(qmod, TensorQuantizer):
            continue
        live_amax = getattr(qmod, "_amax", None)
        if live_amax is None:
            continue
        cpu_copy = live_amax.detach().cpu().clone()
        captured[qname] = cpu_copy
        # In-place write-back of the same values; the live tensor stays on
        # whatever device it was on (copy_ does not change the device).
        live_amax.data.copy_(cpu_copy)

    torch.save(captured, snapshot_path)
    size_mb = os.path.getsize(snapshot_path) / (1024**2)
    # One entry is stored per quantizer handled, so the dict size is the count.
    print(f"✓ Snapshotted {len(captured)} quantizer _amax tensors to CPU ({time.time()-started:.1f}s)")
    print(f" Saved to: {snapshot_path} ({size_mb:.1f} MB)")
    return captured
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def restore_amax_from_snapshot(model, snapshot_path):
    """Copy ``_amax`` values from a saved snapshot back into the model's quantizers.

    Args:
        model: module whose ``named_modules()`` are scanned for TensorQuantizer
            instances; quantizers are matched to snapshot entries by module name.
        snapshot_path: file previously written with ``torch.save`` holding a
            ``{module_name: tensor}`` mapping.

    Quantizers whose ``_amax`` is missing or ``None`` have no tensor to copy
    into; they are skipped and reported instead of raising ``AttributeError``.
    """
    from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer

    print(f"Restoring _amax from snapshot: {snapshot_path}")
    # The snapshot is a plain name -> tensor mapping, so the restricted
    # unpickler is sufficient and avoids executing arbitrary pickled code.
    snapshots = torch.load(snapshot_path, map_location='cpu', weights_only=True)
    n_restored = 0
    n_skipped = 0

    for name, module in model.named_modules():
        if not isinstance(module, TensorQuantizer) or name not in snapshots:
            continue
        current = getattr(module, '_amax', None)
        if current is None:
            n_skipped += 1
            continue
        # In-place copy keeps the existing tensor object (and its device).
        current.data.copy_(snapshots[name].to(current.device))
        n_restored += 1

    print(f"✓ Restored {n_restored} _amax tensors from snapshot")
    if n_skipped:
        print(f" WARNING: {n_skipped} snapshot entries skipped (quantizer has no _amax tensor)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def force_all_amax_to_cpu(model):
    """Move the tensor-valued state of every TensorQuantizer in ``model`` to CPU.

    Handles the same attribute set as the export-time patches in this file
    (``_amax``, ``_pre_quant_scale``, ``_global_amax``, ``global_amax``);
    previously ``global_amax`` was left out here. Prints how many tensors
    were moved.
    """
    from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer

    # `_global_amax` is listed before `global_amax` so that, should the public
    # name merely expose the private one, the tensor is moved (and counted) once.
    attrs = ('_amax', '_pre_quant_scale', '_global_amax', 'global_amax')

    count = 0
    for name, module in model.named_modules():
        if not isinstance(module, TensorQuantizer):
            continue
        for attr in attrs:
            val = getattr(module, attr, None)
            if isinstance(val, torch.Tensor) and val.is_cuda:
                setattr(module, attr, val.cpu())
                count += 1
    print(f"✓ Forced {count} quantizer tensors to CPU")
|
2026-05-09 06:07:22 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_calibrated_state(model, path):
    """Write ``model.state_dict()`` plus a wall-clock timestamp to ``path``.

    The saved object is a dict with keys ``'model_state_dict'`` and
    ``'timestamp'`` (formatted ``%Y-%m-%d %H:%M:%S``). Progress, file size and
    elapsed time are printed.
    """
    rule = "=" * 60
    print(f"\n{rule}")
    print(f"SAVING CALIBRATED STATE → {path}")
    print(rule)

    began = time.time()
    payload = {}
    payload['model_state_dict'] = model.state_dict()
    payload['timestamp'] = time.strftime('%Y-%m-%d %H:%M:%S')
    torch.save(payload, path)

    size_gb = os.path.getsize(path) / (1024**3)
    print(f"✓ Saved calibrated state: {size_gb:.1f} GB ({time.time()-began:.0f}s)")
    print(f" Path: {path}")
    print(" Re-run with --export-only to retry export.\n")
|
|
|
|
|
|
2026-05-09 06:07:22 +00:00
|
|
|
|
2026-05-09 06:31:08 +00:00
|
|
|
def run_calibration(model_path, export_dir, calib_save_path, amax_snapshot_path, calib_size, calib_seq):
    """Run hf_ptq's full pipeline with the defensive patches and a pre-export save hook.

    Steps: switch into ``EXAMPLE_DIR`` so ``hf_ptq`` is importable, install
    the patches, build the argument namespace with hf_ptq's own parser, wrap
    ``hf_ptq.export_quantized`` so the amax snapshot and calibrated state are
    written before the real export, then call ``hf_ptq.main``.

    Args:
        model_path: checkpoint path passed as ``--pyt_ckpt_path``.
        export_dir: output directory passed as ``--export_path``.
        calib_save_path: where the calibrated state dict is saved before export.
        amax_snapshot_path: where the amax snapshot is saved before export.
        calib_size: calibration sample count (``--calib_size``).
        calib_seq: calibration sequence length (``--calib_seq``).
    """
    os.chdir(EXAMPLE_DIR)
    sys.path.insert(0, EXAMPLE_DIR)

    # Only export a token when one is configured; never publish an empty value.
    if HF_TOKEN:
        os.environ["HF_TOKEN"] = HF_TOKEN
        os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN

    import hf_ptq

    apply_patches()

    # ── Build args using hf_ptq's own parser ──
    # This guarantees all attributes exist with hf_ptq's defaults.
    # sys.argv is swapped only for the duration of parse_args() and restored
    # even if parsing fails (argparse exits via SystemExit on bad arguments).
    saved_argv = sys.argv
    sys.argv = [
        "hf_ptq.py",
        "--pyt_ckpt_path", model_path,
        "--qformat", QUANT,
        "--calib_size", str(calib_size),
        "--calib_seq", str(calib_seq),
        "--kv_cache_qformat", KV_CACHE_QUANT,
        "--inference_tensor_parallel", str(TP),
        "--export_path", export_dir,
        "--trust_remote_code",
        "--use_seq_device_map",
        "--gpu_max_mem_percentage", str(GPU_MEM_PCT),
        "--batch_size", "0",
    ]
    try:
        args = hf_ptq.parse_args()
    finally:
        sys.argv = saved_argv

    # Post-parse conversions that, per the original author, hf_ptq performs in
    # its __main__ block between parse_args() and main(); since main() is
    # called directly here they have to be repeated.
    args.dataset = args.dataset.split(",") if isinstance(args.dataset, str) else args.dataset
    args.calib_size = [int(num_sample) for num_sample in args.calib_size.split(",")]

    # ── Post-calibration hook ──
    # export_quantized is wrapped so the defensive saves happen before export.
    orig_export_quantized = hf_ptq.export_quantized

    def patched_export_quantized(exp_args, full_model, language_model, model_type,
                                 tokenizer, default_padding_side, default_pad_token):
        """Wrapper that snapshots amax and saves state before calling the real export."""
        print("\n" + "="*60)
        print("POST-CALIBRATION: Snapshotting amax and saving state")
        print("="*60)

        # Snapshot amax to CPU
        snapshot_amax_to_cpu(language_model, amax_snapshot_path)

        # Force all quantizer state to CPU
        force_all_amax_to_cpu(language_model)

        # Free GPU memory: collect unreachable Python objects first so the
        # blocks they held can actually be released by empty_cache().
        gc.collect()
        torch.cuda.empty_cache()

        # Save calibrated state
        save_calibrated_state(language_model, calib_save_path)

        # Now run the real export, passing its result through unchanged.
        return orig_export_quantized(exp_args, full_model, language_model, model_type,
                                     tokenizer, default_padding_side, default_pad_token)

    hf_ptq.export_quantized = patched_export_quantized
    print("✓ Hooked export_quantized with amax snapshot + state save")

    # ── Run hf_ptq's full pipeline ──
    hf_ptq.main(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_export_only(calib_save_path, amax_snapshot_path, model_path, export_dir):
    """Load a saved calibration state into a freshly loaded model and export it.

    Args:
        calib_save_path: file written by ``save_calibrated_state``.
        amax_snapshot_path: optional amax snapshot; restored when the file exists.
        model_path: checkpoint path to load the model and tokenizer from.
        export_dir: destination directory for the exported checkpoint.

    Raises:
        Whatever the export raises; the paths of the saved artifacts are
        printed first so the run can be retried.

    NOTE(review): nothing visible in this function inserts quantizer modules
    into the freshly loaded model before ``load_state_dict`` — confirm that the
    saved state dict's keys match what ``get_model`` returns.
    """
    os.chdir(EXAMPLE_DIR)
    sys.path.insert(0, EXAMPLE_DIR)

    # Only export a token when one is configured; never publish an empty value.
    if HF_TOKEN:
        os.environ["HF_TOKEN"] = HF_TOKEN
        os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN

    apply_patches()

    from example_utils import get_model, get_tokenizer

    print(f"Loading model from {model_path}...")
    model = get_model(
        model_path,
        device="cpu",
        trust_remote_code=True,
    )
    tokenizer = get_tokenizer(model_path, trust_remote_code=True)

    print(f"Loading calibrated state from {calib_save_path}...")
    state = torch.load(calib_save_path, map_location='cpu')
    model.load_state_dict(state['model_state_dict'])
    print(f"✓ Loaded calibrated state (saved at {state['timestamp']})")

    force_all_amax_to_cpu(model)
    if amax_snapshot_path and os.path.exists(amax_snapshot_path):
        restore_amax_from_snapshot(model, amax_snapshot_path)

    # Collect unreachable Python objects first so the blocks they held can
    # actually be released by empty_cache().
    gc.collect()
    torch.cuda.empty_cache()

    from modelopt.torch.export import export_hf_checkpoint
    from hf_ptq import load_mtp_weights, copy_custom_model_files

    print(f"\n{'='*60}")
    print(f"EXPORTING → {export_dir}")
    print(f"{'='*60}")

    t0 = time.time()
    try:
        mtp_layer_prefixes, mtp_state_dict = load_mtp_weights(model, model_path)
        if mtp_layer_prefixes:
            model._mtp_layer_prefixes = mtp_layer_prefixes

        export_hf_checkpoint(model, export_dir=export_dir, extra_state_dict=mtp_state_dict)
        tokenizer.save_pretrained(export_dir)
        copy_custom_model_files(model_path, export_dir, True)
        print(f"\n✓ Export complete in {time.time()-t0:.0f}s → {export_dir}")
    except Exception as e:
        print(f"\n✗ EXPORT FAILED: {e}")
        # Report the paths this call actually used, not the module defaults,
        # so the hint stays correct when --calib-save/--amax-snapshot are overridden.
        print(f" Calibrated state: {calib_save_path}")
        print(f" Amax snapshots: {amax_snapshot_path}")
        raise
|
|
|
|
|
|
|
|
|
|
|
2026-05-09 06:31:08 +00:00
|
|
|
def run_validate(calib_save_path, amax_snapshot_path):
    """Report whether the saved amax snapshot and calibrated state look usable.

    Each snapshot entry is placed in exactly one category, checked in this
    order: contains NaN, contains ±Inf, contains a negative value, is entirely
    zero (an empty tensor also lands here), otherwise valid. Infinite values
    were previously counted as valid.

    Args:
        calib_save_path: calibrated-state file; only its existence and size
            are reported.
        amax_snapshot_path: ``{name: tensor}`` snapshot file to inspect.
    """
    print("\nValidating calibration state...")

    if os.path.exists(amax_snapshot_path):
        # The snapshot is a plain name -> tensor mapping, so the restricted
        # unpickler is sufficient and avoids executing arbitrary pickled code.
        snapshots = torch.load(amax_snapshot_path, map_location='cpu', weights_only=True)
        n_total = len(snapshots)
        n_valid = n_zero = n_nan = n_neg = n_inf = 0

        for amax in snapshots.values():
            if torch.any(torch.isnan(amax)):
                n_nan += 1
            elif torch.any(torch.isinf(amax)):
                n_inf += 1
            elif torch.any(amax < 0):
                n_neg += 1
            elif torch.all(amax == 0):
                n_zero += 1
            else:
                n_valid += 1

        print("\nAmax snapshot validation:")
        print(
            f" Total: {n_total} Valid: {n_valid} Zero: {n_zero} "
            f"Neg: {n_neg} NaN: {n_nan} Inf: {n_inf}"
        )
        if n_valid == n_total:
            print(f"\n✓ All {n_total} amax snapshots are valid!")
        else:
            print(f"\n✗ {n_total - n_valid} quantizers have invalid amax!")
    else:
        print(f"✗ No amax snapshot found at {amax_snapshot_path}")

    if os.path.exists(calib_save_path):
        size_gb = os.path.getsize(calib_save_path) / (1024**3)
        print(f"\nCalibrated state: {calib_save_path} ({size_gb:.1f} GB)")
    else:
        print(f"\n✗ No calibrated state found at {calib_save_path}")
|
2026-05-09 06:07:22 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Parse the command line and dispatch to validate, export-only, or the full run.

    ``--validate-only`` takes precedence over ``--export-only``; with neither
    flag the full calibration + export pipeline runs. In export-only mode the
    process exits with status 1 if the calibration state file does not exist.
    """
    cli = argparse.ArgumentParser(description="DeepSeek V4 Pro NVFP4 Quantization")

    # Mode switches.
    cli.add_argument(
        "--export-only",
        action="store_true",
        help="Skip calibration, load saved state and run export only",
    )
    cli.add_argument(
        "--validate-only",
        action="store_true",
        help="Validate saved calibration state without running anything",
    )

    # Paths (defaults come from the module-level config).
    cli.add_argument("--model", default=MODEL, help="Path to BF16 model")
    cli.add_argument("--export-dir", default=EXPORT_DIR, help="Export output directory")
    cli.add_argument("--calib-save", default=CALIB_SAVE_PATH, help="Calibration state save path")
    cli.add_argument("--amax-snapshot", default=AMAX_SNAPSHOT_PATH, help="Amax snapshot path")

    # Calibration sizing.
    cli.add_argument("--calib-size", type=int, default=CALIB_SIZE, help="Calibration samples")
    cli.add_argument("--calib-seq", type=int, default=CALIB_SEQ, help="Calibration sequence length")

    opts = cli.parse_args()

    if opts.validate_only:
        run_validate(opts.calib_save, opts.amax_snapshot)
        return

    if opts.export_only:
        # Export-only needs the state file produced by an earlier full run.
        if not os.path.exists(opts.calib_save):
            print(f"ERROR: No calibration state found at {opts.calib_save}")
            sys.exit(1)
        run_export_only(opts.calib_save, opts.amax_snapshot, opts.model, opts.export_dir)
        return

    run_calibration(
        opts.model,
        opts.export_dir,
        opts.calib_save,
        opts.amax_snapshot,
        opts.calib_size,
        opts.calib_seq,
    )


if __name__ == "__main__":
    main()
|