diff --git a/Dockerfile b/Dockerfile
index d00ff017..8efdac99 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -40,23 +40,8 @@ RUN sed -i 's/"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM")
 
 # Patch process_weights_after_loading to call model._post_quant_fix() after quant setup
 ARG VLLM_LOADER_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader
-RUN python3 -c "
-import re
-path = '${VLLM_LOADER_DIR}/utils.py'.replace('\$', '')
-with open(path) as f:
-    src = f.read()
-# Add _post_quant_fix() call at end of process_weights_after_loading
-old = '    if model_config.quantization == \"torchao\":'
-new = '''    # Custom: allow models to run post-quant-init fixes
-    if hasattr(model, '_post_quant_fix'):
-        model._post_quant_fix()
-
-    if model_config.quantization == \"torchao\":'''
-src = src.replace(old, new, 1)
-with open(path, 'w') as f:
-    f.write(src)
-print('Patched process_weights_after_loading')
-"
+COPY scripts/patch_utils.py /tmp/patch_utils.py
+RUN python3 /tmp/patch_utils.py ${VLLM_LOADER_DIR}/utils.py && rm /tmp/patch_utils.py
 
 # Verify
 RUN python3 -c "import torch; print(f'PyTorch {torch.__version__} OK')" && \
diff --git a/scripts/patch_utils.py b/scripts/patch_utils.py
new file mode 100644
index 00000000..66f90e52
--- /dev/null
+++ b/scripts/patch_utils.py
@@ -0,0 +1,60 @@
+"""Patch vLLM's process_weights_after_loading to call model._post_quant_fix()
+after all quant methods have set up their attributes.
+
+Usage: python3 patch_utils.py PATH_TO_UTILS_PY
+
+The hook is inserted directly before the first line matching ANCHOR. The
+script exits non-zero when that line is missing, and leaves the file
+untouched when the hook call is already present, so re-running it is safe.
+"""
+import sys
+
+# Line in the target file that the hook is inserted in front of.
+ANCHOR = '    if model_config.quantization == "torchao":'
+# Presence of this call means the file was already patched.
+HOOK_CALL = 'model._post_quant_fix()'
+HOOK = '''    # Custom: allow models to run post-quant-init fixes
+    if hasattr(model, '_post_quant_fix'):
+        model._post_quant_fix()
+
+'''
+
+
+def patch_source(src):
+    """Return src with HOOK inserted before the first ANCHOR line.
+
+    Returns src unchanged if HOOK_CALL is already present, or None if
+    ANCHOR is not found.
+    """
+    if HOOK_CALL in src:
+        return src
+    if ANCHOR not in src:
+        return None
+    return src.replace(ANCHOR, HOOK + ANCHOR, 1)
+
+
+def main(argv):
+    """Patch the file named by argv[1] in place and return an exit status."""
+    if len(argv) < 2:
+        print(f"usage: {argv[0]} PATH_TO_UTILS_PY", file=sys.stderr)
+        return 2
+    path = argv[1]
+    with open(path, encoding='utf-8') as f:
+        src = f.read()
+
+    patched = patch_source(src)
+    if patched is None:
+        print(f"ERROR: Could not find patch target in {path}", file=sys.stderr)
+        return 1
+    if patched == src:
+        print('process_weights_after_loading already patched; skipping')
+        return 0
+
+    with open(path, 'w', encoding='utf-8') as f:
+        f.write(patched)
+    print('Patched process_weights_after_loading')
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main(sys.argv))