[MoE][Refactor] Remove most arguments to FusedMoEMethodBase.apply (#29066)
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
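Every hunk below applies the same refactor: the per-call routing arguments are dropped from apply(), and the implementation reads the equivalent values from attributes on the layer object it already receives. A minimal sketch of the pattern (illustrative stand-in classes, not the real vLLM FusedMoE / FusedMoEMethodBase API):

import torch


class FusedMoELayer:
    # Stand-in for FusedMoE: the routing configuration lives on the layer.
    def __init__(self, top_k: int, renormalize: bool, activation: str = "silu"):
        self.top_k = top_k
        self.renormalize = renormalize
        self.activation = activation
        self.enable_eplb = False


class MoEMethod:
    # Before: apply(self, layer, x, router_logits, top_k, renormalize, ..., activation, ...).
    # After: only the layer and the tensors; everything else comes from layer.*.
    def apply(
        self,
        layer: FusedMoELayer,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor:
        if layer.enable_eplb:
            raise NotImplementedError("EPLB is not supported in this sketch")
        topk_weights, topk_ids = torch.topk(router_logits, layer.top_k, dim=-1)
        if layer.renormalize:
            topk_weights = torch.softmax(topk_weights, dim=-1)
        # ... dispatch to the expert kernels using the layer.* configuration ...
        return x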
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Callable
 from enum import Enum
 from typing import Optional

@@ -892,25 +891,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: int | None = None,
-        num_expert_group: int | None = None,
-        global_num_experts: int = -1,
-        expert_map: torch.Tensor | None = None,
-        custom_routing_function: Callable | None = None,
-        scoring_func: str = "softmax",
-        routed_scaling_factor: float = 1.0,
-        e_score_correction_bias: torch.Tensor | None = None,
-        apply_router_weight_on_input: bool = False,
-        activation: str = "silu",
-        enable_eplb: bool = False,
-        expert_load_view: torch.Tensor | None = None,
-        logical_to_physical_map: torch.Tensor | None = None,
-        logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if enable_eplb:
+        if layer.enable_eplb:
             raise NotImplementedError("EPLB is not supported for mxfp4")

         if self.mxfp4_backend == Mxfp4Backend.MARLIN:
@@ -933,26 +915,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 global_scale1=None,
                 global_scale2=None,
                 quant_type_id=scalar_types.float4_e2m1f.id,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-                global_num_experts=global_num_experts,
-                activation=activation,
-                expert_map=expert_map,
+                apply_router_weight_on_input=layer.apply_router_weight_on_input,
+                global_num_experts=layer.global_num_experts,
+                activation=layer.activation,
+                expert_map=layer.expert_map,
                 input_dtype=self.marlin_input_dtype,
             )

         assert _can_support_mxfp4(
-            use_grouped_topk,
-            topk_group,
-            num_expert_group,
-            expert_map,
-            custom_routing_function,
-            e_score_correction_bias,
-            apply_router_weight_on_input,
-            scoring_func,
-            activation,
-            expert_load_view,
-            logical_to_physical_map,
-            logical_replica_count,
+            layer.use_grouped_topk,
+            layer.topk_group,
+            layer.num_expert_group,
+            layer.expert_map,
+            layer.custom_routing_function,
+            layer.e_score_correction_bias,
+            layer.apply_router_weight_on_input,
+            layer.scoring_func,
+            layer.activation,
+            layer.expert_load_view,
+            layer.logical_to_physical_map,
+            layer.logical_replica_count,
         ), "MXFP4 are not supported with this configuration."

         if (
@@ -988,8 +970,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 None,  # output1_scale_scalar
                 None,  # output1_scale_gate_scalar
                 None,  # output2_scale_scalar
-                global_num_experts,
-                top_k,
+                layer.global_num_experts,
+                layer.top_k,
                 None,  # n_group
                 None,  # topk_group
                 self.intermediate_size,  # padded to multiple of 256
@@ -997,7 +979,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 self.num_experts,  # local num experts
                 None,
                 None,
-                1 if renormalize else 0,  # routing_method_type, renormalize
+                1 if layer.renormalize else 0,  # routing_method_type, renormalize
                 True,  # do finalize
                 tune_max_num_tokens=max(self.max_capture_size, 1),
             )[0]
@@ -1081,12 +1063,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
                 gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-                global_num_experts=global_num_experts,
-                expert_map=expert_map,
+                topk=layer.top_k,
+                renormalize=layer.renormalize,
+                global_num_experts=layer.global_num_experts,
+                expert_map=layer.expert_map,
                 quant_config=self.moe_quant_config,
-                apply_router_weight_on_input=apply_router_weight_on_input,
+                apply_router_weight_on_input=layer.apply_router_weight_on_input,
             )
         else:
             raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
@@ -1138,37 +1120,20 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
         layer: torch.nn.Module,
         x: torch.Tensor,
         router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: int | None = None,
-        num_expert_group: int | None = None,
-        global_num_experts: int = -1,
-        expert_map: torch.Tensor | None = None,
-        custom_routing_function: Callable | None = None,
-        scoring_func: str = "softmax",
-        routed_scaling_factor: float = 1.0,
-        e_score_correction_bias: torch.Tensor | None = None,
-        apply_router_weight_on_input: bool = False,
-        activation: str = "silu",
-        enable_eplb: bool = False,
-        expert_load_view: torch.Tensor | None = None,
-        logical_to_physical_map: torch.Tensor | None = None,
-        logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        assert activation == "swigluoai", (
+        assert layer.activation == "swigluoai", (
             "Only swiglu_oai activation is supported for IPEX MXFP4 MoE"
         )
         hidden_size_pad = round_up(self.original_hidden_size, 128)
         x_pad = torch.nn.functional.pad(x, (0, hidden_size_pad - x.size(-1)))
         hidden_states = layer.ipex_fusion(
             x_pad,
-            use_grouped_topk,
-            top_k,
+            layer.use_grouped_topk,
+            layer.top_k,
             router_logits,
-            renormalize,
-            topk_group,
-            num_expert_group,
+            layer.renormalize,
+            layer.topk_group,
+            layer.num_expert_group,
             activation="swiglu_oai",
         )
         hidden_states = hidden_states[..., : self.original_hidden_size].contiguous()
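Taken together, the hunks above imply the new call-site contract: apply() receives only the layer and the two tensors, and every routing knob (top_k, renormalize, expert_map, activation, ...) must already be present as an attribute on that layer. A standalone toy illustration of that contract, using hypothetical names rather than vLLM's actual classes:

import types

import torch

# Hypothetical layer object carrying the attributes the refactored apply() reads.
layer = types.SimpleNamespace(
    top_k=2,
    renormalize=True,
    activation="silu",
    global_num_experts=8,
    expert_map=None,
    apply_router_weight_on_input=False,
    enable_eplb=False,
)


def apply(layer, x, router_logits):
    # Toy stand-in: pick layer.top_k experts per token and scale x by the total routing weight.
    weights, _ = torch.topk(torch.softmax(router_logits, dim=-1), layer.top_k, dim=-1)
    if layer.renormalize:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    return x * weights.sum(dim=-1, keepdim=True)


# Old convention (removed by this PR): apply(layer, x, router_logits, top_k=..., renormalize=..., ...)
# New convention: only the tensors are passed per call; the config comes from the layer.
out = apply(layer, torch.randn(4, 16), torch.randn(4, layer.global_num_experts))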