[Kernels] Isolate modular kernel code from FusedMoEMethodBase subclasses. (#27123)

This commit is contained in:
bnellnm
2025-11-04 08:59:45 -05:00
committed by GitHub
parent e4ee658672
commit 938772af03
16 changed files with 271 additions and 311 deletions

View File

@@ -197,8 +197,6 @@ class Mxfp4Config(QuantizationConfig):
class Mxfp4MoEMethod(FusedMoEMethodBase):
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
self.topk_indices_dtype = None
self.moe = moe
self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
self.max_capture_size = (
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
@@ -815,6 +813,18 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
"EP batched experts format"
)
else:
layer.w13_weight = (
self.w13_weight_triton_tensor
if layer.w13_weight is None
else layer.w13_weight
)
layer.w2_weight = (
self.w2_weight_triton_tensor
if layer.w2_weight is None
else layer.w2_weight
)
assert all([w is not None for w in [layer.w13_weight, layer.w2_weight]])
assert self.moe_quant_config is not None
if (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
@@ -838,71 +848,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP"
)
def _route_and_experts(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor:
assert isinstance(self.fused_experts, mk.FusedMoEModularKernel)
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
enable_eplb=enable_eplb,
expert_map=expert_map,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
w13_weight = (
self.w13_weight_triton_tensor
if layer.w13_weight is None
else layer.w13_weight
)
w2_weight = (
self.w2_weight_triton_tensor if layer.w2_weight is None else layer.w2_weight
)
assert all([w is not None for w in [w13_weight, w2_weight]])
return self.fused_experts(
hidden_states=x,
w1=w13_weight,
w2=w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
@property
def allow_inplace(self) -> bool:
return True
def apply(
self,
@@ -930,29 +878,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
if enable_eplb:
raise NotImplementedError("EPLB is not supported for mxfp4")
if self.fused_experts is not None:
return self._route_and_experts(
layer,
x,
router_logits,
top_k,
renormalize,
use_grouped_topk,
topk_group,
num_expert_group,
global_num_experts,
expert_map,
custom_routing_function,
scoring_func,
e_score_correction_bias,
apply_router_weight_on_input,
activation,
enable_eplb,
expert_load_view,
logical_to_physical_map,
logical_replica_count,
)
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,