[Quant][Perf] Use moe_wna16 kernel by default for MoEs with many experts (#13236)

Signed-off-by: mgoin <mgoin64@gmail.com>
Author: Michael Goin
Date: 2025-02-14 15:53:42 -05:00
Committed by: GitHub
parent c9e2d644e7
commit 5e5c8e091e
4 changed files with 39 additions and 26 deletions


@@ -10,20 +10,18 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearMethodBase,
-                                               UnquantizedLinearMethod,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
     MPLinearLayerConfig, choose_mp_linear_kernel)
+from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.gptq_utils import (
     get_linear_quant_method)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     check_marlin_supported, marlin_moe_permute_scales,
     marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    UnquantizedEmbeddingMethod)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,
                                            PackedColumnParameter,
@@ -44,15 +42,10 @@ class GPTQMarlinConfig(QuantizationConfig):
         (8, True): scalar_types.uint8b128,
     }
 
-    def __init__(
-        self,
-        weight_bits: int,
-        group_size: int,
-        desc_act: bool,
-        is_sym: bool,
-        lm_head_quantized: bool,
-        dynamic: Dict[str, Dict[str, Union[int, bool]]],
-    ) -> None:
+    def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
+                 is_sym: bool, lm_head_quantized: bool,
+                 dynamic: Dict[str, Dict[str, Union[int, bool]]],
+                 full_config: Dict[str, Any]) -> None:
         if desc_act and group_size == -1:
             # In this case, act_order == True is the same as act_order == False
             # (since we have only one group per output channel)
@@ -90,6 +83,7 @@ class GPTQMarlinConfig(QuantizationConfig):
         self.group_size = group_size
         self.desc_act = desc_act
         self.lm_head_quantized = lm_head_quantized
+        self.full_config = full_config
 
         if (weight_bits, is_sym) not in self.TYPE_MAP:
             raise ValueError("Unsupported quantization config: "
@@ -132,7 +126,7 @@ class GPTQMarlinConfig(QuantizationConfig):
         lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                  default=False)
         return cls(weight_bits, group_size, desc_act, is_sym,
-                   lm_head_quantized, dynamic)
+                   lm_head_quantized, dynamic, config)
 
     @classmethod
     def override_quantization_method(cls, hf_quant_cfg,
@@ -155,12 +149,15 @@ class GPTQMarlinConfig(QuantizationConfig):
                         " faster inference")
             return None
 
-    def get_quant_method(
-        self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod",
-                        UnquantizedLinearMethod, UnquantizedEmbeddingMethod]]:
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, FusedMoE):
-            return GPTQMarlinMoEMethod(self)
+            if layer.num_experts > 32:
+                # For MoEs with many experts the moe_wna16 kernel is faster
+                return MoeWNA16Config.from_config(
+                    self.full_config).get_quant_method(layer, prefix)
+            else:
+                return GPTQMarlinMoEMethod(self)
         return get_linear_quant_method(self, layer, prefix,
                                        GPTQMarlinLinearMethod)
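
For context, below is a minimal, self-contained sketch (not vLLM code; the classes are hypothetical stand-ins) of the dispatch rule this change introduces: GPTQ-Marlin keeps its own fused-MoE method for small expert counts, but for MoE layers with more than 32 experts it rebuilds a quant method from the stored full checkpoint config so the faster moe_wna16 kernel path is used.

# Minimal sketch of the new dispatch rule -- NOT the vLLM implementation.
# FakeMoELayer / FakeGPTQMarlinConfig are simplified stand-ins used only to
# illustrate the control flow added to get_quant_method above.

class FakeMoELayer:
    def __init__(self, num_experts: int) -> None:
        self.num_experts = num_experts


class FakeGPTQMarlinConfig:
    # Threshold taken from the diff: layers with > 32 experts prefer moe_wna16.
    MOE_WNA16_EXPERT_THRESHOLD = 32

    def __init__(self, full_config: dict) -> None:
        # The untouched checkpoint quant config is kept so it can be re-parsed
        # when delegating to the moe_wna16 path.
        self.full_config = full_config

    def get_quant_method(self, layer: object) -> str:
        if isinstance(layer, FakeMoELayer):
            if layer.num_experts > self.MOE_WNA16_EXPERT_THRESHOLD:
                # Many experts: delegate to a method built from the full config
                # (moe_wna16 kernel path).
                return f"moe_wna16 method from {self.full_config}"
            # Few experts: keep the GPTQ-Marlin fused-MoE method.
            return "gptq_marlin MoE method"
        # Non-MoE layers keep the regular GPTQ-Marlin linear method.
        return "gptq_marlin linear method"


if __name__ == "__main__":
    cfg = FakeGPTQMarlinConfig(full_config={"bits": 4, "group_size": 128})
    print(cfg.get_quant_method(FakeMoELayer(num_experts=8)))   # gptq_marlin MoE method
    print(cfg.get_quant_method(FakeMoELayer(num_experts=64)))  # moe_wna16 method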