[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -12,6 +12,9 @@ from torch.nn.parameter import Parameter
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig, int4_w4a16_moe_quant_config,
|
||||
int8_w8a16_moe_quant_config)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
@@ -269,6 +272,21 @@ class RTNMoEMethod(FusedMoEMethodBase):
|
||||
fix_weights(layer, "w13_weight", weight_bits == 4)
|
||||
fix_weights(layer, "w2_weight", weight_bits == 4)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
weight_bits = self.quant_config.weight_bits
|
||||
group_size = self.quant_config.group_size
|
||||
assert weight_bits == 4 or weight_bits == 8
|
||||
config_builder = (int4_w4a16_moe_quant_config
|
||||
if weight_bits == 4 else int8_w8a16_moe_quant_config)
|
||||
return config_builder(
|
||||
w1_scale=layer.w13_scale,
|
||||
w2_scale=layer.w2_scale,
|
||||
w1_zp=None,
|
||||
w2_zp=None,
|
||||
block_shape=[0, group_size],
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -314,10 +332,7 @@ class RTNMoEMethod(FusedMoEMethodBase):
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
|
||||
weight_bits = self.quant_config.weight_bits
|
||||
group_size = self.quant_config.group_size
|
||||
|
||||
ret = fused_experts(
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
@@ -325,16 +340,11 @@ class RTNMoEMethod(FusedMoEMethodBase):
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
use_int4_w4a16=weight_bits == 4,
|
||||
use_int8_w8a16=weight_bits == 8,
|
||||
global_num_experts=global_num_experts,
|
||||
w1_scale=layer.w13_scale,
|
||||
w2_scale=layer.w2_scale,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
block_shape=[0, group_size])
|
||||
|
||||
return ret
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
|
||||
def rtn_quantize(tensor: torch.Tensor, num_bits: int,
|
||||
|
||||
Reference in New Issue
Block a user