[Quant] [Bugfix] Fix quantization config matching with hf_to_vllm_mapper (#20046)

Author: Kyle Sayers
Date: 2025-07-01 06:20:34 -04:00
Committed by: GitHub
Parent: c05596f1a3
Commit: 9025a9a705
17 changed files with 107 additions and 29 deletions
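Background on the bug: checkpoint-side quantization configs (compressed-tensors, GPTQ, AWQ, ...) name the modules they target or ignore by their Hugging Face checkpoint paths, while models like Qwen2.5-VL register those modules under remapped names via hf_to_vllm_mapper. Unless the config's module lists are translated with the same mapper, matching silently fails. A minimal illustration with assumed example names (the concrete prefix remap here is illustrative, not copied from the model definition):

# Illustrative names only. A checkpoint-side quant config might contain:
ignore = ["model.layers.0.mlp.gate_proj"]          # HF checkpoint name
# ...but after hf_to_vllm_mapper remaps "model." -> "language_model.model."
# (an assumed example mapping), the live vLLM module is named:
vllm_module = "language_model.model.layers.0.mlp.gate_proj"
# Plain name matching therefore fails, and the module is quantized by mistake:
assert vllm_module not in ignore
# The fix: run the config's module lists through the same mapper before
# matching, so "model.layers.0..." becomes "language_model.model.layers.0...".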

vllm/model_executor/models/qwen2_5_vl.py

@@ -61,7 +61,7 @@ from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
-                         SupportsMultiModal, SupportsPP)
+                         SupportsMultiModal, SupportsPP, SupportsQuant)
 from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
 from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo,
                        apply_rotary_pos_emb_vision)
@@ -821,7 +821,8 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
                                         info=Qwen2_5_VLProcessingInfo,
                                         dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
 class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                         SupportsLoRA, SupportsPP):
+                                         SupportsLoRA, SupportsPP,
+                                         SupportsQuant):
 
     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(
@@ -837,7 +838,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
         multimodal_config = vllm_config.model_config.multimodal_config
 
         self.config = config
@@ -846,7 +846,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.visual = Qwen2_5_VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-            quant_config=self._maybe_ignore_quant_config(quant_config),
+            quant_config=self._maybe_ignore_quant_config(self.quant_config),
             prefix=maybe_prefix(prefix, "visual"),
         )
@@ -859,12 +859,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
-    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
+    def _maybe_ignore_quant_config(self, config: Optional[QuantizationConfig]):
         # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
         # seems to avoid vision encoder sections for some models.
-        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
+        if isinstance(config, (GPTQConfig, GPTQMarlinConfig)):
             return None
-        return quant_config
+        return config
 
     def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                         name: str) -> torch.Tensor:
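
How the new plumbing hangs together: the SupportsQuant mixin (from interfaces.py, added to the class bases above) is what now supplies self.quant_config, which is why __init__ no longer reads vllm_config.quant_config itself. The sketch below shows one way such a mixin can attach the config at construction time and translate its ignore list through the model's HF-to-vLLM name mapping. All names and the dict-based config here are illustrative, not vLLM's actual implementation:

from typing import Any, Optional


def _find_quant_config(*args: Any, **kwargs: Any) -> Optional[dict]:
    # Hypothetical helper: scan constructor arguments for a quantization
    # config (vLLM would look inside VllmConfig; a dict stands in here).
    for value in (*args, *kwargs.values()):
        if isinstance(value, dict) and "ignore" in value:
            return value
    return None


class SupportsQuantSketch:
    # Illustrative mixin, not vLLM's SupportsQuant: attaches the engine's
    # quantization config to the instance and remaps HF-style module names
    # in its ignore list into vLLM's module namespace.

    hf_to_vllm_prefixes: dict = {}   # models override with their name remap
    quant_config: Optional[dict] = None

    def __new__(cls, *args: Any, **kwargs: Any):
        instance = super().__new__(cls)
        quant_config = _find_quant_config(*args, **kwargs)
        if quant_config is not None:
            # Translate once, up front, so every later consumer
            # (quant method selection, weight loading) matches against
            # vLLM module paths rather than checkpoint paths.
            quant_config["ignore"] = [
                cls._map_name(name) for name in quant_config.get("ignore", [])
            ]
            instance.quant_config = quant_config
        return instance

    @classmethod
    def _map_name(cls, name: str) -> str:
        for hf_prefix, vllm_prefix in cls.hf_to_vllm_prefixes.items():
            if name.startswith(hf_prefix):
                return vllm_prefix + name[len(hf_prefix):]
        return name


class TinyModel(SupportsQuantSketch):
    hf_to_vllm_prefixes = {"model.": "language_model.model."}

    def __init__(self, quant_config: Optional[dict] = None):
        pass  # a real model would build its submodules here


model = TinyModel({"ignore": ["model.layers.0.mlp.gate_proj"]})
print(model.quant_config["ignore"])
# -> ['language_model.model.layers.0.mlp.gate_proj']

Doing the translation at construction time, as sketched, keeps the per-model code small: Qwen2.5-VL only declares its mapper and bases, and the diff above can delete the hand-rolled quant_config plumbing from __init__.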