[MoE Refactor] Split of DefaultMoERunner class (#35326)

Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
bnellnm
2026-04-06 12:41:59 -04:00
committed by GitHub
parent 608914de30
commit 93bada494f
8 changed files with 868 additions and 705 deletions

View File

@@ -39,8 +39,8 @@ from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
from vllm.model_executor.layers.fused_moe.router.router_factory import (
create_fused_moe_router,
)
from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import (
DefaultMoERunner,
from vllm.model_executor.layers.fused_moe.runner.moe_runner_factory import (
create_moe_runner,
)
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
SharedExperts,
@@ -572,8 +572,8 @@ class FusedMoE(CustomOp):
# Storing the runner in the FusedMoE is an intermediate state, eventually
# the runner will own the FusedMoE layer and provide the execution interface
# for MoE ops.
self.runner = DefaultMoERunner(
layer=self,
self.runner = create_moe_runner(
layer_name=self.layer_name,
moe_config=self.moe_config,
router=self.router,
routed_input_transform=self._routed_input_transform,

View File

@@ -0,0 +1,243 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.forward_context import (
get_forward_context,
)
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.runner.moe_runner_base import MoERunnerBase
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
SharedExperts,
)
from vllm.utils.math_utils import cdiv
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
from vllm.v1.worker.workspace import current_workspace_manager
class ChunkingMoERunner(MoERunnerBase):
"""
MoE runner wrapper that adds chunked processing to any MoERunnerBase.
This runner wraps an inner MoERunnerBase and overrides _forward_impl to
process large batches by breaking them into smaller chunks. Each chunk
is delegated to the inner runner's _forward_impl, making chunking
composable with any runner implementation.
All MoERunnerBase state (moe_config, router, quant_method, etc.) is
transparently delegated to the inner runner via __getattr__.
ChunkingMoERunner only owns chunking-specific state: the pre-allocated
workspace buffers and the reduce_results override.
Key behaviors:
- Pre-allocates workspace tensors for CUDA graph compatibility
- Processes chunks via inner._forward_impl per chunk
- Never reduces results (reduce_results always returns False)
"""
def __init__(self, inner: MoERunnerBase):
# Assert that _maybe_dispatch/_maybe_combine will be nops.
assert inner.moe_config.pcp_size == 1
# Skip MoERunnerBase.__init__ — all state is delegated to inner
# via __getattr__. Only chunking-specific state lives here.
self._inner = inner
# Pre-allocated staging buffers. These need to exist ahead of time
# due to CUDA graph construction needing fixed buffer addresses.
self.batched_hidden_states, self.batched_router_logits = (
self._init_dp_chunking()
)
def __getattr__(self, name):
# Delegate attribute access to the inner runner. This is only
# called when normal lookup (instance __dict__, class MRO) fails,
# so ChunkingMoERunner's own attributes and methods take priority.
return getattr(self._inner, name)
@property
def shared_experts(self) -> SharedExperts | None:
return self._inner.shared_experts
# TODO(bnell): temporary hack, do not call this method.
def _replace_quant_method(self, quant_method: FusedMoEMethodBase):
self._inner._replace_quant_method(quant_method)
self.quant_method = quant_method
def is_internal_router(self) -> bool:
return self._inner.gate is not None
# Reducing results when chunking is handled by the MK finalize operations
# when DP chunking is enabled..
# This will be removed by #35949
@property
def reduce_results(self) -> bool:
return False
def _init_dp_chunking(self) -> list[torch.Tensor]:
states_shape: tuple[int, ...]
logits_shape: tuple[int, ...]
moe = self.moe_config
if self.enable_dbo:
states_shape = (2, moe.max_num_tokens, self.moe_config.hidden_dim)
logits_shape = (2, moe.max_num_tokens, self.moe_config.num_logical_experts)
else:
states_shape = (moe.max_num_tokens, self.moe_config.hidden_dim)
logits_shape = (moe.max_num_tokens, self.moe_config.num_logical_experts)
# Does this need some kind of profiling run check like modular_kernel.py?
return current_workspace_manager().get_simultaneous(
(states_shape, moe.in_dtype),
(logits_shape, moe.router_logits_dtype),
)
def _allocate_dp_chunking_outputs(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> tuple[torch.Tensor | None, torch.Tensor]:
# Assert the inputs are of the proper type and shape.
assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None
assert self.batched_hidden_states.dtype == hidden_states.dtype, (
f"{self.batched_hidden_states.dtype} == {hidden_states.dtype}"
)
assert self.batched_router_logits.dtype == router_logits.dtype, (
f"{self.batched_router_logits.dtype} == {router_logits.dtype}"
)
# Check size compatibility.
assert self.batched_hidden_states.size(-1) == hidden_states.size(-1)
assert self.batched_router_logits.size(-1) == router_logits.size(-1)
final_fused_hidden_states = torch.empty_like(hidden_states)
if self.shared_experts is not None:
if shared_experts_input is not None:
final_shared_hidden_states = torch.empty_like(shared_experts_input)
else:
final_shared_hidden_states = torch.empty_like(hidden_states)
else:
final_shared_hidden_states = None
return final_shared_hidden_states, final_fused_hidden_states
def _slice_and_copy_input(
self,
out_slice: torch.Tensor,
orig: torch.Tensor | None,
start: int,
end: int,
) -> torch.Tensor:
assert orig is not None
slice_size = end - start
orig_slice = orig[start:end, :]
if self.enable_dbo:
assert out_slice.dim() == 3
batch_buffer_idx = dbo_current_ubatch_id()
out_slice = out_slice[batch_buffer_idx, :]
assert out_slice.size(0) >= slice_size
out_slice = out_slice[:slice_size, :]
out_slice.copy_(orig_slice, non_blocking=True)
return out_slice
def _forward_impl(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
final_shared_hidden_states, final_fused_hidden_states = (
self._allocate_dp_chunking_outputs(
hidden_states, router_logits, shared_experts_input
)
)
ctx = get_forward_context()
# flashinfer_cutlass_kernels can handle: optional DP + TP/EP
max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu
moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
# If the input to the MoE is sequence parallel then divide by sp_size
# to find the maximum number of tokens for any individual dispatcher.
if self.moe_config.is_sequence_parallel:
max_tokens_across_dispatchers = cdiv(
max_tokens_across_dispatchers, self.moe_config.sp_size
)
num_tokens = hidden_states.size(0)
for chunk_idx, chunk_start_ in enumerate(
range(0, max_tokens_across_dispatchers, moe_dp_chunk_size_per_rank)
):
chunk_start = chunk_start_
chunk_end = min(
chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dispatchers
)
# clamp start and end
chunk_start = min(chunk_start, num_tokens - 1)
chunk_end = min(chunk_end, num_tokens)
chunk_sizes = ctx.dp_metadata.chunked_sizes(
self.moe_config.sp_size, moe_dp_chunk_size_per_rank, chunk_idx
)
with chunk_sizes:
hidden_states_chunk = self._slice_and_copy_input(
self.batched_hidden_states,
hidden_states,
chunk_start,
chunk_end,
)
router_logits_chunk = self._slice_and_copy_input(
self.batched_router_logits,
router_logits,
chunk_start,
chunk_end,
)
shared_experts_input_chunk = (
shared_experts_input[chunk_start:chunk_end, :]
if shared_experts_input is not None
else None
)
# Delegate per-chunk computation to the inner runner.
chunk_result = self._inner._forward_impl(
layer=layer,
hidden_states=hidden_states_chunk,
router_logits=router_logits_chunk,
shared_experts_input=shared_experts_input_chunk,
)
# Store outputs
# TODO(bnell): document when chunk_start >= num_tokens
if chunk_start < num_tokens:
if self.shared_experts is not None:
assert isinstance(chunk_result, tuple)
shared_output_chunk, hidden_states_chunk = chunk_result
final_fused_hidden_states[chunk_start:chunk_end, :].copy_(
hidden_states_chunk, non_blocking=True
)
assert shared_output_chunk is not None
assert final_shared_hidden_states is not None
final_shared_hidden_states[chunk_start:chunk_end, :].copy_(
shared_output_chunk, non_blocking=True
)
else:
assert isinstance(chunk_result, torch.Tensor)
final_fused_hidden_states[chunk_start:chunk_end, :].copy_(
chunk_result, non_blocking=True
)
if self.shared_experts is None:
return final_fused_hidden_states
else:
assert final_shared_hidden_states is not None
return (final_shared_hidden_states, final_fused_hidden_states)

View File

@@ -1,516 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from contextlib import nullcontext
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F
from vllm.distributed import (
get_ep_group,
get_pcp_group,
tensor_model_parallel_all_reduce,
)
from vllm.forward_context import (
ForwardContext,
get_forward_context,
is_forward_context_available,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.runner.moe_runner import MoERunner
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
SharedExperts,
SharedExpertsOrder,
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import (
HAS_OPAQUE_TYPE,
ModuleName,
direct_register_custom_op,
)
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
logger = init_logger(__name__)
from vllm.model_executor.layers.fused_moe.runner.moe_runner_base import MoERunnerBase
def get_layer_from_name(layer_name: str) -> torch.nn.Module:
forward_context: ForwardContext = get_forward_context()
if layer_name == "from_forward_context":
all_moe_layers = forward_context.all_moe_layers
assert all_moe_layers is not None
moe_layer_index = forward_context.moe_layer_index
if moe_layer_index >= len(all_moe_layers):
raise AssertionError(
"We expected the number of MOE layers in `all_moe_layers` "
"to be equal to the number of "
"{vllm.moe_forward, vllm.moe_forward_shared} calls."
)
layer_name = all_moe_layers[moe_layer_index]
forward_context.moe_layer_index += 1
return forward_context.no_compile_layers[layer_name]
# On torch >= 2.11, layer_name is a hoisted ModuleName opaque object;
# on older versions it remains a plain str.
if TYPE_CHECKING:
from typing import TypeAlias
_layer_name_type: TypeAlias = str | ModuleName
else:
_layer_name_type = ModuleName if HAS_OPAQUE_TYPE else str
def _resolve_layer_name(layer_name: str | ModuleName) -> str:
return layer_name.value if isinstance(layer_name, ModuleName) else layer_name
# Note: _moe_forward and _moe_forward_shared should not contain any
# implementation details, They should merely pass along control to
# the runner's 'forward_dispatch' method.
def _moe_forward(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
layer_name: _layer_name_type,
) -> torch.Tensor:
layer = get_layer_from_name(_resolve_layer_name(layer_name))
return layer.runner.forward_dispatch(
layer,
hidden_states,
router_logits,
shared_experts_input,
)
def _moe_forward_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
layer_name: _layer_name_type,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
def _moe_forward_shared(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
layer_name: _layer_name_type,
) -> tuple[torch.Tensor, torch.Tensor]:
layer = get_layer_from_name(_resolve_layer_name(layer_name))
return layer.runner.forward_dispatch(
layer,
hidden_states,
router_logits,
shared_experts_input,
)
def _moe_forward_shared_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
layer_name: _layer_name_type,
) -> tuple[torch.Tensor, torch.Tensor]:
# Output shapes:
# - fused_out: same as hidden_states (routed experts use transformed size)
# - shared_out: same as shared_experts_input if provided, else same as
# hidden_states
# (For latent MoE: shared experts use original hidden_size, not latent size)
fused_out = torch.empty_like(hidden_states)
if shared_experts_input is not None:
shared_out = torch.empty_like(shared_experts_input)
else:
shared_out = torch.empty_like(hidden_states)
return shared_out, fused_out
direct_register_custom_op(
op_name="moe_forward",
op_func=_moe_forward,
mutates_args=["hidden_states"], # is this still true?
fake_impl=_moe_forward_fake,
tags=(torch.Tag.needs_fixed_stride_order,),
)
direct_register_custom_op(
op_name="moe_forward_shared",
op_func=_moe_forward_shared,
fake_impl=_moe_forward_shared_fake,
tags=(torch.Tag.needs_fixed_stride_order,),
)
class DefaultMoERunner(MoERunner):
class DefaultMoERunner(MoERunnerBase):
"""
Default implementation of the MoE runner for executing Mixture of Experts layers.
Standard MoE runner implementation for executing Mixture of Experts layers.
This class provides a comprehensive implementation for running MoE computations
with support for:
- Expert routing and token dispatching
This is the primary concrete implementation of MoE execution logic, providing
comprehensive support for standard MoE operations. It handles:
- Expert routing and token dispatching using various routing strategies
- Shared experts computation with optional parallel execution using CUDA streams
- Data parallel (DP) chunking for large batch processing
- Tensor model parallel and expert parallel operations
- Various quantization methods and custom operators
- Multiple quantization methods and optimized kernel selection
- Both monolithic and decomposed expert execution paths
- Integration with various parallel execution modes (TP, EP, DP)
The runner handles the complete MoE forward pass including routing tokens to
experts, executing expert computations, and combining results. It supports
advanced features like overlapped execution of shared experts and optimized
kernels for different parallel execution modes.
The runner orchestrates the complete MoE forward pass including routing tokens
to experts, executing expert computations in parallel, and combining results.
It supports advanced features like overlapped execution of shared experts,
optimized kernels for different parallel configurations, and seamless
integration with vLLM's distributed execution framework.
Eventually, this class will be split up and specialized for different
configurations, e.g. the presence or absence of shared experts, a gate, etc.
This implementation is suitable for most standard MoE use cases. For specialized
scenarios like large batch chunking, alternative runners like ChunkingMoERunner
may be more appropriate.
Eventually, this class may be split into more specialized implementations
for different configurations (e.g., with/without shared experts, gates, etc.).
"""
def __init__(
self,
layer: torch.nn.Module,
moe_config: FusedMoEConfig,
router: FusedMoERouter,
routed_input_transform: torch.nn.Module | None,
gate: torch.nn.Module | None,
shared_experts: torch.nn.Module | None,
quant_method: FusedMoEMethodBase,
reduce_results: bool,
enable_dbo: bool,
):
super().__init__()
self.moe_config = moe_config
self.router = router
self.routed_input_transform = routed_input_transform
self.gate = gate
self.quant_method = quant_method
self.reduce_results = reduce_results
self.enable_dbo = enable_dbo
self.shared_experts: SharedExperts | None = None
if shared_experts is not None:
self.shared_experts = SharedExperts(
shared_experts,
moe_config=moe_config,
# Note: For now we must pass quant_method along to SharedExperts so it
# can property determine where the shared experts are supposed to be
# called, i.e. by a MK or by the MoERunner.
# Once the MK can be created upfront, we can just pass in the proper
# flags derived from the quant_method's MK.
reduce_results=reduce_results,
quant_method=quant_method,
enable_dbo=enable_dbo,
)
# Chunked all2all staging tensor
# These need to exist ahead of time due to CUDAgraph construction
# needing a fixed buffer address.
self.use_dp_chunking = self.moe_config.moe_parallel_config.use_dp_chunking
self.batched_hidden_states: torch.Tensor | None = None
self.batched_router_logits: torch.Tensor | None = None
self._maybe_init_dp_chunking()
# Needed for string -> FusedMoE layer lookup in custom ops.
self.layer_name = layer.layer_name
self.forward_entry, self.forward_impl = self._select_forward(layer)
def _select_forward(self, layer: torch.nn.Module) -> tuple[Callable, Callable]:
# Select implementation based on presence of DP chunking.
forward_impl_fn = (
self._forward_impl_chunked if self.use_dp_chunking else self._forward_impl
)
if current_platform.is_tpu() or current_platform.is_cpu():
# TODO: Once the OOM issue for the TPU backend is resolved, we
# will switch to using the moe_forward custom op.
# Note: CPU doesn't require wrapped forward_impl.
return (
_moe_forward if self.shared_experts is None else _moe_forward_shared,
forward_impl_fn,
)
return (
torch.ops.vllm.moe_forward
if self.shared_experts is None
else torch.ops.vllm.moe_forward_shared,
forward_impl_fn,
)
# TODO(bnell): temporary hack, do not call this method.
def _replace_quant_method(self, quant_method: FusedMoEMethodBase):
if self.shared_experts is not None:
self.shared_experts._quant_method = quant_method
self.quant_method = quant_method
def is_internal_router(self) -> bool:
return self.gate is not None
def _maybe_init_dp_chunking(self):
if not self.use_dp_chunking:
return
assert self.batched_hidden_states is None
states_shape: tuple[int, ...]
logits_shape: tuple[int, ...]
moe = self.moe_config
if self.enable_dbo:
states_shape = (2, moe.max_num_tokens, self.moe_config.hidden_dim)
logits_shape = (2, moe.max_num_tokens, self.moe_config.num_logical_experts)
else:
states_shape = (moe.max_num_tokens, self.moe_config.hidden_dim)
logits_shape = (moe.max_num_tokens, self.moe_config.num_logical_experts)
device = torch.accelerator.current_device_index()
self.batched_hidden_states = torch.zeros(
states_shape,
dtype=moe.in_dtype,
device=device,
)
self.batched_router_logits = torch.zeros(
logits_shape,
dtype=moe.router_logits_dtype,
device=device,
)
def must_reduce_shared_expert_outputs(self) -> bool:
"""
The shared_experts are typically computed using the RowParallelLinear
layer. The result of this function is typically used as
the reduce_results argument to the module.
When just tensor-parallel is used, it is not required to reduce
the shared_experts results immediately. Instead we reduce at the
once at the end of the MoE op. (Refer to DeepSeekV2MoE module)
With EP and all2all kernels - this is no longer viable as all
GPU ranks in DP, produce the complete set of hidden_states.
Therefore it is required that we reduce the shared_experts output
early.
"""
return (
self.quant_method.moe_kernel is not None
and self.quant_method.moe_kernel.output_is_reduced()
)
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
"""
Some combine kernels reduce across GPU ranks by default.
"""
if self.must_reduce_shared_expert_outputs():
return final_hidden_states
else:
return tensor_model_parallel_all_reduce(final_hidden_states)
def apply_routed_input_transform(
self, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Apply transform for routed experts (e.g., latent projection).
This is called by FusedMoE.forward_native. The original hidden_states
is saved separately so shared experts get [S, hidden_size] while
routed experts get the transformed [S, moe_latent_size].
TODO: For latent MoE bandwidth optimization, fc2_latent_proj could be
moved inside SharedFusedMoE to all-reduce on the smaller latent
dimension.
Returns (possibly transformed) hidden states and the input for shared
experts (or None if there are no shared experts).
"""
if self.routed_input_transform is not None:
result = self.routed_input_transform(hidden_states)
# ReplicatedLinear returns (output, extra_bias) tuple.
# We only need the output tensor; extra_bias is not used here.
if isinstance(result, tuple):
return result[0], hidden_states
return result, hidden_states
return (
hidden_states,
hidden_states if self.shared_experts is not None else None,
)
def _maybe_reduce_output(
self,
states: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
trunc_sizes: list[int],
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
def trunc(x: torch.Tensor, trunc_size: int) -> torch.Tensor:
return x[..., :trunc_size]
def reduce_and_trunc(x: torch.Tensor, trunc_size: int) -> torch.Tensor:
return trunc(self.maybe_all_reduce_tensor_model_parallel(x), trunc_size)
if (
not self.moe_config.is_sequence_parallel
and not self.use_dp_chunking
and self.reduce_results
and (self.moe_config.tp_size > 1 or self.moe_config.ep_size > 1)
):
func = reduce_and_trunc
else:
func = trunc
if isinstance(states, tuple):
return tuple(
[func(s, trunc_size) for s, trunc_size in zip(states, trunc_sizes)]
)
else:
assert len(trunc_sizes) == 1
return func(states, trunc_sizes[0])
def _encode_layer_name(self) -> str | ModuleName:
if HAS_OPAQUE_TYPE:
return ModuleName(self.layer_name)
# Can be unavailable or None in unittests
if (
is_forward_context_available()
and get_forward_context().all_moe_layers is not None
):
return "from_forward_context"
return self.layer_name
def _maybe_pad_hidden_states(
self,
shared_experts_input: torch.Tensor | None,
hidden_states: torch.Tensor,
) -> tuple[torch.Tensor, list[int]]:
shared_experts_hidden_dim = (
shared_experts_input.shape[-1] if shared_experts_input is not None else 0
)
transformed_hidden_dim = hidden_states.shape[-1]
if (
not self.quant_method.skip_forward_padding
and self.moe_config.hidden_dim != transformed_hidden_dim
):
hidden_states = F.pad(
hidden_states,
(0, self.moe_config.hidden_dim - transformed_hidden_dim),
mode="constant",
value=0.0,
)
if self.shared_experts is not None:
orig_hidden_dims = [shared_experts_hidden_dim, transformed_hidden_dim]
else:
orig_hidden_dims = [transformed_hidden_dim]
return hidden_states, orig_hidden_dims
def _maybe_apply_shared_experts(
self,
shared_experts_input: torch.Tensor | None,
order: SharedExpertsOrder,
):
if self.shared_experts is not None:
assert shared_experts_input is not None
self.shared_experts.apply(shared_experts_input, order)
def _apply_quant_method(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> tuple[torch.Tensor | None, torch.Tensor]:
# Run this before quant_method to avoid inplace issues.
# TODO(bnell): probably not needed anymore since inplace is
# disabled when shared experts are present.
self._maybe_apply_shared_experts(
shared_experts_input, SharedExpertsOrder.NO_OVERLAP
)
if self.quant_method.is_monolithic:
fused_out = self.quant_method.apply_monolithic(
layer=layer,
x=hidden_states,
router_logits=router_logits,
)
else:
topk_weights, topk_ids = self.router.select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
)
# Passing shared_experts_input in case SharedExpertsOrder is
# NO_OVERLAP or MK_INTERNAL_OVERLAPPED.
fused_out = self.quant_method.apply(
layer=layer,
x=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
shared_experts_input=shared_experts_input,
)
self._maybe_apply_shared_experts(
shared_experts_input,
SharedExpertsOrder.MULTI_STREAM_OVERLAPPED,
)
return (
self.shared_experts.output if self.shared_experts is not None else None,
fused_out,
)
def _sequence_parallel_context(self):
ctx = get_forward_context()
return (
ctx.dp_metadata.sp_local_sizes(self.moe_config.sp_size)
if ctx.dp_metadata
else nullcontext()
)
def _allocate_dp_chunking_outputs(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor | None, torch.Tensor]:
assert self.use_dp_chunking
# Assert the inputs are of the proper type and shape.
assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None
assert self.batched_hidden_states.dtype == hidden_states.dtype, (
f"{self.batched_hidden_states.dtype} == {hidden_states.dtype}"
)
assert self.batched_router_logits.dtype == router_logits.dtype, (
f"{self.batched_router_logits.dtype} == {router_logits.dtype}"
)
# Check size compatibility.
assert self.batched_hidden_states.size(-1) == hidden_states.size(-1)
assert self.batched_router_logits.size(-1) == router_logits.size(-1)
final_fused_hidden_states = torch.empty_like(hidden_states)
if self.shared_experts is not None:
final_shared_hidden_states = torch.empty_like(hidden_states)
else:
final_shared_hidden_states = None
return final_shared_hidden_states, final_fused_hidden_states
def _maybe_sync_shared_experts_stream(
self,
shared_experts_input: torch.Tensor | None,
):
# If router/gate provided, then apply it here.
# (Note: This code runs only when "overlapped mode" is on to allow
# parallel execution of shared experts with the FusedMoE via
# separate cuda stream)
if self.shared_experts is not None:
self.shared_experts.maybe_sync_shared_experts_stream(shared_experts_input)
@property
def reduce_results(self) -> bool:
return self._reduce_results
@property
def do_naive_dispatch_combine(self) -> bool:
@@ -572,195 +101,6 @@ class DefaultMoERunner(MoERunner):
else:
return hidden_states
def forward(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""Invoke the fused moe layer.
Input:
- hidden_states
- router_logits
Output:
- The new hidden_states.
or
- A tuple of (shared experts output, new hidden_states).
Calling sequence
- forward
- self.forward_entry (_moe_forward or _moe_forward_shared custom op)
- forward_dispatch
- forward_impl (_forward_impl or _forward_impl_chunked)
Note: The existence of _moe_forward and _moe_forward_shared custom ops are due
to the following reasons:
1. the chunking loop in _forward_impl_chunked cannot be compiled by
torch.compile
2. pytorch cannot handle union types in custom op signatures so _moe_forward
and _moe_forward_shared must be split.
If _forward_impl_chunked can be implemented via torch.scan we can potentially
get rid of _moe_forward and _moe_forward_shared and collapse the whole sequence
into the 'forward' method.
"""
# Apply transform for routed experts (e.g., latent projection for latent MoE)
hidden_states, shared_experts_input = self.apply_routed_input_transform(
hidden_states
)
hidden_states, og_hidden_dims = self._maybe_pad_hidden_states(
shared_experts_input,
hidden_states,
)
fused_output = self.forward_entry(
hidden_states,
router_logits,
shared_experts_input,
self._encode_layer_name(),
)
return self._maybe_reduce_output(fused_output, og_hidden_dims)
def forward_dispatch(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
# TODO(bnell): this can be removed after MK migration is complete.
layer.ensure_moe_quant_config_init()
# Sync aux and main stream for shared expert multi-stream overlap.
self._maybe_sync_shared_experts_stream(shared_experts_input)
# If the Runner holds the gate, apply it after the stream sync,
# so it can run overlapped with the
# NOTE: in future PR, MoE runner will always hold the gate.
if self.gate is not None:
router_logits, _ = self.gate(hidden_states)
self._maybe_apply_shared_experts(
shared_experts_input,
SharedExpertsOrder.EXTERNAL,
)
with self._sequence_parallel_context():
return self.forward_impl(
layer,
hidden_states,
router_logits,
shared_experts_input,
)
def _slice_and_copy_input(
self,
out_slice: torch.Tensor,
orig: torch.Tensor | None,
start: int,
end: int,
) -> torch.Tensor:
assert orig is not None
slice_size = end - start
orig_slice = orig[start:end, :]
if self.enable_dbo:
assert out_slice.dim() == 3
batch_buffer_idx = dbo_current_ubatch_id()
out_slice = out_slice[batch_buffer_idx, :]
assert out_slice.size(0) >= slice_size
out_slice = out_slice[:slice_size, :]
out_slice.copy_(orig_slice, non_blocking=True)
return out_slice
def _forward_impl_chunked(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
final_shared_hidden_states, final_fused_hidden_states = (
self._allocate_dp_chunking_outputs(hidden_states, router_logits)
)
ctx = get_forward_context()
# flashinfer_cutlass_kernels can handle: optional DP + TP/EP
max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu
moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
# If the input to the MoE is sequence parallel then divide by sp_size
# to find the maximum number of tokens for any individual dispatcher.
if self.moe_config.is_sequence_parallel:
max_tokens_across_dispatchers = cdiv(
max_tokens_across_dispatchers, self.moe_config.sp_size
)
num_tokens = hidden_states.size(0)
for chunk_idx, chunk_start_ in enumerate(
range(0, max_tokens_across_dispatchers, moe_dp_chunk_size_per_rank)
):
chunk_start = chunk_start_
chunk_end = min(
chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dispatchers
)
# clamp start and end
chunk_start = min(chunk_start, num_tokens - 1)
chunk_end = min(chunk_end, num_tokens)
chunk_sizes = ctx.dp_metadata.chunked_sizes(
self.moe_config.sp_size, moe_dp_chunk_size_per_rank, chunk_idx
)
with chunk_sizes:
hidden_states_chunk = self._slice_and_copy_input(
self.batched_hidden_states,
hidden_states,
chunk_start,
chunk_end,
)
router_logits_chunk = self._slice_and_copy_input(
self.batched_router_logits,
router_logits,
chunk_start,
chunk_end,
)
shared_experts_input_chunk = (
shared_experts_input[chunk_start:chunk_end, :]
if shared_experts_input is not None
else None
)
shared_output_chunk, hidden_states_chunk = self._apply_quant_method(
layer=layer,
hidden_states=hidden_states_chunk,
router_logits=router_logits_chunk,
shared_experts_input=shared_experts_input_chunk,
)
# Store outputs
# TODO(bnell): document when chunk_start >= num_tokens
if chunk_start < num_tokens:
final_fused_hidden_states[chunk_start:chunk_end, :].copy_(
hidden_states_chunk, non_blocking=True
)
if self.shared_experts is not None:
assert shared_output_chunk is not None
assert final_shared_hidden_states is not None
final_shared_hidden_states[chunk_start:chunk_end, :].copy_(
shared_output_chunk, non_blocking=True
)
if self.shared_experts is None:
return final_fused_hidden_states
else:
assert final_shared_hidden_states is not None
return (final_shared_hidden_states, final_fused_hidden_states)
def _forward_impl(
self,
layer: torch.nn.Module,

View File

@@ -4,6 +4,13 @@ from abc import ABC, abstractmethod
import torch
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
SharedExperts,
)
class MoERunner(ABC):
"""
@@ -36,3 +43,13 @@ class MoERunner(ABC):
@abstractmethod
def is_internal_router(self) -> bool:
raise NotImplementedError
@property
@abstractmethod
def shared_experts(self) -> SharedExperts | None:
raise NotImplementedError
# TODO(bnell): temporary hack, do not call this method.
@abstractmethod
def _replace_quant_method(self, quant_method: FusedMoEMethodBase):
raise NotImplementedError

View File

@@ -0,0 +1,527 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import abstractmethod
from collections.abc import Callable
from contextlib import nullcontext
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F
from vllm.distributed import (
tensor_model_parallel_all_reduce,
)
from vllm.forward_context import (
ForwardContext,
get_forward_context,
is_forward_context_available,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.runner.moe_runner import MoERunner
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
SharedExperts,
SharedExpertsOrder,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import (
HAS_OPAQUE_TYPE,
ModuleName,
direct_register_custom_op,
)
def get_layer_from_name(layer_name: str) -> torch.nn.Module:
forward_context: ForwardContext = get_forward_context()
if layer_name == "from_forward_context":
all_moe_layers = forward_context.all_moe_layers
assert all_moe_layers is not None
moe_layer_index = forward_context.moe_layer_index
if moe_layer_index >= len(all_moe_layers):
raise AssertionError(
"We expected the number of MOE layers in `all_moe_layers` "
"to be equal to the number of "
"{vllm.moe_forward, vllm.moe_forward_shared} calls."
)
layer_name = all_moe_layers[moe_layer_index]
forward_context.moe_layer_index += 1
return forward_context.no_compile_layers[layer_name]
# On torch >= 2.11, layer_name is a hoisted ModuleName opaque object;
# on older versions it remains a plain str.
if TYPE_CHECKING:
from typing import TypeAlias
_layer_name_type: TypeAlias = str | ModuleName
else:
_layer_name_type = ModuleName if HAS_OPAQUE_TYPE else str
def _resolve_layer_name(layer_name: str | ModuleName) -> str:
return layer_name.value if isinstance(layer_name, ModuleName) else layer_name
# Note: _moe_forward and _moe_forward_shared should not contain any
# implementation details, They should merely pass along control to
# the runner's 'forward_dispatch' method.
def _moe_forward(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
layer_name: _layer_name_type,
) -> torch.Tensor:
layer = get_layer_from_name(_resolve_layer_name(layer_name))
return layer.runner.forward_dispatch(
layer,
hidden_states,
router_logits,
shared_experts_input,
)
def _moe_forward_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
layer_name: _layer_name_type,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
def _moe_forward_shared(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
layer_name: _layer_name_type,
) -> tuple[torch.Tensor, torch.Tensor]:
layer = get_layer_from_name(_resolve_layer_name(layer_name))
return layer.runner.forward_dispatch(
layer,
hidden_states,
router_logits,
shared_experts_input,
)
def _moe_forward_shared_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
layer_name: _layer_name_type,
) -> tuple[torch.Tensor, torch.Tensor]:
# Output shapes:
# - fused_out: same as hidden_states (routed experts use transformed size)
# - shared_out: same as shared_experts_input if provided, else same as
# hidden_states
# (For latent MoE: shared experts use original hidden_size, not latent size)
fused_out = torch.empty_like(hidden_states)
if shared_experts_input is not None:
shared_out = torch.empty_like(shared_experts_input)
else:
shared_out = torch.empty_like(hidden_states)
return shared_out, fused_out
direct_register_custom_op(
op_name="moe_forward",
op_func=_moe_forward,
mutates_args=["hidden_states"], # is this still true?
fake_impl=_moe_forward_fake,
tags=(torch.Tag.needs_fixed_stride_order,),
)
direct_register_custom_op(
op_name="moe_forward_shared",
op_func=_moe_forward_shared,
fake_impl=_moe_forward_shared_fake,
tags=(torch.Tag.needs_fixed_stride_order,),
)
class MoERunnerBase(MoERunner):
"""
Abstract base class providing common functionality for MoE runner implementations.
This class serves as the foundation for concrete MoE runner implementations by
providing shared state management and common utilities. It handles:
- Common initialization and configuration management
- Shared expert output reduction logic for tensor parallel scenarios
- Base methods for tensor model parallel reductions
- Common properties and utility functions used across different runner types
Concrete subclasses must implement the abstract methods to define their specific
execution strategies, such as standard execution, chunked processing, or other
specialized approaches. The base class provides the infrastructure while
allowing flexibility in the actual MoE computation implementation.
Key abstract methods that subclasses must implement:
- reduce_results: Determines whether results should be reduced across ranks
- _forward_impl: The core MoE computation logic specific to each runner type
"""
def __init__(
self,
layer_name: str,
moe_config: FusedMoEConfig,
router: FusedMoERouter,
routed_input_transform: torch.nn.Module | None,
gate: torch.nn.Module | None,
shared_experts: torch.nn.Module | None,
quant_method: FusedMoEMethodBase,
reduce_results: bool,
enable_dbo: bool,
):
super().__init__()
self.moe_config = moe_config
self.router = router
self.routed_input_transform = routed_input_transform
self.gate = gate
self.quant_method = quant_method
self._reduce_results = reduce_results
self.enable_dbo = enable_dbo
self._shared_experts: SharedExperts | None = None
if shared_experts is not None:
self._shared_experts = SharedExperts(
shared_experts,
moe_config=moe_config,
# Note: For now we must pass quant_method along to SharedExperts so it
# can property determine where the shared experts are supposed to be
# called, i.e. by a MK or by the MoERunner.
# Once the MK can be created upfront, we can just pass in the proper
# flags derived from the quant_method's MK.
reduce_results=reduce_results,
quant_method=quant_method,
enable_dbo=enable_dbo,
)
# Needed for string -> FusedMoE layer lookup in custom ops.
self.layer_name = layer_name
self.forward_entry = self._select_forward()
def _select_forward(self) -> Callable:
if current_platform.is_tpu() or current_platform.is_cpu():
# TODO: Once the OOM issue for the TPU backend is resolved, we
# will switch to using the moe_forward custom op.
# Note: CPU doesn't require wrapped _forward_impl.
return _moe_forward if self._shared_experts is None else _moe_forward_shared
return (
torch.ops.vllm.moe_forward
if self._shared_experts is None
else torch.ops.vllm.moe_forward_shared
)
@property
def shared_experts(self) -> SharedExperts | None:
return self._shared_experts
# TODO(bnell): temporary hack, do not call this method.
def _replace_quant_method(self, quant_method: FusedMoEMethodBase):
if self._shared_experts is not None:
self._shared_experts._quant_method = quant_method
self.quant_method = quant_method
def is_internal_router(self) -> bool:
return self.gate is not None
@property
@abstractmethod
def reduce_results(self) -> bool:
raise NotImplementedError
def must_reduce_shared_expert_outputs(self) -> bool:
"""
The shared_experts are typically computed using the RowParallelLinear
layer. The result of this function is typically used as
the reduce_results argument to the module.
When just tensor-parallel is used, it is not required to reduce
the shared_experts results immediately. Instead we reduce at the
once at the end of the MoE op. (Refer to DeepSeekV2MoE module)
With EP and all2all kernels - this is no longer viable as all
GPU ranks in DP, produce the complete set of hidden_states.
Therefore it is required that we reduce the shared_experts output
early.
"""
return (
self.quant_method.moe_kernel is not None
and self.quant_method.moe_kernel.output_is_reduced()
)
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
"""
Some combine kernels reduce across GPU ranks by default.
"""
if self.must_reduce_shared_expert_outputs():
return final_hidden_states
else:
return tensor_model_parallel_all_reduce(final_hidden_states)
def apply_routed_input_transform(
self, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Apply transform for routed experts (e.g., latent projection).
This is called by FusedMoE.forward_native. The original hidden_states
is saved separately so shared experts get [S, hidden_size] while
routed experts get the transformed [S, moe_latent_size].
TODO: For latent MoE bandwidth optimization, fc2_latent_proj could be
moved inside SharedFusedMoE to all-reduce on the smaller latent
dimension.
Returns (possibly transformed) hidden states and the input for shared
experts (or None if there are no shared experts).
"""
if self.routed_input_transform is not None:
result = self.routed_input_transform(hidden_states)
# ReplicatedLinear returns (output, extra_bias) tuple.
# We only need the output tensor; extra_bias is not used here.
if isinstance(result, tuple):
return result[0], hidden_states
return result, hidden_states
return (
hidden_states,
hidden_states if self._shared_experts is not None else None,
)
def _maybe_reduce_output(
self,
states: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
trunc_sizes: list[int],
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
def trunc(x: torch.Tensor, trunc_size: int) -> torch.Tensor:
return x[..., :trunc_size]
def reduce_and_trunc(x: torch.Tensor, trunc_size: int) -> torch.Tensor:
return trunc(self.maybe_all_reduce_tensor_model_parallel(x), trunc_size)
if (
not self.moe_config.is_sequence_parallel
and self.reduce_results
and (self.moe_config.tp_size > 1 or self.moe_config.ep_size > 1)
):
func = reduce_and_trunc
else:
func = trunc
if isinstance(states, tuple):
return tuple(
[func(s, trunc_size) for s, trunc_size in zip(states, trunc_sizes)]
)
else:
assert len(trunc_sizes) == 1
return func(states, trunc_sizes[0])
def _encode_layer_name(self) -> str | ModuleName:
if HAS_OPAQUE_TYPE:
return ModuleName(self.layer_name)
# Can be unavailable or None in unittests
if (
is_forward_context_available()
and get_forward_context().all_moe_layers is not None
):
return "from_forward_context"
return self.layer_name
def _maybe_pad_hidden_states(
self,
shared_experts_input: torch.Tensor | None,
hidden_states: torch.Tensor,
) -> tuple[torch.Tensor, list[int]]:
shared_experts_hidden_dim = (
shared_experts_input.shape[-1] if shared_experts_input is not None else 0
)
transformed_hidden_dim = hidden_states.shape[-1]
if (
not self.quant_method.skip_forward_padding
and self.moe_config.hidden_dim != transformed_hidden_dim
):
hidden_states = F.pad(
hidden_states,
(0, self.moe_config.hidden_dim - transformed_hidden_dim),
mode="constant",
value=0.0,
)
if self._shared_experts is not None:
orig_hidden_dims = [shared_experts_hidden_dim, transformed_hidden_dim]
else:
orig_hidden_dims = [transformed_hidden_dim]
return hidden_states, orig_hidden_dims
def _maybe_apply_shared_experts(
self,
shared_experts_input: torch.Tensor | None,
order: SharedExpertsOrder,
):
if self._shared_experts is not None:
assert shared_experts_input is not None
self._shared_experts.apply(shared_experts_input, order)
def _apply_quant_method(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> tuple[torch.Tensor | None, torch.Tensor]:
# Run this before quant_method to avoid inplace issues.
# TODO(bnell): probably not needed anymore since inplace is
# disabled when shared experts are present.
self._maybe_apply_shared_experts(
shared_experts_input, SharedExpertsOrder.NO_OVERLAP
)
if self.quant_method.is_monolithic:
fused_out = self.quant_method.apply_monolithic(
layer=layer,
x=hidden_states,
router_logits=router_logits,
)
else:
topk_weights, topk_ids = self.router.select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
)
# Passing shared_experts_input in case SharedExpertsOrder is
# NO_OVERLAP or MK_INTERNAL_OVERLAPPED.
fused_out = self.quant_method.apply(
layer=layer,
x=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
shared_experts_input=shared_experts_input,
)
self._maybe_apply_shared_experts(
shared_experts_input,
SharedExpertsOrder.MULTI_STREAM_OVERLAPPED,
)
return (
self._shared_experts.output if self._shared_experts is not None else None,
fused_out,
)
def _sequence_parallel_context(self):
ctx = get_forward_context()
return (
ctx.dp_metadata.sp_local_sizes(self.moe_config.sp_size)
if ctx.dp_metadata
else nullcontext()
)
def _maybe_sync_shared_experts_stream(
self,
shared_experts_input: torch.Tensor | None,
):
# If router/gate provided, then apply it here.
# (Note: This code runs only when "overlapped mode" is on to allow
# parallel execution of shared experts with the FusedMoE via
# separate cuda stream)
if self._shared_experts is not None:
self._shared_experts.maybe_sync_shared_experts_stream(shared_experts_input)
def forward(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""Invoke the fused moe layer.
Input:
- hidden_states
- router_logits
Output:
- The new hidden_states.
or
- A tuple of (shared experts output, new hidden_states).
Calling sequence
- forward
- self.forward_entry (_moe_forward or _moe_forward_shared custom op)
- forward_dispatch
- _forward_impl
Note: The existence of _moe_forward and _moe_forward_shared custom ops are due
to the following reasons:
1. the chunking loop in ChunkingMoERunner._forward_impl cannot be compiled by
torch.compile
2. pytorch cannot handle union types in custom op signatures so _moe_forward
and _moe_forward_shared must be split.
If ChunkingMoERunner._forward_impl can be implemented via torch.scan we can
potentially get rid of _moe_forward and _moe_forward_shared and collapse the
whole sequence into the 'forward' method.
"""
# Apply transform for routed experts (e.g., latent projection for latent MoE)
hidden_states, shared_experts_input = self.apply_routed_input_transform(
hidden_states
)
hidden_states, og_hidden_dims = self._maybe_pad_hidden_states(
shared_experts_input,
hidden_states,
)
fused_output = self.forward_entry(
hidden_states,
router_logits,
shared_experts_input,
self._encode_layer_name(),
)
return self._maybe_reduce_output(fused_output, og_hidden_dims)
def forward_dispatch(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
# TODO(bnell): this can be removed after MK migration is complete.
layer.ensure_moe_quant_config_init()
# Sync aux and main stream for shared expert multi-stream overlap.
self._maybe_sync_shared_experts_stream(shared_experts_input)
# If the Runner holds the gate, apply it after the stream sync,
# so it can run overlapped with the
# NOTE: in future PR, MoE runner will always hold the gate.
if self.gate is not None:
router_logits, _ = self.gate(hidden_states)
with self._sequence_parallel_context():
return self._forward_impl(
layer,
hidden_states,
router_logits,
shared_experts_input,
)
@abstractmethod
def _forward_impl(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError

View File

@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.runner.chunking_moe_runner import (
ChunkingMoERunner,
)
from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import (
DefaultMoERunner,
)
from vllm.model_executor.layers.fused_moe.runner.moe_runner import MoERunner
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
SharedExperts,
)
def create_moe_runner(
layer_name: str,
moe_config: FusedMoEConfig,
router: FusedMoERouter,
routed_input_transform: torch.nn.Module | None,
gate: torch.nn.Module | None,
shared_experts: SharedExperts | None,
quant_method: FusedMoEMethodBase,
reduce_results: bool,
enable_dbo: bool,
) -> MoERunner:
runner = DefaultMoERunner(
layer_name,
moe_config,
router,
routed_input_transform,
gate,
shared_experts,
quant_method,
reduce_results,
enable_dbo,
)
if moe_config.moe_parallel_config.use_dp_chunking:
return ChunkingMoERunner(runner)
return runner

View File

@@ -32,19 +32,14 @@ class SharedExpertsOrder(IntEnum):
# No shared experts.
NONE = (0,)
# Get rid of this one? combine with BEFORE?
# Note: this might be important for torch.compile reasons. Can
# get rid of it after _moe_forward is undone.
EXTERNAL = (1,)
# No overlap - defensively called before MK.
NO_OVERLAP = (2,)
NO_OVERLAP = (1,)
# Overlapped with dispatch/combine in DP/EP - called by the MK.
MK_INTERNAL_OVERLAPPED = (3,)
MK_INTERNAL_OVERLAPPED = (2,)
# Overlapped with the gate, router, experts in aux stream.
MULTI_STREAM_OVERLAPPED = (4,)
MULTI_STREAM_OVERLAPPED = (3,)
class SharedExperts:
@@ -110,9 +105,6 @@ class SharedExperts:
self,
hidden_states: torch.Tensor,
) -> SharedExpertsOrder:
if self._use_external_experts:
return SharedExpertsOrder.EXTERNAL
if self._quant_method.mk_owns_shared_expert:
return SharedExpertsOrder.MK_INTERNAL_OVERLAPPED
@@ -205,12 +197,4 @@ class SharedExperts:
else:
self._output[self._output_idx] = self._layer(shared_experts_input)
if order == SharedExpertsOrder.EXTERNAL:
# TODO: figure out how to combine this with maybe_reduce_output?
# or get rid of it completely.
assert self._output[self._output_idx] is not None
self._output[self._output_idx] = self._maybe_reduce_shared_out(
self._output[self._output_idx]
)
assert self._output[self._output_idx] is not None

View File

@@ -231,6 +231,7 @@ class NemotronHMoE(nn.Module):
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
routed_input_transform=self.fc1_latent_proj,
router_logits_dtype=self.gate.out_dtype,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: