[Misc][Refactor] Add FusedMoERouter object (#30519)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -27,6 +27,7 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||
MarlinExperts,
|
||||
fused_marlin_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
|
||||
OAITritonExperts,
|
||||
UnfusedOAITritonExperts,
|
||||
@@ -891,6 +892,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
def apply(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
router: FusedMoERouter,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -898,7 +900,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
raise NotImplementedError("EPLB is not supported for mxfp4")
|
||||
|
||||
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
|
||||
topk_weights, topk_ids = layer.select_experts(
|
||||
topk_weights, topk_ids = router.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
@@ -992,7 +994,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
):
|
||||
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
|
||||
|
||||
topk_weights, topk_ids = layer.select_experts(
|
||||
topk_weights, topk_ids = router.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
@@ -1119,7 +1121,8 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
layer: FusedMoE,
|
||||
router: FusedMoERouter,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
|
||||
Reference in New Issue
Block a user