[Kernels] Clean up FusedMoeMethodBase and modular kernel setup. Remove extra arguments from modular kernel methods. (#22035)
Signed-off-by: Bill Nell <bnell@redhat.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -10,7 +10,8 @@ import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
@@ -76,7 +77,7 @@ class RTNConfig(QuantizationConfig):
|
||||
if isinstance(layer, LinearBase):
|
||||
return RTNLinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return RTNMoEMethod(self)
|
||||
return RTNMoEMethod(self, layer.moe_config)
|
||||
return None
|
||||
|
||||
|
||||
@@ -210,7 +211,8 @@ class RTNLinearMethod(LinearMethodBase):
|
||||
|
||||
class RTNMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def __init__(self, quant_config: RTNConfig):
|
||||
def __init__(self, quant_config: RTNConfig, moe: FusedMoEConfig):
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
@@ -289,6 +291,8 @@ class RTNMoEMethod(FusedMoEMethodBase):
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `RTNMoEMethod` yet.")
|
||||
@@ -305,7 +309,8 @@ class RTNMoEMethod(FusedMoEMethodBase):
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
|
||||
weight_bits = self.quant_config.weight_bits
|
||||
group_size = self.quant_config.group_size
|
||||
|
||||
Reference in New Issue
Block a user