diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md
index 75ebee6ec..9ac31d2c0 100644
--- a/docs/design/moe_kernel_features.md
+++ b/docs/design/moe_kernel_features.md
@@ -32,7 +32,7 @@ th {
| Backend | Output act. format | Quant. types | Quant. format | Async | Apply Weight On Input | Subclass |
|---------|--------------------|--------------|---------------|-------|-----------------------|-----------|
-| naive | standard | all<sup>1</sup> | G,A,T | N | <sup>6</sup> | [layer.py][vllm.model_executor.layers.fused_moe.layer.FusedMoE.forward_impl] |
+| naive | standard | all<sup>1</sup> | G,A,T | N | <sup>6</sup> | [layer.py][vllm.model_executor.layers.fused_moe.layer.FusedMoE] |
| pplx | batched | fp8,int8 | G,A,T | Y | Y | [`PplxPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.pplx_prepare_finalize.PplxPrepareAndFinalize] |
| deepep_high_throughput | standard | fp8 | G(128),A,T<sup>2</sup> | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize.DeepEPHTPrepareAndFinalize] |
| deepep_low_latency | batched | fp8 | G(128),A,T<sup>3</sup> | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize.DeepEPLLPrepareAndFinalize] |
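A note for readers of the table above: the "Output act. format" column distinguishes two tensor layouts. Below is a minimal, illustrative sketch of the two shapes; the dimension names are assumptions chosen for clarity, not vLLM API.

```python
# Illustrative shapes for the two output activation formats in the table
# above; all names here are hypothetical.
import torch

num_tokens, hidden_dim = 16, 128
num_local_experts, max_tokens_per_expert = 4, 8

# "standard" format: one flat row per token (naive, deepep_high_throughput).
standard_acts = torch.randn(num_tokens, hidden_dim)

# "batched" format: tokens grouped per local expert and padded to a fixed
# per-expert capacity (pplx, deepep_low_latency).
batched_acts = torch.zeros(num_local_experts, max_tokens_per_expert, hidden_dim)

print(standard_acts.shape)  # torch.Size([16, 128])
print(batched_acts.shape)   # torch.Size([4, 8, 128])
```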
diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py
index 893968b5c..6dfcd5ebe 100644
--- a/tests/kernels/moe/modular_kernel_tools/common.py
+++ b/tests/kernels/moe/modular_kernel_tools/common.py
@@ -585,6 +585,7 @@ def make_modular_kernel(
tp_size_=get_tensor_model_parallel_world_size(),
pcp_size_=get_pcp_group().world_size,
dp_size_=get_dp_group().world_size,
+ sp_size_=1,
vllm_parallel_config=vllm_config.parallel_config,
)
@@ -594,6 +595,7 @@ def make_modular_kernel(
hidden_dim=config.K,
intermediate_size_per_partition=config.N,
num_local_experts=config.num_local_experts,
+ num_logical_experts=config.E,
moe_parallel_config=moe_parallel_config,
in_dtype=config.dtype,
max_num_tokens=next_power_of_2(config.M),
diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py
index 897bfddce..984fabc47 100644
--- a/tests/kernels/moe/utils.py
+++ b/tests/kernels/moe/utils.py
@@ -52,6 +52,7 @@ def make_dummy_moe_config(
hidden_dim=hidden_dim,
intermediate_size_per_partition=intermediate_size_per_partition,
num_local_experts=num_experts,
+ num_logical_experts=num_experts,
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
activation="silu",
in_dtype=in_dtype,
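The two test-helper hunks above thread the new `num_logical_experts` field through the MoE config. In these single-rank helpers it simply equals `num_experts`; here is a minimal sketch of how the logical and local counts relate once expert parallelism and EPLB redundancy are involved, with illustrative numbers:

```python
# Sketch of the relationship between the expert counts now carried by
# FusedMoEConfig; the values are illustrative.
num_logical_experts = 64    # experts defined by the model architecture
num_redundant_experts = 8   # optional extra replicas added for EPLB
ep_size = 4                 # expert-parallel world size

global_num_experts = num_logical_experts + num_redundant_experts
num_local_experts = global_num_experts // ep_size  # experts hosted per rank

assert num_local_experts == 18
# Without EP or redundancy (as in the test helpers above), local == logical.
```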
diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index b9fee1dd4..6dce6875d 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -913,12 +913,16 @@ class FusedMoEParallelConfig:
pcp_rank: int
dp_rank: int
ep_rank: int
+ sp_size: int
use_ep: bool # whether to use EP or not
all2all_backend: str # all2all backend for MoE communication
- is_sequence_parallel: bool # whether sequence parallelism is used
enable_eplb: bool # whether to enable expert load balancing
+ @property
+ def is_sequence_parallel(self) -> bool:
+ return self.sp_size > 1
+
@property
def use_all2all_kernels(self):
return self.dp_size > 1 and self.use_ep
@@ -974,6 +978,7 @@ class FusedMoEParallelConfig:
tp_size_: int,
pcp_size_: int,
dp_size_: int,
+ sp_size_: int,
vllm_parallel_config: ParallelConfig,
) -> "FusedMoEParallelConfig":
"""
@@ -1073,9 +1078,9 @@ class FusedMoEParallelConfig:
dp_rank=dp_rank,
ep_size=1,
ep_rank=0,
+ sp_size=sp_size_,
use_ep=False,
all2all_backend=vllm_parallel_config.all2all_backend,
- is_sequence_parallel=vllm_parallel_config.use_sequence_parallel_moe,
enable_eplb=vllm_parallel_config.enable_eplb,
)
# DP + EP / TP + EP / DP + TP + EP
@@ -1093,9 +1098,9 @@ class FusedMoEParallelConfig:
dp_rank=dp_rank,
ep_size=ep_size,
ep_rank=ep_rank,
+ sp_size=sp_size_,
use_ep=True,
all2all_backend=vllm_parallel_config.all2all_backend,
- is_sequence_parallel=vllm_parallel_config.use_sequence_parallel_moe,
enable_eplb=vllm_parallel_config.enable_eplb,
)
@@ -1111,10 +1116,10 @@ class FusedMoEParallelConfig:
dp_rank=0,
ep_size=1,
ep_rank=0,
+ sp_size=1,
use_ep=False,
all2all_backend="naive",
enable_eplb=False,
- is_sequence_parallel=False,
)
@@ -1126,6 +1131,7 @@ class FusedMoEConfig:
hidden_dim: int
intermediate_size_per_partition: int
num_local_experts: int
+ num_logical_experts: int
activation: str
device: torch.device | str
routing_method: RoutingMethodType
@@ -1175,6 +1181,14 @@ class FusedMoEConfig:
def ep_size(self):
return self.moe_parallel_config.ep_size
+ @property
+ def sp_size(self):
+ return self.moe_parallel_config.sp_size
+
+ @property
+ def is_sequence_parallel(self):
+ return self.moe_parallel_config.is_sequence_parallel
+
@property
def tp_rank(self):
return self.moe_parallel_config.tp_rank
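The config change above replaces the stored `is_sequence_parallel` flag with a property derived from the new `sp_size` field, so the two values can never disagree. A minimal sketch of the pattern, with the config trimmed to the one relevant field:

```python
# Minimal sketch of the derived-property pattern used by
# FusedMoEParallelConfig above; trimmed to the sequence-parallel field.
from dataclasses import dataclass


@dataclass
class ParallelConfigSketch:
    sp_size: int

    @property
    def is_sequence_parallel(self) -> bool:
        # Sequence parallelism is active exactly when the MoE input is
        # split across more than one sequence-parallel rank.
        return self.sp_size > 1


assert not ParallelConfigSketch(sp_size=1).is_sequence_parallel
assert ParallelConfigSketch(sp_size=4).is_sequence_parallel
```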
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
index 93db1c545..ac7c71e52 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
@@ -121,17 +121,16 @@ class FusedMoEMethodBase(QuantizeMethodBase):
def is_monolithic(self) -> bool:
return False
- # @abstractmethod
def apply(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
- # @abstractmethod
def apply_monolithic(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
index 69a6e70fc..1aa9e3a65 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
@@ -89,6 +89,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.moe_mk is not None
return self.moe_mk(
@@ -101,5 +102,5 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=None if self.disable_expert_map else layer.expert_map,
- shared_experts_input=layer._get_shared_experts_input(x),
+ shared_experts_input=shared_experts_input,
)
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index f35ec87aa..914dc6846 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1,13 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Callable, Generator, Iterable
-from contextlib import contextmanager, nullcontext
+from collections.abc import Callable, Iterable
from enum import Enum
from typing import Literal, cast, get_args, overload
import torch
-import torch.nn.functional as F
from torch.nn.parameter import UninitializedParameter
import vllm.envs as envs
@@ -16,17 +14,10 @@ from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.parallel import ExpertPlacementStrategy
from vllm.distributed import (
get_dp_group,
- get_ep_group,
get_pcp_group,
get_tensor_model_parallel_world_size,
- tensor_model_parallel_all_reduce,
)
from vllm.distributed.eplb.eplb_state import EplbLayerState, EplbState
-from vllm.forward_context import (
- ForwardContext,
- get_forward_context,
- is_forward_context_available,
-)
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.config import (
@@ -47,6 +38,9 @@ from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
from vllm.model_executor.layers.fused_moe.router.router_factory import (
create_fused_moe_router,
)
+from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import (
+ DefaultMoERunner,
+)
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
)
@@ -57,13 +51,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
)
from vllm.platforms import current_platform
-from vllm.utils.math_utils import cdiv, round_up
-from vllm.utils.torch_utils import (
- aux_stream,
- current_stream,
- direct_register_custom_op,
-)
-from vllm.v1.worker.ubatching import dbo_current_ubatch_id
+from vllm.utils.math_utils import round_up
logger = init_logger(__name__)
@@ -264,6 +252,7 @@ def maybe_roundup_hidden_size(
)
current_mxfp4_backend = get_mxfp4_backend(is_lora_enabled)
+
if (
current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
@@ -273,6 +262,7 @@ def maybe_roundup_hidden_size(
current_platform.is_rocm()
or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
+ or current_mxfp4_backend == Mxfp4Backend.MARLIN
):
hidden_size = round_up(hidden_size, 256)
@@ -338,29 +328,15 @@ class FusedMoE(CustomOp):
expert_mapping: list[tuple[str, str, int, str]] | None = None,
n_shared_experts: int | None = None,
router_logits_dtype: torch.dtype | None = None,
- has_shared_experts: bool = False,
+ gate: torch.nn.Module | None = None,
+ shared_experts: torch.nn.Module | None = None,
+ routed_input_transform: torch.nn.Module | None = None,
):
super().__init__()
- # Allow disabling of the separate shared experts stream for
- # debug purposes.
- # TODO: Remove this after more extensive testings with TP/DP
- # and other execution modes
- if envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM:
- logger.debug_once("Disabling MoE shared_experts cuda stream", scope="local")
- self.shared_experts_stream = None
- else:
- # TODO(rob): enable shared expert overlap with non-cuda-alike.
- # aux_stream() returns None on non-cuda-alike platforms.
- self.shared_experts_stream = aux_stream()
- if self.shared_experts_stream is not None:
- logger.debug_once(
- "Enabled separate cuda stream for MoE shared_experts", scope="local"
- )
-
- # For latent MoE: stores original hidden_states before routed_input_transform
- # so shared_experts can use it for cloning (they need original dimension)
- self._shared_experts_input: torch.Tensor | None = None
+ self._gate = gate
+ self._shared_experts = shared_experts
+ self._routed_input_transform = routed_input_transform
if params_dtype is None:
params_dtype = torch.get_default_dtype()
@@ -392,9 +368,12 @@ class FusedMoE(CustomOp):
tp_size_=tp_size_,
pcp_size_=pcp_size_,
dp_size_=dp_size_,
+ sp_size_=self.sp_size,
vllm_parallel_config=vllm_config.parallel_config,
)
+ assert self.moe_parallel_config.is_sequence_parallel == is_sequence_parallel
+
self.global_num_experts = num_experts + num_redundant_experts
self.logical_num_experts = num_experts
@@ -410,6 +389,7 @@ class FusedMoE(CustomOp):
self.layer_name = prefix
self.enable_eplb = enable_eplb
+ # TODO(bnell): should this be owned by router?
self.eplb_state = EplbLayerState()
self.expert_placement_strategy: ExpertPlacementStrategy = (
vllm_config.parallel_config.expert_placement_strategy
@@ -506,7 +486,8 @@ class FusedMoE(CustomOp):
self.reduce_results = reduce_results
self.renormalize = renormalize
- # TODO(bnell): these attributes are only used by cpu/xpu/mxfp4
+ # TODO(bnell): these attributes are only used by monolithic kernels.
+ # Put them in a MoERouterConfig dataclass?
self.use_grouped_topk = use_grouped_topk
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
@@ -565,6 +546,7 @@ class FusedMoE(CustomOp):
hidden_dim=hidden_size,
intermediate_size_per_partition=self.intermediate_size_per_partition,
num_local_experts=self.local_num_experts,
+ num_logical_experts=self.logical_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=moe_in_dtype,
router_logits_dtype=router_logits_dtype,
@@ -576,9 +558,9 @@ class FusedMoE(CustomOp):
device=vllm_config.device_config.device,
routing_method=self.routing_method_type,
# TODO: in_dtype == out_dtype?
- disable_inplace=disable_inplace() or has_shared_experts,
+ disable_inplace=disable_inplace() or self._shared_experts is not None,
)
- if self.use_mori_kernels:
+ if self.moe_config.use_mori_kernels:
assert self.rocm_aiter_fmoe_enabled, (
"Mori needs to be used with aiter fused_moe for now."
)
@@ -641,9 +623,36 @@ class FusedMoE(CustomOp):
self.quant_method.create_weights(layer=self, **moe_quant_params)
- # Chunked all2all staging tensor
- self.batched_hidden_states: torch.Tensor | None = None
- self.batched_router_logits: torch.Tensor | None = None
+ # Disable shared expert overlap if:
+ # - we are using eplb with non-default backend, because of correctness issues
+ # - we are using flashinfer with DP, since there is nothing to gain
+ # - we are using marlin kernels
+ backend = self.moe_parallel_config.all2all_backend
+ self.use_overlapped = (
+ not (
+ (self.enable_eplb and backend != "allgather_reducescatter")
+ or self.moe_parallel_config.use_fi_all2allv_kernels
+ )
+ and self._shared_experts is not None
+ )
+
+ self.runner = self._init_runner()
+
+ def _init_runner(self):
+ # Storing the runner in the FusedMoE is an intermediate state, eventually
+ # the runner will own the FusedMoE layer and provide the execution interface
+ # for MoE ops.
+ return DefaultMoERunner(
+ layer=self,
+ moe_config=self.moe_config,
+ router=self.router,
+ routed_input_transform=self._routed_input_transform,
+ gate=self.gate,
+ shared_experts=self.shared_experts,
+ quant_method=self.quant_method,
+ reduce_results=self.reduce_results,
+ enable_dbo=self.vllm_config.parallel_config.enable_dbo,
+ )
# Note: maybe_init_modular_kernel should only be called by
# prepare_communication_buffer_for_model.
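The `use_overlapped` predicate above gates whether the layer exposes its shared experts for overlapped execution (the `shared_experts` property in a later hunk returns `None` when overlap is off, so callers run them inline instead). A small sketch of the gating logic under the conditions listed in the comment; the function name is illustrative:

```python
# Sketch of the overlap gating computed above; names are illustrative.
def compute_use_overlapped(
    enable_eplb: bool,
    backend: str,
    use_fi_all2allv: bool,
    has_shared_experts: bool,
) -> bool:
    disabled = (
        enable_eplb and backend != "allgather_reducescatter"
    ) or use_fi_all2allv
    return not disabled and has_shared_experts


# EPLB on a non-default backend disables overlap even with shared experts:
assert not compute_use_overlapped(True, "deepep_low_latency", False, True)
assert compute_use_overlapped(False, "allgather_reducescatter", False, True)
```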
@@ -673,10 +682,14 @@ class FusedMoE(CustomOp):
self.shared_experts,
inplace=not self.moe_config.disable_inplace,
)
+ # We need to force reconstruction of runner because we're swapping out
+ # the quant_method with a FusedMoEModularMethod. This logic can go
+ # away once the FusedMoEModularMethod is eliminated.
+ self.runner = self._init_runner()
@property
def shared_experts(self) -> torch.nn.Module | None:
- return None
+ return self._shared_experts if self.use_overlapped else None
@property
def layer_id(self):
@@ -687,53 +700,12 @@ class FusedMoE(CustomOp):
@property
def gate(self) -> torch.nn.Module | None:
- return None
-
- def apply_routed_input_transform(self, hidden_states: torch.Tensor) -> torch.Tensor:
- """Hook to transform hidden_states before passing to routed experts.
- For latent MoE: transforms [S, hidden_size] → [S, moe_latent_size].
- The original hidden_states is saved in _shared_experts_input so
- shared_experts still receive the original [S, hidden_size].
-
- Override in subclasses (e.g., SharedFusedMoE) for latent MoE.
- """
- return hidden_states
-
- @contextmanager
- def _set_shared_experts_input(
- self, value: torch.Tensor | None
- ) -> Generator[None, None, None]:
- """Context manager to safely set/clear _shared_experts_input."""
- self._shared_experts_input = value
- try:
- yield
- finally:
- self._shared_experts_input = None
-
- def _get_shared_experts_input(self, hidden_states: torch.Tensor) -> torch.Tensor:
- """Get input for shared experts.
-
- For latent MoE: shared_experts need original [S, hidden_size],
- not the transformed [S, latent_size] used by routed experts.
- """
- return (
- self._shared_experts_input
- if self._shared_experts_input is not None
- else hidden_states
- )
+ return self._gate
@property
def tp_size(self):
return self.moe_parallel_config.tp_size
- @property
- def dp_size(self):
- return self.moe_parallel_config.dp_size
-
- @property
- def pcp_size(self):
- return self.moe_parallel_config.pcp_size
-
@property
def ep_size(self):
return self.moe_parallel_config.ep_size
@@ -742,14 +714,6 @@ class FusedMoE(CustomOp):
def tp_rank(self):
return self.moe_parallel_config.tp_rank
- @property
- def dp_rank(self):
- return self.moe_parallel_config.dp_rank
-
- @property
- def pcp_rank(self):
- return self.moe_parallel_config.pcp_rank
-
@property
def ep_rank(self):
return self.moe_parallel_config.ep_rank
@@ -758,39 +722,10 @@ class FusedMoE(CustomOp):
def use_ep(self):
return self.moe_parallel_config.use_ep
- @property
- def use_pplx_kernels(self):
- return self.moe_parallel_config.use_pplx_kernels
-
- @property
- def use_deepep_ht_kernels(self):
- return self.moe_parallel_config.use_deepep_ht_kernels
-
- @property
- def use_deepep_ll_kernels(self):
- return self.moe_parallel_config.use_deepep_ll_kernels
-
- @property
- def use_mori_kernels(self):
- return self.moe_parallel_config.use_mori_kernels
-
- @property
- def use_marlin_kernels(self):
- return getattr(self.quant_method, "use_marlin", False)
-
- @property
- def use_dp_chunking(self) -> bool:
- return (
- self.moe_parallel_config.use_pplx_kernels
- or self.moe_parallel_config.use_deepep_ll_kernels
- or self.moe_parallel_config.use_mori_kernels
- or self.moe_parallel_config.use_fi_all2allv_kernels
- ) and envs.VLLM_ENABLE_MOE_DP_CHUNK
-
@property
def is_internal_router(self) -> bool:
# By default, router/gate is called before FusedMoE forward pass
- return False
+ return self._gate is not None
def _maybe_init_expert_routing_tables(
self,
@@ -799,7 +734,7 @@ class FusedMoE(CustomOp):
# with DeepEP-ll all2all backend.
if (
self.expert_placement_strategy != "round_robin"
- or not self.use_deepep_ll_kernels
+ or not self.moe_parallel_config.use_deepep_ll_kernels
):
return None
@@ -892,48 +827,6 @@ class FusedMoE(CustomOp):
dp_size=get_dp_group().world_size,
)
- def _maybe_setup_shared_experts_stream(
- self,
- hidden_states: torch.Tensor,
- has_separate_shared_experts: bool,
- use_chunked_impl: bool,
- ) -> tuple[bool, torch.Tensor | None]:
- use_shared_experts_stream = (
- current_platform.is_cuda()
- and has_separate_shared_experts
- and not use_chunked_impl
- and self.shared_experts_stream is not None
- and (
- hidden_states.shape[0]
- <= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
- )
- )
-
- hidden_states_clone: torch.Tensor | None = None
- if use_shared_experts_stream:
- assert self.shared_experts_stream is not None
-
- shared_experts_input = self._get_shared_experts_input(hidden_states)
-
- # Clone BEFORE switching streams to avoid race condition
- # where routed_expert kernel may mutate hidden_states.
- hidden_states_clone = shared_experts_input.clone()
-
- # Record that the clone will be used by shared_experts_stream
- # to avoid gc issue from deallocation of hidden_states_clone
- # For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
- # NOTE: We don't need shared_output.record_stream(current_stream())
- # because we synch the streams before using shared_output.
- hidden_states_clone.record_stream(self.shared_experts_stream)
-
- # Mark sync start point for the separate shared experts
- # stream here since we want to run in parallel with the
- # router/gate (next op below)
- assert self.shared_experts_stream is not None
- self.shared_experts_stream.wait_stream(current_stream())
-
- return use_shared_experts_stream, hidden_states_clone
-
def _load_per_tensor_weight_scale(
self,
shard_id: str,
@@ -1191,7 +1084,7 @@ class FusedMoE(CustomOp):
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
- if self.quant_method.__class__.__name__ in (
+ if quant_method_name in (
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod",
):
@@ -1488,7 +1381,7 @@ class FusedMoE(CustomOp):
assert all(
weight.is_contiguous()
for name, weight in weights
- if not name.startswith("_shared_experts.")
+ if not (name.startswith("_shared_experts.") or name.startswith("_gate."))
)
# Filter out the non-expert weights.
@@ -1538,32 +1431,6 @@ class FusedMoE(CustomOp):
self.ensure_moe_quant_config_init()
return self.quant_method.moe_quant_config
- def ensure_dp_chunking_init(self):
- if not self.use_dp_chunking or self.batched_hidden_states is not None:
- return
-
- states_shape: tuple[int, ...]
- logits_shape: tuple[int, ...]
-
- moe = self.moe_config
-
- if self.vllm_config.parallel_config.enable_dbo:
- states_shape = (2, moe.max_num_tokens, self.hidden_size)
- logits_shape = (2, moe.max_num_tokens, self.logical_num_experts)
- else:
- states_shape = (moe.max_num_tokens, self.hidden_size)
- logits_shape = (moe.max_num_tokens, self.logical_num_experts)
-
- self.batched_hidden_states = torch.zeros(
- states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
- )
-
- self.batched_router_logits = torch.zeros(
- logits_shape,
- dtype=moe.router_logits_dtype,
- device=torch.cuda.current_device(),
- )
-
def must_reduce_shared_expert_outputs(self) -> bool:
"""
The shared_experts are typically computed using the RowParallelLinear
@@ -1577,100 +1444,24 @@ class FusedMoE(CustomOp):
Therefore it is required that we reduce the shared_experts output
early.
"""
- assert self.quant_method is not None
- return (
- isinstance(self.quant_method, FusedMoEModularMethod)
- and self.quant_method.moe_mk.output_is_reduced() # type: ignore[union-attr]
- )
+ return self.runner.must_reduce_shared_expert_outputs()
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
"""
Some combine kernels reduce across GPU ranks by default.
"""
- if self.must_reduce_shared_expert_outputs():
- return final_hidden_states
- else:
- return tensor_model_parallel_all_reduce(final_hidden_states)
+ return self.runner.maybe_all_reduce_tensor_model_parallel(final_hidden_states)
def forward_native(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- # For latent MoE: save ORIGINAL hidden_states before transform
- # (shared_experts need original dimension, routed experts use transformed)
- original_hidden_states = hidden_states
- original_hidden_dim = hidden_states.shape[-1]
-
- # Apply transform for routed experts (e.g., latent projection for latent MoE)
- hidden_states = self.apply_routed_input_transform(hidden_states)
-
- # This is the dimension after transform (for routed expert output slicing)
- transformed_hidden_dim = hidden_states.shape[-1]
- if self.hidden_size != transformed_hidden_dim:
- hidden_states = F.pad(
- hidden_states,
- (0, self.hidden_size - transformed_hidden_dim),
- mode="constant",
- value=0.0,
- )
-
- def reduce_output(states: torch.Tensor) -> torch.Tensor:
- if (
- not self.is_sequence_parallel
- and not self.use_dp_chunking
- and self.reduce_results
- and (self.tp_size > 1 or self.ep_size > 1)
- ):
- states = self.maybe_all_reduce_tensor_model_parallel(states)
- return states
-
- def encode_layer_name() -> str:
- # Can be unavailable or None in unittests
- if (
- is_forward_context_available()
- and get_forward_context().all_moe_layers is not None
- ):
- return "from_forward_context"
- return self.layer_name
-
- if self.shared_experts is None:
- if current_platform.is_tpu() or current_platform.is_cpu():
- # TODO: Once the OOM issue for the TPU backend is resolved, we
- # will switch to using the moe_forward custom op.
- # Note: CPU doesn't require wrapped forward_impl.
- fused_output = self.forward_impl(hidden_states, router_logits)
- assert not isinstance(fused_output, tuple)
- else:
- fused_output = torch.ops.vllm.moe_forward(
- hidden_states, router_logits, encode_layer_name()
- )
- return reduce_output(fused_output)[..., :transformed_hidden_dim]
- else:
- if current_platform.is_tpu() or current_platform.is_cpu():
- # TODO: Once the OOM issue for the TPU backend is resolved, we
- # will switch to using the moe_forward custom op.
- # Note: CPU doesn't require wrapped forward_impl.
- with self._set_shared_experts_input(original_hidden_states):
- shared_output, fused_output = self.forward_impl(
- hidden_states, router_logits
- )
- else:
- # Custom op handles setting/clearing _shared_experts_input internally
- # We pass original tensor for shared experts (not transformed)
- shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
- hidden_states,
- router_logits,
- encode_layer_name(),
- original_hidden_states,
- )
-
- # shared_output uses original dimension (before transform)
- # fused_output uses transformed dimension (after transform)
- return (
- reduce_output(shared_output)[..., :original_hidden_dim],
- reduce_output(fused_output)[..., :transformed_hidden_dim],
- )
+ self.ensure_moe_quant_config_init()
+ return self.runner.forward(
+ hidden_states,
+ router_logits,
+ )
@property
def expert_map(self) -> torch.Tensor | None:
@@ -1685,312 +1476,6 @@ class FusedMoE(CustomOp):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward_native(hidden_states, router_logits)
- def forward_impl_chunked(
- self,
- full_hidden_states: torch.Tensor,
- full_router_logits: torch.Tensor,
- has_separate_shared_experts: bool,
- ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- assert self.batched_hidden_states is not None
- assert self.batched_router_logits is not None
- assert self.batched_hidden_states.dtype == full_hidden_states.dtype, (
- f"{self.batched_hidden_states.dtype} == {full_hidden_states.dtype}"
- )
- assert self.batched_router_logits.dtype == full_router_logits.dtype, (
- f"{self.batched_router_logits.dtype} == {full_router_logits.dtype}"
- )
- # Check size compatibility.
- assert self.batched_hidden_states.size(-1) == full_hidden_states.size(-1)
- assert self.batched_router_logits.size(-1) == full_router_logits.size(-1)
-
- full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
- if self.shared_experts is not None:
- full_shared_final_hidden_states = torch.empty_like(full_hidden_states)
-
- def process_chunk(chunk_start, chunk_end, skip_result_store=False):
- chunk_size = chunk_end - chunk_start
- hidden_states = full_hidden_states[chunk_start:chunk_end, :]
- router_logits = full_router_logits[chunk_start:chunk_end, :]
-
- assert self.batched_hidden_states is not None
- assert self.batched_router_logits is not None
- # This is only true when DBO has been enabled in the config.
- # Both tensors will have an outer dimension for the ubatch id
- if self.batched_hidden_states.dim() == 3:
- assert self.batched_router_logits.dim() == 3
- batch_buffer_idx = dbo_current_ubatch_id()
- batched_hidden_states = self.batched_hidden_states[batch_buffer_idx, :]
- batched_router_logits = self.batched_router_logits[batch_buffer_idx, :]
- else:
- batched_hidden_states = self.batched_hidden_states
- batched_router_logits = self.batched_router_logits
-
- assert (
- batched_hidden_states.size(0) # type: ignore
- >= chunk_size
- )
- assert (
- batched_router_logits.size(0) # type: ignore
- >= chunk_size
- )
- staged_hidden_states = batched_hidden_states[:chunk_size, :] # type: ignore
- staged_router_logits = batched_router_logits[:chunk_size, :] # type: ignore
- staged_hidden_states.copy_(hidden_states, non_blocking=True)
- staged_router_logits.copy_(router_logits, non_blocking=True)
-
- # Matrix multiply.
- if self.quant_method.is_monolithic:
- final_hidden_states = self.quant_method.apply_monolithic(
- layer=self,
- x=staged_hidden_states,
- router_logits=staged_router_logits,
- )
- else:
- topk_weights, topk_ids = self.router.select_experts(
- hidden_states=staged_hidden_states,
- router_logits=staged_router_logits,
- )
-
- final_hidden_states = self.quant_method.apply(
- layer=self,
- x=staged_hidden_states,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- )
-
- if has_separate_shared_experts:
- assert not isinstance(final_hidden_states, tuple)
- assert self.shared_experts is not None
-
- shared_output = self.shared_experts(staged_hidden_states)
-
- final_hidden_states = (
- shared_output,
- final_hidden_states,
- )
-
- if not skip_result_store:
- if self.shared_experts is None:
- full_fused_final_hidden_states[chunk_start:chunk_end, :].copy_(
- final_hidden_states, non_blocking=True
- )
- else:
- full_shared_final_hidden_states[chunk_start:chunk_end, :].copy_(
- final_hidden_states[0], non_blocking=True
- )
- full_fused_final_hidden_states[chunk_start:chunk_end, :].copy_(
- final_hidden_states[1], non_blocking=True
- )
-
- ctx = get_forward_context()
- # flashinfer_cutlass_kernels can handle: optional DP + TP/EP
- max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu
- moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
-
- # If the input to the MoE is sequence parallel then divide by sp_size
- # to find the maximum number of tokens for any individual dispatcher.
- if self.is_sequence_parallel:
- max_tokens_across_dispatchers = cdiv(
- max_tokens_across_dispatchers, self.sp_size
- )
-
- num_tokens = full_hidden_states.size(0)
- for chunk_idx, chunk_start_ in enumerate(
- range(0, max_tokens_across_dispatchers, moe_dp_chunk_size_per_rank)
- ):
- chunk_start = chunk_start_
- chunk_end = min(
- chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dispatchers
- )
- # clamp start and end
- chunk_start = min(chunk_start, num_tokens - 1)
- chunk_end = min(chunk_end, num_tokens)
- with ctx.dp_metadata.chunked_sizes(
- self.sp_size, moe_dp_chunk_size_per_rank, chunk_idx
- ):
- process_chunk(
- chunk_start, chunk_end, skip_result_store=chunk_start_ >= num_tokens
- )
-
- if self.shared_experts is None:
- return full_fused_final_hidden_states
- else:
- return (full_shared_final_hidden_states, full_fused_final_hidden_states)
-
- def forward_impl(
- self,
- hidden_states: torch.Tensor,
- router_logits: torch.Tensor,
- ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- assert self.quant_method is not None
-
- self.ensure_moe_quant_config_init()
- self.ensure_dp_chunking_init()
-
- has_separate_shared_experts = (
- not self.quant_method.mk_owns_shared_expert
- and self.shared_experts is not None
- )
-
- use_chunked_impl = self.use_dp_chunking
-
- use_shared_experts_stream, hidden_states_clone = (
- self._maybe_setup_shared_experts_stream(
- hidden_states, has_separate_shared_experts, use_chunked_impl
- )
- )
-
- # If router/gate provided, then apply it here.
- # (Note: This code runs only when "overlapped mode" is on to allow
- # parallel execution of shared experts with the FusedMoE via
- # separate cuda stream)
- if self.gate is not None:
- router_logits, _ = self.gate(hidden_states)
-
- if use_chunked_impl:
- return self.forward_impl_chunked(
- hidden_states, router_logits, has_separate_shared_experts
- )
-
- # NOTE(rob): once we finish migrating all the quant methods to use
- # MKs, we can remove the naive dispatch/combine path from here.
- do_naive_dispatch_combine = (
- self.dp_size > 1 and not self.quant_method.supports_internal_mk
- )
-
- ctx = get_forward_context()
- sp_ctx = (
- ctx.dp_metadata.sp_local_sizes(self.sp_size)
- if ctx.dp_metadata
- else nullcontext()
- )
-
- with sp_ctx:
- extra_tensors = None
- if do_naive_dispatch_combine:
- post_quant_allgather = (
- self.quant_method is not None
- and self.dp_size > 1
- and self.use_ep
- and getattr(self.quant_method, "do_post_quant_allgather", False)
- )
- if post_quant_allgather:
- hidden_states_to_dispatch, extra_tensors = (
- self.quant_method.prepare_dp_allgather_tensor(
- self, hidden_states, router_logits
- )
- )
- else:
- hidden_states_to_dispatch = hidden_states
-
- dispatch_res = get_ep_group().dispatch_router_logits(
- hidden_states_to_dispatch,
- router_logits,
- self.is_sequence_parallel,
- extra_tensors=extra_tensors,
- )
- if extra_tensors is not None:
- (
- orig_hidden_states,
- router_logits,
- extra_tensors_combined,
- ) = dispatch_res
- hidden_states_combined = (
- orig_hidden_states,
- extra_tensors_combined[0],
- )
- else:
- hidden_states_combined, router_logits = dispatch_res
- orig_hidden_states = hidden_states_combined
- else:
- orig_hidden_states = hidden_states
-
- # Run shared experts before matrix multiply.
- # because matrix multiply maybe modify the hidden_states.
- if has_separate_shared_experts and not use_shared_experts_stream:
- assert self.shared_experts is not None
- shared_input = self._get_shared_experts_input(hidden_states)
- shared_output = self.shared_experts(shared_input)
-
- # NOTE: Similar with DP, PCP also needs dispatch and combine. For
- # simplicity, AgRsAll2All was added separately for PCP here. Maybe
- # we should modify All2AllManager abstract to better support PCP.
- if self.pcp_size > 1:
- hidden_states = get_pcp_group().all_gather(
- hidden_states,
- dim=0,
- )
- router_logits = get_pcp_group().all_gather(
- router_logits,
- dim=0,
- )
-
- # Matrix multiply.
- x = hidden_states_combined if do_naive_dispatch_combine else hidden_states
-
- # TODO(bnell): deal with fp4 flashinfer tuple hidden states hack (#30014).
- # Figure out nicer way to do this.
- x_orig = orig_hidden_states if do_naive_dispatch_combine else hidden_states
-
- if self.quant_method.is_monolithic:
- final_hidden_states = self.quant_method.apply_monolithic(
- layer=self,
- x=x,
- router_logits=router_logits,
- )
- else:
- topk_weights, topk_ids = self.router.select_experts(
- hidden_states=x_orig,
- router_logits=router_logits,
- )
-
- final_hidden_states = self.quant_method.apply(
- layer=self,
- x=x, # The type signture of this is wrong due to the hack.
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- )
-
- if has_separate_shared_experts:
- assert self.shared_experts is not None
-
- if use_shared_experts_stream:
- # Run shared experts in parallel on a separate stream
- # NOTE: We start the separate stream here and mark the
- # sync end point immediately after it is done. This is
- # important to avoid excessive stream allocations by the cuda
- # graph replay later.
- with torch.cuda.stream(self.shared_experts_stream):
- # Note that hidden_states clone() is necessary here to avoid
- # conflict with the main stream
- shared_output = self.shared_experts(hidden_states_clone)
- current_stream().wait_stream(self.shared_experts_stream)
-
- final_hidden_states = (
- shared_output,
- final_hidden_states,
- )
-
- def combine_output(states: torch.Tensor) -> torch.Tensor:
- if do_naive_dispatch_combine:
- states = get_ep_group().combine(states, self.is_sequence_parallel)
-
- if self.pcp_size > 1:
- states = get_pcp_group().reduce_scatter(
- states,
- dim=0,
- )
-
- return states
-
- if self.shared_experts is not None:
- return (
- final_hidden_states[0],
- combine_output(final_hidden_states[1]),
- )
- else:
- return combine_output(final_hidden_states)
-
@classmethod
def make_expert_params_mapping(
cls,
@@ -2051,94 +1536,6 @@ class FusedMoE(CustomOp):
return s
-def get_layer_from_name(layer_name: str) -> FusedMoE:
- forward_context: ForwardContext = get_forward_context()
- if layer_name == "from_forward_context":
- all_moe_layers = forward_context.all_moe_layers
- assert all_moe_layers is not None
- moe_layer_index = forward_context.moe_layer_index
- if moe_layer_index >= len(all_moe_layers):
- raise AssertionError(
- "We expected the number of MOE layers in `all_moe_layers` "
- "to be equal to the number of "
- "{vllm.moe_forward, vllm.moe_forward_shared} calls."
- )
- layer_name = all_moe_layers[moe_layer_index]
- forward_context.moe_layer_index += 1
- self = cast(FusedMoE, forward_context.no_compile_layers[layer_name])
- return self
-
-
-def moe_forward(
- hidden_states: torch.Tensor,
- router_logits: torch.Tensor,
- layer_name: str,
-) -> torch.Tensor:
- self = get_layer_from_name(layer_name)
- assert self.shared_experts is None
- return self.forward_impl(hidden_states, router_logits)
-
-
-def moe_forward_fake(
- hidden_states: torch.Tensor,
- router_logits: torch.Tensor,
- layer_name: str,
-) -> torch.Tensor:
- return torch.empty_like(hidden_states)
-
-
-direct_register_custom_op(
- op_name="moe_forward",
- op_func=moe_forward,
- mutates_args=["hidden_states"],
- fake_impl=moe_forward_fake,
- tags=(torch.Tag.needs_fixed_stride_order,),
-)
-
-
-def moe_forward_shared(
- hidden_states: torch.Tensor,
- router_logits: torch.Tensor,
- layer_name: str,
- shared_experts_input: torch.Tensor | None = None,
-) -> tuple[torch.Tensor, torch.Tensor]:
- self = get_layer_from_name(layer_name)
- assert self.shared_experts is not None
-
- # Set here because torch.compile skips forward_native() setup code
- # and calls this op directly. forward_impl() reads from this var.
- with self._set_shared_experts_input(shared_experts_input):
- return self.forward_impl(hidden_states, router_logits)
-
-
-def moe_forward_shared_fake(
- hidden_states: torch.Tensor,
- router_logits: torch.Tensor,
- layer_name: str,
- shared_experts_input: torch.Tensor | None = None,
-) -> tuple[torch.Tensor, torch.Tensor]:
- # Output shapes:
- # - fused_out: same as hidden_states (routed experts use transformed size)
- # - shared_out: same as shared_experts_input if provided, else same as hidden_states
- # (For latent MoE: shared experts use original hidden_size, not latent size)
- fused_out = torch.empty_like(hidden_states)
-
- if shared_experts_input is not None:
- shared_out = torch.empty_like(shared_experts_input)
- else:
- shared_out = torch.empty_like(hidden_states)
-
- return shared_out, fused_out
-
-
-direct_register_custom_op(
- op_name="moe_forward_shared",
- op_func=moe_forward_shared,
- mutates_args=["hidden_states"],
- fake_impl=moe_forward_shared_fake,
- tags=(torch.Tag.needs_fixed_stride_order,),
-)
-
# Mark the FusedMoE weight_loader as supporting MoE-specific parameters
# to avoid expensive runtime reflection in model loading code
FusedMoE.weight_loader.supports_moe_loading = True # type: ignore[attr-defined]
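The marker attribute set above lets model-loading code detect MoE-aware weight loaders with a cheap attribute check instead of runtime reflection. A hedged sketch of the consumer side; the loader function here is hypothetical:

```python
# Hypothetical consumer of the supports_moe_loading marker set above.
def load_weight(param, loaded_weight, weight_loader, **moe_kwargs):
    if getattr(weight_loader, "supports_moe_loading", False):
        # The loader advertises MoE support, so expert-specific kwargs
        # (e.g. shard_id, expert_id) can be passed straight through.
        weight_loader(param, loaded_weight, **moe_kwargs)
    else:
        # A plain loader only accepts the standard two-argument form.
        weight_loader(param, loaded_weight)
```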
diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
index 8a670216b..e2f77d6c8 100644
--- a/vllm/model_executor/layers/fused_moe/modular_kernel.py
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -1228,7 +1228,7 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
- shared_experts_input: torch.Tensor | None = None,
+ shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
The _finalize method is a wrapper around self.prepare_finalize.finalize
diff --git a/vllm/model_executor/layers/fused_moe/runner/__init__.py b/vllm/model_executor/layers/fused_moe/runner/__init__.py
new file mode 100644
index 000000000..208f01a7c
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/runner/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
diff --git a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py
new file mode 100644
index 000000000..12b795f30
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py
@@ -0,0 +1,743 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from contextlib import nullcontext
+
+import torch
+import torch.nn.functional as F
+
+import vllm.envs as envs
+from vllm.distributed import (
+ get_ep_group,
+ get_pcp_group,
+ tensor_model_parallel_all_reduce,
+)
+from vllm.forward_context import (
+ ForwardContext,
+ get_forward_context,
+ is_forward_context_available,
+)
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+)
+from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
+ FusedMoEMethodBase,
+)
+from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
+ FusedMoERouter,
+)
+from vllm.model_executor.layers.fused_moe.runner.moe_runner import MoERunner
+from vllm.platforms import current_platform
+from vllm.utils.math_utils import cdiv
+from vllm.utils.torch_utils import (
+ aux_stream,
+ current_stream,
+ direct_register_custom_op,
+)
+from vllm.v1.worker.ubatching import dbo_current_ubatch_id
+
+logger = init_logger(__name__)
+
+
+def get_layer_from_name(layer_name: str) -> torch.nn.Module:
+ forward_context: ForwardContext = get_forward_context()
+ if layer_name == "from_forward_context":
+ all_moe_layers = forward_context.all_moe_layers
+ assert all_moe_layers is not None
+ moe_layer_index = forward_context.moe_layer_index
+ if moe_layer_index >= len(all_moe_layers):
+ raise AssertionError(
+ "We expected the number of MOE layers in `all_moe_layers` "
+ "to be equal to the number of "
+ "{vllm.moe_forward, vllm.moe_forward_shared} calls."
+ )
+ layer_name = all_moe_layers[moe_layer_index]
+ forward_context.moe_layer_index += 1
+ return forward_context.no_compile_layers[layer_name]
+
+
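`get_layer_from_name` exists because the custom ops below receive the layer as a plain string (a compile-friendly argument) and resolve the actual module at runtime from the forward context; under torch.compile the sentinel `"from_forward_context"` replays the MoE layers in call order. A toy model of the lookup, with all names illustrative:

```python
# Toy model of the name-based layer lookup above; not the vLLM API.
class ToyForwardContext:
    def __init__(self, layers, call_order):
        self.no_compile_layers = layers   # name -> module
        self.all_moe_layers = call_order  # layer names in call order
        self.moe_layer_index = 0


def toy_get_layer(ctx, layer_name):
    if layer_name == "from_forward_context":
        # Compiled path: the string is a sentinel, so replay call order.
        layer_name = ctx.all_moe_layers[ctx.moe_layer_index]
        ctx.moe_layer_index += 1
    return ctx.no_compile_layers[layer_name]


ctx = ToyForwardContext({"moe.0": "layer0", "moe.1": "layer1"}, ["moe.0", "moe.1"])
assert toy_get_layer(ctx, "from_forward_context") == "layer0"
assert toy_get_layer(ctx, "from_forward_context") == "layer1"
assert toy_get_layer(ctx, "moe.0") == "layer0"
```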
+def _moe_forward(
+ hidden_states: torch.Tensor,
+ router_logits: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
+ layer_name: str,
+) -> torch.Tensor:
+ layer = get_layer_from_name(layer_name)
+ return layer.runner.forward_impl(
+ layer, hidden_states, router_logits, shared_experts_input
+ )
+
+
+def _moe_forward_fake(
+ hidden_states: torch.Tensor,
+ router_logits: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
+ layer_name: str,
+) -> torch.Tensor:
+ return torch.empty_like(hidden_states)
+
+
+def _moe_forward_shared(
+ hidden_states: torch.Tensor,
+ router_logits: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
+ layer_name: str,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ layer = get_layer_from_name(layer_name)
+ return layer.runner.forward_impl(
+ layer, hidden_states, router_logits, shared_experts_input
+ )
+
+
+def _moe_forward_shared_fake(
+ hidden_states: torch.Tensor,
+ router_logits: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
+ layer_name: str,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ # Output shapes:
+ # - fused_out: same as hidden_states (routed experts use transformed size)
+ # - shared_out: same as shared_experts_input if provided, else same as
+ # hidden_states
+ # (For latent MoE: shared experts use original hidden_size, not latent size)
+ fused_out = torch.empty_like(hidden_states)
+
+ if shared_experts_input is not None:
+ shared_out = torch.empty_like(shared_experts_input)
+ else:
+ shared_out = torch.empty_like(hidden_states)
+
+ return shared_out, fused_out
+
+
+direct_register_custom_op(
+ op_name="moe_forward",
+ op_func=_moe_forward,
+ mutates_args=["hidden_states"],
+ fake_impl=_moe_forward_fake,
+ tags=(torch.Tag.needs_fixed_stride_order,),
+)
+
+
+direct_register_custom_op(
+ op_name="moe_forward_shared",
+ op_func=_moe_forward_shared,
+ mutates_args=["hidden_states"],
+ fake_impl=_moe_forward_shared_fake,
+ tags=(torch.Tag.needs_fixed_stride_order,),
+)
+
+
+class DefaultMoERunner(MoERunner):
+ """
+ Default implementation of the MoE runner for executing Mixture of Experts layers.
+
+ This class provides a comprehensive implementation for running MoE computations
+ with support for:
+ - Expert routing and token dispatching
+ - Shared experts computation with optional parallel execution using CUDA streams
+ - Data parallel (DP) chunking for large batch processing
+ - Tensor model parallel and expert parallel operations
+ - Various quantization methods and custom operators
+ - Both monolithic and decomposed expert execution paths
+
+ The runner handles the complete MoE forward pass including routing tokens to
+ experts, executing expert computations, and combining results. It supports
+ advanced features like overlapped execution of shared experts and optimized
+ kernels for different parallel execution modes.
+
+ Eventually, this class will be split up and specialized for different
+ configurations, e.g. the presence or absence of shared experts, a gate, etc.
+ """
+
+ def __init__(
+ self,
+ layer: torch.nn.Module,
+ moe_config: FusedMoEConfig,
+ router: FusedMoERouter,
+ routed_input_transform: torch.nn.Module | None,
+ gate: torch.nn.Module | None,
+ shared_experts: torch.nn.Module | None,
+ quant_method: FusedMoEMethodBase,
+ reduce_results: bool,
+ enable_dbo: bool,
+ ):
+ super().__init__()
+ self.moe_config = moe_config
+ self.router = router
+ self.routed_input_transform = routed_input_transform
+ self.gate = gate
+ self.shared_experts = shared_experts
+ self.quant_method = quant_method
+ self.reduce_results = reduce_results
+ self.enable_dbo = enable_dbo
+
+ # Allow disabling of the separate shared experts stream for
+ # debug purposes.
+ # TODO: Remove this after more extensive testing with TP/DP
+ # and other execution modes
+ if envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM:
+ logger.debug_once("Disabling MoE shared_experts cuda stream", scope="local")
+ self.shared_experts_stream = None
+ else:
+ # TODO(rob): enable shared expert overlap with non-cuda-alike.
+ # aux_stream() returns None on non-cuda-alike platforms.
+ self.shared_experts_stream = aux_stream()
+ if self.shared_experts_stream is not None:
+ logger.debug_once(
+ "Enabled separate cuda stream for MoE shared_experts", scope="local"
+ )
+
+ # Needed for string -> FusedMoE layer lookup in custom ops.
+ self.layer_name = layer.layer_name
+
+ if current_platform.is_tpu() or current_platform.is_cpu():
+ # TODO: Once the OOM issue for the TPU backend is resolved, we
+ # will switch to using the moe_forward custom op.
+ # Note: CPU doesn't require wrapped forward_impl.
+ if self.shared_experts is None:
+ self.moe_forward = _moe_forward
+ else:
+ self.moe_forward = _moe_forward_shared
+ else:
+ if self.shared_experts is None:
+ self.moe_forward = torch.ops.vllm.moe_forward
+ else:
+ self.moe_forward = torch.ops.vllm.moe_forward_shared
+
+ # Chunked all2all staging tensor
+ self.batched_hidden_states: torch.Tensor | None = None
+ self.batched_router_logits: torch.Tensor | None = None
+
+ @property
+ def use_dp_chunking(self) -> bool:
+ return (
+ self.moe_config.moe_parallel_config.use_pplx_kernels
+ or self.moe_config.moe_parallel_config.use_deepep_ll_kernels
+ or self.moe_config.moe_parallel_config.use_mori_kernels
+ or self.moe_config.moe_parallel_config.use_fi_all2allv_kernels
+ ) and envs.VLLM_ENABLE_MOE_DP_CHUNK
+
+ def _maybe_setup_shared_experts_stream(
+ self,
+ hidden_states: torch.Tensor,
+ shared_input: torch.Tensor | None,
+ has_separate_shared_experts: bool,
+ use_chunked_impl: bool,
+ ) -> tuple[bool, torch.Tensor | None]:
+ use_shared_experts_stream = (
+ current_platform.is_cuda()
+ and has_separate_shared_experts
+ and not use_chunked_impl
+ and self.shared_experts_stream is not None
+ and (
+ hidden_states.shape[0]
+ <= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
+ )
+ )
+
+ hidden_states_clone: torch.Tensor | None = None
+ if use_shared_experts_stream:
+ assert self.shared_experts_stream is not None
+
+ shared_experts_input = (
+ shared_input if shared_input is not None else hidden_states
+ )
+
+ # Clone BEFORE switching streams to avoid race condition
+ # where routed_expert kernel may mutate hidden_states.
+ hidden_states_clone = shared_experts_input.clone()
+
+ # Record that the clone will be used by shared_experts_stream
+ # to avoid gc issue from deallocation of hidden_states_clone
+ # For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
+ # NOTE: We don't need shared_output.record_stream(current_stream())
+ # because we sync the streams before using shared_output.
+ hidden_states_clone.record_stream(self.shared_experts_stream)
+
+ # Mark sync start point for the separate shared experts
+ # stream here since we want to run in parallel with the
+ # router/gate (next op below)
+ assert self.shared_experts_stream is not None
+ self.shared_experts_stream.wait_stream(current_stream())
+
+ return use_shared_experts_stream, hidden_states_clone
+
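The helper above encodes a subtle ordering rule: the clone happens on the main stream before the stream switch (so the side stream never reads a buffer the routed-expert kernel may mutate), `record_stream` keeps the caching allocator from recycling the clone while the side stream still uses it, and `wait_stream` fences the two streams. A standalone sketch of that pattern, runnable only on a CUDA device:

```python
# Minimal two-stream overlap sketch mirroring the helper above.
import torch

if torch.cuda.is_available():
    main = torch.cuda.current_stream()
    side = torch.cuda.Stream()

    x = torch.randn(4, 8, device="cuda")
    # Clone on the main stream so the side stream never observes in-place
    # mutations of x by later kernels on the main stream.
    x_clone = x.clone()
    # Inform the caching allocator that `side` consumes the clone, so its
    # memory is not reused while that stream still needs it.
    x_clone.record_stream(side)

    side.wait_stream(main)            # side starts after the clone is done
    with torch.cuda.stream(side):
        shared_out = x_clone * 2.0    # stand-in for shared_experts(...)
    main.wait_stream(side)            # main waits before reading shared_out
    print(shared_out.sum().item())
```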
+ def ensure_dp_chunking_init(self):
+ if not self.use_dp_chunking or self.batched_hidden_states is not None:
+ return
+
+ states_shape: tuple[int, ...]
+ logits_shape: tuple[int, ...]
+
+ moe = self.moe_config
+
+ if self.enable_dbo:
+ states_shape = (2, moe.max_num_tokens, self.moe_config.hidden_dim)
+ logits_shape = (2, moe.max_num_tokens, self.moe_config.num_logical_experts)
+ else:
+ states_shape = (moe.max_num_tokens, self.moe_config.hidden_dim)
+ logits_shape = (moe.max_num_tokens, self.moe_config.num_logical_experts)
+
+ self.batched_hidden_states = torch.zeros(
+ states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
+ )
+
+ self.batched_router_logits = torch.zeros(
+ logits_shape,
+ dtype=moe.router_logits_dtype,
+ device=torch.cuda.current_device(),
+ )
+
+ def must_reduce_shared_expert_outputs(self) -> bool:
+ """
+ The shared_experts are typically computed using the RowParallelLinear
+ layer. The result of this function is typically used as
+ the reduce_results argument to the module.
+ When just tensor-parallel is used, it is not required to reduce
+ the shared_experts results immediately. Instead we reduce once
+ at the end of the MoE op. (Refer to DeepSeekV2MoE module)
+ With EP and all2all kernels this is no longer viable, as all
+ GPU ranks in DP produce the complete set of hidden_states.
+ Therefore it is required that we reduce the shared_experts output
+ early.
+ """
+ assert self.quant_method is not None
+ return (
+ self.quant_method.moe_mk is not None
+ and self.quant_method.moe_mk.output_is_reduced()
+ )
+
+ def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
+ """
+ Some combine kernels reduce across GPU ranks by default.
+ """
+ if self.must_reduce_shared_expert_outputs():
+ return final_hidden_states
+ else:
+ return tensor_model_parallel_all_reduce(final_hidden_states)
+
+ def apply_routed_input_transform(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """Apply transform for routed experts (e.g., latent projection).
+
+ This is called by FusedMoE.forward_native. The original hidden_states
+ is saved separately so shared experts get [S, hidden_size] while
+ routed experts get the transformed [S, moe_latent_size].
+
+ TODO: For latent MoE bandwidth optimization, fc2_latent_proj could be
+ moved inside SharedFusedMoE to all-reduce on the smaller latent
+ dimension.
+ """
+ if self.routed_input_transform is not None:
+ result = self.routed_input_transform(hidden_states)
+ # ReplicatedLinear returns (output, extra_bias) tuple.
+ # We only need the output tensor; extra_bias is not used here.
+ if isinstance(result, tuple):
+ return result[0]
+ return result
+ return hidden_states
+
+ def _reduce_output(
+ self,
+ states: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
+ trunc_sizes: list[int],
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+ def trunc(x: torch.Tensor, trunc_size: int) -> torch.Tensor:
+ return x[..., :trunc_size]
+
+ def reduce_and_trunc(x: torch.Tensor, trunc_size: int) -> torch.Tensor:
+ return trunc(self.maybe_all_reduce_tensor_model_parallel(x), trunc_size)
+
+ if (
+ not self.moe_config.is_sequence_parallel
+ and not self.use_dp_chunking
+ and self.reduce_results
+ and (self.moe_config.tp_size > 1 or self.moe_config.ep_size > 1)
+ ):
+ func = reduce_and_trunc
+ else:
+ func = trunc
+
+ if isinstance(states, tuple):
+ return tuple(
+ [func(s, trunc_size) for s, trunc_size in zip(states, trunc_sizes)]
+ )
+ else:
+ assert len(trunc_sizes) == 1
+ return func(states, trunc_sizes[0])
+
+ def _encode_layer_name(self) -> str:
+ # Can be unavailable or None in unittests
+ if (
+ is_forward_context_available()
+ and get_forward_context().all_moe_layers is not None
+ ):
+ return "from_forward_context"
+ return self.layer_name
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ router_logits: torch.Tensor,
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+ # For latent MoE: save ORIGINAL hidden_states before transform
+ # (shared_experts need original dimension, routed experts use transformed)
+ original_hidden_states = hidden_states
+ original_hidden_dim = hidden_states.shape[-1]
+
+ # Apply transform for routed experts (e.g., latent projection for latent MoE)
+ hidden_states = self.apply_routed_input_transform(hidden_states)
+
+ # This is the dimension after transform (for routed expert output slicing)
+ transformed_hidden_dim = hidden_states.shape[-1]
+ if self.moe_config.hidden_dim != transformed_hidden_dim:
+ hidden_states = F.pad(
+ hidden_states,
+ (0, self.moe_config.hidden_dim - transformed_hidden_dim),
+ mode="constant",
+ value=0.0,
+ )
+
+ fused_output = self.moe_forward(
+ hidden_states,
+ router_logits,
+ original_hidden_states,
+ self._encode_layer_name(),
+ )
+
+ if isinstance(fused_output, tuple):
+ orig_hidden_dims = [original_hidden_dim, transformed_hidden_dim]
+ else:
+ orig_hidden_dims = [transformed_hidden_dim]
+
+ return self._reduce_output(fused_output, orig_hidden_dims)
+
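`forward` above does dimension bookkeeping for latent MoE: the routed input may be projected to a smaller latent size, padded up to `hidden_dim` for the fused kernel, and the outputs truncated back afterwards. A worked shape example under assumed sizes:

```python
# Worked example of the pad/truncate round trip in forward() above, for a
# hypothetical latent MoE; all sizes are illustrative.
import torch
import torch.nn.functional as F

S, hidden_dim, latent_dim = 4, 128, 96

original = torch.randn(S, hidden_dim)        # shared experts consume this
transformed = original[:, :latent_dim]       # stand-in for the latent proj
padded = F.pad(transformed, (0, hidden_dim - latent_dim))  # kernel input

fused_out = padded                           # stand-in for the MoE output
routed = fused_out[..., :latent_dim]         # truncate back to latent size
assert routed.shape == (S, latent_dim)
assert original.shape == (S, hidden_dim)     # shared path keeps full width
```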
+ def forward_impl_chunked(
+ self,
+ layer: torch.nn.Module,
+ full_hidden_states: torch.Tensor,
+ full_router_logits: torch.Tensor,
+ shared_input: torch.Tensor | None,
+ has_separate_shared_experts: bool,
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+ assert self.batched_hidden_states is not None
+ assert self.batched_router_logits is not None
+ assert self.batched_hidden_states.dtype == full_hidden_states.dtype, (
+ f"{self.batched_hidden_states.dtype} == {full_hidden_states.dtype}"
+ )
+ assert self.batched_router_logits.dtype == full_router_logits.dtype, (
+ f"{self.batched_router_logits.dtype} == {full_router_logits.dtype}"
+ )
+ # Check size compatibility.
+ assert self.batched_hidden_states.size(-1) == full_hidden_states.size(-1)
+ assert self.batched_router_logits.size(-1) == full_router_logits.size(-1)
+
+ # TODO(bnell): Fix shared_expert_inputs w/chunking.
+ # assert shared_input is None, (
+ # "Routed input transform is not currently supported with DP chunking."
+ # )
+
+ full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
+ if self.shared_experts is not None:
+ full_shared_final_hidden_states = torch.empty_like(full_hidden_states)
+
+ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
+ chunk_size = chunk_end - chunk_start
+ hidden_states = full_hidden_states[chunk_start:chunk_end, :]
+ router_logits = full_router_logits[chunk_start:chunk_end, :]
+
+ assert self.batched_hidden_states is not None
+ assert self.batched_router_logits is not None
+ # This is only true when DBO has been enabled in the config.
+ # Both tensors will have an outer dimension for the ubatch id
+ if self.batched_hidden_states.dim() == 3:
+ assert self.batched_router_logits.dim() == 3
+ batch_buffer_idx = dbo_current_ubatch_id()
+ batched_hidden_states = self.batched_hidden_states[batch_buffer_idx, :]
+ batched_router_logits = self.batched_router_logits[batch_buffer_idx, :]
+ else:
+ batched_hidden_states = self.batched_hidden_states
+ batched_router_logits = self.batched_router_logits
+
+ assert (
+ batched_hidden_states.size(0) # type: ignore
+ >= chunk_size
+ )
+ assert (
+ batched_router_logits.size(0) # type: ignore
+ >= chunk_size
+ )
+ staged_hidden_states = batched_hidden_states[:chunk_size, :] # type: ignore
+ staged_router_logits = batched_router_logits[:chunk_size, :] # type: ignore
+ staged_hidden_states.copy_(hidden_states, non_blocking=True)
+ staged_router_logits.copy_(router_logits, non_blocking=True)
+
+ # Matrix multiply.
+ if self.quant_method.is_monolithic:
+ final_hidden_states = self.quant_method.apply_monolithic(
+ layer=layer,
+ x=staged_hidden_states,
+ router_logits=staged_router_logits,
+ )
+ else:
+ topk_weights, topk_ids = self.router.select_experts(
+ hidden_states=staged_hidden_states,
+ router_logits=staged_router_logits,
+ )
+
+ final_hidden_states = self.quant_method.apply(
+ layer=layer,
+ x=staged_hidden_states,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ shared_experts_input=shared_input,
+ )
+
+ if has_separate_shared_experts:
+ assert not isinstance(final_hidden_states, tuple)
+ assert self.shared_experts is not None
+
+ shared_output = self.shared_experts(staged_hidden_states)
+
+ final_hidden_states = (
+ shared_output,
+ final_hidden_states,
+ )
+
+ if not skip_result_store:
+ if self.shared_experts is None:
+ full_fused_final_hidden_states[chunk_start:chunk_end, :].copy_(
+ final_hidden_states, non_blocking=True
+ )
+ else:
+ full_shared_final_hidden_states[chunk_start:chunk_end, :].copy_(
+ final_hidden_states[0], non_blocking=True
+ )
+ full_fused_final_hidden_states[chunk_start:chunk_end, :].copy_(
+ final_hidden_states[1], non_blocking=True
+ )
+
+ ctx = get_forward_context()
+ # flashinfer_cutlass_kernels can handle: optional DP + TP/EP
+ max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu
+ moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
+
+        # If the input to the MoE is sequence parallel, divide by sp_size
+        # to find the maximum number of tokens for any individual dispatcher.
+ if self.moe_config.is_sequence_parallel:
+ max_tokens_across_dispatchers = cdiv(
+ max_tokens_across_dispatchers, self.moe_config.sp_size
+ )
+
+ num_tokens = full_hidden_states.size(0)
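+        # Every rank iterates over the same number of chunks (derived from
+        # the max token count across dispatchers) so the collective ops in
+        # process_chunk stay in lockstep; ranks that have exhausted their
+        # local tokens run dummy chunks with skip_result_store=True.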
+ for chunk_idx, chunk_start_ in enumerate(
+ range(0, max_tokens_across_dispatchers, moe_dp_chunk_size_per_rank)
+ ):
+ chunk_start = chunk_start_
+ chunk_end = min(
+ chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dispatchers
+ )
+            # Clamp start and end to this rank's local token count.
+ chunk_start = min(chunk_start, num_tokens - 1)
+ chunk_end = min(chunk_end, num_tokens)
+ with ctx.dp_metadata.chunked_sizes(
+ self.moe_config.sp_size, moe_dp_chunk_size_per_rank, chunk_idx
+ ):
+ process_chunk(
+ chunk_start, chunk_end, skip_result_store=chunk_start_ >= num_tokens
+ )
+
+ if self.shared_experts is None:
+ return full_fused_final_hidden_states
+ else:
+ return (full_shared_final_hidden_states, full_fused_final_hidden_states)
+
+ def forward_impl(
+ self,
+ layer: torch.nn.Module,
+ hidden_states: torch.Tensor,
+ router_logits: torch.Tensor,
+ shared_input: torch.Tensor | None,
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+ assert self.quant_method is not None
+
+ self.ensure_dp_chunking_init()
+
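+        # The layer runs the shared experts itself only when the quant
+        # method's modular kernel does not already own them.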
+ has_separate_shared_experts = (
+ not self.quant_method.mk_owns_shared_expert
+ and self.shared_experts is not None
+ )
+
+ use_chunked_impl = self.use_dp_chunking
+
+ use_shared_experts_stream, hidden_states_clone = (
+ self._maybe_setup_shared_experts_stream(
+ hidden_states,
+ shared_input,
+ has_separate_shared_experts,
+ use_chunked_impl,
+ )
+ )
+
+        # If a router/gate was provided, apply it here.
+        # (Note: this code runs only when "overlapped mode" is on, which
+        # allows the shared experts to execute in parallel with the FusedMoE
+        # on a separate CUDA stream.)
+ if self.gate is not None:
+ router_logits, _ = self.gate(hidden_states)
+
+ if use_chunked_impl:
+ return self.forward_impl_chunked(
+ layer,
+ hidden_states,
+ router_logits,
+ shared_input,
+ has_separate_shared_experts,
+ )
+
+ # NOTE(rob): once we finish migrating all the quant methods to use
+ # MKs, we can remove the naive dispatch/combine path from here.
+ do_naive_dispatch_combine = (
+ self.moe_config.dp_size > 1 and not self.quant_method.supports_internal_mk
+ )
+
+ ctx = get_forward_context()
+ sp_ctx = (
+ ctx.dp_metadata.sp_local_sizes(self.moe_config.sp_size)
+ if ctx.dp_metadata
+ else nullcontext()
+ )
+
+ with sp_ctx:
+ extra_tensors = None
+ if do_naive_dispatch_combine:
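+                # Some quant methods quantize before the all-gather so the
+                # collective moves the smaller quantized activations plus
+                # any auxiliary tensors (e.g. scales) instead of the
+                # full-precision activations.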
+ post_quant_allgather = (
+ self.quant_method is not None
+ and self.moe_config.dp_size > 1
+ and self.moe_config.use_ep
+ and getattr(self.quant_method, "do_post_quant_allgather", False)
+ )
+ if post_quant_allgather:
+ hidden_states_to_dispatch, extra_tensors = (
+ self.quant_method.prepare_dp_allgather_tensor(
+ layer, hidden_states, router_logits
+ )
+ )
+ else:
+ hidden_states_to_dispatch = hidden_states
+
+ dispatch_res = get_ep_group().dispatch_router_logits(
+ hidden_states_to_dispatch,
+ router_logits,
+ self.moe_config.is_sequence_parallel,
+ extra_tensors=extra_tensors,
+ )
+ if extra_tensors is not None:
+ (
+ orig_hidden_states,
+ router_logits,
+ extra_tensors_combined,
+ ) = dispatch_res
+ hidden_states_combined = (
+ orig_hidden_states,
+ extra_tensors_combined[0],
+ )
+ else:
+ hidden_states_combined, router_logits = dispatch_res
+ orig_hidden_states = hidden_states_combined
+ else:
+ orig_hidden_states = hidden_states
+
+            # Run the shared experts before the matrix multiply, since the
+            # matrix multiply may modify hidden_states.
+ if has_separate_shared_experts and not use_shared_experts_stream:
+ assert self.shared_experts is not None
+ shared_input = (
+ shared_input if shared_input is not None else hidden_states
+ )
+ shared_output = self.shared_experts(shared_input)
+
+            # NOTE: Like DP, PCP also needs dispatch and combine. For
+            # simplicity, AgRsAll2All was added separately for PCP here.
+            # Maybe we should modify the All2AllManager abstraction to
+            # better support PCP.
+ if self.moe_config.pcp_size > 1:
+ hidden_states = get_pcp_group().all_gather(
+ hidden_states,
+ dim=0,
+ )
+ router_logits = get_pcp_group().all_gather(
+ router_logits,
+ dim=0,
+ )
+
+            # TODO(bnell): deal with the fp4 flashinfer tuple hidden-states
+            # hack (#30014). Figure out a nicer way to do this.
+ if do_naive_dispatch_combine:
+ x = hidden_states_combined
+ x_orig = orig_hidden_states
+ else:
+ x = hidden_states
+ x_orig = hidden_states
+
+ # Matrix multiply.
+ if self.quant_method.is_monolithic:
+ final_hidden_states = self.quant_method.apply_monolithic(
+ layer=layer,
+ x=x,
+ router_logits=router_logits,
+ )
+ else:
+ topk_weights, topk_ids = self.router.select_experts(
+ hidden_states=x_orig,
+ router_logits=router_logits,
+ )
+
+ final_hidden_states = self.quant_method.apply(
+ layer=layer,
+                    x=x,  # The type signature of this is wrong due to the hack.
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ shared_experts_input=shared_input,
+ )
+
+ if has_separate_shared_experts:
+ assert self.shared_experts is not None
+
+ if use_shared_experts_stream:
+                    # Run the shared experts in parallel on a separate stream.
+                    # NOTE: We start the separate stream here and mark the
+                    # sync end point immediately after it finishes. This is
+                    # important to avoid excessive stream allocations during
+                    # CUDA graph replay later.
+ with torch.cuda.stream(self.shared_experts_stream):
+                        # Note: the clone() of hidden_states is necessary
+                        # here to avoid a conflict with the main stream.
+ shared_output = self.shared_experts(hidden_states_clone)
+ current_stream().wait_stream(self.shared_experts_stream)
+
+ final_hidden_states = (
+ shared_output,
+ final_hidden_states,
+ )
+
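+        # Undo the dispatch: combine across the EP group on the naive path
+        # and reduce-scatter across the PCP group, mirroring the gathers
+        # above.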
+ def combine_output(states: torch.Tensor) -> torch.Tensor:
+ if do_naive_dispatch_combine:
+ states = get_ep_group().combine(
+ states, self.moe_config.is_sequence_parallel
+ )
+
+ if self.moe_config.pcp_size > 1:
+ states = get_pcp_group().reduce_scatter(
+ states,
+ dim=0,
+ )
+
+ return states
+
+ if self.shared_experts is not None:
+ return (
+ final_hidden_states[0],
+ combine_output(final_hidden_states[1]),
+ )
+ else:
+ return combine_output(final_hidden_states)
diff --git a/vllm/model_executor/layers/fused_moe/runner/moe_runner.py b/vllm/model_executor/layers/fused_moe/runner/moe_runner.py
new file mode 100644
index 000000000..b298cc2d0
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/runner/moe_runner.py
@@ -0,0 +1,34 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from abc import ABC, abstractmethod
+
+import torch
+
+
+class MoERunner(ABC):
+ """
+ Abstract base class for Mixture of Experts (MoE) runners.
+
+ This class defines the interface that all MoE runner implementations must follow.
+ MoE runners are responsible for executing the forward pass of MoE layers, handling
+ expert routing, and managing tensor parallel operations.
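+
+    A minimal sketch of a concrete subclass (hypothetical, for illustration
+    only; not prescribed by this interface):
+
+        class EchoMoERunner(MoERunner):
+            def forward(self, hidden_states, router_logits):
+                # Trivially pass tokens through an identity "expert".
+                return hidden_states
+
+            def must_reduce_shared_expert_outputs(self) -> bool:
+                return False
+
+            def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states):
+                return final_hidden_states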
+ """
+
+ @abstractmethod
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ router_logits: torch.Tensor,
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+ raise NotImplementedError
+
+ @abstractmethod
+ def must_reduce_shared_expert_outputs(self) -> bool:
+ raise NotImplementedError
+
+ @abstractmethod
+ def maybe_all_reduce_tensor_model_parallel(
+ self,
+ final_hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+ raise NotImplementedError
diff --git a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py
index 937d13d34..37336df17 100644
--- a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py
@@ -18,70 +18,6 @@ class SharedFusedMoE(FusedMoE):
can be interleaved with the fused all2all dispatch communication step.
"""
- def __init__(
- self,
- shared_experts: torch.nn.Module | None,
- gate: torch.nn.Module | None = None,
- use_overlapped: bool = True,
- routed_input_transform: torch.nn.Module | None = None,
- **kwargs,
- ):
- # Pass has_shared_experts so FusedMoE.__init__ can set disable_inplace
- # without accessing self.shared_experts (submodules cannot be set before
- # Module.__init__()).
- kwargs["has_shared_experts"] = shared_experts is not None
- super().__init__(**kwargs)
- self._shared_experts = shared_experts
- self._routed_input_transform = routed_input_transform
-
- # Disable shared expert overlap if:
- # - we are using eplb with non-default backend, because of correctness issues
- # - we are using flashinfer with DP, since there nothing to gain
- # - we are using marlin kernels
- backend = self.moe_parallel_config.all2all_backend
- self.use_overlapped = (
- use_overlapped
- and not (
- (self.enable_eplb and backend != "allgather_reducescatter")
- or self.moe_parallel_config.use_fi_all2allv_kernels
- )
- and self._shared_experts is not None
- )
-
- self._gate = gate
-
- @property
- def shared_experts(self) -> torch.nn.Module | None:
- return self._shared_experts if self.use_overlapped else None
-
- @property
- def gate(self) -> torch.nn.Module | None:
- return self._gate if self.use_overlapped else None
-
- @property
- def is_internal_router(self) -> bool:
- return self.gate is not None
-
- def apply_routed_input_transform(self, hidden_states: torch.Tensor) -> torch.Tensor:
- """Apply transform for routed experts (e.g., latent projection).
-
- This is called by FusedMoE.forward_native. The original hidden_states
- is saved separately so shared experts get [S, hidden_size] while
- routed experts get the transformed [S, moe_latent_size].
-
- TODO: For latent MoE bandwidth optimization, fc2_latent_proj could be
- moved inside SharedFusedMoE to all-reduce on the smaller latent
- dimension.
- """
- if self._routed_input_transform is not None:
- result = self._routed_input_transform(hidden_states)
- # ReplicatedLinear returns (output, extra_bias) tuple.
- # We only need the output tensor; extra_bias is not used here.
- if isinstance(result, tuple):
- return result[0]
- return result
- return hidden_states
-
def forward(
self,
hidden_states: torch.Tensor,
diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
index 8a35be78b..5c86064a9 100644
--- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
+++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
@@ -55,6 +55,8 @@ logger = init_logger(__name__)
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""
+ # --8<-- [end:unquantized_fused_moe]
+
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
self.unquantized_backend = select_unquantized_moe_backend(
@@ -90,8 +92,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- return self.forward_cuda(layer, x, topk_weights, topk_ids)
+ return self.forward_cuda(layer, x, topk_weights, topk_ids, shared_experts_input)
@property
def is_monolithic(self) -> bool:
@@ -293,12 +296,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward(
layer=layer,
x=x,
topk_weights=topk_weights,
topk_ids=topk_ids,
+ shared_experts_input=shared_experts_input,
)
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
@@ -316,6 +321,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel is not None
@@ -329,6 +335,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
+ shared_experts_input=shared_experts_input,
)
def forward_monolithic_cuda(
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index 642088a45..5b7af3193 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -764,6 +764,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return fused_marlin_moe(
x,
diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py
index 2fd567d7f..983c076bd 100644
--- a/vllm/model_executor/layers/quantization/bitsandbytes.py
+++ b/vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -501,6 +501,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 604373c0a..023cf3f67 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -349,6 +349,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.moe_mk is not None
return self.moe_mk(
@@ -361,7 +362,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
- shared_experts_input=layer._get_shared_experts_input(x),
+ shared_experts_input=shared_experts_input,
)
@@ -645,6 +646,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
assert layer.activation == "silu", "Only SiLU activation is supported."
@@ -673,7 +675,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
- shared_experts_input=layer._get_shared_experts_input(x),
+ shared_experts_input=shared_experts_input,
)
@@ -1064,6 +1066,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
assert self.moe_mk is not None
@@ -1079,7 +1082,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
- shared_experts_input=layer._get_shared_experts_input(x),
+ shared_experts_input=shared_experts_input,
)
@property
@@ -1203,6 +1206,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
@@ -1713,6 +1717,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel_backend == "Marlin"
return fused_marlin_moe(
@@ -1961,6 +1966,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
@@ -2575,6 +2581,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if layer.enable_eplb:
raise NotImplementedError(
diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py
index 176bfe040..d971f3b5b 100644
--- a/vllm/model_executor/layers/quantization/experts_int8.py
+++ b/vllm/model_executor/layers/quantization/experts_int8.py
@@ -140,6 +140,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index b8040e894..279f97dd6 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -1010,6 +1010,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.moe_mk is not None
assert not self.is_monolithic
@@ -1023,7 +1024,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
- shared_experts_input=layer._get_shared_experts_input(x),
+ shared_experts_input=shared_experts_input,
)
diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
index ce84d2521..f7d995598 100644
--- a/vllm/model_executor/layers/quantization/gguf.py
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -635,6 +635,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert layer.activation == "silu", "Only SiLU activation is supported."
if layer.apply_router_weight_on_input:
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index d18c7207d..4c175fddb 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -900,6 +900,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return fused_marlin_moe(
x,
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 8b151133b..570317ad3 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -958,6 +958,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
@@ -980,7 +981,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
- shared_experts_input=layer._get_shared_experts_input(x),
+ shared_experts_input=shared_experts_input,
)
@@ -1524,6 +1525,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
@@ -1551,7 +1553,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
- shared_experts_input=layer._get_shared_experts_input(x),
+ shared_experts_input=shared_experts_input,
)
diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py
index bca2516d4..4365d1693 100644
--- a/vllm/model_executor/layers/quantization/moe_wna16.py
+++ b/vllm/model_executor/layers/quantization/moe_wna16.py
@@ -367,6 +367,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index d1c9cb6bb..13199124b 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -900,6 +900,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
if layer.enable_eplb:
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 190890130..7faa4fcc9 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -419,6 +419,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
@@ -607,6 +608,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts,
@@ -977,6 +979,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if not self.emulate:
if (
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 2b7d9ff29..635402f3d 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -816,10 +816,14 @@ class Worker(WorkerBase):
for module in moe_modules:
module.moe_config.num_experts = num_local_experts * new_ep_size
module.global_num_experts = module.moe_config.num_experts
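+            # sp_size is the TP world size when sequence-parallel MoE is
+            # enabled, and 1 otherwise.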
+ tp_size = get_tp_group().world_size
+ is_sequence_parallel = parallel_config.use_sequence_parallel_moe
+ sp_size = tp_size if is_sequence_parallel else 1
module.moe_parallel_config = FusedMoEParallelConfig.make(
- tp_size_=get_tp_group().world_size,
+ tp_size_=tp_size,
pcp_size_=get_pcp_group().world_size,
dp_size_=get_dp_group().world_size,
+ sp_size_=sp_size,
vllm_parallel_config=parallel_config,
)
module.moe_config.moe_parallel_config = module.moe_parallel_config