From ebd0a17e0e0b5dd326dc6533cf4d8f49806e69e1 Mon Sep 17 00:00:00 2001 From: joninco Date: Fri, 23 Jan 2026 17:19:56 -0500 Subject: [PATCH] [Bugfix] Fix missing is_layer_skipped check for FusedMoE in AWQConfig (#32935) Signed-off-by: jon --- vllm/model_executor/layers/quantization/awq.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index ab68c5dca..3cf3116f0 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -106,7 +106,7 @@ class AWQConfig(QuantizationConfig): return AWQLinearMethod(self) elif isinstance(layer, FusedMoE): # Lazy import to avoid circular import. - from .awq_marlin import AWQMarlinConfig, AWQMarlinMoEMethod + from .awq_marlin import AWQMarlinConfig from .moe_wna16 import MoeWNA16Config from .utils.marlin_utils import check_moe_marlin_supports_layer @@ -121,6 +121,7 @@ class AWQConfig(QuantizationConfig): "group_size": self.group_size, "zero_point": self.zero_point, "lm_head": False, + "modules_to_not_convert": self.modules_to_not_convert, } return MoeWNA16Config.from_config(config).get_quant_method( layer, prefix @@ -136,7 +137,7 @@ class AWQConfig(QuantizationConfig): awq_marlin_config = AWQMarlinConfig.from_config( marlin_compatible_config_dict ) - return AWQMarlinMoEMethod(awq_marlin_config, layer.moe_config) + return awq_marlin_config.get_quant_method(layer, prefix) return None def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):