[Quantization] Allow GGUF quantization to skip unquantized layer (#23188)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py
2025-08-23 03:04:22 +08:00
committed by GitHub
parent cd7a3df26f
commit 4645024d3a
3 changed files with 36 additions and 3 deletions

View File

@@ -13,7 +13,8 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
@@ -28,8 +29,10 @@ logger = init_logger(__name__)
class GGUFConfig(QuantizationConfig):
"""Config class for GGUF."""
def __init__(self,
             unquantized_modules: Optional[list[str]] = None) -> None:
    """Create a GGUF quantization config.

    Args:
        unquantized_modules: Module names stored on the config as
            ``self.unquantized_modules``. ``None`` (or any other falsy
            value, including an empty list) is replaced by a new empty
            list, so the attribute is always a list.
    """
    super().__init__()
    # `or []` keeps the caller's list object when it is non-empty and
    # avoids a shared mutable default when nothing is passed.
    self.unquantized_modules = unquantized_modules or []
def __repr__(self) -> str:
    """Return the fixed textual representation of this config."""
    return "GGUFConfig()"
@@ -55,6 +58,8 @@ class GGUFConfig(QuantizationConfig):
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
if is_layer_skipped_gguf(prefix, self.unquantized_modules):
return UnquantizedLinearMethod()
return GGUFLinearMethod(self)
elif isinstance(layer, VocabParallelEmbedding):
return GGUFEmbeddingMethod(self)
@@ -63,6 +68,10 @@ class GGUFConfig(QuantizationConfig):
return None
def is_layer_skipped_gguf(prefix: str,
                          unquantized_modules: list[str]) -> bool:
    """Return whether the layer at ``prefix`` should be left unquantized.

    A layer is skipped when any entry of ``unquantized_modules`` occurs as
    a substring of ``prefix``.

    Args:
        prefix: Name (prefix) of the layer being checked.
        unquantized_modules: Module names to look for inside ``prefix``.

    Returns:
        ``True`` if at least one non-empty module name is contained in
        ``prefix``, ``False`` otherwise (including for an empty list).
    """
    # An empty string is a substring of every prefix; ignore such entries so
    # a stray "" cannot silently mark every layer as unquantized.
    return any(module_name and module_name in prefix
               for module_name in unquantized_modules)
# Weight types this module classifies as unquantized: the F32, F16 and BF16
# members of WeightType.
UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
STANDARD_QUANT_TYPES = {
WeightType.Q4_0,