diff --git a/scripts/quantize_nvfp4.py b/scripts/quantize_nvfp4.py index 37a520f..d7cdab8 100644 --- a/scripts/quantize_nvfp4.py +++ b/scripts/quantize_nvfp4.py @@ -114,6 +114,66 @@ def apply_patches(): nvfp4_tensor.NVFP4QTensor.get_activation_scaling_factor = patched_get_activation_scaling_factor print("✓ Patched NVFP4QTensor.get_activation_scaling_factor (CPU + clamp)") + # ── Patch 4: get_weight_scaling_factor — force weight to CPU before export computation ── + # After hours of calibration with use_seq_device_map, model weight tensors on GPU + # have stale memory. Any attempt to read them (even .to(device)) triggers + # cudaErrorIllegalAddress. We force weight to CPU and do all math on CPU. + from modelopt.torch.export import quant_utils + from modelopt.torch.quantization.utils import quantizer_attr_names as _quantizer_attr_names + + orig_get_weight_scaling_factor = quant_utils.get_weight_scaling_factor + + def patched_get_weight_scaling_factor(module, weight_name="weight"): + """Force weight and all quantizer state to CPU before computing scaling factors.""" + # Move weight to CPU if on GPU (avoids stale GPU tensor reads) + weight = getattr(module, weight_name) + if isinstance(weight, torch.Tensor) and weight.is_cuda: + try: + weight_cpu = weight.cpu() + # Update the module parameter so downstream code uses CPU version + with torch.no_grad(): + setattr(module, weight_name, torch.nn.Parameter(weight_cpu)) + weight = weight_cpu + except (torch.cuda.CudaError, RuntimeError) as e: + print(f" WARNING: weight.cpu() failed for {weight_name} ({e}), using zeros") + weight = torch.zeros_like(weight, device='cpu') + + # Move quantizer amax to CPU + weight_quantizer = getattr(module, _quantizer_attr_names(weight_name).weight_quantizer, None) + if weight_quantizer is not None: + for attr in ['_amax', '_pre_quant_scale', 'global_amax', '_global_amax']: + if hasattr(weight_quantizer, attr): + val = getattr(weight_quantizer, attr) + if val is not None and isinstance(val, torch.Tensor) and val.is_cuda: + try: + setattr(weight_quantizer, attr, val.cpu()) + except (torch.cuda.CudaError, RuntimeError): + pass + + return orig_get_weight_scaling_factor(module, weight_name) + + quant_utils.get_weight_scaling_factor = patched_get_weight_scaling_factor + print("✓ Patched get_weight_scaling_factor (force weight + quantizer to CPU)") + + # ── Patch 5: get_weight_scaling_factor_2 — force quantizer state to CPU ── + orig_get_weight_scaling_factor_2 = quant_utils.get_weight_scaling_factor_2 + + def patched_get_weight_scaling_factor_2(module, weight_name="weight"): + weight_quantizer = getattr(module, _quantizer_attr_names(weight_name).weight_quantizer, None) + if weight_quantizer is not None: + for attr in ['_amax', '_pre_quant_scale', 'global_amax', '_global_amax']: + if hasattr(weight_quantizer, attr): + val = getattr(weight_quantizer, attr) + if val is not None and isinstance(val, torch.Tensor) and val.is_cuda: + try: + setattr(weight_quantizer, attr, val.cpu()) + except (torch.cuda.CudaError, RuntimeError): + pass + return orig_get_weight_scaling_factor_2(module, weight_name) + + quant_utils.get_weight_scaling_factor_2 = patched_get_weight_scaling_factor_2 + print("✓ Patched get_weight_scaling_factor_2 (force quantizer to CPU)") + def snapshot_amax_to_cpu(model, snapshot_path): """Walk all quantizers, copy _amax to CPU, save to disk."""