[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2025-09-17 19:43:31 -04:00
committed by GitHub
parent e6585ddb45
commit 5963b98b46
68 changed files with 2698 additions and 2526 deletions

View File

@@ -12,6 +12,9 @@ from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -269,6 +272,21 @@ class RTNMoEMethod(FusedMoEMethodBase):
fix_weights(layer, "w13_weight", weight_bits == 4)
fix_weights(layer, "w2_weight", weight_bits == 4)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
weight_bits = self.quant_config.weight_bits
group_size = self.quant_config.group_size
assert weight_bits == 4 or weight_bits == 8
config_builder = (int4_w4a16_moe_quant_config
if weight_bits == 4 else int8_w8a16_moe_quant_config)
return config_builder(
w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale,
w1_zp=None,
w2_zp=None,
block_shape=[0, group_size],
)
def apply(
self,
layer: torch.nn.Module,
@@ -314,10 +332,7 @@ class RTNMoEMethod(FusedMoEMethodBase):
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
weight_bits = self.quant_config.weight_bits
group_size = self.quant_config.group_size
ret = fused_experts(
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
@@ -325,16 +340,11 @@ class RTNMoEMethod(FusedMoEMethodBase):
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8,
global_num_experts=global_num_experts,
w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
block_shape=[0, group_size])
return ret
quant_config=self.moe_quant_config,
)
def rtn_quantize(tensor: torch.Tensor, num_bits: int,