Clean rewrite: verified all imports against runtime, removed dead code

- get_model/get_tokenizer imported from example_utils (not hf_ptq)
- KV_QUANT_CFG_CHOICES imported from hf_ptq (not mtq)
- Removed dead _FORCE_AMAX_CPU global and global reference in run_export_only
- Fixed stale comments
- All 16 imports and references verified against the actual B200 runtime
- Zero divergences from modelopt example path except get_model()
This commit is contained in:
2026-05-09 09:26:23 +00:00
parent 86dd8df302
commit 6c1bff6997

View File

@@ -61,18 +61,18 @@ def apply_patches():
"""Apply runtime patches for V4 compatibility and GPU tensor safety."""
from modelopt.torch.quantization.nn.modules import tensor_quantizer as tq_module
from modelopt.torch.quantization.qtensor import nvfp4_tensor
# ── Patch 1: load_calib_amax — force _amax to CPU after calibration ──
#
# load_calib_amax() is called by max_calibrate() after the forward loop
# finishes. It writes _amax to GPU by default. We patch it so _amax
# goes to CPU instead, preventing GPU corruption during the long wait
# before export.
# goes to CPU immediately, preventing GPU corruption during the long
# wait before export.
orig_load_calib_amax = tq_module.TensorQuantizer.load_calib_amax
def patched_load_calib_amax(self, *args, **kwargs):
orig_load_calib_amax(self, *args, **kwargs)
# After _amax is written, move it to CPU
if hasattr(self, '_amax') and self._amax is not None:
self._amax = self._amax.cpu()
@@ -92,8 +92,6 @@ def apply_patches():
print("✓ Patched TensorQuantizer.export_amax (CPU fallback)")
# ── Patch 3: NVFP4QTensor.get_activation_scaling_factor — graceful degradation ──
from modelopt.torch.quantization.qtensor import nvfp4_tensor
@classmethod
def patched_get_activation_scaling_factor(cls, quantizer):
if not quantizer.is_enabled:
@@ -111,8 +109,6 @@ def apply_patches():
amax = amax.cpu()
activation_scaling_factor = amax.float() / (quantizer.maxbound * 448.0)
# Clamp instead of hard assert — bad values from GPU corruption should
# not kill the entire 6-hour run
if not torch.all(activation_scaling_factor > 0):
n_bad = (activation_scaling_factor <= 0).sum().item()
n_total = activation_scaling_factor.numel()
@@ -128,13 +124,10 @@ def apply_patches():
def snapshot_amax_to_cpu(model, snapshot_path):
"""Walk all quantizers, copy their _amax to CPU, save to disk.
This is the core defensive measure. After calibration completes, the _amax
tensors are fresh and valid on GPU. We copy them to CPU immediately and
save to disk. This costs almost nothing (~50MB for ~49K quantizers) but
guarantees we have valid calibration data even if CUDA corrupts the GPU
copies later.
Returns the snapshot dict: {quantizer_name: amax_tensor_on_cpu}
After calibration completes, the _amax tensors are fresh and valid on GPU.
We copy them to CPU immediately and save to disk. This costs almost nothing
(~50MB for ~49K quantizers) but guarantees we have valid calibration data
even if CUDA corrupts the GPU copies later.
"""
from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer
@@ -147,14 +140,11 @@ def snapshot_amax_to_cpu(model, snapshot_path):
if not isinstance(module, TensorQuantizer):
continue
if hasattr(module, '_amax') and module._amax is not None:
# Copy to CPU immediately
amax_cpu = module._amax.detach().cpu().clone()
snapshots[name] = amax_cpu
# Replace the GPU copy with the CPU copy
module._amax.data.copy_(amax_cpu)
n_moved += 1
# Save snapshots to disk
torch.save(snapshots, snapshot_path)
size_mb = os.path.getsize(snapshot_path) / (1024**2)
print(f"✓ Snapshotted {n_moved} quantizer _amax tensors to CPU ({time.time()-t0:.1f}s)")
@@ -164,11 +154,7 @@ def snapshot_amax_to_cpu(model, snapshot_path):
def restore_amax_from_snapshot(model, snapshot_path):
"""Restore _amax from a previously saved CPU snapshot.
Used by --export-only to guarantee valid amax values even if the
model state dict has corrupted GPU tensors.
"""
"""Restore _amax from a previously saved CPU snapshot."""
from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer
print(f"Restoring _amax from snapshot: {snapshot_path}")
@@ -186,11 +172,7 @@ def restore_amax_from_snapshot(model, snapshot_path):
def force_all_amax_to_cpu(model):
"""Force ALL quantizer tensors to CPU. Nuclear option after calibration.
After calling this, no quantizer _amax lives on GPU. Export can't hit
CUDA illegal memory access because there's nothing on GPU to corrupt.
"""
"""Force ALL quantizer tensors to CPU. Nuclear option after calibration."""
from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer
count = 0
@@ -207,18 +189,13 @@ def force_all_amax_to_cpu(model):
def save_calibrated_state(model, path):
"""Save model state dict after calibration.
The insurance policy: if export crashes, we can reload and retry
without re-running 6 hours of calibration.
"""
"""Save model state dict after calibration."""
print(f"\n{'='*60}")
print(f"SAVING CALIBRATED STATE → {path}")
print(f"{'='*60}")
start = time.time()
# All quantizer state should already be on CPU from snapshot_amax_to_cpu
state = {
'model_state_dict': model.state_dict(),
'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
@@ -234,35 +211,36 @@ def save_calibrated_state(model, path):
def run_calibration(model_path, export_dir, calib_save_path, amax_snapshot_path, calib_size, calib_seq):
"""Full pipeline: load → quantize → calibrate → snapshot → save → export."""
global _FORCE_AMAX_CPU
os.chdir(EXAMPLE_DIR)
sys.path.insert(0, EXAMPLE_DIR)
os.environ["HF_TOKEN"] = HF_TOKEN
os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
# Import from hf_ptq and modelopt — all verified against the example script
from example_utils import get_model, get_tokenizer
from hf_ptq import (
get_model as modelopt_get_model, get_tokenizer, make_calib_dataloader,
build_quant_cfg, load_mtp_weights, copy_custom_model_files,
QUANT_CFG_CHOICES, KV_QUANT_CFG_CHOICES,
make_calib_dataloader,
build_quant_cfg,
load_mtp_weights,
copy_custom_model_files,
QUANT_CFG_CHOICES,
KV_QUANT_CFG_CHOICES,
)
from modelopt.torch import quantization as mtq
from modelopt.torch.quantization.config import need_calibration
from modelopt.torch.utils.dataset_utils import get_max_batch_size
from modelopt.torch.export import export_hf_checkpoint
from transformers import AutoTokenizer
apply_patches()
# ── Load model ──
# Use modelopt's get_model() instead of raw AutoModelForCausalLM.from_pretrained.
# The raw call OOMs during weight conversion (torch.cat on experts needs 31.5GB,
# only 25.9GB free). modelopt's loader handles max_memory/device_map properly.
# Use modelopt's get_model() — handles max_memory properly for 3TB model.
# Raw AutoModelForCausalLM.from_pretrained OOMs during expert weight conversion.
print(f"\nLoading model from {model_path}...")
t0 = time.time()
model = modelopt_get_model(
model = get_model(
model_path,
gpu_mem_percentage=GPU_MEM_PCT,
trust_remote_code=True,
@@ -272,17 +250,22 @@ def run_calibration(model_path, export_dir, calib_save_path, amax_snapshot_path,
print(f"✓ Model loaded in {time.time()-t0:.0f}s")
# ── Setup quantization config ──
# Same flow as hf_ptq's quantize_main()
quant_cfg = copy.deepcopy(QUANT_CFG_CHOICES[QUANT])
quant_cfg = build_quant_cfg(QUANT, quant_cfg, None, None, None)
if KV_CACHE_QUANT != "none":
enable_quant_kv_cache = True
print(f"✓ KV cache quantization: {KV_CACHE_QUANT}")
quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
quant_cfg,
getattr(mtq, KV_QUANT_CFG_CHOICES[KV_CACHE_QUANT])["quant_cfg"],
)
print(f"✓ KV cache quantization: {KV_CACHE_QUANT}")
else:
enable_quant_kv_cache = False
# ── Detect batch size ──
# Same as hf_ptq's quantize_main()
print("\nDetecting max calibration batch size...")
batch_size = get_max_batch_size(
model,
@@ -293,6 +276,7 @@ def run_calibration(model_path, export_dir, calib_save_path, amax_snapshot_path,
print(f"✓ Using calibration batch_size={batch_size}")
# ── Prepare dataloader ──
# Same args structure as hf_ptq
args = argparse.Namespace(
calib_size=[calib_size],
calib_seq=calib_seq,
@@ -310,23 +294,17 @@ def run_calibration(model_path, export_dir, calib_save_path, amax_snapshot_path,
print(f"{'='*60}")
t0 = time.time()
# _FORCE_AMAX_CPU is False during calibration — amax stays on GPU for
# fake quantization during the forward passes
model = mtq.quantize(model, quant_cfg, forward_loop=calib_dataloader)
print(f"✓ Quantization + calibration complete in {time.time()-t0:.0f}s")
# ── IMMEDIATELY snapshot all _amax to CPU ──
# This is the critical defensive step. Right after mtq.quantize() returns,
# the _amax tensors are fresh and valid on GPU. We copy them to CPU NOW,
# before any other GPU operation has a chance to corrupt them.
snapshots = snapshot_amax_to_cpu(model, amax_snapshot_path)
# ── Force ALL quantizer state to CPU ──
# After snapshotting, force remaining GPU tensors to CPU too
force_all_amax_to_cpu(model)
# ── Force ALL quantizer state to CPU ──
# ── Free GPU memory ──
torch.cuda.empty_cache()
gc.collect()
@@ -346,12 +324,10 @@ def run_export(model, tokenizer, model_path, export_dir, amax_snapshot_path=None
print(f"EXPORTING → {export_dir}")
print(f"{'='*60}")
# Ensure all quantizer state is on CPU
force_all_amax_to_cpu(model)
if amax_snapshot_path and os.path.exists(amax_snapshot_path):
restore_amax_from_snapshot(model, amax_snapshot_path)
# Free GPU memory before export
torch.cuda.empty_cache()
gc.collect()
@@ -383,6 +359,7 @@ def run_export(model, tokenizer, model_path, export_dir, amax_snapshot_path=None
def run_export_only(calib_save_path, amax_snapshot_path, model_path, export_dir):
"""Load saved calibration state and run export only."""
os.chdir(EXAMPLE_DIR)
sys.path.insert(0, EXAMPLE_DIR)
@@ -391,10 +368,10 @@ def run_export_only(calib_save_path, amax_snapshot_path, model_path, export_dir)
apply_patches()
from hf_ptq import get_model as modelopt_get_model, get_tokenizer
from example_utils import get_model, get_tokenizer
print(f"Loading model skeleton from {model_path}...")
model = modelopt_get_model(
print(f"Loading model from {model_path}...")
model = get_model(
model_path,
device="cpu",
trust_remote_code=True,
@@ -413,7 +390,6 @@ def run_validate(calib_save_path, amax_snapshot_path):
"""Validate saved calibration state — check amax values are valid."""
print(f"\nValidating calibration state...")
# Check amax snapshots
if os.path.exists(amax_snapshot_path):
snapshots = torch.load(amax_snapshot_path, map_location='cpu')
n_total = len(snapshots)
@@ -446,7 +422,6 @@ def run_validate(calib_save_path, amax_snapshot_path):
else:
print(f"✗ No amax snapshot found at {amax_snapshot_path}")
# Check full state dict
if os.path.exists(calib_save_path):
size_gb = os.path.getsize(calib_save_path) / (1024**3)
print(f"\nCalibrated state: {calib_save_path} ({size_gb:.1f} GB)")