[Kernels] Clean up FusedMoeMethodBase and modular kernel setup. Remove extra arguments from modular kernel methods. (#22035)
Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
@@ -10,7 +10,7 @@ import vllm.model_executor.layers.fused_moe  # noqa
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import (
-    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
+    FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
     UnquantizedFusedMoEMethod)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod,
@@ -151,7 +151,7 @@ class AWQMarlinConfig(QuantizationConfig):
                     "Falling back to Moe WNA16 kernels.")
                 return MoeWNA16Config.from_config(
                     self.full_config).get_quant_method(layer, prefix)
-            return AWQMoEMethod(self)
+            return AWQMoEMethod(self, layer.moe_config)
         return None
 
     @classmethod
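The call-site change above is the visible half of the new contract: the quantization config's factory now hands the layer's FusedMoEConfig to the MoE method it constructs. A minimal self-contained sketch of that forwarding, with stand-in classes (only `layer.moe_config` and the two-argument constructor come from the diff):

class StubMoEConfig:
    pass

class StubFusedMoELayer:
    def __init__(self) -> None:
        # Assumption mirrored from the diff: FusedMoE layers expose their
        # config as `moe_config`.
        self.moe_config = StubMoEConfig()

class StubMoEMethod:
    def __init__(self, quant_config, moe) -> None:
        self.quant_config = quant_config
        self.moe = moe

layer = StubFusedMoELayer()
# Previously the method was constructed from the quant config alone; now
# the layer's MoE config rides along.
method = StubMoEMethod(quant_config=None, moe=layer.moe_config)
assert method.moe is layer.moe_config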
@@ -328,7 +328,12 @@ class AWQMarlinLinearMethod(LinearMethodBase):
 
 class AWQMoEMethod(FusedMoEMethodBase):
 
-    def __init__(self, quant_config: AWQMarlinConfig):
+    def __init__(
+        self,
+        quant_config: AWQMarlinConfig,
+        moe: FusedMoEConfig,
+    ):
+        super().__init__(moe)
         self.quant_config = quant_config
         if self.quant_config.weight_bits != 4:
             raise ValueError("AWQMoEMethod only supports 4bit now.")
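On the other side of that contract, every FusedMoEMethodBase subclass now routes the config through super().__init__ instead of each method keeping private MoE state. A hedged, self-contained sketch of the pattern (the stub classes stand in for vLLM's real ones in vllm.model_executor.layers.fused_moe; field names beyond `moe` and `fused_experts` are assumptions):

class FusedMoEConfigStub:
    def __init__(self, num_experts: int, top_k: int) -> None:
        self.num_experts = num_experts
        self.top_k = top_k

class FusedMoEMethodBaseStub:
    def __init__(self, moe: FusedMoEConfigStub) -> None:
        self.moe = moe
        # None means "use this method's own kernel path"; modular-kernel
        # setup would install a callable here.
        self.fused_experts = None

class AWQLikeMoEMethod(FusedMoEMethodBaseStub):
    def __init__(self, weight_bits: int, moe: FusedMoEConfigStub) -> None:
        super().__init__(moe)  # the base class now owns the MoE config
        if weight_bits != 4:
            raise ValueError("this sketch only supports 4-bit weights")
        self.weight_bits = weight_bits

m = AWQLikeMoEMethod(weight_bits=4, moe=FusedMoEConfigStub(8, 2))
assert m.moe.num_experts == 8 and m.fused_experts is None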
@@ -500,6 +505,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        assert self.fused_experts is None
+
         if enable_eplb:
             raise NotImplementedError(
                 "EPLB not supported for `AWQMoEMethod` yet.")
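The added `assert self.fused_experts is None` makes a precondition explicit: this apply() always runs the Marlin kernel directly, so if modular-kernel setup had installed a fused_experts callable it would be silently ignored. A small self-contained sketch of that fail-loudly guard (stand-in names, stand-in math):

class MarlinLikeMethod:
    def __init__(self) -> None:
        self.fused_experts = None  # modular-kernel setup would set this

    def apply(self, x: list[float]) -> list[float]:
        # Fail loudly rather than silently ignore an installed modular kernel.
        assert self.fused_experts is None, "modular kernel set but unused"
        return [2.0 * v for v in x]  # stand-in for the fused_marlin_moe call

assert MarlinLikeMethod().apply([1.0, 2.0]) == [2.0, 4.0]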
@@ -516,7 +523,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias)
+            e_score_correction_bias=e_score_correction_bias,
+            indices_type=self.topk_indices_dtype)
 
         return torch.ops.vllm.fused_marlin_moe(
             x,
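Passing indices_type=self.topk_indices_dtype means expert selection returns top-k ids already in the integer dtype the downstream kernel wants, rather than each call site casting afterwards. A sketch of the idea (the function name and default below are assumptions, not vLLM's exact select_experts API):

import torch

def select_experts_sketch(router_logits: torch.Tensor, top_k: int,
                          indices_type: torch.dtype = torch.int64):
    topk_weights, topk_ids = torch.topk(router_logits, top_k, dim=-1)
    # torch.topk always yields int64 indices; cast once here to the dtype
    # the downstream kernel (e.g. Marlin) expects.
    return topk_weights, topk_ids.to(indices_type)

logits = torch.randn(4, 8)  # 4 tokens, 8 experts
weights, ids = select_experts_sketch(logits, top_k=2,
                                     indices_type=torch.int32)
assert ids.dtype == torch.int32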
@@ -535,4 +543,4 @@ class AWQMoEMethod(FusedMoEMethodBase):
             expert_map=expert_map,
             w1_zeros=layer.w13_qzeros,
             w2_zeros=layer.w2_qzeros,
-            workspace=layer.workspace)
+            workspace=layer.workspace)