"""Patch vLLM's process_weights_after_loading to call model._post_quant_fix()
after all quant methods have set up their attributes.

Usage::

    python <this script> PATH

PATH is the vLLM source file containing the line
``if model_config.quantization == "torchao":``. A hook that calls
``model._post_quant_fix()`` (when the model defines that attribute) is
inserted immediately before that line, using the line's own indentation.

Exit status: 0 when the file was patched or already carried the hook,
1 when the target line was not found, 2 when PATH is missing.
"""
import re
import sys

# First line of the injected hook; also serves as the marker that makes a
# second run a no-op instead of injecting a duplicate call.
_HOOK_MARKER = "# Custom: allow models to run post-quant-init fixes"

# The statement the hook is inserted in front of.
_TARGET_STMT = 'if model_config.quantization == "torchao":'

# Match the target only where it starts a line, capturing its indentation so
# the injected lines line up with it. A hard-coded indent inside a plain
# substring search can match mid-line and splice in mis-indented code.
_TARGET_RE = re.compile(
    r"^(?P<indent>[ \t]*)" + re.escape(_TARGET_STMT),
    re.MULTILINE,
)


def patch_source(src: str) -> tuple[str, bool]:
    """Insert the ``_post_quant_fix`` hook before the torchao check in *src*.

    Args:
        src: Full text of the vLLM source file to patch.

    Returns:
        ``(new_src, changed)``. When the hook marker is already present,
        *src* is returned unchanged with ``changed`` False, so re-running
        the patch never adds a second call.

    Raises:
        LookupError: If no line starts with the target statement.
    """
    if _HOOK_MARKER in src:
        return src, False

    match = _TARGET_RE.search(src)
    if match is None:
        raise LookupError(f"patch target not found: {_TARGET_STMT!r}")

    indent = match.group("indent")
    hook = (
        f"{indent}{_HOOK_MARKER}\n"
        f"{indent}if hasattr(model, '_post_quant_fix'):\n"
        f"{indent}    model._post_quant_fix()\n"
        "\n"
    )
    # With re.MULTILINE, match.start() is the beginning of the target's line
    # (before its indentation), so the hook lands as whole lines above it.
    start = match.start()
    return src[:start] + hook + src[start:], True


def main(argv: "list[str] | None" = None) -> int:
    """Patch the file named by the first argument and return an exit status.

    Args:
        argv: Arguments excluding the program name; defaults to
            ``sys.argv[1:]``. Arguments after the first are ignored.

    Returns:
        0 if the file was patched or already patched, 1 if the target line
        is missing, 2 if no path was given.
    """
    args = sys.argv[1:] if argv is None else argv
    if not args:
        print(f"usage: {sys.argv[0]} PATH", file=sys.stderr)
        return 2
    path = args[0]

    with open(path, encoding="utf-8") as f:
        src = f.read()

    try:
        patched, changed = patch_source(src)
    except LookupError:
        print(f"WARNING: Could not find patch target in {path}", file=sys.stderr)
        return 1

    if not changed:
        print(f"Already patched: {path}")
        return 0

    with open(path, "w", encoding="utf-8") as f:
        f.write(patched)
    print('Patched process_weights_after_loading')
    return 0


if __name__ == "__main__":
    sys.exit(main())