diff --git a/vllm/model_executor/layers/quantization/auto_round.py b/vllm/model_executor/layers/quantization/auto_round.py index 2889bc92d..0e4815be6 100644 --- a/vllm/model_executor/layers/quantization/auto_round.py +++ b/vllm/model_executor/layers/quantization/auto_round.py @@ -4,6 +4,7 @@ from fractions import Fraction from typing import TYPE_CHECKING, Any +import regex as re import torch from vllm.logger import init_logger @@ -128,11 +129,44 @@ class AutoRoundConfig(QuantizationConfig): def get_layer_config(self, layer, layer_name: str): def get_config(name: str, quantized: bool = True): - cfg = self.extra_config.get(name, {}) if self.extra_config else {} + if not self.extra_config: + return ( + self.weight_bits if quantized else 16, + self.group_size if quantized else -1, + self.sym if quantized else True, + ) + + # exact match first + if name in self.extra_config: + cfg = self.extra_config[name] + return ( + cfg.get("bits", self.weight_bits if quantized else 16), + cfg.get("group_size", self.group_size if quantized else -1), + cfg.get("sym", self.sym if quantized else True), + ) + + REGEX_SPECIAL_CHARS = set(r"*+?^$()[]{}|\\") + for pattern, cfg in self.extra_config.items(): + if not isinstance(pattern, str) or not any( + c in REGEX_SPECIAL_CHARS for c in pattern + ): + continue + + try: + if re.search(re.compile(pattern), name) is not None: + return ( + cfg.get("bits", self.weight_bits if quantized else 16), + cfg.get("group_size", self.group_size if quantized else -1), + cfg.get("sym", self.sym if quantized else True), + ) + except re.error: + # Invalid regex, ignore. + continue + return ( - cfg.get("bits", self.weight_bits if quantized else 16), - cfg.get("group_size", self.group_size if quantized else -1), - cfg.get("sym", self.sym if quantized else True), + self.weight_bits if quantized else 16, + self.group_size if quantized else -1, + self.sym if quantized else True, ) # 1. Exact match from config @@ -176,7 +210,7 @@ class AutoRoundConfig(QuantizationConfig): f"consistent quant config for {sub_names}" ) - # 5. Fallback + # 5. Fallback or try a regular expression match return get_config(layer_name, quantized) def check_quantized(self, weight_bits: int) -> bool: