[Quant] [Bugfix] Fix quantization config matching with hf_to_vllm_mapper (#20046)

Kyle Sayers
2025-07-01 06:20:34 -04:00
committed by GitHub
parent c05596f1a3
commit 9025a9a705
17 changed files with 107 additions and 29 deletions


@@ -24,6 +24,7 @@ from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.models.adapters import (as_classification_model,
                                                  as_embedding_model,
                                                  as_reward_model)
+from vllm.model_executor.models.interfaces import SupportsQuant
 from vllm.utils import is_pin_memory_available

 logger = init_logger(__name__)
@@ -294,13 +295,16 @@ def configure_quant_config(quant_config: QuantizationConfig,
     Note that model attributes are passed by reference to quant_config,
     enabling them to be updated by model_class.__new__ (ex. chatglm, qwen)
+
+    Once the `SupportsQuant` mixin has been added to all models, this
+    function can be removed
     """
-    packed_mapping = getattr(model_class, "packed_modules_mapping", None)
-    if packed_mapping is not None:
-        # pass packed_modules_mapping by reference to quant_config
-        quant_config.packed_modules_mapping = packed_mapping
-    else:
-        logger.warning(
-            "The model class %s has not defined `packed_modules_mapping`, "
-            "this may lead to incorrect mapping of quantized or ignored "
-            "modules", model_class.__name__)
+    if not issubclass(model_class, SupportsQuant):
+        hf_to_vllm_mapper = getattr(model_class, "hf_to_vllm_mapper", None)
+        packed_mapping = getattr(model_class, "packed_modules_mapping", None)
+
+        # pass mappings by reference to quant_config
+        if hf_to_vllm_mapper is not None:
+            quant_config.apply_vllm_mapper(hf_to_vllm_mapper)
+        if packed_mapping is not None:
+            quant_config.packed_modules_mapping = packed_mapping

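For context, a minimal self-contained sketch of the mismatch this change addresses: a checkpoint's quantization config names modules by their Hugging Face paths, while vLLM matches modules by its own (possibly re-prefixed) names, so lists like `ignore` must be remapped before matching. The `ToyWeightsMapper` and `ToyQuantConfig` classes below are hypothetical stand-ins, assuming a prefix-rewriting mapper similar in spirit to vLLM's `WeightsMapper` and the `apply_vllm_mapper` hook seen in the diff; they are not vLLM's actual implementation.

    from dataclasses import dataclass, field


    @dataclass
    class ToyWeightsMapper:
        """Hypothetical stand-in for vLLM's WeightsMapper: rewrites
        HF checkpoint name prefixes into vLLM module name prefixes."""
        orig_to_new_prefix: dict = field(default_factory=dict)

        def map_name(self, name: str) -> str:
            for old, new in self.orig_to_new_prefix.items():
                if name.startswith(old):
                    return new + name[len(old):]
            return name


    @dataclass
    class ToyQuantConfig:
        """Hypothetical quant config whose ignore list uses HF names."""
        ignore: list = field(default_factory=list)

        def apply_vllm_mapper(self, mapper: ToyWeightsMapper) -> None:
            # Rewrite HF module names into vLLM module names so that
            # is_layer_ignored() matches the modules vLLM actually builds.
            self.ignore = [mapper.map_name(n) for n in self.ignore]

        def is_layer_ignored(self, vllm_module_name: str) -> bool:
            return vllm_module_name in self.ignore


    cfg = ToyQuantConfig(ignore=["model.embed_tokens", "lm_head"])
    mapper = ToyWeightsMapper(
        orig_to_new_prefix={"model.": "language_model.model."})

    # Before remapping, vLLM's module name fails to match the HF name,
    # so the layer would be (wrongly) quantized instead of skipped:
    assert not cfg.is_layer_ignored("language_model.model.embed_tokens")

    cfg.apply_vllm_mapper(mapper)
    assert cfg.is_layer_ignored("language_model.model.embed_tokens")
    print("ignore list after remapping:", cfg.ignore)

This also suggests why the new code is gated on `if not issubclass(model_class, SupportsQuant)`: models that already use the `SupportsQuant` mixin wire these mappings up themselves, so this fallback only applies to models that have not yet been migrated.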