diff --git a/scripts/quantize_nvfp4.py b/scripts/quantize_nvfp4.py index 48be183..fd13fc7 100644 --- a/scripts/quantize_nvfp4.py +++ b/scripts/quantize_nvfp4.py @@ -56,44 +56,28 @@ EXPORT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4" CALIB_SAVE_PATH = "/root/nvidia-meeting/v4_nvfp4_calibrated_state.pt" AMAX_SNAPSHOT_PATH = "/root/nvidia-meeting/v4_nvfp4_amax_snapshots.pt" -# Flag: when True, force all new _amax writes to CPU -_FORCE_AMAX_CPU = False - def apply_patches(): """Apply runtime patches for V4 compatibility and GPU tensor safety.""" from modelopt.torch.quantization.nn.modules import tensor_quantizer as tq_module - # ── Patch 1: Force _amax to CPU after calibration completes ── + # ── Patch 1: load_calib_amax — force _amax to CPU after calibration ── # - # The _amax property setter is called by load_calib_amax() at the end of - # calibration. By default it stores on GPU. We patch it so that when - # _FORCE_AMAX_CPU is True, _amax goes to CPU instead. - # - # During calibration (before the flag is set), _amax stays on GPU for - # fake quantization. After calibration, we set the flag and re-call - # load_calib_amax() to re-populate _amax on CPU. + # load_calib_amax() is called by max_calibrate() after the forward loop + # finishes. It writes _amax to GPU by default. We patch it so _amax + # goes to CPU instead, preventing GPU corruption during the long wait + # before export. + orig_load_calib_amax = tq_module.TensorQuantizer.load_calib_amax - orig_amax_setter = tq_module.TensorQuantizer.amax.fset + def patched_load_calib_amax(self, *args, **kwargs): + orig_load_calib_amax(self, *args, **kwargs) + # After _amax is written, move it to CPU + if hasattr(self, '_amax') and self._amax is not None: + self._amax = self._amax.cpu() - def patched_amax_setter(self, value): - assert value is not None, "amax cannot be set to None." - if not isinstance(value, torch.Tensor): - value = torch.tensor(value) - if not hasattr(self, "_amax"): - if _FORCE_AMAX_CPU: - self.register_buffer("_amax", value.clone().detach().cpu()) - else: - self.register_buffer("_amax", value.clone().detach()) - else: - if self._amax.shape != value.shape: - raise RuntimeError("Changing shape when setting amax is not allowed.") - target = self._amax.cpu() if _FORCE_AMAX_CPU else self._amax - self._amax.data.copy_(value.clone().detach().to(target.device)) - - tq_module.TensorQuantizer.amax.fset = patched_amax_setter - print("✓ Patched TensorQuantizer.amax setter (CPU mode controlled by _FORCE_AMAX_CPU)") + tq_module.TensorQuantizer.load_calib_amax = patched_load_calib_amax + print("✓ Patched TensorQuantizer.load_calib_amax (force _amax to CPU)") # ── Patch 2: export_amax — CPU safety ── # If any _amax is still on GPU at export time, move it before reading. @@ -342,10 +326,7 @@ def run_calibration(model_path, export_dir, calib_save_path, amax_snapshot_path, # After snapshotting, force remaining GPU tensors to CPU too force_all_amax_to_cpu(model) - # ── Enable CPU mode for any future amax writes ── - _FORCE_AMAX_CPU = True - - # ── Free GPU memory ── + # ── Force ALL quantizer state to CPU ── torch.cuda.empty_cache() gc.collect() @@ -402,9 +383,6 @@ def run_export(model, tokenizer, model_path, export_dir, amax_snapshot_path=None def run_export_only(calib_save_path, amax_snapshot_path, model_path, export_dir): """Load saved calibration state and run export only.""" - global _FORCE_AMAX_CPU - _FORCE_AMAX_CPU = True # Force CPU for any amax writes - os.chdir(EXAMPLE_DIR) sys.path.insert(0, EXAMPLE_DIR)