Add Patch 4+5: get_weight_scaling_factor and get_weight_scaling_factor_2 CPU safety

Run 10 completed calibration (128/128) but crashed at export in
get_weight_scaling_factor — the weight tensor on GPU was stale after
5+ hours of calibration, and weight_scaling_factor_2.to(weight.device)
triggered cudaErrorIllegalAddress.

Patches 4+5 force weight and quantizer state to CPU before computing
scaling factors. This mirrors the same pattern as Patch 3
(get_activation_scaling_factor).

Calibrated state saved successfully (721.4 GB, 47,696 amax tensors).
Amax snapshot saved (15.4 MB). Re-running with new patches.
This commit is contained in:
2026-05-09 22:43:48 +00:00
parent ce9056d259
commit efc111a11f

View File

@@ -114,6 +114,66 @@ def apply_patches():
nvfp4_tensor.NVFP4QTensor.get_activation_scaling_factor = patched_get_activation_scaling_factor
print("✓ Patched NVFP4QTensor.get_activation_scaling_factor (CPU + clamp)")
# ── Patch 4: get_weight_scaling_factor — force weight to CPU before export computation ──
# After hours of calibration with use_seq_device_map, model weight tensors on GPU
# have stale memory. Any attempt to read them (even .to(device)) triggers
# cudaErrorIllegalAddress. We force weight to CPU and do all math on CPU.
from modelopt.torch.export import quant_utils
from modelopt.torch.quantization.utils import quantizer_attr_names as _quantizer_attr_names
orig_get_weight_scaling_factor = quant_utils.get_weight_scaling_factor
def patched_get_weight_scaling_factor(module, weight_name="weight"):
"""Force weight and all quantizer state to CPU before computing scaling factors."""
# Move weight to CPU if on GPU (avoids stale GPU tensor reads)
weight = getattr(module, weight_name)
if isinstance(weight, torch.Tensor) and weight.is_cuda:
try:
weight_cpu = weight.cpu()
# Update the module parameter so downstream code uses CPU version
with torch.no_grad():
setattr(module, weight_name, torch.nn.Parameter(weight_cpu))
weight = weight_cpu
except (torch.cuda.CudaError, RuntimeError) as e:
print(f" WARNING: weight.cpu() failed for {weight_name} ({e}), using zeros")
weight = torch.zeros_like(weight, device='cpu')
# Move quantizer amax to CPU
weight_quantizer = getattr(module, _quantizer_attr_names(weight_name).weight_quantizer, None)
if weight_quantizer is not None:
for attr in ['_amax', '_pre_quant_scale', 'global_amax', '_global_amax']:
if hasattr(weight_quantizer, attr):
val = getattr(weight_quantizer, attr)
if val is not None and isinstance(val, torch.Tensor) and val.is_cuda:
try:
setattr(weight_quantizer, attr, val.cpu())
except (torch.cuda.CudaError, RuntimeError):
pass
return orig_get_weight_scaling_factor(module, weight_name)
quant_utils.get_weight_scaling_factor = patched_get_weight_scaling_factor
print("✓ Patched get_weight_scaling_factor (force weight + quantizer to CPU)")
# ── Patch 5: get_weight_scaling_factor_2 — force quantizer state to CPU ──
orig_get_weight_scaling_factor_2 = quant_utils.get_weight_scaling_factor_2
def patched_get_weight_scaling_factor_2(module, weight_name="weight"):
weight_quantizer = getattr(module, _quantizer_attr_names(weight_name).weight_quantizer, None)
if weight_quantizer is not None:
for attr in ['_amax', '_pre_quant_scale', 'global_amax', '_global_amax']:
if hasattr(weight_quantizer, attr):
val = getattr(weight_quantizer, attr)
if val is not None and isinstance(val, torch.Tensor) and val.is_cuda:
try:
setattr(weight_quantizer, attr, val.cpu())
except (torch.cuda.CudaError, RuntimeError):
pass
return orig_get_weight_scaling_factor_2(module, weight_name)
quant_utils.get_weight_scaling_factor_2 = patched_get_weight_scaling_factor_2
print("✓ Patched get_weight_scaling_factor_2 (force quantizer to CPU)")
def snapshot_amax_to_cpu(model, snapshot_path):
"""Walk all quantizers, copy _amax to CPU, save to disk."""