[Bugfix] Allow fallback to AWQ from AWQMarlin at per-layer granularity (#13119)

Author: Michael Goin
Date: 2025-02-12 12:19:53 -05:00
Committed by: GitHub
Commit: 09972e716c
Parent: 36a08630e8
4 changed files with 61 additions and 32 deletions

@@ -13,15 +13,17 @@ from vllm.model_executor.layers.fused_moe.layer import (
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod,
                                                set_weight_attrs)
-from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq
+from vllm.model_executor.layers.quantization.awq import (AWQConfig,
+                                                          is_layer_skipped_awq)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
-    marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
-    marlin_permute_scales, moe_awq_to_marlin_zero_points,
-    verify_marlin_supported, verify_marlin_supports_shape)
+    check_marlin_supports_layer, marlin_make_empty_g_idx,
+    marlin_make_workspace, marlin_moe_permute_scales, marlin_permute_scales,
+    moe_awq_to_marlin_zero_points, verify_marlin_supported,
+    verify_marlin_supports_shape)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            PackedvLLMParameter)
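
The two new imports carry the feature: check_marlin_supports_layer decides whether a layer's
(tensor-parallel-partitioned) shapes can run on the Marlin kernels, and AWQConfig is what the
config falls back to when they cannot. As a rough, hypothetical sketch of the kind of
divisibility check this involves (the helper name, constants, and rules below are assumptions;
the real logic lives in marlin_utils.py):

    # Hypothetical sketch, not vLLM's implementation: the sort of shape check a
    # helper like check_marlin_supports_layer performs. Constants are assumed.
    MARLIN_MIN_THREAD_N = 64   # assumed minimum tile along the output dimension
    MARLIN_MIN_THREAD_K = 128  # assumed minimum tile along the input dimension

    def shapes_fit_marlin(input_size_per_partition: int,
                          output_size_per_partition: int,
                          group_size: int) -> bool:
        """True if a linear layer's partitioned shapes meet the (assumed)
        divisibility constraints of the Marlin GEMM kernels."""
        if output_size_per_partition % MARLIN_MIN_THREAD_N != 0:
            return False
        if input_size_per_partition % MARLIN_MIN_THREAD_K != 0:
            return False
        # Grouped quantization: the reduction dim must be a whole number of groups.
        if group_size > 0 and input_size_per_partition % group_size != 0:
            return False
        return True
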
@@ -40,18 +42,17 @@ class AWQMarlinConfig(QuantizationConfig):
         8: scalar_types.uint8,
     }
 
-    def __init__(self,
-                 weight_bits: int,
-                 group_size: int,
-                 zero_point: bool,
+    def __init__(self, weight_bits: int, group_size: int, zero_point: bool,
                  lm_head_quantized: bool,
-                 modules_to_not_convert: Optional[List[str]] = None) -> None:
+                 modules_to_not_convert: Optional[List[str]],
+                 full_config: Dict[str, Any]) -> None:
         self.pack_factor = 32 // weight_bits  # packed into int32
         self.group_size = group_size
         self.zero_point = zero_point
         self.lm_head_quantized = lm_head_quantized
         self.weight_bits = weight_bits
         self.modules_to_not_convert = modules_to_not_convert or []
+        self.full_config = full_config
 
         if self.weight_bits not in self.TYPE_MAP:
             raise ValueError(f"Unsupported num_bits = {self.weight_bits}. "
@@ -96,7 +97,7 @@ class AWQMarlinConfig(QuantizationConfig):
         modules_to_not_convert = cls.get_from_keys_or(
             config, ["modules_to_not_convert"], None)
         return cls(weight_bits, group_size, zero_point, lm_head_quantized,
-                   modules_to_not_convert)
+                   modules_to_not_convert, config)
 
     @classmethod
     def override_quantization_method(cls, hf_quant_cfg,
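
For context, from_config receives the checkpoint's quantization_config dict and now passes it
through unchanged as the final constructor argument, where it is stored as self.full_config.
A usage sketch (the dict keys below are typical of AWQ checkpoints but are assumptions, not
taken from this diff):

    # Hypothetical AWQ quantization_config, roughly what an AWQ checkpoint's
    # config.json contains; the exact keys are an assumption for illustration.
    hf_quant_config = {
        "quant_method": "awq",
        "bits": 4,
        "group_size": 128,
        "zero_point": True,
        "modules_to_not_convert": ["visual"],
    }

    marlin_config = AWQMarlinConfig.from_config(hf_quant_config)
    # The raw dict is kept verbatim so a plain AWQConfig can be rebuilt from it
    # whenever a layer cannot use the Marlin kernels.
    assert marlin_config.full_config is hf_quant_config
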
@@ -124,6 +125,13 @@ class AWQMarlinConfig(QuantizationConfig):
                 (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
             if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
                 return UnquantizedLinearMethod()
+            # Check if the layer is supported by AWQMarlin.
+            if not check_marlin_supports_layer(layer, self.group_size):
+                logger.warning_once(
+                    f"Layer '{prefix}' is not supported by AWQMarlin. "
+                    "Falling back to unoptimized AWQ kernels.")
+                return AWQConfig.from_config(
+                    self.full_config).get_quant_method(layer, prefix)
             return AWQMarlinLinearMethod(self)
         elif isinstance(layer, FusedMoE):
             return AWQMoEMethod(self)
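
Because get_quant_method runs once per layer, this branch lets individual layers whose shapes
the Marlin kernels cannot handle (for example, a partitioned or oddly-sized dimension that
misses the kernels' tile multiples) drop down to the original AWQ kernels, while every other
layer in the same model keeps the optimized AWQMarlin path. A small, hypothetical illustration
of the per-layer outcome (layer names, shapes, and divisibility constants are made up; the real
decision is made in the hunk above):

    # Hypothetical per-layer dispatch after this change.
    def describe(prefix: str, in_features: int, out_features: int) -> str:
        # Stand-in for check_marlin_supports_layer: assume Marlin needs the
        # partitioned dims to be multiples of its tile sizes.
        marlin_ok = in_features % 128 == 0 and out_features % 64 == 0
        kind = "AWQMarlinLinearMethod" if marlin_ok else "AWQLinearMethod (fallback)"
        return f"{prefix}: {kind}"

    print(describe("model.layers.0.self_attn.qkv_proj", 4096, 12288))
    print(describe("visual.blocks.0.mlp.fc1", 1176, 1280))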