[Kernels] Clean up FusedMoeMethodBase and modular kernel setup. Remove extra arguments from modular kernel methods. (#22035)

Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
bnellnm
2025-08-15 14:46:00 -04:00
committed by GitHub
parent 48b01fd4d4
commit 8ad7285ea2
54 changed files with 2010 additions and 1293 deletions

View File

@@ -10,7 +10,8 @@ import torch.nn.functional as F
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -76,7 +77,7 @@ class RTNConfig(QuantizationConfig):
if isinstance(layer, LinearBase):
return RTNLinearMethod(self)
elif isinstance(layer, FusedMoE):
return RTNMoEMethod(self)
return RTNMoEMethod(self, layer.moe_config)
return None
@@ -210,7 +211,8 @@ class RTNLinearMethod(LinearMethodBase):
class RTNMoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: RTNConfig):
def __init__(self, quant_config: RTNConfig, moe: FusedMoEConfig):
super().__init__(moe)
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int,
@@ -289,6 +291,8 @@ class RTNMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `RTNMoEMethod` yet.")
@@ -305,7 +309,8 @@ class RTNMoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
weight_bits = self.quant_config.weight_bits
group_size = self.quant_config.group_size