[MoE][Refactor] Remove most arguments to FusedMoEMethodBase.apply (#29066)

Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
bnellnm
2025-12-09 16:48:25 -05:00
committed by GitHub
parent 7618dc973d
commit 00e5cbb967
18 changed files with 318 additions and 872 deletions


@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Callable
from enum import Enum
from typing import Optional
@@ -892,25 +891,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
-top_k: int,
-renormalize: bool,
-use_grouped_topk: bool = False,
-topk_group: int | None = None,
-num_expert_group: int | None = None,
-global_num_experts: int = -1,
-expert_map: torch.Tensor | None = None,
-custom_routing_function: Callable | None = None,
-scoring_func: str = "softmax",
-routed_scaling_factor: float = 1.0,
-e_score_correction_bias: torch.Tensor | None = None,
-apply_router_weight_on_input: bool = False,
-activation: str = "silu",
-enable_eplb: bool = False,
-expert_load_view: torch.Tensor | None = None,
-logical_to_physical_map: torch.Tensor | None = None,
-logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-if enable_eplb:
+if layer.enable_eplb:
raise NotImplementedError("EPLB is not supported for mxfp4")
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
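
The hunk above is the core of the refactor: apply() loses its per-call routing keywords and keeps only (layer, x, router_logits), reading routing state from the FusedMoE layer instead. A minimal sketch of that contract, with _DemoLayer as a hypothetical stand-in for the real layer (not vLLM code):

    # Illustrative sketch only; _DemoLayer is a hypothetical stand-in for FusedMoE.
    from dataclasses import dataclass

    import torch


    @dataclass
    class _DemoLayer:
        top_k: int = 2
        renormalize: bool = True
        enable_eplb: bool = False


    def apply(layer: _DemoLayer, x: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor:
        # Routing configuration comes from the layer, not from keyword arguments.
        if layer.enable_eplb:
            raise NotImplementedError("EPLB is not supported for mxfp4")
        topk_weights, _topk_ids = torch.topk(router_logits, layer.top_k, dim=-1)
        if layer.renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        # A real implementation would dispatch to a quantized MoE kernel here;
        # returning x unchanged keeps the sketch self-contained.
        return x


    out = apply(_DemoLayer(), torch.randn(4, 16), torch.randn(4, 8))
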
@@ -933,26 +915,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
global_scale1=None,
global_scale2=None,
quant_type_id=scalar_types.float4_e2m1f.id,
-apply_router_weight_on_input=apply_router_weight_on_input,
-global_num_experts=global_num_experts,
-activation=activation,
-expert_map=expert_map,
+apply_router_weight_on_input=layer.apply_router_weight_on_input,
+global_num_experts=layer.global_num_experts,
+activation=layer.activation,
+expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype,
)
assert _can_support_mxfp4(
-use_grouped_topk,
-topk_group,
-num_expert_group,
-expert_map,
-custom_routing_function,
-e_score_correction_bias,
-apply_router_weight_on_input,
-scoring_func,
-activation,
-expert_load_view,
-logical_to_physical_map,
-logical_replica_count,
+layer.use_grouped_topk,
+layer.topk_group,
+layer.num_expert_group,
+layer.expert_map,
+layer.custom_routing_function,
+layer.e_score_correction_bias,
+layer.apply_router_weight_on_input,
+layer.scoring_func,
+layer.activation,
+layer.expert_load_view,
+layer.logical_to_physical_map,
+layer.logical_replica_count,
), "MXFP4 are not supported with this configuration."
if (
@@ -988,8 +970,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
None, # output1_scale_scalar
None, # output1_scale_gate_scalar
None, # output2_scale_scalar
-global_num_experts,
-top_k,
+layer.global_num_experts,
+layer.top_k,
None, # n_group
None, # topk_group
self.intermediate_size, # padded to multiple of 256
@@ -997,7 +979,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.num_experts, # local num experts
None,
None,
-1 if renormalize else 0, # routing_method_type, renormalize
+1 if layer.renormalize else 0, # routing_method_type, renormalize
True, # do finalize
tune_max_num_tokens=max(self.max_capture_size, 1),
)[0]
@@ -1081,12 +1063,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w1=layer.w13_weight,
w2=layer.w2_weight,
gating_output=router_logits,
-topk=top_k,
-renormalize=renormalize,
-global_num_experts=global_num_experts,
-expert_map=expert_map,
+topk=layer.top_k,
+renormalize=layer.renormalize,
+global_num_experts=layer.global_num_experts,
+expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
-apply_router_weight_on_input=apply_router_weight_on_input,
+apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
else:
raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
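
Within each backend branch the change is mechanical: kernel keyword arguments that used to come from apply()'s parameters are now derived from layer attributes. A small hedged sketch of that packing step (the helper name is hypothetical, not a vLLM function):

    # Hypothetical helper illustrating how kernel kwargs can be derived from
    # layer attributes; not part of vLLM.
    import types


    def _kernel_kwargs_from_layer(layer) -> dict:
        return {
            "topk": layer.top_k,
            "renormalize": layer.renormalize,
            "global_num_experts": layer.global_num_experts,
            "expert_map": layer.expert_map,
            "apply_router_weight_on_input": layer.apply_router_weight_on_input,
            "activation": layer.activation,
        }


    layer = types.SimpleNamespace(
        top_k=2,
        renormalize=True,
        global_num_experts=8,
        expert_map=None,
        apply_router_weight_on_input=False,
        activation="silu",
    )
    assert _kernel_kwargs_from_layer(layer)["topk"] == 2
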
@@ -1138,37 +1120,20 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
-top_k: int,
-renormalize: bool,
-use_grouped_topk: bool = False,
-topk_group: int | None = None,
-num_expert_group: int | None = None,
-global_num_experts: int = -1,
-expert_map: torch.Tensor | None = None,
-custom_routing_function: Callable | None = None,
-scoring_func: str = "softmax",
-routed_scaling_factor: float = 1.0,
-e_score_correction_bias: torch.Tensor | None = None,
-apply_router_weight_on_input: bool = False,
-activation: str = "silu",
-enable_eplb: bool = False,
-expert_load_view: torch.Tensor | None = None,
-logical_to_physical_map: torch.Tensor | None = None,
-logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor:
-assert activation == "swigluoai", (
+assert layer.activation == "swigluoai", (
"Only swiglu_oai activation is supported for IPEX MXFP4 MoE"
)
hidden_size_pad = round_up(self.original_hidden_size, 128)
x_pad = torch.nn.functional.pad(x, (0, hidden_size_pad - x.size(-1)))
hidden_states = layer.ipex_fusion(
x_pad,
-use_grouped_topk,
-top_k,
+layer.use_grouped_topk,
+layer.top_k,
router_logits,
-renormalize,
-topk_group,
-num_expert_group,
+layer.renormalize,
+layer.topk_group,
+layer.num_expert_group,
activation="swiglu_oai",
)
hidden_states = hidden_states[..., : self.original_hidden_size].contiguous()
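
The IPEX path follows the same reduced contract. A self-contained, hedged sketch of a backend-specific apply() in that style; DemoIpexStyleMethod is hypothetical and round_up is re-implemented locally so the example runs on its own:

    # Hedged sketch of a backend-specific apply() under the reduced contract,
    # loosely mirroring the IPEX path above.
    import torch


    def round_up(x: int, multiple: int) -> int:
        return ((x + multiple - 1) // multiple) * multiple


    class DemoIpexStyleMethod:
        def __init__(self, original_hidden_size: int):
            self.original_hidden_size = original_hidden_size

        def apply(self, layer, x: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor:
            # The supported activation is validated from layer state, not a kwarg.
            assert layer.activation == "swigluoai", (
                "Only swiglu_oai activation is supported in this sketch"
            )
            # Pad the hidden dimension to a multiple of 128, run the (elided)
            # fused kernel, then slice back to the original width.
            hidden_size_pad = round_up(self.original_hidden_size, 128)
            x_pad = torch.nn.functional.pad(x, (0, hidden_size_pad - x.size(-1)))
            hidden_states = x_pad  # fused-kernel call elided
            return hidden_states[..., : self.original_hidden_size].contiguous()


    layer = type("DemoLayer", (), {"activation": "swigluoai"})()
    out = DemoIpexStyleMethod(300).apply(layer, torch.randn(2, 300), torch.randn(2, 8))
    assert out.shape == (2, 300)
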