[MoE Refactor] Separate Router into OO Classes (#30623)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2026-01-18 11:40:49 -05:00
committed by GitHub
parent 2f03035a61
commit 327a02d8db
45 changed files with 1750 additions and 684 deletions

View File

@@ -11,9 +11,6 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoeWeightScaleSupported,
@@ -23,6 +20,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
@@ -83,13 +83,17 @@ if HAS_TRITON:
BatchedTritonExperts,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
GroupedTopk,
TritonExperts,
TritonWNA16Experts,
fused_experts,
fused_topk,
get_config_file_name,
)
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import (
fused_topk,
)
from vllm.model_executor.layers.fused_moe.router.grouped_topk_router import (
GroupedTopk,
)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts,
)

View File

@@ -117,8 +117,12 @@ class RoutingMethodType(IntEnum):
RenormalizeNaive = (4,)
# TopK: TopK (no softmax)
TopK = (5,)
# Custom
Custom = (6,)
# Simulated
Simulated = (7,)
# Unspecified
Unspecified = 6.0
Unspecified = 8.0
@dataclass

View File

@@ -13,9 +13,7 @@ import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
@@ -34,9 +32,6 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_grouped_topk,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
@@ -49,7 +44,6 @@ from vllm.model_executor.layers.fused_moe.utils import (
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme
from vllm.model_executor.utils import maybe_disable_graph_partition
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
@@ -1318,375 +1312,6 @@ def try_get_optimal_moe_config(
return config
def vllm_topk_softmax(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
) -> tuple[torch.Tensor, ...]:
ops.topk_softmax(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
renormalize,
)
return topk_weights, topk_indices
def dispatch_topk_func(
use_rocm_aiter: bool = False,
) -> Callable[..., tuple[torch.Tensor, ...]]:
if use_rocm_aiter:
return rocm_aiter_ops.topk_softmax
return vllm_topk_softmax
def fused_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
indices_type: torch.dtype | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
M, _ = hidden_states.size()
topk_weights = torch.empty(
M, topk, dtype=torch.float32, device=hidden_states.device
)
topk_ids = torch.empty(
M,
topk,
dtype=torch.int32 if indices_type is None else indices_type,
device=hidden_states.device,
)
token_expert_indices = torch.empty(
M, topk, dtype=torch.int32, device=hidden_states.device
)
topk_func = dispatch_topk_func(use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled())
topk_weights, topk_ids = topk_func(
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
)
return topk_weights, topk_ids, token_expert_indices
def fused_topk_bias(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
e_score_correction_bias: torch.Tensor,
topk: int,
renormalize: bool,
):
n_routed_experts = gating_output.shape[-1]
scores = gating_output.softmax(dim=-1)
scores_for_choice = scores.view(
-1, n_routed_experts
) + e_score_correction_bias.unsqueeze(0)
# For batch invariance, use sorted=True to ensure deterministic expert selection
use_sorted = vllm_is_batch_invariant()
topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
topk_weights = scores.gather(1, topk_indices)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights.to(torch.float32), topk_indices.to(torch.int32)
# This is used by the Deepseek-V2 and Deepseek-V3 model
@torch.compile(
dynamic=True,
backend=current_platform.simple_compile_backend,
options=maybe_disable_graph_partition(current_platform.simple_compile_backend),
)
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if (
envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK
and current_platform.is_cuda()
and num_expert_group <= 32
and topk <= 32
and e_score_correction_bias is not None
):
return fused_grouped_topk(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
e_score_correction_bias=e_score_correction_bias,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
)
assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
if scoring_func == "softmax":
scores = torch.softmax(gating_output, dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
num_token = scores.size(0)
if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
original_scores = scores
scores = scores + e_score_correction_bias.unsqueeze(0)
group_scores = (
scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
)
else:
group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
) # [n, n_group]
# For batch invariance, use sorted=True to ensure deterministic expert selection
use_sorted = vllm_is_batch_invariant()
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
1
] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = (
group_mask.unsqueeze(-1)
.expand(num_token, num_expert_group, scores.size(-1) // num_expert_group)
.reshape(num_token, -1)
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e]
if e_score_correction_bias is not None:
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_scores.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(
tmp_scores, k=topk, dim=-1, sorted=use_sorted
)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
if routed_scaling_factor != 1.0:
topk_weights = topk_weights * routed_scaling_factor
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
# --8<-- [start:grouped_topk]
@CustomOp.register("grouped_topk")
class GroupedTopk(CustomOp):
"""GroupedTopk used by the Deepseek-V2 and Deepseek-V3 model."""
# --8<-- [end:grouped_topk]
def __init__(
self,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
num_fused_shared_experts: int = 0,
) -> None:
super().__init__()
self.native_impl = grouped_topk
self.topk = topk
self.renormalize = renormalize
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.scoring_func = scoring_func
self.routed_scaling_factor = routed_scaling_factor
self.num_fused_shared_experts = num_fused_shared_experts
def forward_native(
self,
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.native_impl(
hidden_states,
gating_output,
self.topk,
self.renormalize,
self.num_expert_group,
self.topk_group,
self.scoring_func,
self.routed_scaling_factor,
e_score_correction_bias,
)
def forward_cuda(
self,
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.forward_native(
hidden_states, gating_output, e_score_correction_bias
)
def forward_hip(
self,
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if rocm_aiter_ops.is_fused_moe_enabled():
if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
assert self.num_fused_shared_experts == 0
return rocm_aiter_grouped_topk(
hidden_states,
gating_output,
self.topk,
self.renormalize,
self.num_expert_group,
self.topk_group,
self.scoring_func,
self.routed_scaling_factor,
e_score_correction_bias,
self.num_fused_shared_experts,
)
else:
return self.forward_native(
hidden_states, gating_output, e_score_correction_bias
)
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def eplb_map_to_physical_and_record(
topk_ids: torch.Tensor,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> torch.Tensor:
"""
Map the logical expert ids to physical expert ids
and record the expert load metrics.
This will select a pseudo-random replica for each logical expert.
Only used for EPLB.
Args:
topk_ids: The logical expert ids.
expert_load_view: The expert load view.
logical_to_physical_map: The logical to physical map.
logical_replica_count: The logical replica count.
Returns:
The physical expert ids.
"""
# 1. Convert the logical expert ids to physical expert ids
# Directly select a random replica for each logical expert
# In case `indices_type` is not `torch.long` or `torch.int`,
# e.g. `torch.uint32` as required by dispatch/combine kernels
topk_ids_long = topk_ids.long()
# Use (token position) modulo (replica count)
# to deterministically choose a replica
replica_count = logical_replica_count[topk_ids_long]
# Flatten-position based index, reshaped back to `topk_ids` shape
pos_indices = torch.arange(
topk_ids.numel(), device=topk_ids.device, dtype=torch.long
).reshape_as(topk_ids)
# Compute pseudo-random indices by modulo
replica_indices = (pos_indices % replica_count).unsqueeze(-1)
physical_ids = (
logical_to_physical_map[topk_ids_long].gather(-1, replica_indices).squeeze(-1)
)
topk_ids = physical_ids
# 2. Record expert load metrics.
# TODO(bowen): When using `FusedMoEModularKernel`, this
# can be done in a more unified way, since
# `FusedMoEPrepareAndFinalize` will return the expert
# token count, in some cases directly from the kernel.
# However, now there are many code paths not using
# the modular kernel, e.g. calling `fused_experts`,
# so we decide to keep the logic here.
#
# If later refactor moved all the MoE kernel calls
# to the modular kernel, we can move this logic there
# to achieve better efficiency.
# `expert_load_view`: (num_physical_experts,)
# `torch.bincount` is not compilable, so use `scatter_add_` instead.
topk_ids_flatten = topk_ids.flatten()
expert_load_view.scatter_add_(
dim=0,
index=topk_ids_flatten.long(),
src=torch.ones_like(topk_ids_flatten).to(expert_load_view),
)
return topk_ids
def fused_grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
e_score_correction_bias: torch.Tensor,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
if scoring_func == "sigmoid":
# Fully fused kernel path for sigmoid
topk_values, topk_indices = ops.grouped_topk(
gating_output, # raw logits
num_expert_group,
topk_group,
topk,
renormalize,
routed_scaling_factor,
e_score_correction_bias,
1, # scoring_func=1 for sigmoid
)
elif scoring_func == "softmax":
# Apply softmax in Python, then use fused kernel
# TODO: Add support for softmax in kernel
scores = torch.softmax(gating_output, dim=-1)
topk_values, topk_indices = ops.grouped_topk(
scores, # pre-computed scores
num_expert_group,
topk_group,
topk,
renormalize,
routed_scaling_factor,
e_score_correction_bias,
0, # scoring_func=0 (no activation, scores already computed)
)
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
# Fused kernel outputs float32 values and int32 indices directly
return topk_values, topk_indices
def inplace_fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,

View File

@@ -10,13 +10,13 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase,
)

View File

@@ -12,11 +12,13 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel,
FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
logger = init_logger(__name__)

View File

@@ -21,7 +21,7 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.eplb.eplb_state import EplbLayerState, EplbState
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
@@ -31,14 +31,24 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
FusedMoEModularMethod,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
init_aiter_topK_meta_data,
)
from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
RoutedExpertsCapturer,
)
from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator
from vllm.model_executor.layers.fused_moe.router.router_factory import (
create_fused_moe_router,
)
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
)
@@ -52,31 +62,6 @@ from vllm.utils.torch_utils import (
)
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
if current_platform.is_cuda_alike():
from .fused_moe import eplb_map_to_physical_and_record
else:
def _eplb_map_to_physical_and_record(
topk_ids: torch.Tensor,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> torch.Tensor:
# CPU fallback: no EPLB so just return as is
return topk_ids
eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record
from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
FusedMoEModularMethod,
)
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
)
logger = init_logger(__name__)
@@ -288,23 +273,6 @@ def maybe_roundup_hidden_size(
return hidden_size
class FusedMoERouterImpl(FusedMoERouter):
def __init__(self, layer: "FusedMoE"):
super().__init__()
self.layer = layer
@property
def routing_method_type(self) -> RoutingMethodType:
return self.layer.routing_method_type
def select_experts(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.layer._select_experts(hidden_states, router_logits)
# --8<-- [start:fused_moe]
@CustomOp.register("fused_moe")
class FusedMoE(CustomOp):
@@ -440,9 +408,7 @@ class FusedMoE(CustomOp):
self.layer_name = prefix
self.enable_eplb = enable_eplb
self.expert_load_view: torch.Tensor | None = None
self.logical_to_physical_map: torch.Tensor | None = None
self.logical_replica_count: torch.Tensor | None = None
self.eplb_state = EplbLayerState()
self.expert_placement_strategy: ExpertPlacementStrategy = (
vllm_config.parallel_config.expert_placement_strategy
)
@@ -538,6 +504,8 @@ class FusedMoE(CustomOp):
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
self.renormalize = renormalize
# TODO(bnell): these attributes are only used by cpu/xpu/mxfp4
self.use_grouped_topk = use_grouped_topk
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
@@ -547,46 +515,11 @@ class FusedMoE(CustomOp):
self.scoring_func = scoring_func
self.routed_scaling_factor = routed_scaling_factor
self.e_score_correction_bias = e_score_correction_bias
# TODO(bnell): end attributes
self.apply_router_weight_on_input = apply_router_weight_on_input
self.activation = activation
self._grouped_topk_impl: GroupedTopk | None = None
if self.use_grouped_topk:
assert self.num_expert_group is not None
assert self.topk_group is not None
self._grouped_topk_impl = GroupedTopk(
topk=self.top_k,
renormalize=self.renormalize,
num_expert_group=self.num_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
num_fused_shared_experts=self.num_fused_shared_experts,
)
if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError(
"Only softmax scoring function is supported for non-grouped topk."
)
# ToDo: Better logic to determine the routing method type
if routing_method_type is not None:
self.routing_method_type: RoutingMethodType = routing_method_type
else:
if scoring_func == "sigmoid":
if self.use_grouped_topk:
self.routing_method_type = RoutingMethodType.DeepSeekV3
elif self.top_k == 1:
self.routing_method_type = RoutingMethodType.Llama4
elif self.scoring_func == "softmax":
self.routing_method_type = (
RoutingMethodType.Renormalize
if not self.renormalize
else RoutingMethodType.RenormalizeNaive
)
else:
self.routing_method_type = RoutingMethodType.TopK
self.moe_config: FusedMoEConfig = FusedMoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
@@ -637,8 +570,7 @@ class FusedMoE(CustomOp):
# If you plan to add support for more quantization methods,
# please refer to the implementation in `Fp8MoEMethod`.
raise NotImplementedError(
f"EPLB is not supported {self.quant_method.__class__.__name__}. "
"EPLB is only supported for FP8 quantization for now."
f"EPLB is not supported {self.quant_method.__class__.__name__}."
)
moe_quant_params = {
@@ -663,7 +595,38 @@ class FusedMoE(CustomOp):
self.batched_hidden_states: torch.Tensor | None = None
self.batched_router_logits: torch.Tensor | None = None
self.router = FusedMoERouterImpl(self)
# TODO(bnell): in next PR move capture back to layer
capture: Callable[[torch.Tensor], None] | None = None
if (
self.vllm_config.model_config is not None
and self.vllm_config.model_config.enable_return_routed_experts
):
# In dummy runs, the capturer is not initialized.
capturer = RoutedExpertsCapturer.get_instance()
if capturer is not None:
capture = lambda topk_ids: capturer.capture(self.layer_id, topk_ids)
self.router = create_fused_moe_router(
top_k=top_k,
global_num_experts=self.global_num_experts,
eplb_state=self.eplb_state,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
num_fused_shared_experts=self.num_fused_shared_experts,
enable_eplb=enable_eplb,
# TODO(bnell): once we can construct the MK at init time, we
# can make this a value.
indices_type_getter=lambda: self.quant_method.topk_indices_dtype,
routing_method_type=routing_method_type,
capture=capture,
)
self.routing_method_type: RoutingMethodType = self.router.routing_method_type
# Note: maybe_init_modular_kernel should only be called by
# prepare_communication_buffer_for_model.
@@ -1492,9 +1455,9 @@ class FusedMoE(CustomOp):
This is used later in forward pass, where we get the expert mapping
and record the load metrics in `expert_load_view`.
"""
self.expert_load_view = expert_load_view[moe_layer_idx]
self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
self.logical_replica_count = logical_replica_count[moe_layer_idx]
self.eplb_state.expert_load_view = expert_load_view[moe_layer_idx]
self.eplb_state.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
self.eplb_state.logical_replica_count = logical_replica_count[moe_layer_idx]
def ensure_moe_quant_config_init(self):
if self.quant_method.moe_quant_config is None:
@@ -1535,130 +1498,6 @@ class FusedMoE(CustomOp):
device=torch.cuda.current_device(),
)
def _select_experts(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Route the input hidden states to the top-k experts based on the
router logits.
Returns:
(topk_weights, topk_ids)
(tuple[torch.Tensor, torch.Tensor]):
The weights and expert ids.
**Compatibility**: When EPLB is not enabled, the returned ids are
equivalent to global logical ids, so should be compatible with
plain MoE implementations without redundant experts.
"""
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk,
fused_topk_bias,
)
if self.enable_eplb:
if self.quant_method.supports_eplb:
if self.expert_load_view is None:
raise ValueError(
"enable_eplb=True requiere expert_load_view != None"
)
if self.logical_to_physical_map is None:
raise ValueError(
"enable_eplb=True requiere logical_to_physical_map != None"
)
if self.logical_replica_count is None:
raise ValueError(
"enable_eplb=True requiere logical_replica_count != None"
)
else:
raise NotImplementedError(
f"EPLB is not supported for {self.quant_method.method_name}."
)
def valid_grouping() -> bool:
# Check if num_experts is greater than num_expert_group
# and is divisible by num_expert_group
num_experts = router_logits.shape[-1]
if num_experts <= self.num_expert_group:
return False
return num_experts % self.num_expert_group == 0
indices_type = self.quant_method.topk_indices_dtype
# Check if we should use a routing simulation strategy
routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
if routing_strategy != "":
topk_weights, topk_ids = RoutingSimulator.simulate_routing(
hidden_states=hidden_states,
router_logits=router_logits,
strategy_name=routing_strategy,
top_k=self.top_k,
indices_type=indices_type,
)
# DeepSeekv2 uses grouped_top_k
elif self.use_grouped_topk and valid_grouping():
assert self._grouped_topk_impl is not None
topk_weights, topk_ids = self._grouped_topk_impl(
hidden_states=hidden_states,
gating_output=router_logits,
e_score_correction_bias=self.e_score_correction_bias,
)
elif self.e_score_correction_bias is not None:
topk_weights, topk_ids = fused_topk_bias(
hidden_states=hidden_states,
gating_output=router_logits,
e_score_correction_bias=self.e_score_correction_bias.data,
topk=self.top_k,
renormalize=self.renormalize,
)
if self.routed_scaling_factor != 1.0:
topk_weights *= self.routed_scaling_factor
elif self.custom_routing_function is None:
topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=self.top_k,
renormalize=self.renormalize,
indices_type=indices_type,
)
else:
topk_weights, topk_ids = self.custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
topk=self.top_k,
renormalize=self.renormalize,
)
if self.enable_eplb:
topk_ids = eplb_map_to_physical_and_record(
topk_ids=topk_ids,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
)
if (indices_type is not None) and topk_ids.dtype != indices_type:
topk_ids = topk_ids.to(dtype=indices_type)
assert topk_ids.dtype == indices_type or indices_type is None
if (
self.vllm_config.model_config is not None
and self.vllm_config.model_config.enable_return_routed_experts
):
# In dummy runs, the capturer is not initialized.
capturer = RoutedExpertsCapturer.get_instance()
if capturer is not None: # in dummmy_run may be None
capturer.capture( # noqa
layer_id=self.layer_id,
topk_ids=topk_ids,
)
return topk_weights, topk_ids
def must_reduce_shared_expert_outputs(self) -> bool:
"""
The shared_experts are typically computed using the RowParallelLinear
@@ -1761,8 +1600,12 @@ class FusedMoE(CustomOp):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None
assert self.batched_hidden_states.dtype == full_hidden_states.dtype
assert self.batched_router_logits.dtype == full_router_logits.dtype
assert self.batched_hidden_states.dtype == full_hidden_states.dtype, (
f"{self.batched_hidden_states.dtype} == {full_hidden_states.dtype}"
)
assert self.batched_router_logits.dtype == full_router_logits.dtype, (
f"{self.batched_router_logits.dtype} == {full_router_logits.dtype}"
)
# Check size compatibility.
assert self.batched_hidden_states.size(-1) == full_hidden_states.size(-1)
assert self.batched_router_logits.size(-1) == full_router_logits.size(-1)
@@ -2080,15 +1923,8 @@ class FusedMoE(CustomOp):
f"tp_size={self.tp_size},\n"
f"ep_size={self.ep_size}, "
f"reduce_results={self.reduce_results}, "
f"renormalize={self.renormalize}, "
f"use_grouped_topk={self.use_grouped_topk}"
)
if self.use_grouped_topk:
s += f", num_expert_group={self.num_expert_group}, topk_group={self.topk_group}" # noqa: E501
s += f", scoring_func='{self.scoring_func}', activation='{self.activation}'" # noqa: E501
return s

View File

@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

View File

@@ -0,0 +1,245 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import abstractmethod
from collections.abc import Callable
import torch
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
from vllm.platforms import current_platform
if current_platform.is_cuda_alike():
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def eplb_map_to_physical_and_record(
topk_ids: torch.Tensor,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> torch.Tensor:
"""
Map the logical expert ids to physical expert ids
and record the expert load metrics.
This will select a pseudo-random replica for each logical expert.
Only used for EPLB.
Args:
topk_ids: The logical expert ids.
expert_load_view: The expert load view.
logical_to_physical_map: The logical to physical map.
logical_replica_count: The logical replica count.
Returns:
The physical expert ids.
"""
# 1. Convert the logical expert ids to physical expert ids
# Directly select a random replica for each logical expert
# In case `indices_type` is not `torch.long` or `torch.int`,
# e.g. `torch.uint32` as required by dispatch/combine kernels
topk_ids_long = topk_ids.long()
# Use (token position) modulo (replica count)
# to deterministically choose a replica
replica_count = logical_replica_count[topk_ids_long]
# Flatten-position based index, reshaped back to `topk_ids` shape
pos_indices = torch.arange(
topk_ids.numel(), device=topk_ids.device, dtype=torch.long
).reshape_as(topk_ids)
# Compute pseudo-random indices by modulo
replica_indices = (pos_indices % replica_count).unsqueeze(-1)
physical_ids = (
logical_to_physical_map[topk_ids_long]
.gather(-1, replica_indices)
.squeeze(-1)
)
topk_ids = physical_ids
# 2. Record expert load metrics.
# TODO(bowen): When using `FusedMoEModularKernel`, this
# can be done in a more unified way, since
# `FusedMoEPrepareAndFinalize` will return the expert
# token count, in some cases directly from the kernel.
# However, now there are many code paths not using
# the modular kernel, e.g. calling `fused_experts`,
# so we decide to keep the logic here.
#
# If later refactor moved all the MoE kernel calls
# to the modular kernel, we can move this logic there
# to achieve better efficiency.
# `expert_load_view`: (num_physical_experts,)
# `torch.bincount` is not compilable, so use `scatter_add_` instead.
topk_ids_flatten = topk_ids.flatten()
expert_load_view.scatter_add_(
dim=0,
index=topk_ids_flatten.long(),
src=torch.ones_like(topk_ids_flatten).to(expert_load_view),
)
return topk_ids
else:
def eplb_map_to_physical_and_record(
topk_ids: torch.Tensor,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> torch.Tensor:
# CPU fallback: no EPLB so just return as is
return topk_ids
class BaseRouter(FusedMoERouter):
"""
Base router class that provides common functionality for all router implementations.
This class implements the template method pattern where select_experts() handles
common pre-processing and post-processing, delegating the actual routing logic
to the abstract _compute_routing() method.
"""
def __init__(
self,
top_k: int,
global_num_experts: int,
eplb_state: EplbLayerState,
enable_eplb: bool = False,
# TODO(bnell): Once the MK is constructed at layer init time, we
# can make this a plain value instead of a callback.
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
):
"""
Note: the indices dtype might not be available at router construction
time, so we need to supply a callback to get it at runtime. This is
because the indices type is supplied by modular kernels which are
created after MoE layer/router construction.
"""
super().__init__()
self.top_k = top_k
self.global_num_experts = global_num_experts
self.eplb_state = eplb_state
self.enable_eplb = enable_eplb
self.indices_type_getter = indices_type_getter
self.capture: Callable[[torch.tensor], None] | None = None
def _validate_eplb_state(self) -> None:
"""Validate that EPLB state is properly initialized if EPLB is enabled."""
if self.enable_eplb:
if self.eplb_state.expert_load_view is None:
raise ValueError("enable_eplb=True requires expert_load_view != None")
if self.eplb_state.logical_to_physical_map is None:
raise ValueError(
"enable_eplb=True requires logical_to_physical_map != None"
)
if self.eplb_state.logical_replica_count is None:
raise ValueError(
"enable_eplb=True requires logical_replica_count != None"
)
def _get_indices_type(self) -> torch.dtype | None:
"""Get the desired indices dtype from the getter function."""
return (
self.indices_type_getter() if self.indices_type_getter is not None else None
)
def _apply_eplb_mapping(self, topk_ids: torch.Tensor) -> torch.Tensor:
"""Apply EPLB mapping to convert logical expert IDs to physical expert IDs."""
if self.enable_eplb:
assert self.eplb_state.expert_load_view is not None
assert self.eplb_state.logical_to_physical_map is not None
assert self.eplb_state.logical_replica_count is not None
return eplb_map_to_physical_and_record(
topk_ids=topk_ids,
expert_load_view=self.eplb_state.expert_load_view,
logical_to_physical_map=self.eplb_state.logical_to_physical_map,
logical_replica_count=self.eplb_state.logical_replica_count,
)
return topk_ids
def _convert_indices_dtype(
self, topk_ids: torch.Tensor, indices_type: torch.dtype | None
) -> torch.Tensor:
"""Convert topk_ids to the desired dtype if needed."""
if (indices_type is not None) and topk_ids.dtype != indices_type:
topk_ids = topk_ids.to(dtype=indices_type)
assert topk_ids.dtype == indices_type or indices_type is None
return topk_ids
@abstractmethod
def _compute_routing(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
indices_type: torch.dtype | None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compute the actual routing logic.
This method must be implemented by subclasses to provide the specific
routing algorithm (e.g., grouped_topk, fused_topk, custom routing, etc.).
Args:
hidden_states: Input hidden states
router_logits: Router logits for expert selection
indices_type: Desired dtype for expert indices (may be None)
Returns:
tuple of (topk_weights, topk_ids)
"""
raise NotImplementedError
def select_experts(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Route the input hidden states to the top-k experts based on the
router logits.
This method implements the template method pattern:
1. Validates EPLB state
2. Gets indices type
3. Calls _compute_routing() to get topk_weights and topk_ids
4. Applies EPLB mapping if enabled
5. Converts indices dtype if needed
Returns:
(topk_weights, topk_ids)
(tuple[torch.Tensor, torch.Tensor]):
The weights and expert ids computation result.
**Compatibility**: When EPLB is not enabled, the returned ids are
equivalent to global logical ids, so should be compatible with
plain MoE implementations without redundant experts.
"""
# Step 1: Validate EPLB state
self._validate_eplb_state()
# Step 2: Get indices type.
indices_type = self._get_indices_type()
# Step 3: Compute routing (delegated to subclass)
topk_weights, topk_ids = self._compute_routing(
hidden_states, router_logits, indices_type
)
# Step 4: Apply EPLB mapping
topk_ids = self._apply_eplb_mapping(topk_ids)
# Step 5: Convert indices dtype
topk_ids = self._convert_indices_dtype(topk_ids, indices_type)
# TODO(bnell): temporary hack until select_experts is moved into FusedMoE
if self.capture is not None:
self.capture(topk_ids)
return topk_weights, topk_ids

View File

@@ -0,0 +1,53 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
class CustomRoutingRouter(BaseRouter):
"""Router using a custom user-provided routing function."""
def __init__(
self,
top_k: int,
global_num_experts: int,
eplb_state: EplbLayerState,
custom_routing_function: Callable,
renormalize: bool = True,
enable_eplb: bool = False,
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
):
super().__init__(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
self.custom_routing_function = custom_routing_function
self.renormalize = renormalize
@property
def routing_method_type(self) -> RoutingMethodType:
return RoutingMethodType.Custom
def _compute_routing(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
indices_type: torch.dtype | None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute routing using the custom routing function."""
topk_weights, topk_ids = self.custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
topk=self.top_k,
renormalize=self.renormalize,
)
return topk_weights.to(torch.float32), topk_ids

View File

@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
def fused_topk_bias(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
e_score_correction_bias: torch.Tensor,
topk: int,
renormalize: bool,
):
n_routed_experts = gating_output.shape[-1]
scores = gating_output.softmax(dim=-1)
scores_for_choice = scores.view(
-1, n_routed_experts
) + e_score_correction_bias.unsqueeze(0)
# For batch invariance, use sorted=True to ensure deterministic expert selection
use_sorted = vllm_is_batch_invariant()
topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
topk_weights = scores.gather(1, topk_indices)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights.to(torch.float32), topk_indices.to(torch.int32)
class FusedTopKBiasRouter(BaseRouter):
"""Router using fused top-k with e_score_correction_bias."""
def __init__(
self,
top_k: int,
global_num_experts: int,
eplb_state: EplbLayerState,
e_score_correction_bias: torch.Tensor,
renormalize: bool = True,
routed_scaling_factor: float = 1.0,
enable_eplb: bool = False,
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
):
super().__init__(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
self.e_score_correction_bias = e_score_correction_bias
self.renormalize = renormalize
self.routed_scaling_factor = routed_scaling_factor
@property
def routing_method_type(self) -> RoutingMethodType:
return (
RoutingMethodType.Renormalize
if not self.renormalize
else RoutingMethodType.RenormalizeNaive
)
def _compute_routing(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
indices_type: torch.dtype | None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute routing using fused top-k with bias."""
topk_weights, topk_ids = fused_topk_bias(
hidden_states=hidden_states,
gating_output=router_logits,
e_score_correction_bias=self.e_score_correction_bias.data,
topk=self.top_k,
renormalize=self.renormalize,
)
if self.routed_scaling_factor != 1.0:
topk_weights *= self.routed_scaling_factor
return topk_weights, topk_ids

View File

@@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
import vllm._custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
def vllm_topk_softmax(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
) -> tuple[torch.Tensor, ...]:
ops.topk_softmax(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
renormalize,
)
return topk_weights, topk_indices
def dispatch_topk_func(
use_rocm_aiter: bool = False,
) -> Callable[..., tuple[torch.Tensor, ...]]:
if use_rocm_aiter:
return rocm_aiter_ops.topk_softmax
return vllm_topk_softmax
def fused_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
indices_type: torch.dtype | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
M, _ = hidden_states.size()
topk_weights = torch.empty(
M, topk, dtype=torch.float32, device=hidden_states.device
)
topk_ids = torch.empty(
M,
topk,
dtype=torch.int32 if indices_type is None else indices_type,
device=hidden_states.device,
)
token_expert_indices = torch.empty(
M, topk, dtype=torch.int32, device=hidden_states.device
)
topk_func = dispatch_topk_func(use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled())
topk_weights, topk_ids = topk_func(
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
)
return topk_weights, topk_ids, token_expert_indices
class FusedTopKRouter(BaseRouter):
"""Default router using standard fused top-k routing."""
def __init__(
self,
top_k: int,
global_num_experts: int,
eplb_state: EplbLayerState,
scoring_func: str = "softmax",
renormalize: bool = True,
enable_eplb: bool = False,
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
):
assert scoring_func == "softmax", "FusedTopKRouter only supports softmax."
super().__init__(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
self.renormalize = renormalize
@property
def routing_method_type(self) -> RoutingMethodType:
return (
RoutingMethodType.Renormalize
if not self.renormalize
else RoutingMethodType.RenormalizeNaive
)
def _compute_routing(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
indices_type: torch.dtype | None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute routing using standard fused top-k."""
topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=self.top_k,
renormalize=self.renormalize,
indices_type=indices_type,
)
return topk_weights, topk_ids

View File

@@ -0,0 +1,353 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from functools import partial
import torch
from vllm import _custom_ops as ops
from vllm import envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_grouped_topk,
)
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
fused_topk_bias,
)
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import fused_topk
from vllm.model_executor.utils import maybe_disable_graph_partition
from vllm.platforms import current_platform
def fused_grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
e_score_correction_bias: torch.Tensor,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
if scoring_func == "sigmoid":
# Fully fused kernel path for sigmoid
topk_values, topk_indices = ops.grouped_topk(
gating_output, # raw logits
num_expert_group,
topk_group,
topk,
renormalize,
routed_scaling_factor,
e_score_correction_bias,
1, # scoring_func=1 for sigmoid
)
elif scoring_func == "softmax":
# Apply softmax in Python, then use fused kernel
# TODO: Add support for softmax in kernel
scores = torch.softmax(gating_output, dim=-1)
topk_values, topk_indices = ops.grouped_topk(
scores, # pre-computed scores
num_expert_group,
topk_group,
topk,
renormalize,
routed_scaling_factor,
e_score_correction_bias,
0, # scoring_func=0 (no activation, scores already computed)
)
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
# Fused kernel outputs float32 values and int32 indices directly
return topk_values, topk_indices
# This is used by the Deepseek-V2 and Deepseek-V3 model
@torch.compile(
dynamic=True,
backend=current_platform.simple_compile_backend,
options=maybe_disable_graph_partition(current_platform.simple_compile_backend),
)
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if (
envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK
and current_platform.is_cuda()
and num_expert_group <= 32
and topk <= 32
and e_score_correction_bias is not None
):
return fused_grouped_topk(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
e_score_correction_bias=e_score_correction_bias,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
)
assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
if scoring_func == "softmax":
scores = torch.softmax(gating_output, dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
num_token = scores.size(0)
if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
original_scores = scores
scores = scores + e_score_correction_bias.unsqueeze(0)
group_scores = (
scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
)
else:
group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
) # [n, n_group]
# For batch invariance, use sorted=True to ensure deterministic expert selection
use_sorted = vllm_is_batch_invariant()
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
1
] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = (
group_mask.unsqueeze(-1)
.expand(num_token, num_expert_group, scores.size(-1) // num_expert_group)
.reshape(num_token, -1)
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e]
if e_score_correction_bias is not None:
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_scores.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(
tmp_scores, k=topk, dim=-1, sorted=use_sorted
)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
if routed_scaling_factor != 1.0:
topk_weights = topk_weights * routed_scaling_factor
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
# --8<-- [start:grouped_topk]
@CustomOp.register("grouped_topk")
class GroupedTopk(CustomOp):
"""GroupedTopk used by the Deepseek-V2 and Deepseek-V3 model."""
# --8<-- [end:grouped_topk]
def __init__(
self,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
num_fused_shared_experts: int = 0,
) -> None:
super().__init__()
self.native_impl = grouped_topk
self.topk = topk
self.renormalize = renormalize
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.scoring_func = scoring_func
self.routed_scaling_factor = routed_scaling_factor
self.num_fused_shared_experts = num_fused_shared_experts
def forward_native(
self,
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.native_impl(
hidden_states,
gating_output,
self.topk,
self.renormalize,
self.num_expert_group,
self.topk_group,
self.scoring_func,
self.routed_scaling_factor,
e_score_correction_bias,
)
def forward_cuda(
self,
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.forward_native(
hidden_states, gating_output, e_score_correction_bias
)
def forward_hip(
self,
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if rocm_aiter_ops.is_fused_moe_enabled():
if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
assert self.num_fused_shared_experts == 0
return rocm_aiter_grouped_topk(
hidden_states,
gating_output,
self.topk,
self.renormalize,
self.num_expert_group,
self.topk_group,
self.scoring_func,
self.routed_scaling_factor,
e_score_correction_bias,
self.num_fused_shared_experts,
)
else:
return self.forward_native(
hidden_states, gating_output, e_score_correction_bias
)
class GroupedTopKRouter(BaseRouter):
"""Router using grouped top-k routing (e.g., DeepSeekV2/V3)."""
def __init__(
self,
top_k: int,
global_num_experts: int,
eplb_state: EplbLayerState,
num_expert_group: int,
topk_group: int,
renormalize: bool = True,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
num_fused_shared_experts: int = 0,
enable_eplb: bool = False,
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
routing_method_type: RoutingMethodType | None = None,
):
super().__init__(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.renormalize = renormalize
self.scoring_func = scoring_func
self.routed_scaling_factor = routed_scaling_factor
self.e_score_correction_bias = e_score_correction_bias
self.num_fused_shared_experts = num_fused_shared_experts
# Determine routing method type
if routing_method_type is not None:
self._routing_method_type = routing_method_type
elif scoring_func == "sigmoid":
self._routing_method_type = RoutingMethodType.DeepSeekV3
else:
self._routing_method_type = RoutingMethodType.TopK
@property
def routing_method_type(self) -> RoutingMethodType:
return self._routing_method_type
def _compute_routing(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
indices_type: torch.dtype | None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute routing using grouped top-k."""
def valid_grouping() -> bool:
# Check if num_experts is greater than num_expert_group
# and is divisible by num_expert_group
num_experts = router_logits.shape[-1]
if num_experts <= self.num_expert_group:
return False
return num_experts % self.num_expert_group == 0
if not valid_grouping():
if self.e_score_correction_bias is not None:
topk_weights, topk_ids = fused_topk_bias(
hidden_states=hidden_states,
gating_output=router_logits,
e_score_correction_bias=self.e_score_correction_bias.data,
topk=self.top_k,
renormalize=self.renormalize,
)
if self.routed_scaling_factor != 1.0:
topk_weights *= self.routed_scaling_factor
else:
topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=self.top_k,
renormalize=self.renormalize,
indices_type=indices_type,
)
return topk_weights, topk_ids
# Select grouped_topk implementation
if rocm_aiter_ops.is_fused_moe_enabled():
if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
assert self.num_fused_shared_experts == 0
grouped_topk_impl = partial(
rocm_aiter_grouped_topk,
num_fused_shared_experts=self.num_fused_shared_experts,
)
else:
grouped_topk_impl = grouped_topk
topk_weights, topk_ids = grouped_topk_impl(
hidden_states=hidden_states,
gating_output=router_logits,
topk=self.top_k,
renormalize=self.renormalize,
num_expert_group=self.num_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias,
)
return topk_weights, topk_ids

View File

@@ -0,0 +1,178 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
import vllm.envs as envs
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
from vllm.model_executor.layers.fused_moe.router.custom_routing_router import (
CustomRoutingRouter,
)
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
FusedTopKBiasRouter,
)
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import (
FusedTopKRouter,
)
from vllm.model_executor.layers.fused_moe.router.grouped_topk_router import (
GroupedTopKRouter,
)
from vllm.model_executor.layers.fused_moe.router.routing_simulator_router import (
RoutingSimulatorRouter,
)
EMPTY_EPLB_STATE: EplbLayerState = EplbLayerState()
def create_fused_moe_router(
# common parameters
top_k: int,
global_num_experts: int,
renormalize: bool = True,
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
routing_method_type: RoutingMethodType | None = None,
# grouped topk parameters
use_grouped_topk: bool = False,
num_expert_group: int | None = None,
topk_group: int | None = None,
scoring_func: str = "softmax",
num_fused_shared_experts: int = 0,
# grouped topk + fused topk bias parameters
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
# custom routing paramaters
custom_routing_function: Callable | None = None,
# eplb parameters
enable_eplb: bool = False,
eplb_state: EplbLayerState = EMPTY_EPLB_STATE,
capture: Callable[[torch.tensor], None] | None = None,
) -> FusedMoERouter:
"""
Factory function to create the appropriate FusedMoERouter subclass based on
the provided parameters.
The selection logic follows this priority order:
1. RoutingSimulatorRouter - if VLLM_MOE_ROUTING_SIMULATION_STRATEGY env var is set
2. GroupedTopKRouter - if use_grouped_topk is True
3. CustomRoutingRouter - if custom_routing_function is not None
4. FusedTopKBiasRouter - if e_score_correction_bias is not None
5. FusedTopKRouter - default fallback
Common arguments:
top_k: Number of experts to select per token
global_num_experts: Total number of experts in the model
renormalize: Whether to renormalize the routing weights
indices_type_getter: Function to get the desired indices dtype
routing_method_type: Optional explicit routing method type
Grouped topk arguments:
use_grouped_topk: Whether to use grouped top-k routing
num_expert_group: Number of expert groups (for grouped routing)
topk_group: Top-k within each group (for grouped routing)
scoring_func: Scoring function to use ("softmax" or "sigmoid")
num_fused_shared_experts: Number of fused shared experts (for ROCm AITER)
Grouped topk and fused topk bias arguments:
routed_scaling_factor: Scaling factor for routed weights
e_score_correction_bias: Optional bias correction for expert scores
Custom routing arguments:
custom_routing_function: Optional custom routing function
EPLB arguments:
enable_eplb: Whether EPLB is enabled
eplb_state: EPLB (Expert Parallelism Load Balancing) state
Returns:
An instance of the appropriate FusedMoERouter subclass
"""
router: BaseRouter
routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
if routing_strategy != "":
router = RoutingSimulatorRouter(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
# TODO(bnell): this is temporary until select_experts is
# separated from apply.
router.capture = capture
return router
if use_grouped_topk:
assert custom_routing_function is None
if num_expert_group is None or topk_group is None:
raise ValueError(
"num_expert_group and topk_group must be provided when "
"use_grouped_topk is True"
)
router = GroupedTopKRouter(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
num_expert_group=num_expert_group,
topk_group=topk_group,
renormalize=renormalize,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
num_fused_shared_experts=num_fused_shared_experts,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
routing_method_type=routing_method_type,
)
router.capture = capture
return router
if custom_routing_function is not None:
router = CustomRoutingRouter(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
custom_routing_function=custom_routing_function,
renormalize=renormalize,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
router.capture = capture
return router
if scoring_func != "softmax":
raise ValueError(
"Only softmax scoring function is supported for non-grouped topk."
)
if e_score_correction_bias is not None:
router = FusedTopKBiasRouter(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
e_score_correction_bias=e_score_correction_bias,
renormalize=renormalize,
routed_scaling_factor=routed_scaling_factor,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
router.capture = capture
return router
router = FusedTopKRouter(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
renormalize=renormalize,
scoring_func=scoring_func,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
router.capture = capture
return router

View File

@@ -1,20 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Token-to-Expert Routing Simulator
This module provides a framework for simulating and testing different
token-to-expert routing strategies for Mixture of Experts (MoE) models.
It supports routing logic customization and includes example implementations
like uniform random routing.
"""
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any
import torch
import vllm.envs as envs
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
logger = init_logger(__name__)
@@ -308,3 +304,44 @@ class RoutingSimulator:
top_k=top_k,
indices_type=indices_type,
)
class RoutingSimulatorRouter(BaseRouter):
"""Router that uses routing simulation strategies for testing/debugging."""
def __init__(
self,
top_k: int,
global_num_experts: int,
eplb_state: EplbLayerState,
enable_eplb: bool = False,
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
):
super().__init__(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
@property
def routing_method_type(self) -> RoutingMethodType:
return RoutingMethodType.Simulated
def _compute_routing(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
indices_type: torch.dtype | None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Use routing simulator to compute routing."""
routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
topk_weights, topk_ids = RoutingSimulator.simulate_routing(
hidden_states=hidden_states,
router_logits=router_logits,
strategy_name=routing_strategy,
top_k=self.top_k,
indices_type=indices_type,
)
return topk_weights, topk_ids

View File

@@ -20,7 +20,6 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat,
FusedMoEPermuteExpertsUnpermute,
@@ -32,6 +31,9 @@ from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
make_unquantized_moe_kernel,
select_unquantized_moe_backend,
)
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
@@ -312,9 +314,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if (
layer.enable_eplb is not False
or layer.expert_load_view is not None
or layer.logical_to_physical_map is not None
or layer.logical_replica_count is not None
or layer.eplb_state.expert_load_view is not None
or layer.eplb_state.logical_to_physical_map is not None
or layer.eplb_state.logical_replica_count is not None
):
raise NotImplementedError("Expert load balancing is not supported for CPU.")
@@ -346,9 +348,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if (
layer.enable_eplb is not False
or layer.expert_load_view is not None
or layer.logical_to_physical_map is not None
or layer.logical_replica_count is not None
or layer.eplb_state.expert_load_view is not None
or layer.eplb_state.logical_to_physical_map is not None
or layer.eplb_state.logical_replica_count is not None
):
raise NotImplementedError("Expert load balancing is not supported for XPU.")
return layer.ipex_fusion(

View File

@@ -10,12 +10,12 @@ from torch.nn import Parameter
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,

View File

@@ -6,11 +6,11 @@ from typing import Any, Union
import torch
from packaging import version
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,

View File

@@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoEConfig,
FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute,
FusedMoERouter,
FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod,
)
@@ -40,7 +41,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
fused_marlin_moe,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,

View File

@@ -10,12 +10,12 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase,
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
int8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (

View File

@@ -23,13 +23,13 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
FusedMoERouter,
FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,

View File

@@ -12,11 +12,11 @@ from torch.nn.parameter import Parameter, UninitializedParameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,

View File

@@ -10,12 +10,12 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,

View File

@@ -8,10 +8,10 @@ from packaging import version
from torch.nn import Module
from vllm._ipex_ops import ipex_ops as ops
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_moe_router import (
from vllm.model_executor.layers.fused_moe import (
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.linear import (
LinearBase,
LinearMethodBase,

View File

@@ -13,11 +13,11 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,

View File

@@ -6,12 +6,12 @@ from typing import Any, Optional
import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEConfig,

View File

@@ -14,6 +14,7 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase,
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import (
@@ -27,7 +28,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
fused_marlin_moe,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
OAITritonExperts,
UnfusedOAITritonExperts,
@@ -936,9 +936,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer.apply_router_weight_on_input,
layer.scoring_func,
layer.activation,
layer.expert_load_view,
layer.logical_to_physical_map,
layer.logical_replica_count,
layer.eplb_state.expert_load_view,
layer.eplb_state.logical_to_physical_map,
layer.eplb_state.logical_replica_count,
), "MXFP4 are not supported with this configuration."
if (

View File

@@ -548,7 +548,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = layer.select_experts(
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)

View File

@@ -10,12 +10,12 @@ import torch
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,

View File

@@ -201,6 +201,7 @@ class Ernie4_5_MoeMoE(nn.Module):
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
router_logits_dtype=torch.float32,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

View File

@@ -269,6 +269,7 @@ class Ernie4_5_VLMoeMoE(nn.Module):
quant_config=quant_config,
e_score_correction_bias=self.e_score_correction_bias[0],
prefix=f"{prefix}.text_experts",
router_logits_dtype=torch.float32,
)
else:
self.text_experts = Ernie4_5_VLMoeMLP(
@@ -306,6 +307,7 @@ class Ernie4_5_VLMoeMoE(nn.Module):
quant_config=quant_config,
e_score_correction_bias=self.e_score_correction_bias[1],
prefix=f"{prefix}.vision_experts",
router_logits_dtype=torch.float32,
)
else:
self.vision_experts = Ernie4_5_VLMoeMLP(