[Misc][Refactor] Add FusedMoERouter object (#30519)

Signed-off-by: Bill Nell <bnell@redhat.com>
bnellnm
2026-01-08 15:52:55 -05:00
committed by GitHub
parent aa125ecf0e
commit e74698c27a
20 changed files with 165 additions and 36 deletions

View File

@@ -11,6 +11,9 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
+ from vllm.model_executor.layers.fused_moe.fused_moe_router import (
+ FusedMoERouter,
+ )
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoeWeightScaleSupported,
@@ -48,6 +51,7 @@ def get_config() -> dict[str, Any] | None:
__all__ = [
"FusedMoE",
"FusedMoERouter",
"FusedMoEConfig",
"FusedMoEMethodBase",
"UnquantizedFusedMoEMethod",

View File

@@ -10,6 +10,9 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
+ from vllm.model_executor.layers.fused_moe.fused_moe_router import (
+ FusedMoERouter,
+ )
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
@@ -109,6 +112,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
def apply(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:

View File

@@ -12,6 +12,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
+ from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel,
FusedMoEPrepareAndFinalize,
@@ -88,10 +89,11 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
def apply(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- topk_weights, topk_ids = layer.select_experts(
+ topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)

View File

@@ -0,0 +1,40 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+ from abc import ABC, abstractmethod
+
+ import torch
+
+ from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
+
+
+ class FusedMoERouter(ABC):
+     """
+     Abstract base class providing a 'select_experts' method used to route
+     hidden states to experts based on the router logits.
+     """
+
+     @property
+     @abstractmethod
+     def routing_method_type(self) -> RoutingMethodType:
+         raise NotImplementedError
+
+     @abstractmethod
+     def select_experts(
+         self,
+         hidden_states: torch.Tensor,
+         router_logits: torch.Tensor,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """
+         Route the input hidden states to the top-k experts based on the
+         router logits.
+
+         Returns:
+             tuple[torch.Tensor, torch.Tensor]: the computed top-k
+             weights and expert ids, as (topk_weights, topk_ids).
+
+         **Compatibility**: When EPLB is not enabled, the returned ids are
+         equivalent to global logical ids, so they should be compatible
+         with plain MoE implementations without redundant experts.
+         """
+         raise NotImplementedError
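For orientation, a minimal concrete router against this interface might look as follows. This is a sketch, not part of the commit; it assumes a plain softmax-then-top-k scheme, and the class name, the top_k parameter, and the choice of RoutingMethodType member are illustrative.

import torch

from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter


class SoftmaxTopKRouter(FusedMoERouter):
    """Illustrative router: softmax over the logits, then top-k selection."""

    def __init__(self, top_k: int):
        super().__init__()
        self.top_k = top_k

    @property
    def routing_method_type(self) -> RoutingMethodType:
        # Assumption: RoutingMethodType exposes a TopK member; use whichever
        # member matches the model's actual routing scheme.
        return RoutingMethodType.TopK

    def select_experts(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Normalize the logits, then keep the k most probable experts.
        probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
        topk_weights, topk_ids = probs.topk(self.top_k, dim=-1)
        return topk_weights, topk_ids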

View File

@@ -31,6 +31,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
RoutingMethodType,
)
+ from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
init_aiter_topK_meta_data,
)
@@ -284,6 +285,23 @@ def maybe_roundup_hidden_size(
return hidden_size
+
+
+ class FusedMoERouterImpl(FusedMoERouter):
+     def __init__(self, layer: "FusedMoE"):
+         super().__init__()
+         self.layer = layer
+
+     @property
+     def routing_method_type(self) -> RoutingMethodType:
+         return self.layer.routing_method_type
+
+     def select_experts(
+         self,
+         hidden_states: torch.Tensor,
+         router_logits: torch.Tensor,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         return self.layer._select_experts(hidden_states, router_logits)
@CustomOp.register("fused_moe")
class FusedMoE(CustomOp):
"""FusedMoE layer for MoE models.
@@ -339,7 +357,7 @@ class FusedMoE(CustomOp):
is_sequence_parallel=False,
expert_mapping: list[tuple[str, str, int, str]] | None = None,
n_shared_experts: int | None = None,
- routing_method_type: int | None = None,
+ routing_method_type: RoutingMethodType | None = None,
router_logits_dtype: torch.dtype | None = None,
):
super().__init__()
@@ -529,7 +547,7 @@ class FusedMoE(CustomOp):
# ToDo: Better logic to determine the routing method type
if routing_method_type is not None:
- self.routing_method_type = routing_method_type
+ self.routing_method_type: RoutingMethodType = routing_method_type
else:
if scoring_func == "sigmoid":
if self.use_grouped_topk:
@@ -640,6 +658,8 @@ class FusedMoE(CustomOp):
self.batched_hidden_states: torch.Tensor | None = None
self.batched_router_logits: torch.Tensor | None = None
+ self.router = FusedMoERouterImpl(self)
# Note: maybe_init_modular_kernel should only be called by
# prepare_communication_buffer_for_model.
# This is called after all weight loading and post-processing, so it
@@ -1503,7 +1523,7 @@ class FusedMoE(CustomOp):
device=torch.cuda.current_device(),
)
- def select_experts(
+ def _select_experts(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
@@ -1772,6 +1792,7 @@ class FusedMoE(CustomOp):
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
+ router=self.router,
x=staged_hidden_states,
router_logits=staged_router_logits,
)
@@ -1944,6 +1965,7 @@ class FusedMoE(CustomOp):
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
+ router=self.router,
x=hidden_states_combined
if do_naive_dispatch_combine
else hidden_states,
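Taken together, the hunks in this file establish the new call contract: FusedMoE wraps itself in a FusedMoERouterImpl, select_experts becomes the private _select_experts, and the router is threaded into quant_method.apply(). Below is a hedged sketch of what a concrete quant method sees under the new signature; MyMoEMethod and run_experts are hypothetical stand-ins, and the remaining FusedMoEMethodBase abstract methods are elided.

import torch

from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
    FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter


class MyMoEMethod(FusedMoEMethodBase):
    def apply(
        self,
        layer: "FusedMoE",  # noqa: F821
        router: FusedMoERouter,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor:
        # Expert selection now goes through the router object ...
        topk_weights, topk_ids = router.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )
        # ... while weights and workspaces still come from the layer.
        # run_experts is a hypothetical helper standing in for the kernel
        # dispatch each concrete method performs at this point.
        return self.run_experts(layer, x, topk_weights, topk_ids)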

View File

@@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
+ from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat,
FusedMoEPermuteExpertsUnpermute,
@@ -285,10 +286,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def apply(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward(
+ router=router,
layer=layer,
x=x,
router_logits=router_logits,
@@ -306,10 +309,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_cuda(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- topk_weights, topk_ids = layer.select_experts(
+ topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
@@ -332,6 +336,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_cpu(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -365,6 +370,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_xpu(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:

View File

@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
+ from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,
@@ -759,12 +760,13 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: FusedMoE,
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert layer.activation == "silu", "Only SiLU activation is supported."
- topk_weights, topk_ids = layer.select_experts(
+ topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)

View File

@@ -10,7 +10,11 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
- from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
+ from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
+ from vllm.model_executor.layers.fused_moe.layer import (
+ FusedMoE,
+ FusedMoEMethodBase,
+ )
from vllm.model_executor.layers.linear import (
LinearBase,
LinearMethodBase,
@@ -495,12 +499,13 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: FusedMoE,
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
- topk_weights, topk_ids = layer.select_experts(
+ topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)

View File

@@ -40,6 +40,7 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
fused_marlin_moe,
)
+ from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
@@ -458,6 +459,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
layer: FusedMoE,
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -484,7 +486,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
x_routing, _ = x
else:
x_routing = x
- topk_weights, topk_ids = layer.select_experts(
+ topk_weights, topk_ids = router.select_experts(
hidden_states=x_routing,
router_logits=router_logits,
)
@@ -926,10 +928,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
layer: FusedMoE,
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- topk_weights, topk_ids = layer.select_experts(
+ topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
@@ -1066,12 +1069,13 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
layer: FusedMoE,
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
- topk_weights, topk_ids = layer.select_experts(
+ topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
@@ -1426,6 +1430,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
layer: FusedMoE,
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -1433,7 +1438,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
f"{layer.activation} not supported for Marlin MoE."
)
- topk_weights, topk_ids = layer.select_experts(
+ topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
@@ -1677,12 +1682,13 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
layer: FusedMoE,
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
- topk_weights, topk_ids = layer.select_experts(
+ topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
@@ -1978,6 +1984,7 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
layer: FusedMoE,
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor:
@@ -2290,6 +2297,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
layer: FusedMoE,
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
):
@@ -2298,7 +2306,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
"EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet."
)
assert self.moe_quant_config is not None
- topk_weights, topk_ids = layer.select_experts(
+ topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)

View File

@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
int8_w8a16_moe_quant_config,
)
+ from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
@@ -137,12 +138,13 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: FusedMoE,
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
- topk_weights, topk_ids = layer.select_experts(
+ topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)

View File

@@ -29,6 +29,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
RoutingMethodType,
)
+ from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
@@ -997,6 +998,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: FusedMoE,
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -1051,7 +1053,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
- topk_weights, topk_ids = layer.select_experts(
+ topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)

View File

@@ -16,7 +16,11 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
- from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
+ from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
+ from vllm.model_executor.layers.fused_moe.layer import (
+ FusedMoE,
+ FusedMoEMethodBase,
+ )
from vllm.model_executor.layers.linear import (
LinearBase,
LinearMethodBase,
@@ -629,6 +633,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: FusedMoE,
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -639,7 +644,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
"fused GGUF MoE method."
)
- topk_weights, topk_ids = layer.select_experts(
+ topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)

View File

@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
+ from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,
@@ -895,12 +896,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: FusedMoE,
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert layer.activation == "silu", "Only SiLU activation is supported."
- topk_weights, topk_ids = layer.select_experts(
+ topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)

View File

@@ -9,6 +9,9 @@ from torch.nn import Module
from vllm._ipex_ops import ipex_ops as ops
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+ from vllm.model_executor.layers.fused_moe.fused_moe_router import (
+ FusedMoERouter,
+ )
from vllm.model_executor.layers.linear import (
LinearBase,
LinearMethodBase,
@@ -384,6 +387,7 @@ class XPUFp8MoEMethod(Fp8OnlineMoEMethod):
def apply(
self,
layer: torch.nn.Module,
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor:

View File

@@ -14,8 +14,10 @@ from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
FusedMoEQuantConfig,
)
+ from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,
@@ -200,7 +202,9 @@ class ModelOptQuantConfigBase(QuantizationConfig):
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return quant_method
elif isinstance(layer, FusedMoE):
- quant_method = self.FusedMoEMethodCls(quant_config=self, layer=layer)
+ quant_method = self.FusedMoEMethodCls(
+ quant_config=self, moe_config=layer.moe_config
+ )
if getattr(quant_method, "backend", "") == "marlin":
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return quant_method
@@ -720,14 +724,14 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def __init__(
self,
quant_config: ModelOptFp8Config,
- layer: FusedMoE,
+ moe_config: FusedMoEConfig,
) -> None:
- super().__init__(layer.moe_config)
+ super().__init__(moe_config)
self.quant_config = quant_config
assert self.quant_config.is_checkpoint_fp8_serialized
self.fp8_backend = select_fp8_moe_backend(
block_quant=False,
- tp_size=layer.moe_parallel_config.tp_size,
+ tp_size=moe_config.moe_parallel_config.tp_size,
with_lora_support=self.moe.is_lora_enabled,
)
self.kernel: mk.FusedMoEModularKernel | None = None
@@ -935,6 +939,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: FusedMoE,
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -961,7 +966,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
- topk_weights, topk_ids = layer.select_experts(
+ # Expert selection
+ topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
@@ -1325,9 +1331,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def __init__(
self,
quant_config: ModelOptNvFp4Config,
- layer: FusedMoE,
+ moe_config: FusedMoEConfig,
) -> None:
- super().__init__(layer.moe_config)
+ super().__init__(moe_config)
self.quant_config = quant_config
self.nvfp4_backend = select_nvfp4_moe_backend()
# TODO: move this type of check into the oracle.
@@ -1597,6 +1603,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def apply(
self,
layer: FusedMoE,
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -1621,7 +1628,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
x_routing, _ = x
else:
x_routing = x
- topk_weights, topk_ids = layer.select_experts(
+ topk_weights, topk_ids = router.select_experts(
hidden_states=x_routing,
router_logits=router_logits,
)

View File

@@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.config import (
int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config,
)
+ from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEConfig,
@@ -364,13 +365,14 @@ class MoeWNA16Method(FusedMoEMethodBase):
def apply(
self,
layer: FusedMoE,
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
assert layer.activation == "silu", "Only SiLU activation is supported."
- topk_weights, topk_ids = layer.select_experts(
+ topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)

View File

@@ -27,6 +27,7 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
fused_marlin_moe,
)
+ from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
OAITritonExperts,
UnfusedOAITritonExperts,
@@ -891,6 +892,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: FusedMoE,
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -898,7 +900,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
raise NotImplementedError("EPLB is not supported for mxfp4")
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
- topk_weights, topk_ids = layer.select_experts(
+ topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
@@ -992,7 +994,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
):
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
- topk_weights, topk_ids = layer.select_experts(
+ topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
@@ -1119,7 +1121,8 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor:

View File

@@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase,
+ FusedMoERouter,
FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.config import (
@@ -350,10 +351,11 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
def apply(
self,
layer: FusedMoE,
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- topk_weights, topk_ids = layer.select_experts(
+ topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
@@ -542,6 +544,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
def apply(
self,
layer: FusedMoE,
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -750,10 +753,11 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
def apply(
self,
layer: FusedMoE,
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- topk_weights, topk_ids = layer.select_experts(
+ topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)

View File

@@ -15,7 +15,11 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
- from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
+ from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
+ from vllm.model_executor.layers.fused_moe.layer import (
+ FusedMoE,
+ FusedMoEMethodBase,
+ )
from vllm.model_executor.layers.linear import (
LinearBase,
LinearMethodBase,
@@ -356,10 +360,11 @@ class RTNMoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: FusedMoE,
+ router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- topk_weights, topk_ids = layer.select_experts(
+ topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)