Files
nvfp4-megamoe-kernel/scripts/patch_utils.py

24 lines
673 B
Python
Raw Normal View History

"""Patch vLLM's process_weights_after_loading to call model._post_quant_fix()
after all quant methods have set up their attributes."""
import sys
path = sys.argv[1]
with open(path) as f:
src = f.read()
old = ' if model_config.quantization == "torchao":'
new = ''' # Custom: allow models to run post-quant-init fixes
if hasattr(model, '_post_quant_fix'):
model._post_quant_fix()
if model_config.quantization == "torchao":'''
if old not in src:
print(f"WARNING: Could not find patch target in {path}")
sys.exit(1)
src = src.replace(old, new, 1)
with open(path, 'w') as f:
f.write(src)
print('Patched process_weights_after_loading')