[Quantization] Add field to skip unquantized modules for GPTQ config (#25455)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py
2025-09-26 23:47:41 +08:00
committed by GitHub
parent db1e42f627
commit d4d9899860
16 changed files with 219 additions and 153 deletions

View File

@@ -5,6 +5,7 @@ from copy import deepcopy
from typing import Any, Callable, Optional, Union
import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
@@ -35,6 +36,8 @@ from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
RowvLLMParameter)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.transformers_utils.config import get_safetensors_params_metadata
from vllm.utils import is_list_of
logger = init_logger(__name__)
@@ -71,10 +74,16 @@ class GPTQMarlinConfig(QuantizationConfig):
(8, True): scalar_types.uint8b128,
}
def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
is_sym: bool, lm_head_quantized: bool,
dynamic: dict[str, dict[str, Union[int, bool]]],
full_config: dict[str, Any]) -> None:
def __init__(
self,
weight_bits: int,
group_size: int,
desc_act: bool,
is_sym: bool,
lm_head_quantized: bool,
dynamic: dict[str, dict[str, Union[int, bool]]],
full_config: dict[str, Any],
modules_in_block_to_quantize: Optional[list[str]] = None) -> None:
super().__init__()
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
@@ -121,15 +130,19 @@ class GPTQMarlinConfig(QuantizationConfig):
self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
self.modules_in_block_to_quantize = modules_in_block_to_quantize or []
# used to identify GPTQ model quantized by autoround
self.autoround_version = full_config.get("autoround_version", "")
def __repr__(self) -> str:
return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}, "
f"lm_head_quantized={self.lm_head_quantized}), "
f"dynamic={self.dynamic}")
return (
f"GPTQMarlinConfig(quant_type={self.quant_type}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}, "
f"lm_head_quantized={self.lm_head_quantized}, "
f"dynamic={self.dynamic}, "
f"modules_in_block_to_quantize={self.modules_in_block_to_quantize})"
)
@classmethod
def get_name(cls) -> QuantizationMethods:
@@ -158,8 +171,11 @@ class GPTQMarlinConfig(QuantizationConfig):
is_sym = cls.get_from_keys(config, ["sym"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
modules_in_block_to_quantize = cls.get_from_keys_or(
config, ["modules_in_block_to_quantize"], default=None)
return cls(weight_bits, group_size, desc_act, is_sym,
lm_head_quantized, dynamic, config)
lm_head_quantized, dynamic, config,
modules_in_block_to_quantize)
@classmethod
def override_quantization_method(
@@ -223,6 +239,35 @@ class GPTQMarlinConfig(QuantizationConfig):
return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
group_size=group_size)
def apply_vllm_mapper(self, hf_to_vllm_mapper):
if self.modules_in_block_to_quantize is not None:
self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list(
self.modules_in_block_to_quantize)
def maybe_update_config(self,
model_name: str,
revision: Optional[str] = None):
if self.modules_in_block_to_quantize:
if is_list_of(self.modules_in_block_to_quantize, list):
# original modules_in_block_to_quantize: list[list[str]]
# flatten original modules_in_block_to_quantize
self.modules_in_block_to_quantize = [
item for sublist in self.modules_in_block_to_quantize
for item in sublist
]
return
unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
metadata = get_safetensors_params_metadata(model_name,
revision=revision)
quant_layers: set[str] = {
param_name.rsplit(".", 1)[0]
for param_name, info in metadata.items()
if (dtype := info.get('dtype', None))
and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
}
self.modules_in_block_to_quantize = list(quant_layers)
class GPTQMarlinLinearMethod(LinearMethodBase):
"""Linear method for GPTQ Marlin.