[MoE Refactor] Introduce MoERunner abstraction and move execution logic from FusedMoE to DefaultMoERunner (#32344)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2026-02-10 19:51:07 -05:00
committed by GitHub
parent dc6de33c3d
commit d1481ba783
25 changed files with 913 additions and 753 deletions

View File

@@ -32,7 +32,7 @@ th {
| Backend | Output act. format | Quant. types | Quant. format | Async | Apply Weight On Input | Subclass |
|---------|--------------------|--------------|---------------|-------|-----------------------|-----------|
| naive | standard | all<sup>1</sup> | G,A,T | N | <sup>6</sup> | [layer.py][vllm.model_executor.layers.fused_moe.layer.FusedMoE.forward_impl] |
| naive | standard | all<sup>1</sup> | G,A,T | N | <sup>6</sup> | [layer.py][vllm.model_executor.layers.fused_moe.layer.FusedMoE] |
| pplx | batched | fp8,int8 | G,A,T | Y | Y | [`PplxPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.pplx_prepare_finalize.PplxPrepareAndFinalize] |
| deepep_high_throughput | standard | fp8 | G(128),A,T<sup>2</sup> | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize.DeepEPHTPrepareAndFinalize] |
| deepep_low_latency | batched | fp8 | G(128),A,T<sup>3</sup> | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize.DeepEPLLPrepareAndFinalize] |

View File

@@ -585,6 +585,7 @@ def make_modular_kernel(
tp_size_=get_tensor_model_parallel_world_size(),
pcp_size_=get_pcp_group().world_size,
dp_size_=get_dp_group().world_size,
sp_size_=1,
vllm_parallel_config=vllm_config.parallel_config,
)
@@ -594,6 +595,7 @@ def make_modular_kernel(
hidden_dim=config.K,
intermediate_size_per_partition=config.N,
num_local_experts=config.num_local_experts,
num_logical_experts=config.E,
moe_parallel_config=moe_parallel_config,
in_dtype=config.dtype,
max_num_tokens=next_power_of_2(config.M),

View File

@@ -52,6 +52,7 @@ def make_dummy_moe_config(
hidden_dim=hidden_dim,
intermediate_size_per_partition=intermediate_size_per_partition,
num_local_experts=num_experts,
num_logical_experts=num_experts,
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
activation="silu",
in_dtype=in_dtype,

View File

@@ -913,12 +913,16 @@ class FusedMoEParallelConfig:
pcp_rank: int
dp_rank: int
ep_rank: int
sp_size: int
use_ep: bool # whether to use EP or not
all2all_backend: str # all2all backend for MoE communication
is_sequence_parallel: bool # whether sequence parallelism is used
enable_eplb: bool # whether to enable expert load balancing
@property
def is_sequence_parallel(self) -> bool:
return self.sp_size > 1
@property
def use_all2all_kernels(self):
return self.dp_size > 1 and self.use_ep
@@ -974,6 +978,7 @@ class FusedMoEParallelConfig:
tp_size_: int,
pcp_size_: int,
dp_size_: int,
sp_size_: int,
vllm_parallel_config: ParallelConfig,
) -> "FusedMoEParallelConfig":
"""
@@ -1073,9 +1078,9 @@ class FusedMoEParallelConfig:
dp_rank=dp_rank,
ep_size=1,
ep_rank=0,
sp_size=sp_size_,
use_ep=False,
all2all_backend=vllm_parallel_config.all2all_backend,
is_sequence_parallel=vllm_parallel_config.use_sequence_parallel_moe,
enable_eplb=vllm_parallel_config.enable_eplb,
)
# DP + EP / TP + EP / DP + TP + EP
@@ -1093,9 +1098,9 @@ class FusedMoEParallelConfig:
dp_rank=dp_rank,
ep_size=ep_size,
ep_rank=ep_rank,
sp_size=sp_size_,
use_ep=True,
all2all_backend=vllm_parallel_config.all2all_backend,
is_sequence_parallel=vllm_parallel_config.use_sequence_parallel_moe,
enable_eplb=vllm_parallel_config.enable_eplb,
)
@@ -1111,10 +1116,10 @@ class FusedMoEParallelConfig:
dp_rank=0,
ep_size=1,
ep_rank=0,
sp_size=1,
use_ep=False,
all2all_backend="naive",
enable_eplb=False,
is_sequence_parallel=False,
)
@@ -1126,6 +1131,7 @@ class FusedMoEConfig:
hidden_dim: int
intermediate_size_per_partition: int
num_local_experts: int
num_logical_experts: int
activation: str
device: torch.device | str
routing_method: RoutingMethodType
@@ -1175,6 +1181,14 @@ class FusedMoEConfig:
def ep_size(self):
return self.moe_parallel_config.ep_size
@property
def sp_size(self):
return self.moe_parallel_config.sp_size
@property
def is_sequence_parallel(self):
return self.moe_parallel_config.is_sequence_parallel
@property
def tp_rank(self):
return self.moe_parallel_config.tp_rank

View File

@@ -121,17 +121,16 @@ class FusedMoEMethodBase(QuantizeMethodBase):
def is_monolithic(self) -> bool:
return False
# @abstractmethod
def apply(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
# @abstractmethod
def apply_monolithic(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821

View File

@@ -89,6 +89,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.moe_mk is not None
return self.moe_mk(
@@ -101,5 +102,5 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=None if self.disable_expert_map else layer.expert_map,
shared_experts_input=layer._get_shared_experts_input(x),
shared_experts_input=shared_experts_input,
)

View File

@@ -1,13 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Generator, Iterable
from contextlib import contextmanager, nullcontext
from collections.abc import Callable, Iterable
from enum import Enum
from typing import Literal, cast, get_args, overload
import torch
import torch.nn.functional as F
from torch.nn.parameter import UninitializedParameter
import vllm.envs as envs
@@ -16,17 +14,10 @@ from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.parallel import ExpertPlacementStrategy
from vllm.distributed import (
get_dp_group,
get_ep_group,
get_pcp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.distributed.eplb.eplb_state import EplbLayerState, EplbState
from vllm.forward_context import (
ForwardContext,
get_forward_context,
is_forward_context_available,
)
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.config import (
@@ -47,6 +38,9 @@ from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
from vllm.model_executor.layers.fused_moe.router.router_factory import (
create_fused_moe_router,
)
from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import (
DefaultMoERunner,
)
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
)
@@ -57,13 +51,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.torch_utils import (
aux_stream,
current_stream,
direct_register_custom_op,
)
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
from vllm.utils.math_utils import round_up
logger = init_logger(__name__)
@@ -264,6 +252,7 @@ def maybe_roundup_hidden_size(
)
current_mxfp4_backend = get_mxfp4_backend(is_lora_enabled)
if (
current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
@@ -273,6 +262,7 @@ def maybe_roundup_hidden_size(
current_platform.is_rocm()
or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
or current_mxfp4_backend == Mxfp4Backend.MARLIN
):
hidden_size = round_up(hidden_size, 256)
@@ -338,29 +328,15 @@ class FusedMoE(CustomOp):
expert_mapping: list[tuple[str, str, int, str]] | None = None,
n_shared_experts: int | None = None,
router_logits_dtype: torch.dtype | None = None,
has_shared_experts: bool = False,
gate: torch.nn.Module | None = None,
shared_experts: torch.nn.Module | None = None,
routed_input_transform: torch.nn.Module | None = None,
):
super().__init__()
# Allow disabling of the separate shared experts stream for
# debug purposes.
# TODO: Remove this after more extensive testings with TP/DP
# and other execution modes
if envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM:
logger.debug_once("Disabling MoE shared_experts cuda stream", scope="local")
self.shared_experts_stream = None
else:
# TODO(rob): enable shared expert overlap with non-cuda-alike.
# aux_stream() returns None on non-cuda-alike platforms.
self.shared_experts_stream = aux_stream()
if self.shared_experts_stream is not None:
logger.debug_once(
"Enabled separate cuda stream for MoE shared_experts", scope="local"
)
# For latent MoE: stores original hidden_states before routed_input_transform
# so shared_experts can use it for cloning (they need original dimension)
self._shared_experts_input: torch.Tensor | None = None
self._gate = gate
self._shared_experts = shared_experts
self._routed_input_transform = routed_input_transform
if params_dtype is None:
params_dtype = torch.get_default_dtype()
@@ -392,9 +368,12 @@ class FusedMoE(CustomOp):
tp_size_=tp_size_,
pcp_size_=pcp_size_,
dp_size_=dp_size_,
sp_size_=self.sp_size,
vllm_parallel_config=vllm_config.parallel_config,
)
assert self.moe_parallel_config.is_sequence_parallel == is_sequence_parallel
self.global_num_experts = num_experts + num_redundant_experts
self.logical_num_experts = num_experts
@@ -410,6 +389,7 @@ class FusedMoE(CustomOp):
self.layer_name = prefix
self.enable_eplb = enable_eplb
# TODO(bnell): should this be owned by router?
self.eplb_state = EplbLayerState()
self.expert_placement_strategy: ExpertPlacementStrategy = (
vllm_config.parallel_config.expert_placement_strategy
@@ -506,7 +486,8 @@ class FusedMoE(CustomOp):
self.reduce_results = reduce_results
self.renormalize = renormalize
# TODO(bnell): these attributes are only used by cpu/xpu/mxfp4
# TODO(bnell): these attributes are only used by monolithic kernels.
# Put them in a MoERouterConfig dataclass?
self.use_grouped_topk = use_grouped_topk
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
@@ -565,6 +546,7 @@ class FusedMoE(CustomOp):
hidden_dim=hidden_size,
intermediate_size_per_partition=self.intermediate_size_per_partition,
num_local_experts=self.local_num_experts,
num_logical_experts=self.logical_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=moe_in_dtype,
router_logits_dtype=router_logits_dtype,
@@ -576,9 +558,9 @@ class FusedMoE(CustomOp):
device=vllm_config.device_config.device,
routing_method=self.routing_method_type,
# TODO: in_dtype == out_dtype?
disable_inplace=disable_inplace() or has_shared_experts,
disable_inplace=disable_inplace() or self._shared_experts is not None,
)
if self.use_mori_kernels:
if self.moe_config.use_mori_kernels:
assert self.rocm_aiter_fmoe_enabled, (
"Mori needs to be used with aiter fused_moe for now."
)
@@ -641,9 +623,36 @@ class FusedMoE(CustomOp):
self.quant_method.create_weights(layer=self, **moe_quant_params)
# Chunked all2all staging tensor
self.batched_hidden_states: torch.Tensor | None = None
self.batched_router_logits: torch.Tensor | None = None
# Disable shared expert overlap if:
# - we are using eplb with non-default backend, because of correctness issues
# - we are using flashinfer with DP, since there nothing to gain
# - we are using marlin kernels
backend = self.moe_parallel_config.all2all_backend
self.use_overlapped = (
not (
(self.enable_eplb and backend != "allgather_reducescatter")
or self.moe_parallel_config.use_fi_all2allv_kernels
)
and self._shared_experts is not None
)
self.runner = self._init_runner()
def _init_runner(self):
# Storing the runner in the FusedMoE is an intermediate state, eventually
# the runner will own the FusedMoE layer and provide the execution interface
# for MoE ops.
return DefaultMoERunner(
layer=self,
moe_config=self.moe_config,
router=self.router,
routed_input_transform=self._routed_input_transform,
gate=self.gate,
shared_experts=self.shared_experts,
quant_method=self.quant_method,
reduce_results=self.reduce_results,
enable_dbo=self.vllm_config.parallel_config.enable_dbo,
)
# Note: maybe_init_modular_kernel should only be called by
# prepare_communication_buffer_for_model.
@@ -673,10 +682,14 @@ class FusedMoE(CustomOp):
self.shared_experts,
inplace=not self.moe_config.disable_inplace,
)
# We need to force reconstruction of runner because we're swapping out
# the quant_method with a FusedMoEModularMethod. This logic can go
# away once the FusedMoEModularMethod is eliminated.
self.runner = self._init_runner()
@property
def shared_experts(self) -> torch.nn.Module | None:
return None
return self._shared_experts if self.use_overlapped else None
@property
def layer_id(self):
@@ -687,53 +700,12 @@ class FusedMoE(CustomOp):
@property
def gate(self) -> torch.nn.Module | None:
return None
def apply_routed_input_transform(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""Hook to transform hidden_states before passing to routed experts.
For latent MoE: transforms [S, hidden_size] → [S, moe_latent_size].
The original hidden_states is saved in _shared_experts_input so
shared_experts still receive the original [S, hidden_size].
Override in subclasses (e.g., SharedFusedMoE) for latent MoE.
"""
return hidden_states
@contextmanager
def _set_shared_experts_input(
self, value: torch.Tensor | None
) -> Generator[None, None, None]:
"""Context manager to safely set/clear _shared_experts_input."""
self._shared_experts_input = value
try:
yield
finally:
self._shared_experts_input = None
def _get_shared_experts_input(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""Get input for shared experts.
For latent MoE: shared_experts need original [S, hidden_size],
not the transformed [S, latent_size] used by routed experts.
"""
return (
self._shared_experts_input
if self._shared_experts_input is not None
else hidden_states
)
return self._gate
@property
def tp_size(self):
return self.moe_parallel_config.tp_size
@property
def dp_size(self):
return self.moe_parallel_config.dp_size
@property
def pcp_size(self):
return self.moe_parallel_config.pcp_size
@property
def ep_size(self):
return self.moe_parallel_config.ep_size
@@ -742,14 +714,6 @@ class FusedMoE(CustomOp):
def tp_rank(self):
return self.moe_parallel_config.tp_rank
@property
def dp_rank(self):
return self.moe_parallel_config.dp_rank
@property
def pcp_rank(self):
return self.moe_parallel_config.pcp_rank
@property
def ep_rank(self):
return self.moe_parallel_config.ep_rank
@@ -758,39 +722,10 @@ class FusedMoE(CustomOp):
def use_ep(self):
return self.moe_parallel_config.use_ep
@property
def use_pplx_kernels(self):
return self.moe_parallel_config.use_pplx_kernels
@property
def use_deepep_ht_kernels(self):
return self.moe_parallel_config.use_deepep_ht_kernels
@property
def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels
@property
def use_mori_kernels(self):
return self.moe_parallel_config.use_mori_kernels
@property
def use_marlin_kernels(self):
return getattr(self.quant_method, "use_marlin", False)
@property
def use_dp_chunking(self) -> bool:
return (
self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels
or self.moe_parallel_config.use_mori_kernels
or self.moe_parallel_config.use_fi_all2allv_kernels
) and envs.VLLM_ENABLE_MOE_DP_CHUNK
@property
def is_internal_router(self) -> bool:
# By default, router/gate is called before FusedMoE forward pass
return False
return self._gate is not None
def _maybe_init_expert_routing_tables(
self,
@@ -799,7 +734,7 @@ class FusedMoE(CustomOp):
# with DeepEP-ll all2all backend.
if (
self.expert_placement_strategy != "round_robin"
or not self.use_deepep_ll_kernels
or not self.moe_parallel_config.use_deepep_ll_kernels
):
return None
@@ -892,48 +827,6 @@ class FusedMoE(CustomOp):
dp_size=get_dp_group().world_size,
)
def _maybe_setup_shared_experts_stream(
self,
hidden_states: torch.Tensor,
has_separate_shared_experts: bool,
use_chunked_impl: bool,
) -> tuple[bool, torch.Tensor | None]:
use_shared_experts_stream = (
current_platform.is_cuda()
and has_separate_shared_experts
and not use_chunked_impl
and self.shared_experts_stream is not None
and (
hidden_states.shape[0]
<= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
)
)
hidden_states_clone: torch.Tensor | None = None
if use_shared_experts_stream:
assert self.shared_experts_stream is not None
shared_experts_input = self._get_shared_experts_input(hidden_states)
# Clone BEFORE switching streams to avoid race condition
# where routed_expert kernel may mutate hidden_states.
hidden_states_clone = shared_experts_input.clone()
# Record that the clone will be used by shared_experts_stream
# to avoid gc issue from deallocation of hidden_states_clone
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
# NOTE: We don't need shared_output.record_stream(current_stream())
# because we synch the streams before using shared_output.
hidden_states_clone.record_stream(self.shared_experts_stream)
# Mark sync start point for the separate shared experts
# stream here since we want to run in parallel with the
# router/gate (next op below)
assert self.shared_experts_stream is not None
self.shared_experts_stream.wait_stream(current_stream())
return use_shared_experts_stream, hidden_states_clone
def _load_per_tensor_weight_scale(
self,
shard_id: str,
@@ -1191,7 +1084,7 @@ class FusedMoE(CustomOp):
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
if self.quant_method.__class__.__name__ in (
if quant_method_name in (
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod",
):
@@ -1488,7 +1381,7 @@ class FusedMoE(CustomOp):
assert all(
weight.is_contiguous()
for name, weight in weights
if not name.startswith("_shared_experts.")
if not (name.startswith("_shared_experts.") or name.startswith("_gate."))
)
# Filter out the non-expert weights.
@@ -1538,32 +1431,6 @@ class FusedMoE(CustomOp):
self.ensure_moe_quant_config_init()
return self.quant_method.moe_quant_config
def ensure_dp_chunking_init(self):
if not self.use_dp_chunking or self.batched_hidden_states is not None:
return
states_shape: tuple[int, ...]
logits_shape: tuple[int, ...]
moe = self.moe_config
if self.vllm_config.parallel_config.enable_dbo:
states_shape = (2, moe.max_num_tokens, self.hidden_size)
logits_shape = (2, moe.max_num_tokens, self.logical_num_experts)
else:
states_shape = (moe.max_num_tokens, self.hidden_size)
logits_shape = (moe.max_num_tokens, self.logical_num_experts)
self.batched_hidden_states = torch.zeros(
states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
)
self.batched_router_logits = torch.zeros(
logits_shape,
dtype=moe.router_logits_dtype,
device=torch.cuda.current_device(),
)
def must_reduce_shared_expert_outputs(self) -> bool:
"""
The shared_experts are typically computed using the RowParallelLinear
@@ -1577,100 +1444,24 @@ class FusedMoE(CustomOp):
Therefore it is required that we reduce the shared_experts output
early.
"""
assert self.quant_method is not None
return (
isinstance(self.quant_method, FusedMoEModularMethod)
and self.quant_method.moe_mk.output_is_reduced() # type: ignore[union-attr]
)
return self.runner.must_reduce_shared_expert_outputs()
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
"""
Some combine kernels reduce across GPU ranks by default.
"""
if self.must_reduce_shared_expert_outputs():
return final_hidden_states
else:
return tensor_model_parallel_all_reduce(final_hidden_states)
return self.runner.maybe_all_reduce_tensor_model_parallel(final_hidden_states)
def forward_native(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
# For latent MoE: save ORIGINAL hidden_states before transform
# (shared_experts need original dimension, routed experts use transformed)
original_hidden_states = hidden_states
original_hidden_dim = hidden_states.shape[-1]
# Apply transform for routed experts (e.g., latent projection for latent MoE)
hidden_states = self.apply_routed_input_transform(hidden_states)
# This is the dimension after transform (for routed expert output slicing)
transformed_hidden_dim = hidden_states.shape[-1]
if self.hidden_size != transformed_hidden_dim:
hidden_states = F.pad(
hidden_states,
(0, self.hidden_size - transformed_hidden_dim),
mode="constant",
value=0.0,
)
def reduce_output(states: torch.Tensor) -> torch.Tensor:
if (
not self.is_sequence_parallel
and not self.use_dp_chunking
and self.reduce_results
and (self.tp_size > 1 or self.ep_size > 1)
):
states = self.maybe_all_reduce_tensor_model_parallel(states)
return states
def encode_layer_name() -> str:
# Can be unavailable or None in unittests
if (
is_forward_context_available()
and get_forward_context().all_moe_layers is not None
):
return "from_forward_context"
return self.layer_name
if self.shared_experts is None:
if current_platform.is_tpu() or current_platform.is_cpu():
# TODO: Once the OOM issue for the TPU backend is resolved, we
# will switch to using the moe_forward custom op.
# Note: CPU doesn't require wrapped forward_impl.
fused_output = self.forward_impl(hidden_states, router_logits)
assert not isinstance(fused_output, tuple)
else:
fused_output = torch.ops.vllm.moe_forward(
hidden_states, router_logits, encode_layer_name()
)
return reduce_output(fused_output)[..., :transformed_hidden_dim]
else:
if current_platform.is_tpu() or current_platform.is_cpu():
# TODO: Once the OOM issue for the TPU backend is resolved, we
# will switch to using the moe_forward custom op.
# Note: CPU doesn't require wrapped forward_impl.
with self._set_shared_experts_input(original_hidden_states):
shared_output, fused_output = self.forward_impl(
hidden_states, router_logits
)
else:
# Custom op handles setting/clearing _shared_experts_input internally
# We pass original tensor for shared experts (not transformed)
shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
hidden_states,
router_logits,
encode_layer_name(),
original_hidden_states,
)
# shared_output uses original dimension (before transform)
# fused_output uses transformed dimension (after transform)
return (
reduce_output(shared_output)[..., :original_hidden_dim],
reduce_output(fused_output)[..., :transformed_hidden_dim],
)
self.ensure_moe_quant_config_init()
return self.runner.forward(
hidden_states,
router_logits,
)
@property
def expert_map(self) -> torch.Tensor | None:
@@ -1685,312 +1476,6 @@ class FusedMoE(CustomOp):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward_native(hidden_states, router_logits)
def forward_impl_chunked(
self,
full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor,
has_separate_shared_experts: bool,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None
assert self.batched_hidden_states.dtype == full_hidden_states.dtype, (
f"{self.batched_hidden_states.dtype} == {full_hidden_states.dtype}"
)
assert self.batched_router_logits.dtype == full_router_logits.dtype, (
f"{self.batched_router_logits.dtype} == {full_router_logits.dtype}"
)
# Check size compatibility.
assert self.batched_hidden_states.size(-1) == full_hidden_states.size(-1)
assert self.batched_router_logits.size(-1) == full_router_logits.size(-1)
full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
if self.shared_experts is not None:
full_shared_final_hidden_states = torch.empty_like(full_hidden_states)
def process_chunk(chunk_start, chunk_end, skip_result_store=False):
chunk_size = chunk_end - chunk_start
hidden_states = full_hidden_states[chunk_start:chunk_end, :]
router_logits = full_router_logits[chunk_start:chunk_end, :]
assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None
# This is only true when DBO has been enabled in the config.
# Both tensors will have an outer dimension for the ubatch id
if self.batched_hidden_states.dim() == 3:
assert self.batched_router_logits.dim() == 3
batch_buffer_idx = dbo_current_ubatch_id()
batched_hidden_states = self.batched_hidden_states[batch_buffer_idx, :]
batched_router_logits = self.batched_router_logits[batch_buffer_idx, :]
else:
batched_hidden_states = self.batched_hidden_states
batched_router_logits = self.batched_router_logits
assert (
batched_hidden_states.size(0) # type: ignore
>= chunk_size
)
assert (
batched_router_logits.size(0) # type: ignore
>= chunk_size
)
staged_hidden_states = batched_hidden_states[:chunk_size, :] # type: ignore
staged_router_logits = batched_router_logits[:chunk_size, :] # type: ignore
staged_hidden_states.copy_(hidden_states, non_blocking=True)
staged_router_logits.copy_(router_logits, non_blocking=True)
# Matrix multiply.
if self.quant_method.is_monolithic:
final_hidden_states = self.quant_method.apply_monolithic(
layer=self,
x=staged_hidden_states,
router_logits=staged_router_logits,
)
else:
topk_weights, topk_ids = self.router.select_experts(
hidden_states=staged_hidden_states,
router_logits=staged_router_logits,
)
final_hidden_states = self.quant_method.apply(
layer=self,
x=staged_hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
)
if has_separate_shared_experts:
assert not isinstance(final_hidden_states, tuple)
assert self.shared_experts is not None
shared_output = self.shared_experts(staged_hidden_states)
final_hidden_states = (
shared_output,
final_hidden_states,
)
if not skip_result_store:
if self.shared_experts is None:
full_fused_final_hidden_states[chunk_start:chunk_end, :].copy_(
final_hidden_states, non_blocking=True
)
else:
full_shared_final_hidden_states[chunk_start:chunk_end, :].copy_(
final_hidden_states[0], non_blocking=True
)
full_fused_final_hidden_states[chunk_start:chunk_end, :].copy_(
final_hidden_states[1], non_blocking=True
)
ctx = get_forward_context()
# flashinfer_cutlass_kernels can handle: optional DP + TP/EP
max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu
moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
# If the input to the MoE is sequence parallel then divide by sp_size
# to find the maximum number of tokens for any individual dispatcher.
if self.is_sequence_parallel:
max_tokens_across_dispatchers = cdiv(
max_tokens_across_dispatchers, self.sp_size
)
num_tokens = full_hidden_states.size(0)
for chunk_idx, chunk_start_ in enumerate(
range(0, max_tokens_across_dispatchers, moe_dp_chunk_size_per_rank)
):
chunk_start = chunk_start_
chunk_end = min(
chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dispatchers
)
# clamp start and end
chunk_start = min(chunk_start, num_tokens - 1)
chunk_end = min(chunk_end, num_tokens)
with ctx.dp_metadata.chunked_sizes(
self.sp_size, moe_dp_chunk_size_per_rank, chunk_idx
):
process_chunk(
chunk_start, chunk_end, skip_result_store=chunk_start_ >= num_tokens
)
if self.shared_experts is None:
return full_fused_final_hidden_states
else:
return (full_shared_final_hidden_states, full_fused_final_hidden_states)
def forward_impl(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.quant_method is not None
self.ensure_moe_quant_config_init()
self.ensure_dp_chunking_init()
has_separate_shared_experts = (
not self.quant_method.mk_owns_shared_expert
and self.shared_experts is not None
)
use_chunked_impl = self.use_dp_chunking
use_shared_experts_stream, hidden_states_clone = (
self._maybe_setup_shared_experts_stream(
hidden_states, has_separate_shared_experts, use_chunked_impl
)
)
# If router/gate provided, then apply it here.
# (Note: This code runs only when "overlapped mode" is on to allow
# parallel execution of shared experts with the FusedMoE via
# separate cuda stream)
if self.gate is not None:
router_logits, _ = self.gate(hidden_states)
if use_chunked_impl:
return self.forward_impl_chunked(
hidden_states, router_logits, has_separate_shared_experts
)
# NOTE(rob): once we finish migrating all the quant methods to use
# MKs, we can remove the naive dispatch/combine path from here.
do_naive_dispatch_combine = (
self.dp_size > 1 and not self.quant_method.supports_internal_mk
)
ctx = get_forward_context()
sp_ctx = (
ctx.dp_metadata.sp_local_sizes(self.sp_size)
if ctx.dp_metadata
else nullcontext()
)
with sp_ctx:
extra_tensors = None
if do_naive_dispatch_combine:
post_quant_allgather = (
self.quant_method is not None
and self.dp_size > 1
and self.use_ep
and getattr(self.quant_method, "do_post_quant_allgather", False)
)
if post_quant_allgather:
hidden_states_to_dispatch, extra_tensors = (
self.quant_method.prepare_dp_allgather_tensor(
self, hidden_states, router_logits
)
)
else:
hidden_states_to_dispatch = hidden_states
dispatch_res = get_ep_group().dispatch_router_logits(
hidden_states_to_dispatch,
router_logits,
self.is_sequence_parallel,
extra_tensors=extra_tensors,
)
if extra_tensors is not None:
(
orig_hidden_states,
router_logits,
extra_tensors_combined,
) = dispatch_res
hidden_states_combined = (
orig_hidden_states,
extra_tensors_combined[0],
)
else:
hidden_states_combined, router_logits = dispatch_res
orig_hidden_states = hidden_states_combined
else:
orig_hidden_states = hidden_states
# Run shared experts before matrix multiply.
# because matrix multiply may modify the hidden_states.
if has_separate_shared_experts and not use_shared_experts_stream:
assert self.shared_experts is not None
shared_input = self._get_shared_experts_input(hidden_states)
shared_output = self.shared_experts(shared_input)
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
# we should modify All2AllManager abstract to better support PCP.
if self.pcp_size > 1:
hidden_states = get_pcp_group().all_gather(
hidden_states,
dim=0,
)
router_logits = get_pcp_group().all_gather(
router_logits,
dim=0,
)
# Matrix multiply.
x = hidden_states_combined if do_naive_dispatch_combine else hidden_states
# TODO(bnell): deal with fp4 flashinfer tuple hidden states hack (#30014).
# Figure out nicer way to do this.
x_orig = orig_hidden_states if do_naive_dispatch_combine else hidden_states
if self.quant_method.is_monolithic:
final_hidden_states = self.quant_method.apply_monolithic(
layer=self,
x=x,
router_logits=router_logits,
)
else:
topk_weights, topk_ids = self.router.select_experts(
hidden_states=x_orig,
router_logits=router_logits,
)
final_hidden_states = self.quant_method.apply(
layer=self,
x=x,  # The type signature of this is wrong due to the hack.
topk_weights=topk_weights,
topk_ids=topk_ids,
)
if has_separate_shared_experts:
assert self.shared_experts is not None
if use_shared_experts_stream:
# Run shared experts in parallel on a separate stream
# NOTE: We start the separate stream here and mark the
# sync end point immediately after it is done. This is
# important to avoid excessive stream allocations by the cuda
# graph replay later.
with torch.cuda.stream(self.shared_experts_stream):
# Note that hidden_states clone() is necessary here to avoid
# conflict with the main stream
shared_output = self.shared_experts(hidden_states_clone)
current_stream().wait_stream(self.shared_experts_stream)
final_hidden_states = (
shared_output,
final_hidden_states,
)
def combine_output(states: torch.Tensor) -> torch.Tensor:
if do_naive_dispatch_combine:
states = get_ep_group().combine(states, self.is_sequence_parallel)
if self.pcp_size > 1:
states = get_pcp_group().reduce_scatter(
states,
dim=0,
)
return states
if self.shared_experts is not None:
return (
final_hidden_states[0],
combine_output(final_hidden_states[1]),
)
else:
return combine_output(final_hidden_states)
@classmethod
def make_expert_params_mapping(
cls,
@@ -2051,94 +1536,6 @@ class FusedMoE(CustomOp):
return s
def get_layer_from_name(layer_name: str) -> FusedMoE:
    """Resolve a FusedMoE layer from its name, or from the forward context.

    The sentinel name "from_forward_context" means: take the next layer from
    the forward context's ordered MoE layer list, advancing its counter.
    """
    ctx: ForwardContext = get_forward_context()
    if layer_name == "from_forward_context":
        layers = ctx.all_moe_layers
        assert layers is not None
        idx = ctx.moe_layer_index
        if idx >= len(layers):
            raise AssertionError(
                "We expected the number of MOE layers in `all_moe_layers` "
                "to be equal to the number of "
                "{vllm.moe_forward, vllm.moe_forward_shared} calls."
            )
        layer_name = layers[idx]
        ctx.moe_layer_index = idx + 1
    return cast(FusedMoE, ctx.no_compile_layers[layer_name])
def moe_forward(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Custom-op body for MoE layers that have no shared experts."""
    layer = get_layer_from_name(layer_name)
    # This op variant is only registered for layers without shared experts;
    # layers with shared experts go through moe_forward_shared instead.
    assert layer.shared_experts is None
    return layer.forward_impl(hidden_states, router_logits)
def moe_forward_fake(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake (meta) implementation: output metadata mirrors hidden_states."""
    result = torch.empty_like(hidden_states)
    return result
# Register moe_forward as a torch custom op so that torch.compile treats it
# as opaque; hidden_states is declared mutated because the underlying kernels
# may write to it in place.
direct_register_custom_op(
    op_name="moe_forward",
    op_func=moe_forward,
    mutates_args=["hidden_states"],
    fake_impl=moe_forward_fake,
    tags=(torch.Tag.needs_fixed_stride_order,),
)
def moe_forward_shared(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    layer_name: str,
    shared_experts_input: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Custom-op body for MoE layers that DO have shared experts."""
    layer = get_layer_from_name(layer_name)
    assert layer.shared_experts is not None
    # torch.compile calls this op directly and skips the forward_native()
    # setup code, so the shared-experts input must be stashed here;
    # forward_impl() reads it back from the layer.
    with layer._set_shared_experts_input(shared_experts_input):
        return layer.forward_impl(hidden_states, router_logits)
def moe_forward_shared_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
shared_experts_input: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# Output shapes:
# - fused_out: same as hidden_states (routed experts use transformed size)
# - shared_out: same as shared_experts_input if provided, else same as hidden_states
# (For latent MoE: shared experts use original hidden_size, not latent size)
fused_out = torch.empty_like(hidden_states)
if shared_experts_input is not None:
shared_out = torch.empty_like(shared_experts_input)
else:
shared_out = torch.empty_like(hidden_states)
return shared_out, fused_out
# Register the shared-experts variant as a torch custom op; hidden_states is
# declared mutated because the underlying kernels may write to it in place.
direct_register_custom_op(
    op_name="moe_forward_shared",
    op_func=moe_forward_shared,
    mutates_args=["hidden_states"],
    fake_impl=moe_forward_shared_fake,
    tags=(torch.Tag.needs_fixed_stride_order,),
)
# Mark the FusedMoE weight_loader as supporting MoE-specific parameters
# to avoid expensive runtime reflection in model loading code
FusedMoE.weight_loader.supports_moe_loading = True  # type: ignore[attr-defined]

View File

@@ -1228,7 +1228,7 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
shared_experts_input: torch.Tensor | None = None,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
The _finalize method is a wrapper around self.prepare_finalize.finalize

View File

@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

View File

@@ -0,0 +1,743 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import nullcontext
import torch
import torch.nn.functional as F
import vllm.envs as envs
from vllm.distributed import (
get_ep_group,
get_pcp_group,
tensor_model_parallel_all_reduce,
)
from vllm.forward_context import (
ForwardContext,
get_forward_context,
is_forward_context_available,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.runner.moe_runner import MoERunner
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import (
aux_stream,
current_stream,
direct_register_custom_op,
)
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
logger = init_logger(__name__)
def get_layer_from_name(layer_name: str) -> torch.nn.Module:
    """Look up a MoE layer by name, resolving the forward-context sentinel.

    "from_forward_context" means: take the next entry from the forward
    context's ordered MoE layer list, advancing its running counter.
    """
    ctx: ForwardContext = get_forward_context()
    if layer_name == "from_forward_context":
        layers = ctx.all_moe_layers
        assert layers is not None
        idx = ctx.moe_layer_index
        if idx >= len(layers):
            raise AssertionError(
                "We expected the number of MOE layers in `all_moe_layers` "
                "to be equal to the number of "
                "{vllm.moe_forward, vllm.moe_forward_shared} calls."
            )
        layer_name = layers[idx]
        ctx.moe_layer_index = idx + 1
    return ctx.no_compile_layers[layer_name]
def _moe_forward(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    shared_experts_input: torch.Tensor | None,
    layer_name: str,
) -> torch.Tensor:
    """Custom-op body: delegate the forward pass to the layer's runner."""
    target = get_layer_from_name(layer_name)
    return target.runner.forward_impl(
        target, hidden_states, router_logits, shared_experts_input
    )
def _moe_forward_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
layer_name: str,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
def _moe_forward_shared(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    shared_experts_input: torch.Tensor | None,
    layer_name: str,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Custom-op body (shared-experts variant): delegate to the runner."""
    target = get_layer_from_name(layer_name)
    return target.runner.forward_impl(
        target, hidden_states, router_logits, shared_experts_input
    )
def _moe_forward_shared_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
layer_name: str,
) -> tuple[torch.Tensor, torch.Tensor]:
# Output shapes:
# - fused_out: same as hidden_states (routed experts use transformed size)
# - shared_out: same as shared_experts_input if provided, else same as
# hidden_states
# (For latent MoE: shared experts use original hidden_size, not latent size)
fused_out = torch.empty_like(hidden_states)
if shared_experts_input is not None:
shared_out = torch.empty_like(shared_experts_input)
else:
shared_out = torch.empty_like(hidden_states)
return shared_out, fused_out
# Register both op variants so torch.compile treats them as opaque calls.
# hidden_states is declared mutated because the underlying kernels may write
# to it in place.
direct_register_custom_op(
    op_name="moe_forward",
    op_func=_moe_forward,
    mutates_args=["hidden_states"],
    fake_impl=_moe_forward_fake,
    tags=(torch.Tag.needs_fixed_stride_order,),
)
# Shared-experts variant: returns a (shared_out, fused_out) tuple.
direct_register_custom_op(
    op_name="moe_forward_shared",
    op_func=_moe_forward_shared,
    mutates_args=["hidden_states"],
    fake_impl=_moe_forward_shared_fake,
    tags=(torch.Tag.needs_fixed_stride_order,),
)
class DefaultMoERunner(MoERunner):
    """
    Default implementation of the MoE runner for executing Mixture of Experts layers.
    This class provides a comprehensive implementation for running MoE computations
    with support for:
    - Expert routing and token dispatching
    - Shared experts computation with optional parallel execution using CUDA streams
    - Data parallel (DP) chunking for large batch processing
    - Tensor model parallel and expert parallel operations
    - Various quantization methods and custom operators
    - Both monolithic and decomposed expert execution paths
    The runner handles the complete MoE forward pass including routing tokens to
    experts, executing expert computations, and combining results. It supports
    advanced features like overlapped execution of shared experts and optimized
    kernels for different parallel execution modes.
    Eventually, this class will be split up and specialized for different
    configurations, e.g. the presence or absence of shared experts, a gate, etc.
    """
    def __init__(
        self,
        layer: torch.nn.Module,
        moe_config: FusedMoEConfig,
        router: FusedMoERouter,
        routed_input_transform: torch.nn.Module | None,
        gate: torch.nn.Module | None,
        shared_experts: torch.nn.Module | None,
        quant_method: FusedMoEMethodBase,
        reduce_results: bool,
        enable_dbo: bool,
    ):
        """Bind the runner to a FusedMoE layer and its collaborators.

        Args:
            layer: The owning FusedMoE module; only its ``layer_name`` is
                kept here (used for string -> layer lookup in custom ops).
            moe_config: Parallelism/shape configuration for the MoE layer.
            router: Computes top-k expert weights/ids from router logits.
            routed_input_transform: Optional transform applied only to the
                routed-expert input (e.g. latent projection).
            gate: Optional gate module; when set, router logits are computed
                inside forward_impl (overlapped mode).
            shared_experts: Optional shared-experts module.
            quant_method: Quantized (or unquantized) expert execution backend.
            reduce_results: Whether this runner performs the final TP/EP
                all-reduce of the output.
            enable_dbo: Whether dual-batch-overlap staging buffers are used.
        """
        super().__init__()
        self.moe_config = moe_config
        self.router = router
        self.routed_input_transform = routed_input_transform
        self.gate = gate
        self.shared_experts = shared_experts
        self.quant_method = quant_method
        self.reduce_results = reduce_results
        self.enable_dbo = enable_dbo
        # Allow disabling of the separate shared experts stream for
        # debug purposes.
        # TODO: Remove this after more extensive testings with TP/DP
        # and other execution modes
        if envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM:
            logger.debug_once("Disabling MoE shared_experts cuda stream", scope="local")
            self.shared_experts_stream = None
        else:
            # TODO(rob): enable shared expert overlap with non-cuda-alike.
            # aux_stream() returns None on non-cuda-alike platforms.
            self.shared_experts_stream = aux_stream()
            if self.shared_experts_stream is not None:
                logger.debug_once(
                    "Enabled separate cuda stream for MoE shared_experts", scope="local"
                )
        # Needed for string -> FusedMoE layer lookup in custom ops.
        self.layer_name = layer.layer_name
        if current_platform.is_tpu() or current_platform.is_cpu():
            # TODO: Once the OOM issue for the TPU backend is resolved, we
            # will switch to using the moe_forward custom op.
            # Note: CPU doesn't require wrapped forward_impl.
            if self.shared_experts is None:
                self.moe_forward = _moe_forward
            else:
                self.moe_forward = _moe_forward_shared
        else:
            # Go through the registered custom ops so torch.compile sees an
            # opaque call.
            if self.shared_experts is None:
                self.moe_forward = torch.ops.vllm.moe_forward
            else:
                self.moe_forward = torch.ops.vllm.moe_forward_shared
        # Chunked all2all staging tensor
        self.batched_hidden_states: torch.Tensor | None = None
        self.batched_router_logits: torch.Tensor | None = None
    @property
    def use_dp_chunking(self) -> bool:
        """True when the active all2all backend requires DP chunked execution
        and chunking is enabled via VLLM_ENABLE_MOE_DP_CHUNK."""
        return (
            self.moe_config.moe_parallel_config.use_pplx_kernels
            or self.moe_config.moe_parallel_config.use_deepep_ll_kernels
            or self.moe_config.moe_parallel_config.use_mori_kernels
            or self.moe_config.moe_parallel_config.use_fi_all2allv_kernels
        ) and envs.VLLM_ENABLE_MOE_DP_CHUNK
    def _maybe_setup_shared_experts_stream(
        self,
        hidden_states: torch.Tensor,
        shared_input: torch.Tensor | None,
        has_separate_shared_experts: bool,
        use_chunked_impl: bool,
    ) -> tuple[bool, torch.Tensor | None]:
        """Decide whether to overlap shared experts on a side CUDA stream.

        Returns (use_shared_experts_stream, hidden_states_clone). The clone
        is only made when overlap is used; it is what the shared experts will
        consume on the side stream.
        """
        use_shared_experts_stream = (
            current_platform.is_cuda()
            and has_separate_shared_experts
            and not use_chunked_impl
            and self.shared_experts_stream is not None
            and (
                hidden_states.shape[0]
                <= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
            )
        )
        hidden_states_clone: torch.Tensor | None = None
        if use_shared_experts_stream:
            assert self.shared_experts_stream is not None
            shared_experts_input = (
                shared_input if shared_input is not None else hidden_states
            )
            # Clone BEFORE switching streams to avoid race condition
            # where routed_expert kernel may mutate hidden_states.
            hidden_states_clone = shared_experts_input.clone()
            # Record that the clone will be used by shared_experts_stream
            # to avoid gc issue from deallocation of hidden_states_clone
            # For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
            # NOTE: We don't need shared_output.record_stream(current_stream())
            # because we synch the streams before using shared_output.
            hidden_states_clone.record_stream(self.shared_experts_stream)
            # Mark sync start point for the separate shared experts
            # stream here since we want to run in parallel with the
            # router/gate (next op below)
            assert self.shared_experts_stream is not None
            self.shared_experts_stream.wait_stream(current_stream())
        return use_shared_experts_stream, hidden_states_clone
    def ensure_dp_chunking_init(self) -> None:
        """Lazily allocate the chunked all2all staging buffers.

        No-op when chunking is disabled or the buffers already exist. With
        DBO enabled, a leading dimension of 2 holds one buffer per ubatch.
        """
        if not self.use_dp_chunking or self.batched_hidden_states is not None:
            return
        states_shape: tuple[int, ...]
        logits_shape: tuple[int, ...]
        moe = self.moe_config
        if self.enable_dbo:
            states_shape = (2, moe.max_num_tokens, self.moe_config.hidden_dim)
            logits_shape = (2, moe.max_num_tokens, self.moe_config.num_logical_experts)
        else:
            states_shape = (moe.max_num_tokens, self.moe_config.hidden_dim)
            logits_shape = (moe.max_num_tokens, self.moe_config.num_logical_experts)
        self.batched_hidden_states = torch.zeros(
            states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
        )
        self.batched_router_logits = torch.zeros(
            logits_shape,
            dtype=moe.router_logits_dtype,
            device=torch.cuda.current_device(),
        )
    def must_reduce_shared_expert_outputs(self) -> bool:
        """
        The shared_experts are typically computed using the RowParallelLinear
        layer. The result of this function is typically used as
        the reduce_results argument to the module.
        When just tensor-parallel is used, it is not required to reduce
        the shared_experts results immediately. Instead we reduce once
        at the end of the MoE op. (Refer to DeepSeekV2MoE module)
        With EP and all2all kernels - this is no longer viable as all
        GPU ranks in DP, produce the complete set of hidden_states.
        Therefore it is required that we reduce the shared_experts output
        early.
        """
        assert self.quant_method is not None
        return (
            self.quant_method.moe_mk is not None
            and self.quant_method.moe_mk.output_is_reduced()
        )
    def maybe_all_reduce_tensor_model_parallel(
        self, final_hidden_states: torch.Tensor
    ) -> torch.Tensor:
        """
        Some combine kernels reduce across GPU ranks by default.
        Skip the TP all-reduce when the modular kernel's combine step has
        already produced a reduced output.
        """
        if self.must_reduce_shared_expert_outputs():
            return final_hidden_states
        else:
            return tensor_model_parallel_all_reduce(final_hidden_states)
    def apply_routed_input_transform(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply transform for routed experts (e.g., latent projection).
        This is called by FusedMoE.forward_native. The original hidden_states
        is saved separately so shared experts get [S, hidden_size] while
        routed experts get the transformed [S, moe_latent_size].
        TODO: For latent MoE bandwidth optimization, fc2_latent_proj could be
        moved inside SharedFusedMoE to all-reduce on the smaller latent
        dimension.
        """
        if self.routed_input_transform is not None:
            result = self.routed_input_transform(hidden_states)
            # ReplicatedLinear returns (output, extra_bias) tuple.
            # We only need the output tensor; extra_bias is not used here.
            if isinstance(result, tuple):
                return result[0]
            return result
        return hidden_states
    def _reduce_output(
        self,
        states: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        trunc_sizes: list[int],
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Optionally all-reduce, then truncate padded hidden dims.

        ``trunc_sizes`` gives, per output tensor, the original last-dim size
        to slice back to (undoing the F.pad applied in forward()). The
        all-reduce is performed only when this runner owns the final
        reduction (reduce_results) and TP/EP is active.
        """
        def trunc(x: torch.Tensor, trunc_size: int) -> torch.Tensor:
            return x[..., :trunc_size]
        def reduce_and_trunc(x: torch.Tensor, trunc_size: int) -> torch.Tensor:
            return trunc(self.maybe_all_reduce_tensor_model_parallel(x), trunc_size)
        if (
            not self.moe_config.is_sequence_parallel
            and not self.use_dp_chunking
            and self.reduce_results
            and (self.moe_config.tp_size > 1 or self.moe_config.ep_size > 1)
        ):
            func = reduce_and_trunc
        else:
            func = trunc
        if isinstance(states, tuple):
            return tuple(
                [func(s, trunc_size) for s, trunc_size in zip(states, trunc_sizes)]
            )
        else:
            assert len(trunc_sizes) == 1
            return func(states, trunc_sizes[0])
    def _encode_layer_name(self) -> str:
        """Name passed to the custom op; the sentinel defers the lookup to
        the forward context when it carries the ordered MoE layer list."""
        # Can be unavailable or None in unittests
        if (
            is_forward_context_available()
            and get_forward_context().all_moe_layers is not None
        ):
            return "from_forward_context"
        return self.layer_name
    def forward(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Entry point: transform/pad the routed input, invoke the custom op,
        then reduce and truncate the output(s).

        Returns a single tensor, or (shared_out, fused_out) when the layer
        has shared experts.
        """
        # For latent MoE: save ORIGINAL hidden_states before transform
        # (shared_experts need original dimension, routed experts use transformed)
        original_hidden_states = hidden_states
        original_hidden_dim = hidden_states.shape[-1]
        # Apply transform for routed experts (e.g., latent projection for latent MoE)
        hidden_states = self.apply_routed_input_transform(hidden_states)
        # This is the dimension after transform (for routed expert output slicing)
        transformed_hidden_dim = hidden_states.shape[-1]
        if self.moe_config.hidden_dim != transformed_hidden_dim:
            # Zero-pad up to the configured hidden dim; undone by
            # _reduce_output's truncation.
            hidden_states = F.pad(
                hidden_states,
                (0, self.moe_config.hidden_dim - transformed_hidden_dim),
                mode="constant",
                value=0.0,
            )
        fused_output = self.moe_forward(
            hidden_states,
            router_logits,
            original_hidden_states,
            self._encode_layer_name(),
        )
        if isinstance(fused_output, tuple):
            # (shared_out, fused_out): shared output keeps the original dim.
            orig_hidden_dims = [original_hidden_dim, transformed_hidden_dim]
        else:
            orig_hidden_dims = [transformed_hidden_dim]
        return self._reduce_output(fused_output, orig_hidden_dims)
    def forward_impl_chunked(
        self,
        layer: torch.nn.Module,
        full_hidden_states: torch.Tensor,
        full_router_logits: torch.Tensor,
        shared_input: torch.Tensor | None,
        has_separate_shared_experts: bool,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Chunked DP execution path: stage fixed-size chunks into the
        preallocated buffers and process them one at a time so every DP rank
        performs the same number of dispatch/combine steps.
        """
        assert self.batched_hidden_states is not None
        assert self.batched_router_logits is not None
        assert self.batched_hidden_states.dtype == full_hidden_states.dtype, (
            f"{self.batched_hidden_states.dtype} == {full_hidden_states.dtype}"
        )
        assert self.batched_router_logits.dtype == full_router_logits.dtype, (
            f"{self.batched_router_logits.dtype} == {full_router_logits.dtype}"
        )
        # Check size compatibility.
        assert self.batched_hidden_states.size(-1) == full_hidden_states.size(-1)
        assert self.batched_router_logits.size(-1) == full_router_logits.size(-1)
        # TODO(bnell): Fix shared_expert_inputs w/chunking.
        # assert shared_input is None, (
        #     "Routed input transform is not currently supported with DP chunking."
        # )
        full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
        if self.shared_experts is not None:
            full_shared_final_hidden_states = torch.empty_like(full_hidden_states)
        def process_chunk(
            chunk_start: int, chunk_end: int, skip_result_store: bool = False
        ) -> None:
            # Run one chunk through staging buffers; optionally skip writing
            # results (ranks that have run out of real tokens still execute
            # the collective ops on a dummy chunk).
            chunk_size = chunk_end - chunk_start
            hidden_states = full_hidden_states[chunk_start:chunk_end, :]
            router_logits = full_router_logits[chunk_start:chunk_end, :]
            assert self.batched_hidden_states is not None
            assert self.batched_router_logits is not None
            # This is only true when DBO has been enabled in the config.
            # Both tensors will have an outer dimension for the ubatch id
            if self.batched_hidden_states.dim() == 3:
                assert self.batched_router_logits.dim() == 3
                batch_buffer_idx = dbo_current_ubatch_id()
                batched_hidden_states = self.batched_hidden_states[batch_buffer_idx, :]
                batched_router_logits = self.batched_router_logits[batch_buffer_idx, :]
            else:
                batched_hidden_states = self.batched_hidden_states
                batched_router_logits = self.batched_router_logits
            assert (
                batched_hidden_states.size(0)  # type: ignore
                >= chunk_size
            )
            assert (
                batched_router_logits.size(0)  # type: ignore
                >= chunk_size
            )
            staged_hidden_states = batched_hidden_states[:chunk_size, :]  # type: ignore
            staged_router_logits = batched_router_logits[:chunk_size, :]  # type: ignore
            staged_hidden_states.copy_(hidden_states, non_blocking=True)
            staged_router_logits.copy_(router_logits, non_blocking=True)
            # Matrix multiply.
            if self.quant_method.is_monolithic:
                final_hidden_states = self.quant_method.apply_monolithic(
                    layer=layer,
                    x=staged_hidden_states,
                    router_logits=staged_router_logits,
                )
            else:
                topk_weights, topk_ids = self.router.select_experts(
                    hidden_states=staged_hidden_states,
                    router_logits=staged_router_logits,
                )
                final_hidden_states = self.quant_method.apply(
                    layer=layer,
                    x=staged_hidden_states,
                    topk_weights=topk_weights,
                    topk_ids=topk_ids,
                    shared_experts_input=shared_input,
                )
            if has_separate_shared_experts:
                assert not isinstance(final_hidden_states, tuple)
                assert self.shared_experts is not None
                # NOTE(review): shared experts run on the staged (possibly
                # transformed) chunk here, not on shared_input — see the
                # TODO(bnell) above about shared inputs with chunking.
                shared_output = self.shared_experts(staged_hidden_states)
                final_hidden_states = (
                    shared_output,
                    final_hidden_states,
                )
            if not skip_result_store:
                if self.shared_experts is None:
                    full_fused_final_hidden_states[chunk_start:chunk_end, :].copy_(
                        final_hidden_states, non_blocking=True
                    )
                else:
                    full_shared_final_hidden_states[chunk_start:chunk_end, :].copy_(
                        final_hidden_states[0], non_blocking=True
                    )
                    full_fused_final_hidden_states[chunk_start:chunk_end, :].copy_(
                        final_hidden_states[1], non_blocking=True
                    )
        ctx = get_forward_context()
        # flashinfer_cutlass_kernels can handle: optional DP + TP/EP
        max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu
        moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
        # If the input to the MoE is sequence parallel then divide by sp_size
        # to find the maximum number of tokens for any individual dispatcher.
        if self.moe_config.is_sequence_parallel:
            max_tokens_across_dispatchers = cdiv(
                max_tokens_across_dispatchers, self.moe_config.sp_size
            )
        num_tokens = full_hidden_states.size(0)
        # Iterate over the worst-case token count so every rank executes the
        # same number of chunks (collectives must stay in lockstep).
        for chunk_idx, chunk_start_ in enumerate(
            range(0, max_tokens_across_dispatchers, moe_dp_chunk_size_per_rank)
        ):
            chunk_start = chunk_start_
            chunk_end = min(
                chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dispatchers
            )
            # clamp start and end
            chunk_start = min(chunk_start, num_tokens - 1)
            chunk_end = min(chunk_end, num_tokens)
            with ctx.dp_metadata.chunked_sizes(
                self.moe_config.sp_size, moe_dp_chunk_size_per_rank, chunk_idx
            ):
                process_chunk(
                    chunk_start, chunk_end, skip_result_store=chunk_start_ >= num_tokens
                )
        if self.shared_experts is None:
            return full_fused_final_hidden_states
        else:
            return (full_shared_final_hidden_states, full_fused_final_hidden_states)
    def forward_impl(
        self,
        layer: torch.nn.Module,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        shared_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Core MoE forward: optional gate, dispatch, expert matmul, shared
        experts (inline or on a side stream), and combine.

        Returns a single tensor, or (shared_out, fused_out) when the layer
        has shared experts.
        """
        assert self.quant_method is not None
        self.ensure_dp_chunking_init()
        has_separate_shared_experts = (
            not self.quant_method.mk_owns_shared_expert
            and self.shared_experts is not None
        )
        use_chunked_impl = self.use_dp_chunking
        use_shared_experts_stream, hidden_states_clone = (
            self._maybe_setup_shared_experts_stream(
                hidden_states,
                shared_input,
                has_separate_shared_experts,
                use_chunked_impl,
            )
        )
        # If router/gate provided, then apply it here.
        # (Note: This code runs only when "overlapped mode" is on to allow
        # parallel execution of shared experts with the FusedMoE via
        # separate cuda stream)
        if self.gate is not None:
            router_logits, _ = self.gate(hidden_states)
        if use_chunked_impl:
            return self.forward_impl_chunked(
                layer,
                hidden_states,
                router_logits,
                shared_input,
                has_separate_shared_experts,
            )
        # NOTE(rob): once we finish migrating all the quant methods to use
        # MKs, we can remove the naive dispatch/combine path from here.
        do_naive_dispatch_combine = (
            self.moe_config.dp_size > 1 and not self.quant_method.supports_internal_mk
        )
        ctx = get_forward_context()
        sp_ctx = (
            ctx.dp_metadata.sp_local_sizes(self.moe_config.sp_size)
            if ctx.dp_metadata
            else nullcontext()
        )
        with sp_ctx:
            extra_tensors = None
            if do_naive_dispatch_combine:
                # Some quant methods prefer to quantize before the DP
                # all-gather so the collective moves fewer bytes.
                post_quant_allgather = (
                    self.quant_method is not None
                    and self.moe_config.dp_size > 1
                    and self.moe_config.use_ep
                    and getattr(self.quant_method, "do_post_quant_allgather", False)
                )
                if post_quant_allgather:
                    hidden_states_to_dispatch, extra_tensors = (
                        self.quant_method.prepare_dp_allgather_tensor(
                            layer, hidden_states, router_logits
                        )
                    )
                else:
                    hidden_states_to_dispatch = hidden_states
                dispatch_res = get_ep_group().dispatch_router_logits(
                    hidden_states_to_dispatch,
                    router_logits,
                    self.moe_config.is_sequence_parallel,
                    extra_tensors=extra_tensors,
                )
                if extra_tensors is not None:
                    (
                        orig_hidden_states,
                        router_logits,
                        extra_tensors_combined,
                    ) = dispatch_res
                    hidden_states_combined = (
                        orig_hidden_states,
                        extra_tensors_combined[0],
                    )
                else:
                    hidden_states_combined, router_logits = dispatch_res
                    orig_hidden_states = hidden_states_combined
            else:
                orig_hidden_states = hidden_states
            # Run shared experts before the matrix multiply,
            # because the matrix multiply may modify hidden_states.
            if has_separate_shared_experts and not use_shared_experts_stream:
                assert self.shared_experts is not None
                shared_input = (
                    shared_input if shared_input is not None else hidden_states
                )
                shared_output = self.shared_experts(shared_input)
            # NOTE: Similar to DP, PCP also needs dispatch and combine. For
            # simplicity, AgRsAll2All was added separately for PCP here. Maybe
            # we should modify All2AllManager abstract to better support PCP.
            if self.moe_config.pcp_size > 1:
                hidden_states = get_pcp_group().all_gather(
                    hidden_states,
                    dim=0,
                )
                router_logits = get_pcp_group().all_gather(
                    router_logits,
                    dim=0,
                )
            # TODO(bnell): deal with fp4 flashinfer tuple hidden states hack (#30014).
            # Figure out nicer way to do this.
            if do_naive_dispatch_combine:
                x = hidden_states_combined
                x_orig = orig_hidden_states
            else:
                x = hidden_states
                x_orig = hidden_states
            # Matrix multiply.
            if self.quant_method.is_monolithic:
                final_hidden_states = self.quant_method.apply_monolithic(
                    layer=layer,
                    x=x,
                    router_logits=router_logits,
                )
            else:
                topk_weights, topk_ids = self.router.select_experts(
                    hidden_states=x_orig,
                    router_logits=router_logits,
                )
                final_hidden_states = self.quant_method.apply(
                    layer=layer,
                    x=x,  # The type signature of this is wrong due to the hack.
                    topk_weights=topk_weights,
                    topk_ids=topk_ids,
                    shared_experts_input=shared_input,
                )
            if has_separate_shared_experts:
                assert self.shared_experts is not None
                if use_shared_experts_stream:
                    # Run shared experts in parallel on a separate stream
                    # NOTE: We start the separate stream here and mark the
                    # sync end point immediately after it is done. This is
                    # important to avoid excessive stream allocations by the cuda
                    # graph replay later.
                    with torch.cuda.stream(self.shared_experts_stream):
                        # Note that hidden_states clone() is necessary here to avoid
                        # conflict with the main stream
                        shared_output = self.shared_experts(hidden_states_clone)
                    current_stream().wait_stream(self.shared_experts_stream)
                final_hidden_states = (
                    shared_output,
                    final_hidden_states,
                )
            def combine_output(states: torch.Tensor) -> torch.Tensor:
                # Undo the naive dispatch and/or PCP all-gather.
                if do_naive_dispatch_combine:
                    states = get_ep_group().combine(
                        states, self.moe_config.is_sequence_parallel
                    )
                if self.moe_config.pcp_size > 1:
                    states = get_pcp_group().reduce_scatter(
                        states,
                        dim=0,
                    )
                return states
            if self.shared_experts is not None:
                return (
                    final_hidden_states[0],
                    combine_output(final_hidden_states[1]),
                )
            else:
                return combine_output(final_hidden_states)

View File

@@ -0,0 +1,34 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
import torch
class MoERunner(ABC):
    """
    Abstract base class for Mixture of Experts (MoE) runners.
    This class defines the interface that all MoE runner implementations must follow.
    MoE runners are responsible for executing the forward pass of MoE layers, handling
    expert routing, and managing tensor parallel operations.
    """
    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the MoE forward pass.

        Returns a single tensor, or a (shared_out, fused_out) pair for
        layers with shared experts.
        """
        raise NotImplementedError
    @abstractmethod
    def must_reduce_shared_expert_outputs(self) -> bool:
        """Return True when shared-expert outputs must be reduced early
        (i.e. the combine step already produces a reduced routed output)."""
        raise NotImplementedError
    @abstractmethod
    def maybe_all_reduce_tensor_model_parallel(
        self,
        final_hidden_states: torch.Tensor,
    ):
        """All-reduce the final output across TP ranks unless the combine
        kernel has already reduced it."""
        raise NotImplementedError

View File

@@ -18,70 +18,6 @@ class SharedFusedMoE(FusedMoE):
can be interleaved with the fused all2all dispatch communication step.
"""
    def __init__(
        self,
        shared_experts: torch.nn.Module | None,
        gate: torch.nn.Module | None = None,
        use_overlapped: bool = True,
        routed_input_transform: torch.nn.Module | None = None,
        **kwargs,
    ):
        """FusedMoE variant that carries shared experts and an optional gate.

        Args:
            shared_experts: Module run on the (original) hidden states
                alongside the routed experts; may be None.
            gate: Optional router/gate module, applied internally when
                overlapped mode is active.
            use_overlapped: Request overlapping of shared experts with the
                fused dispatch; may still be disabled below.
            routed_input_transform: Optional transform applied only to the
                routed-expert input (e.g. latent projection).
            **kwargs: Forwarded to FusedMoE.__init__.
        """
        # Pass has_shared_experts so FusedMoE.__init__ can set disable_inplace
        # without accessing self.shared_experts (submodules cannot be set before
        # Module.__init__()).
        kwargs["has_shared_experts"] = shared_experts is not None
        super().__init__(**kwargs)
        self._shared_experts = shared_experts
        self._routed_input_transform = routed_input_transform
        # Disable shared expert overlap if:
        # - we are using eplb with non-default backend, because of correctness issues
        # - we are using flashinfer with DP, since there is nothing to gain
        # - we are using marlin kernels
        # NOTE(review): the marlin case listed above is not checked in the
        # expression below — confirm whether it is handled elsewhere.
        backend = self.moe_parallel_config.all2all_backend
        self.use_overlapped = (
            use_overlapped
            and not (
                (self.enable_eplb and backend != "allgather_reducescatter")
                or self.moe_parallel_config.use_fi_all2allv_kernels
            )
            and self._shared_experts is not None
        )
        self._gate = gate
    @property
    def shared_experts(self) -> torch.nn.Module | None:
        """The shared-experts module, exposed only in overlapped mode."""
        return self._shared_experts if self.use_overlapped else None
    @property
    def gate(self) -> torch.nn.Module | None:
        """The gate module, exposed only in overlapped mode."""
        return self._gate if self.use_overlapped else None
    @property
    def is_internal_router(self) -> bool:
        """True when routing logits are computed inside the MoE op (gate set)."""
        return self.gate is not None
    def apply_routed_input_transform(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply transform for routed experts (e.g., latent projection).
        This is called by FusedMoE.forward_native. The original hidden_states
        is saved separately so shared experts get [S, hidden_size] while
        routed experts get the transformed [S, moe_latent_size].
        Returns hidden_states unchanged when no transform is configured.
        TODO: For latent MoE bandwidth optimization, fc2_latent_proj could be
        moved inside SharedFusedMoE to all-reduce on the smaller latent
        dimension.
        """
        if self._routed_input_transform is not None:
            result = self._routed_input_transform(hidden_states)
            # ReplicatedLinear returns (output, extra_bias) tuple.
            # We only need the output tensor; extra_bias is not used here.
            if isinstance(result, tuple):
                return result[0]
            return result
        return hidden_states
def forward(
self,
hidden_states: torch.Tensor,

View File

@@ -55,6 +55,8 @@ logger = init_logger(__name__)
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""
# --8<-- [end:unquantized_fused_moe]
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
self.unquantized_backend = select_unquantized_moe_backend(
@@ -90,8 +92,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward_cuda(layer, x, topk_weights, topk_ids)
return self.forward_cuda(layer, x, topk_weights, topk_ids, shared_experts_input)
@property
def is_monolithic(self) -> bool:
@@ -293,12 +296,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward(
layer=layer,
x=x,
topk_weights=topk_weights,
topk_ids=topk_ids,
shared_experts_input=shared_experts_input,
)
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
@@ -316,6 +321,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel is not None
@@ -329,6 +335,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
shared_experts_input=shared_experts_input,
)
def forward_monolithic_cuda(

View File

@@ -764,6 +764,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return fused_marlin_moe(
x,

View File

@@ -501,6 +501,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts

View File

@@ -349,6 +349,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.moe_mk is not None
return self.moe_mk(
@@ -361,7 +362,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=layer._get_shared_experts_input(x),
shared_experts_input=shared_experts_input,
)
@@ -645,6 +646,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
assert layer.activation == "silu", "Only SiLU activation is supported."
@@ -673,7 +675,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=layer._get_shared_experts_input(x),
shared_experts_input=shared_experts_input,
)
@@ -1064,6 +1066,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
assert self.moe_mk is not None
@@ -1079,7 +1082,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=layer._get_shared_experts_input(x),
shared_experts_input=shared_experts_input,
)
@property
@@ -1203,6 +1206,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
@@ -1713,6 +1717,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel_backend == "Marlin"
return fused_marlin_moe(
@@ -1961,6 +1966,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
@@ -2575,6 +2581,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if layer.enable_eplb:
raise NotImplementedError(

View File

@@ -140,6 +140,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts

View File

@@ -1010,6 +1010,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.moe_mk is not None
assert not self.is_monolithic
@@ -1023,7 +1024,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=layer._get_shared_experts_input(x),
shared_experts_input=shared_experts_input,
)

View File

@@ -635,6 +635,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert layer.activation == "silu", "Only SiLU activation is supported."
if layer.apply_router_weight_on_input:

View File

@@ -900,6 +900,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return fused_marlin_moe(
x,

View File

@@ -958,6 +958,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
@@ -980,7 +981,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=layer._get_shared_experts_input(x),
shared_experts_input=shared_experts_input,
)
@@ -1524,6 +1525,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
@@ -1551,7 +1553,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=layer._get_shared_experts_input(x),
shared_experts_input=shared_experts_input,
)

View File

@@ -367,6 +367,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts

View File

@@ -900,6 +900,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
if layer.enable_eplb:

View File

@@ -419,6 +419,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
@@ -607,6 +608,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts,
@@ -977,6 +979,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if not self.emulate:
if (

View File

@@ -816,10 +816,14 @@ class Worker(WorkerBase):
for module in moe_modules:
module.moe_config.num_experts = num_local_experts * new_ep_size
module.global_num_experts = module.moe_config.num_experts
tp_size = get_tp_group().world_size
is_sequence_parallel = parallel_config.use_sequence_parallel_moe
sp_size = tp_size if is_sequence_parallel else 1
module.moe_parallel_config = FusedMoEParallelConfig.make(
tp_size_=get_tp_group().world_size,
tp_size_=tp_size,
pcp_size_=get_pcp_group().world_size,
dp_size_=get_dp_group().world_size,
sp_size_=sp_size,
vllm_parallel_config=parallel_config,
)
module.moe_config.moe_parallel_config = module.moe_parallel_config