Add Patch 4+5: get_weight_scaling_factor and get_weight_scaling_factor_2 CPU safety
Run 10 completed calibration (128/128) but crashed at export in get_weight_scaling_factor — the weight tensor on GPU was stale after 5+ hours of calibration, and weight_scaling_factor_2.to(weight.device) triggered cudaErrorIllegalAddress. Patches 4+5 force weight and quantizer state to CPU before computing scaling factors. This mirrors the same pattern as Patch 3 (get_activation_scaling_factor). Calibrated state saved successfully (721.4 GB, 47,696 amax tensors). Amax snapshot saved (15.4 MB). Re-running with new patches.
This commit is contained in:
@@ -114,6 +114,66 @@ def apply_patches():
|
||||
nvfp4_tensor.NVFP4QTensor.get_activation_scaling_factor = patched_get_activation_scaling_factor
|
||||
print("✓ Patched NVFP4QTensor.get_activation_scaling_factor (CPU + clamp)")
|
||||
|
||||
# ── Patch 4: get_weight_scaling_factor — force weight to CPU before export computation ──
|
||||
# After hours of calibration with use_seq_device_map, model weight tensors on GPU
|
||||
# have stale memory. Any attempt to read them (even .to(device)) triggers
|
||||
# cudaErrorIllegalAddress. We force weight to CPU and do all math on CPU.
|
||||
from modelopt.torch.export import quant_utils
|
||||
from modelopt.torch.quantization.utils import quantizer_attr_names as _quantizer_attr_names
|
||||
|
||||
orig_get_weight_scaling_factor = quant_utils.get_weight_scaling_factor
|
||||
|
||||
def patched_get_weight_scaling_factor(module, weight_name="weight"):
    """Force the weight and weight-quantizer state to CPU, then delegate.

    If ``getattr(module, weight_name)`` is a CUDA tensor, it is replaced on the
    module by a CPU copy before the original ``get_weight_scaling_factor`` is
    called. CUDA tensors stored on the matching weight quantizer under
    ``_amax``, ``_pre_quant_scale``, ``global_amax`` or ``_global_amax`` are
    moved to CPU as well.

    Args:
        module: Object holding the weight under ``weight_name`` and, optionally,
            the weight quantizer named by ``_quantizer_attr_names``.
        weight_name: Attribute name of the weight on ``module``.

    Returns:
        Whatever the original ``get_weight_scaling_factor`` returns for
        ``(module, weight_name)``.
    """
    weight = getattr(module, weight_name)
    if isinstance(weight, torch.Tensor) and weight.is_cuda:
        try:
            weight_cpu = weight.detach().cpu()
        except (torch.cuda.CudaError, RuntimeError) as e:
            print(f" WARNING: weight.cpu() failed for {weight_name} ({e}), using zeros")
            # zeros_like only needs shape/dtype metadata, not the device data.
            # NOTE(review): the placeholder makes this layer's scaling factor
            # meaningless; the warning above is the only trace of that.
            weight_cpu = torch.zeros_like(weight, device="cpu")
        # Write the CPU tensor (real copy or zero placeholder) back onto the
        # module: the delegate reads the weight from the module, so a local
        # variable alone would never reach it. requires_grad is carried over
        # so frozen weights stay frozen and integer/packed weights do not
        # raise (Parameter defaults to requires_grad=True).
        setattr(
            module,
            weight_name,
            torch.nn.Parameter(weight_cpu, requires_grad=weight.requires_grad),
        )

    # Move quantizer state to CPU as well (best effort, per attribute).
    weight_quantizer = getattr(
        module, _quantizer_attr_names(weight_name).weight_quantizer, None
    )
    if weight_quantizer is not None:
        for attr in ("_amax", "_pre_quant_scale", "global_amax", "_global_amax"):
            val = getattr(weight_quantizer, attr, None)
            if not (isinstance(val, torch.Tensor) and val.is_cuda):
                continue
            try:
                setattr(weight_quantizer, attr, val.cpu())
            except (torch.cuda.CudaError, RuntimeError) as e:
                # Still best effort, but no longer silent: report which
                # tensor stayed on the GPU before delegating.
                print(
                    f" WARNING: could not move {attr} to CPU for "
                    f"{weight_name} ({e}); leaving it on GPU"
                )

    return orig_get_weight_scaling_factor(module, weight_name)
|
||||
|
||||
quant_utils.get_weight_scaling_factor = patched_get_weight_scaling_factor
|
||||
print("✓ Patched get_weight_scaling_factor (force weight + quantizer to CPU)")
|
||||
|
||||
# ── Patch 5: get_weight_scaling_factor_2 — force quantizer state to CPU ──
|
||||
orig_get_weight_scaling_factor_2 = quant_utils.get_weight_scaling_factor_2
|
||||
|
||||
def patched_get_weight_scaling_factor_2(module, weight_name="weight"):
    """Move weight-quantizer tensors to CPU, then delegate to the original.

    CUDA tensors stored on the weight quantizer under ``_amax``,
    ``_pre_quant_scale``, ``global_amax`` or ``_global_amax`` are replaced by
    CPU copies before the original ``get_weight_scaling_factor_2`` is called.
    A failed copy is reported and the attribute is left as it was.

    Args:
        module: Object that may hold the weight quantizer named by
            ``_quantizer_attr_names(weight_name).weight_quantizer``.
        weight_name: Attribute name of the weight the quantizer belongs to.

    Returns:
        Whatever the original ``get_weight_scaling_factor_2`` returns for
        ``(module, weight_name)``.
    """
    weight_quantizer = getattr(
        module, _quantizer_attr_names(weight_name).weight_quantizer, None
    )
    if weight_quantizer is not None:
        for attr in ("_amax", "_pre_quant_scale", "global_amax", "_global_amax"):
            val = getattr(weight_quantizer, attr, None)
            if not (isinstance(val, torch.Tensor) and val.is_cuda):
                continue
            try:
                setattr(weight_quantizer, attr, val.cpu())
            except (torch.cuda.CudaError, RuntimeError) as e:
                # Best effort: keep going, but say which tensor stayed on GPU
                # instead of swallowing the failure.
                print(
                    f" WARNING: could not move {attr} to CPU for "
                    f"{weight_name} ({e}); leaving it on GPU"
                )
    return orig_get_weight_scaling_factor_2(module, weight_name)
|
||||
|
||||
quant_utils.get_weight_scaling_factor_2 = patched_get_weight_scaling_factor_2
|
||||
print("✓ Patched get_weight_scaling_factor_2 (force quantizer to CPU)")
|
||||
|
||||
|
||||
def snapshot_amax_to_cpu(model, snapshot_path):
|
||||
"""Walk all quantizers, copy _amax to CPU, save to disk."""
|
||||
|
||||
Reference in New Issue
Block a user