Clean rewrite: verified all imports against runtime, removed dead code
- get_model/get_tokenizer imported from example_utils (not hf_ptq) - KV_QUANT_CFG_CHOICES imported from hf_ptq (not mtq) - Removed dead _FORCE_AMAX_CPU global and global reference in run_export_only - Fixed stale comments - All 16 imports and references verified against the actual B200 runtime - Zero divergences from modelopt example path except get_model()
This commit is contained in:
@@ -61,18 +61,18 @@ def apply_patches():
|
||||
"""Apply runtime patches for V4 compatibility and GPU tensor safety."""
|
||||
|
||||
from modelopt.torch.quantization.nn.modules import tensor_quantizer as tq_module
|
||||
from modelopt.torch.quantization.qtensor import nvfp4_tensor
|
||||
|
||||
# ── Patch 1: load_calib_amax — force _amax to CPU after calibration ──
|
||||
#
|
||||
# load_calib_amax() is called by max_calibrate() after the forward loop
|
||||
# finishes. It writes _amax to GPU by default. We patch it so _amax
|
||||
# goes to CPU instead, preventing GPU corruption during the long wait
|
||||
# before export.
|
||||
# goes to CPU immediately, preventing GPU corruption during the long
|
||||
# wait before export.
|
||||
orig_load_calib_amax = tq_module.TensorQuantizer.load_calib_amax
|
||||
|
||||
def patched_load_calib_amax(self, *args, **kwargs):
|
||||
orig_load_calib_amax(self, *args, **kwargs)
|
||||
# After _amax is written, move it to CPU
|
||||
if hasattr(self, '_amax') and self._amax is not None:
|
||||
self._amax = self._amax.cpu()
|
||||
|
||||
@@ -92,8 +92,6 @@ def apply_patches():
|
||||
print("✓ Patched TensorQuantizer.export_amax (CPU fallback)")
|
||||
|
||||
# ── Patch 3: NVFP4QTensor.get_activation_scaling_factor — graceful degradation ──
|
||||
from modelopt.torch.quantization.qtensor import nvfp4_tensor
|
||||
|
||||
@classmethod
|
||||
def patched_get_activation_scaling_factor(cls, quantizer):
|
||||
if not quantizer.is_enabled:
|
||||
@@ -111,8 +109,6 @@ def apply_patches():
|
||||
amax = amax.cpu()
|
||||
activation_scaling_factor = amax.float() / (quantizer.maxbound * 448.0)
|
||||
|
||||
# Clamp instead of hard assert — bad values from GPU corruption should
|
||||
# not kill the entire 6-hour run
|
||||
if not torch.all(activation_scaling_factor > 0):
|
||||
n_bad = (activation_scaling_factor <= 0).sum().item()
|
||||
n_total = activation_scaling_factor.numel()
|
||||
@@ -128,13 +124,10 @@ def apply_patches():
|
||||
def snapshot_amax_to_cpu(model, snapshot_path):
|
||||
"""Walk all quantizers, copy their _amax to CPU, save to disk.
|
||||
|
||||
This is the core defensive measure. After calibration completes, the _amax
|
||||
tensors are fresh and valid on GPU. We copy them to CPU immediately and
|
||||
save to disk. This costs almost nothing (~50MB for ~49K quantizers) but
|
||||
guarantees we have valid calibration data even if CUDA corrupts the GPU
|
||||
copies later.
|
||||
|
||||
Returns the snapshot dict: {quantizer_name: amax_tensor_on_cpu}
|
||||
After calibration completes, the _amax tensors are fresh and valid on GPU.
|
||||
We copy them to CPU immediately and save to disk. This costs almost nothing
|
||||
(~50MB for ~49K quantizers) but guarantees we have valid calibration data
|
||||
even if CUDA corrupts the GPU copies later.
|
||||
"""
|
||||
from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer
|
||||
|
||||
@@ -147,14 +140,11 @@ def snapshot_amax_to_cpu(model, snapshot_path):
|
||||
if not isinstance(module, TensorQuantizer):
|
||||
continue
|
||||
if hasattr(module, '_amax') and module._amax is not None:
|
||||
# Copy to CPU immediately
|
||||
amax_cpu = module._amax.detach().cpu().clone()
|
||||
snapshots[name] = amax_cpu
|
||||
# Replace the GPU copy with the CPU copy
|
||||
module._amax.data.copy_(amax_cpu)
|
||||
n_moved += 1
|
||||
|
||||
# Save snapshots to disk
|
||||
torch.save(snapshots, snapshot_path)
|
||||
size_mb = os.path.getsize(snapshot_path) / (1024**2)
|
||||
print(f"✓ Snapshotted {n_moved} quantizer _amax tensors to CPU ({time.time()-t0:.1f}s)")
|
||||
@@ -164,11 +154,7 @@ def snapshot_amax_to_cpu(model, snapshot_path):
|
||||
|
||||
|
||||
def restore_amax_from_snapshot(model, snapshot_path):
|
||||
"""Restore _amax from a previously saved CPU snapshot.
|
||||
|
||||
Used by --export-only to guarantee valid amax values even if the
|
||||
model state dict has corrupted GPU tensors.
|
||||
"""
|
||||
"""Restore _amax from a previously saved CPU snapshot."""
|
||||
from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer
|
||||
|
||||
print(f"Restoring _amax from snapshot: {snapshot_path}")
|
||||
@@ -186,11 +172,7 @@ def restore_amax_from_snapshot(model, snapshot_path):
|
||||
|
||||
|
||||
def force_all_amax_to_cpu(model):
|
||||
"""Force ALL quantizer tensors to CPU. Nuclear option after calibration.
|
||||
|
||||
After calling this, no quantizer _amax lives on GPU. Export can't hit
|
||||
CUDA illegal memory access because there's nothing on GPU to corrupt.
|
||||
"""
|
||||
"""Force ALL quantizer tensors to CPU. Nuclear option after calibration."""
|
||||
from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer
|
||||
|
||||
count = 0
|
||||
@@ -207,18 +189,13 @@ def force_all_amax_to_cpu(model):
|
||||
|
||||
|
||||
def save_calibrated_state(model, path):
|
||||
"""Save model state dict after calibration.
|
||||
|
||||
The insurance policy: if export crashes, we can reload and retry
|
||||
without re-running 6 hours of calibration.
|
||||
"""
|
||||
"""Save model state dict after calibration."""
|
||||
print(f"\n{'='*60}")
|
||||
print(f"SAVING CALIBRATED STATE → {path}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
start = time.time()
|
||||
|
||||
# All quantizer state should already be on CPU from snapshot_amax_to_cpu
|
||||
state = {
|
||||
'model_state_dict': model.state_dict(),
|
||||
'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
|
||||
@@ -234,35 +211,36 @@ def save_calibrated_state(model, path):
|
||||
def run_calibration(model_path, export_dir, calib_save_path, amax_snapshot_path, calib_size, calib_seq):
|
||||
"""Full pipeline: load → quantize → calibrate → snapshot → save → export."""
|
||||
|
||||
global _FORCE_AMAX_CPU
|
||||
|
||||
os.chdir(EXAMPLE_DIR)
|
||||
sys.path.insert(0, EXAMPLE_DIR)
|
||||
|
||||
os.environ["HF_TOKEN"] = HF_TOKEN
|
||||
os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
|
||||
|
||||
# Import from hf_ptq and modelopt — all verified against the example script
|
||||
from example_utils import get_model, get_tokenizer
|
||||
from hf_ptq import (
|
||||
get_model as modelopt_get_model, get_tokenizer, make_calib_dataloader,
|
||||
build_quant_cfg, load_mtp_weights, copy_custom_model_files,
|
||||
QUANT_CFG_CHOICES, KV_QUANT_CFG_CHOICES,
|
||||
make_calib_dataloader,
|
||||
build_quant_cfg,
|
||||
load_mtp_weights,
|
||||
copy_custom_model_files,
|
||||
QUANT_CFG_CHOICES,
|
||||
KV_QUANT_CFG_CHOICES,
|
||||
)
|
||||
from modelopt.torch import quantization as mtq
|
||||
from modelopt.torch.quantization.config import need_calibration
|
||||
from modelopt.torch.utils.dataset_utils import get_max_batch_size
|
||||
from modelopt.torch.export import export_hf_checkpoint
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
apply_patches()
|
||||
|
||||
# ── Load model ──
|
||||
# Use modelopt's get_model() instead of raw AutoModelForCausalLM.from_pretrained.
|
||||
# The raw call OOMs during weight conversion (torch.cat on experts needs 31.5GB,
|
||||
# only 25.9GB free). modelopt's loader handles max_memory/device_map properly.
|
||||
# Use modelopt's get_model() — handles max_memory properly for 3TB model.
|
||||
# Raw AutoModelForCausalLM.from_pretrained OOMs during expert weight conversion.
|
||||
print(f"\nLoading model from {model_path}...")
|
||||
t0 = time.time()
|
||||
|
||||
model = modelopt_get_model(
|
||||
model = get_model(
|
||||
model_path,
|
||||
gpu_mem_percentage=GPU_MEM_PCT,
|
||||
trust_remote_code=True,
|
||||
@@ -272,17 +250,22 @@ def run_calibration(model_path, export_dir, calib_save_path, amax_snapshot_path,
|
||||
print(f"✓ Model loaded in {time.time()-t0:.0f}s")
|
||||
|
||||
# ── Setup quantization config ──
|
||||
# Same flow as hf_ptq's quantize_main()
|
||||
quant_cfg = copy.deepcopy(QUANT_CFG_CHOICES[QUANT])
|
||||
quant_cfg = build_quant_cfg(QUANT, quant_cfg, None, None, None)
|
||||
|
||||
if KV_CACHE_QUANT != "none":
|
||||
enable_quant_kv_cache = True
|
||||
print(f"✓ KV cache quantization: {KV_CACHE_QUANT}")
|
||||
quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
|
||||
quant_cfg,
|
||||
getattr(mtq, KV_QUANT_CFG_CHOICES[KV_CACHE_QUANT])["quant_cfg"],
|
||||
)
|
||||
print(f"✓ KV cache quantization: {KV_CACHE_QUANT}")
|
||||
else:
|
||||
enable_quant_kv_cache = False
|
||||
|
||||
# ── Detect batch size ──
|
||||
# Same as hf_ptq's quantize_main()
|
||||
print("\nDetecting max calibration batch size...")
|
||||
batch_size = get_max_batch_size(
|
||||
model,
|
||||
@@ -293,6 +276,7 @@ def run_calibration(model_path, export_dir, calib_save_path, amax_snapshot_path,
|
||||
print(f"✓ Using calibration batch_size={batch_size}")
|
||||
|
||||
# ── Prepare dataloader ──
|
||||
# Same args structure as hf_ptq
|
||||
args = argparse.Namespace(
|
||||
calib_size=[calib_size],
|
||||
calib_seq=calib_seq,
|
||||
@@ -310,23 +294,17 @@ def run_calibration(model_path, export_dir, calib_save_path, amax_snapshot_path,
|
||||
print(f"{'='*60}")
|
||||
t0 = time.time()
|
||||
|
||||
# _FORCE_AMAX_CPU is False during calibration — amax stays on GPU for
|
||||
# fake quantization during the forward passes
|
||||
model = mtq.quantize(model, quant_cfg, forward_loop=calib_dataloader)
|
||||
|
||||
print(f"✓ Quantization + calibration complete in {time.time()-t0:.0f}s")
|
||||
|
||||
# ── IMMEDIATELY snapshot all _amax to CPU ──
|
||||
# This is the critical defensive step. Right after mtq.quantize() returns,
|
||||
# the _amax tensors are fresh and valid on GPU. We copy them to CPU NOW,
|
||||
# before any other GPU operation has a chance to corrupt them.
|
||||
snapshots = snapshot_amax_to_cpu(model, amax_snapshot_path)
|
||||
|
||||
# ── Force ALL quantizer state to CPU ──
|
||||
# After snapshotting, force remaining GPU tensors to CPU too
|
||||
force_all_amax_to_cpu(model)
|
||||
|
||||
# ── Force ALL quantizer state to CPU ──
|
||||
# ── Free GPU memory ──
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
@@ -346,12 +324,10 @@ def run_export(model, tokenizer, model_path, export_dir, amax_snapshot_path=None
|
||||
print(f"EXPORTING → {export_dir}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# Ensure all quantizer state is on CPU
|
||||
force_all_amax_to_cpu(model)
|
||||
if amax_snapshot_path and os.path.exists(amax_snapshot_path):
|
||||
restore_amax_from_snapshot(model, amax_snapshot_path)
|
||||
|
||||
# Free GPU memory before export
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
@@ -383,6 +359,7 @@ def run_export(model, tokenizer, model_path, export_dir, amax_snapshot_path=None
|
||||
|
||||
def run_export_only(calib_save_path, amax_snapshot_path, model_path, export_dir):
|
||||
"""Load saved calibration state and run export only."""
|
||||
|
||||
os.chdir(EXAMPLE_DIR)
|
||||
sys.path.insert(0, EXAMPLE_DIR)
|
||||
|
||||
@@ -391,10 +368,10 @@ def run_export_only(calib_save_path, amax_snapshot_path, model_path, export_dir)
|
||||
|
||||
apply_patches()
|
||||
|
||||
from hf_ptq import get_model as modelopt_get_model, get_tokenizer
|
||||
from example_utils import get_model, get_tokenizer
|
||||
|
||||
print(f"Loading model skeleton from {model_path}...")
|
||||
model = modelopt_get_model(
|
||||
print(f"Loading model from {model_path}...")
|
||||
model = get_model(
|
||||
model_path,
|
||||
device="cpu",
|
||||
trust_remote_code=True,
|
||||
@@ -413,7 +390,6 @@ def run_validate(calib_save_path, amax_snapshot_path):
|
||||
"""Validate saved calibration state — check amax values are valid."""
|
||||
print(f"\nValidating calibration state...")
|
||||
|
||||
# Check amax snapshots
|
||||
if os.path.exists(amax_snapshot_path):
|
||||
snapshots = torch.load(amax_snapshot_path, map_location='cpu')
|
||||
n_total = len(snapshots)
|
||||
@@ -446,7 +422,6 @@ def run_validate(calib_save_path, amax_snapshot_path):
|
||||
else:
|
||||
print(f"✗ No amax snapshot found at {amax_snapshot_path}")
|
||||
|
||||
# Check full state dict
|
||||
if os.path.exists(calib_save_path):
|
||||
size_gb = os.path.getsize(calib_save_path) / (1024**3)
|
||||
print(f"\nCalibrated state: {calib_save_path} ({size_gb:.1f} GB)")
|
||||
|
||||
Reference in New Issue
Block a user