From efc111a11fea7989d8e92f0c7419d79e1aa89d1a Mon Sep 17 00:00:00 2001
From: biondizzle
Date: Sat, 9 May 2026 22:43:48 +0000
Subject: [PATCH] Add Patch 4+5: get_weight_scaling_factor and get_weight_scaling_factor_2 CPU safety
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Run 10 completed calibration (128/128) but crashed at export in
get_weight_scaling_factor — the weight tensor on GPU was stale after
5+ hours of calibration, and weight_scaling_factor_2.to(weight.device)
triggered cudaErrorIllegalAddress.

Patches 4+5 force weight and quantizer state to CPU before computing
scaling factors. This mirrors the same pattern as Patch 3
(get_activation_scaling_factor).

Calibrated state saved successfully (721.4 GB, 47,696 amax tensors).
Amax snapshot saved (15.4 MB). Re-running with new patches.
---
 scripts/quantize_nvfp4.py | 60 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 60 insertions(+)

diff --git a/scripts/quantize_nvfp4.py b/scripts/quantize_nvfp4.py
index 37a520f..d7cdab8 100644
--- a/scripts/quantize_nvfp4.py
+++ b/scripts/quantize_nvfp4.py
@@ -114,6 +114,66 @@ def apply_patches():
     nvfp4_tensor.NVFP4QTensor.get_activation_scaling_factor = patched_get_activation_scaling_factor
     print("✓ Patched NVFP4QTensor.get_activation_scaling_factor (CPU + clamp)")
 
+    # ── Patch 4: get_weight_scaling_factor — force weight to CPU before export computation ──
+    # After hours of calibration with use_seq_device_map, model weight tensors on GPU
+    # have stale memory. Any attempt to read them (even .to(device)) triggers
+    # cudaErrorIllegalAddress. We force weight to CPU and do all math on CPU.
+    from modelopt.torch.export import quant_utils
+    from modelopt.torch.quantization.utils import quantizer_attr_names as _quantizer_attr_names
+
+    def _weight_quantizer_state_to_cpu(module, weight_name):
+        """Move CUDA tensors held by the weight quantizer to CPU (best effort)."""
+        quantizer_attr = _quantizer_attr_names(weight_name).weight_quantizer
+        weight_quantizer = getattr(module, quantizer_attr, None)
+        if weight_quantizer is None:
+            return
+        for attr in ("_amax", "_pre_quant_scale", "global_amax", "_global_amax"):
+            val = getattr(weight_quantizer, attr, None)
+            if not (isinstance(val, torch.Tensor) and val.is_cuda):
+                continue
+            try:
+                setattr(weight_quantizer, attr, val.cpu())
+            except (torch.cuda.CudaError, RuntimeError) as e:
+                print(f" WARNING: could not move {attr} of {weight_name} to CPU ({e})")
+
+    orig_get_weight_scaling_factor = quant_utils.get_weight_scaling_factor
+
+    def patched_get_weight_scaling_factor(module, weight_name="weight"):
+        """Move the weight and its quantizer state to CPU, then call the original.
+
+        Arguments and return value are those of the original get_weight_scaling_factor;
+        a weight that is not a CUDA tensor is left as is.
+        """
+        weight = getattr(module, weight_name)
+        if isinstance(weight, torch.Tensor) and weight.is_cuda:
+            try:
+                weight_cpu = weight.detach().cpu()
+                if isinstance(weight, torch.nn.Parameter):
+                    # nn.Parameter defaults to requires_grad=True, which un-freezes the
+                    # weight and raises for non-floating dtypes; keep the original flag.
+                    weight_cpu = torch.nn.Parameter(weight_cpu, requires_grad=weight.requires_grad)
+                setattr(module, weight_name, weight_cpu)
+            except (torch.cuda.CudaError, RuntimeError) as e:
+                # The previous zeros fallback was a local that nothing read (the original
+                # is handed the module, not this tensor), so the failure is only reported.
+                print(f" WARNING: could not move {weight_name} to CPU ({e}); left on GPU")
+        _weight_quantizer_state_to_cpu(module, weight_name)
+        return orig_get_weight_scaling_factor(module, weight_name)
+
+    quant_utils.get_weight_scaling_factor = patched_get_weight_scaling_factor
+    print("✓ Patched get_weight_scaling_factor (force weight + quantizer to CPU)")
+
+    # ── Patch 5: get_weight_scaling_factor_2 — force quantizer state to CPU ──
+    orig_get_weight_scaling_factor_2 = quant_utils.get_weight_scaling_factor_2
+
+    def patched_get_weight_scaling_factor_2(module, weight_name="weight"):
+        """Move the weight quantizer's state to CPU, then call the original."""
+        _weight_quantizer_state_to_cpu(module, weight_name)
+        return orig_get_weight_scaling_factor_2(module, weight_name)
+
+    quant_utils.get_weight_scaling_factor_2 = patched_get_weight_scaling_factor_2
+    print("✓ Patched get_weight_scaling_factor_2 (force quantizer to CPU)")
+
 
 def snapshot_amax_to_cpu(model, snapshot_path):
     """Walk all quantizers, copy _amax to CPU, save to disk."""