[MoE Refactor] Split of DefaultMoERunner class (#35326)
Signed-off-by: Bill Nell <bnell@redhat.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
@@ -39,8 +39,8 @@ from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
from vllm.model_executor.layers.fused_moe.router.router_factory import (
|
||||
create_fused_moe_router,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import (
|
||||
DefaultMoERunner,
|
||||
from vllm.model_executor.layers.fused_moe.runner.moe_runner_factory import (
|
||||
create_moe_runner,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
|
||||
SharedExperts,
|
||||
@@ -572,8 +572,8 @@ class FusedMoE(CustomOp):
|
||||
# Storing the runner in the FusedMoE is an intermediate state, eventually
|
||||
# the runner will own the FusedMoE layer and provide the execution interface
|
||||
# for MoE ops.
|
||||
self.runner = DefaultMoERunner(
|
||||
layer=self,
|
||||
self.runner = create_moe_runner(
|
||||
layer_name=self.layer_name,
|
||||
moe_config=self.moe_config,
|
||||
router=self.router,
|
||||
routed_input_transform=self._routed_input_transform,
|
||||
|
||||
@@ -0,0 +1,243 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.forward_context import (
|
||||
get_forward_context,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
|
||||
FusedMoEMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.runner.moe_runner_base import MoERunnerBase
|
||||
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
|
||||
SharedExperts,
|
||||
)
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
|
||||
from vllm.v1.worker.workspace import current_workspace_manager
|
||||
|
||||
|
||||
class ChunkingMoERunner(MoERunnerBase):
    """
    MoE runner wrapper that adds chunked processing to any MoERunnerBase.

    This runner wraps an inner MoERunnerBase and overrides _forward_impl to
    process large batches by breaking them into smaller chunks. Each chunk
    is delegated to the inner runner's _forward_impl, making chunking
    composable with any runner implementation.

    All MoERunnerBase state (moe_config, router, quant_method, etc.) is
    transparently delegated to the inner runner via __getattr__.
    ChunkingMoERunner only owns chunking-specific state: the pre-allocated
    workspace buffers and the reduce_results override.

    Key behaviors:
    - Pre-allocates workspace tensors for CUDA graph compatibility
    - Processes chunks via inner._forward_impl per chunk
    - Never reduces results (reduce_results always returns False)
    """

    def __init__(self, inner: MoERunnerBase):
        # Assert that _maybe_dispatch/_maybe_combine will be nops.
        assert inner.moe_config.pcp_size == 1

        # Skip MoERunnerBase.__init__ — all state is delegated to inner
        # via __getattr__. Only chunking-specific state lives here.
        self._inner = inner

        # Pre-allocated staging buffers. These need to exist ahead of time
        # due to CUDA graph construction needing fixed buffer addresses.
        self.batched_hidden_states, self.batched_router_logits = (
            self._init_dp_chunking()
        )

    def __getattr__(self, name):
        # Delegate attribute access to the inner runner. This is only
        # called when normal lookup (instance __dict__, class MRO) fails,
        # so ChunkingMoERunner's own attributes and methods take priority.
        return getattr(self._inner, name)

    @property
    def shared_experts(self) -> SharedExperts | None:
        # Explicit property (not __getattr__) because class-level properties
        # must shadow delegation for attribute lookup to stay predictable.
        return self._inner.shared_experts

    # TODO(bnell): temporary hack, do not call this method.
    def _replace_quant_method(self, quant_method: FusedMoEMethodBase):
        self._inner._replace_quant_method(quant_method)
        self.quant_method = quant_method

    def is_internal_router(self) -> bool:
        return self._inner.gate is not None

    # Reducing results when chunking is handled by the MK finalize operations
    # when DP chunking is enabled.
    # This will be removed by #35949
    @property
    def reduce_results(self) -> bool:
        return False

    def _init_dp_chunking(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Allocate the (hidden_states, router_logits) staging buffers.

        Returns the pair of workspace tensors sized for a maximal chunk.
        With DBO enabled a leading dimension of 2 provides one slot per
        micro-batch.

        Note: the original annotation said ``list[torch.Tensor]`` but the
        result is always unpacked as exactly two tensors, so it is
        annotated as a pair here.
        """
        states_shape: tuple[int, ...]
        logits_shape: tuple[int, ...]

        moe = self.moe_config

        if self.enable_dbo:
            states_shape = (2, moe.max_num_tokens, self.moe_config.hidden_dim)
            logits_shape = (2, moe.max_num_tokens, self.moe_config.num_logical_experts)
        else:
            states_shape = (moe.max_num_tokens, self.moe_config.hidden_dim)
            logits_shape = (moe.max_num_tokens, self.moe_config.num_logical_experts)

        # Does this need some kind of profiling run check like modular_kernel.py?
        return current_workspace_manager().get_simultaneous(
            (states_shape, moe.in_dtype),
            (logits_shape, moe.router_logits_dtype),
        )

    def _allocate_dp_chunking_outputs(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> tuple[torch.Tensor | None, torch.Tensor]:
        """Validate staging buffers and allocate the final output tensors.

        Returns (shared_output or None, fused_output), both uninitialized;
        they are filled chunk-by-chunk in _forward_impl.
        """
        # Assert the inputs are of the proper type and shape.
        assert self.batched_hidden_states is not None
        assert self.batched_router_logits is not None

        assert self.batched_hidden_states.dtype == hidden_states.dtype, (
            f"{self.batched_hidden_states.dtype} == {hidden_states.dtype}"
        )
        assert self.batched_router_logits.dtype == router_logits.dtype, (
            f"{self.batched_router_logits.dtype} == {router_logits.dtype}"
        )

        # Check size compatibility.
        assert self.batched_hidden_states.size(-1) == hidden_states.size(-1)
        assert self.batched_router_logits.size(-1) == router_logits.size(-1)

        final_fused_hidden_states = torch.empty_like(hidden_states)
        if self.shared_experts is not None:
            # Shared experts may take a differently-shaped input (e.g. latent
            # MoE), so size the shared output off that input when present.
            if shared_experts_input is not None:
                final_shared_hidden_states = torch.empty_like(shared_experts_input)
            else:
                final_shared_hidden_states = torch.empty_like(hidden_states)
        else:
            final_shared_hidden_states = None

        return final_shared_hidden_states, final_fused_hidden_states

    def _slice_and_copy_input(
        self,
        out_slice: torch.Tensor,
        orig: torch.Tensor | None,
        start: int,
        end: int,
    ) -> torch.Tensor:
        """Copy orig[start:end] into the staging buffer; return the written view."""
        assert orig is not None
        slice_size = end - start
        orig_slice = orig[start:end, :]
        if self.enable_dbo:
            # Select this micro-batch's slot of the doubled staging buffer.
            assert out_slice.dim() == 3
            batch_buffer_idx = dbo_current_ubatch_id()
            out_slice = out_slice[batch_buffer_idx, :]

        assert out_slice.size(0) >= slice_size
        out_slice = out_slice[:slice_size, :]
        out_slice.copy_(orig_slice, non_blocking=True)
        return out_slice

    def _forward_impl(
        self,
        layer: torch.nn.Module,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Chunked forward: run inner._forward_impl per chunk and stitch outputs."""
        final_shared_hidden_states, final_fused_hidden_states = (
            self._allocate_dp_chunking_outputs(
                hidden_states, router_logits, shared_experts_input
            )
        )

        ctx = get_forward_context()
        # flashinfer_cutlass_kernels can handle: optional DP + TP/EP
        max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu
        moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens

        # If the input to the MoE is sequence parallel then divide by sp_size
        # to find the maximum number of tokens for any individual dispatcher.
        if self.moe_config.is_sequence_parallel:
            max_tokens_across_dispatchers = cdiv(
                max_tokens_across_dispatchers, self.moe_config.sp_size
            )

        num_tokens = hidden_states.size(0)
        for chunk_idx, chunk_start_ in enumerate(
            range(0, max_tokens_across_dispatchers, moe_dp_chunk_size_per_rank)
        ):
            chunk_start = chunk_start_
            chunk_end = min(
                chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dispatchers
            )
            # clamp start and end
            chunk_start = min(chunk_start, num_tokens - 1)
            chunk_end = min(chunk_end, num_tokens)
            chunk_sizes = ctx.dp_metadata.chunked_sizes(
                self.moe_config.sp_size, moe_dp_chunk_size_per_rank, chunk_idx
            )
            with chunk_sizes:
                hidden_states_chunk = self._slice_and_copy_input(
                    self.batched_hidden_states,
                    hidden_states,
                    chunk_start,
                    chunk_end,
                )

                router_logits_chunk = self._slice_and_copy_input(
                    self.batched_router_logits,
                    router_logits,
                    chunk_start,
                    chunk_end,
                )

                shared_experts_input_chunk = (
                    shared_experts_input[chunk_start:chunk_end, :]
                    if shared_experts_input is not None
                    else None
                )

                # Delegate per-chunk computation to the inner runner.
                chunk_result = self._inner._forward_impl(
                    layer=layer,
                    hidden_states=hidden_states_chunk,
                    router_logits=router_logits_chunk,
                    shared_experts_input=shared_experts_input_chunk,
                )

                # Store outputs
                # TODO(bnell): document when chunk_start >= num_tokens
                if chunk_start < num_tokens:
                    if self.shared_experts is not None:
                        assert isinstance(chunk_result, tuple)
                        shared_output_chunk, hidden_states_chunk = chunk_result
                        final_fused_hidden_states[chunk_start:chunk_end, :].copy_(
                            hidden_states_chunk, non_blocking=True
                        )
                        assert shared_output_chunk is not None
                        assert final_shared_hidden_states is not None
                        final_shared_hidden_states[chunk_start:chunk_end, :].copy_(
                            shared_output_chunk, non_blocking=True
                        )
                    else:
                        assert isinstance(chunk_result, torch.Tensor)
                        final_fused_hidden_states[chunk_start:chunk_end, :].copy_(
                            chunk_result, non_blocking=True
                        )

        if self.shared_experts is None:
            return final_fused_hidden_states
        else:
            assert final_shared_hidden_states is not None
            return (final_shared_hidden_states, final_fused_hidden_states)
|
||||
@@ -1,516 +1,45 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
from contextlib import nullcontext
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.distributed import (
|
||||
get_ep_group,
|
||||
get_pcp_group,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.forward_context import (
|
||||
ForwardContext,
|
||||
get_forward_context,
|
||||
is_forward_context_available,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
|
||||
FusedMoEMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
|
||||
FusedMoERouter,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.runner.moe_runner import MoERunner
|
||||
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
|
||||
SharedExperts,
|
||||
SharedExpertsOrder,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import (
|
||||
HAS_OPAQUE_TYPE,
|
||||
ModuleName,
|
||||
direct_register_custom_op,
|
||||
)
|
||||
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
|
||||
|
||||
logger = init_logger(__name__)
|
||||
from vllm.model_executor.layers.fused_moe.runner.moe_runner_base import MoERunnerBase
|
||||
|
||||
|
||||
def get_layer_from_name(layer_name: str) -> torch.nn.Module:
    """Resolve a FusedMoE layer module from its (possibly indirect) name."""
    ctx: ForwardContext = get_forward_context()
    if layer_name != "from_forward_context":
        return ctx.no_compile_layers[layer_name]

    # Sentinel name: resolve the layer positionally from the forward context.
    layers = ctx.all_moe_layers
    assert layers is not None
    idx = ctx.moe_layer_index
    if idx >= len(layers):
        raise AssertionError(
            "We expected the number of MOE layers in `all_moe_layers` "
            "to be equal to the number of "
            "{vllm.moe_forward, vllm.moe_forward_shared} calls."
        )
    ctx.moe_layer_index += 1
    return ctx.no_compile_layers[layers[idx]]
|
||||
|
||||
|
||||
# On torch >= 2.11, layer_name is a hoisted ModuleName opaque object;
# on older versions it remains a plain str.
if TYPE_CHECKING:
    from typing import TypeAlias

    _layer_name_type: TypeAlias = str | ModuleName
else:
    _layer_name_type = ModuleName if HAS_OPAQUE_TYPE else str


def _resolve_layer_name(layer_name: str | ModuleName) -> str:
    """Unwrap a ModuleName into its underlying string; pass plain strs through."""
    if isinstance(layer_name, ModuleName):
        return layer_name.value
    return layer_name
|
||||
|
||||
|
||||
# Note: _moe_forward and _moe_forward_shared should not contain any
# implementation details. They should merely pass along control to
# the runner's 'forward_dispatch' method.
def _moe_forward(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    shared_experts_input: torch.Tensor | None,
    layer_name: _layer_name_type,
) -> torch.Tensor:
    """Custom-op entry point for MoE layers without shared experts."""
    target = get_layer_from_name(_resolve_layer_name(layer_name))
    return target.runner.forward_dispatch(
        target,
        hidden_states,
        router_logits,
        shared_experts_input,
    )
|
||||
|
||||
|
||||
def _moe_forward_fake(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    shared_experts_input: torch.Tensor | None,
    layer_name: _layer_name_type,
) -> torch.Tensor:
    """Fake (meta) implementation: output shape/dtype mirror hidden_states."""
    return torch.empty_like(hidden_states)
|
||||
|
||||
|
||||
def _moe_forward_shared(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    shared_experts_input: torch.Tensor | None,
    layer_name: _layer_name_type,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Custom-op entry point for MoE layers with shared experts."""
    target = get_layer_from_name(_resolve_layer_name(layer_name))
    return target.runner.forward_dispatch(
        target,
        hidden_states,
        router_logits,
        shared_experts_input,
    )
|
||||
|
||||
|
||||
def _moe_forward_shared_fake(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    shared_experts_input: torch.Tensor | None,
    layer_name: _layer_name_type,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake (meta) implementation for the shared-experts entry point.

    Output shapes:
    - fused_out: same as hidden_states (routed experts use transformed size)
    - shared_out: same as shared_experts_input if provided, else same as
      hidden_states
      (For latent MoE: shared experts use original hidden_size, not latent size)
    """
    fused_out = torch.empty_like(hidden_states)
    shared_template = (
        shared_experts_input if shared_experts_input is not None else hidden_states
    )
    return torch.empty_like(shared_template), fused_out
|
||||
|
||||
|
||||
# Register the MoE forward entry points as torch custom ops so compiled
# graphs treat them as opaque calls with a fixed input stride order.
direct_register_custom_op(
    op_name="moe_forward",
    op_func=_moe_forward,
    mutates_args=["hidden_states"],  # is this still true?
    fake_impl=_moe_forward_fake,
    tags=(torch.Tag.needs_fixed_stride_order,),
)


direct_register_custom_op(
    op_name="moe_forward_shared",
    op_func=_moe_forward_shared,
    fake_impl=_moe_forward_shared_fake,
    tags=(torch.Tag.needs_fixed_stride_order,),
)
|
||||
|
||||
|
||||
class DefaultMoERunner(MoERunner):
|
||||
class DefaultMoERunner(MoERunnerBase):
|
||||
"""
|
||||
Default implementation of the MoE runner for executing Mixture of Experts layers.
|
||||
Standard MoE runner implementation for executing Mixture of Experts layers.
|
||||
|
||||
This class provides a comprehensive implementation for running MoE computations
|
||||
with support for:
|
||||
- Expert routing and token dispatching
|
||||
This is the primary concrete implementation of MoE execution logic, providing
|
||||
comprehensive support for standard MoE operations. It handles:
|
||||
- Expert routing and token dispatching using various routing strategies
|
||||
- Shared experts computation with optional parallel execution using CUDA streams
|
||||
- Data parallel (DP) chunking for large batch processing
|
||||
- Tensor model parallel and expert parallel operations
|
||||
- Various quantization methods and custom operators
|
||||
- Multiple quantization methods and optimized kernel selection
|
||||
- Both monolithic and decomposed expert execution paths
|
||||
- Integration with various parallel execution modes (TP, EP, DP)
|
||||
|
||||
The runner handles the complete MoE forward pass including routing tokens to
|
||||
experts, executing expert computations, and combining results. It supports
|
||||
advanced features like overlapped execution of shared experts and optimized
|
||||
kernels for different parallel execution modes.
|
||||
The runner orchestrates the complete MoE forward pass including routing tokens
|
||||
to experts, executing expert computations in parallel, and combining results.
|
||||
It supports advanced features like overlapped execution of shared experts,
|
||||
optimized kernels for different parallel configurations, and seamless
|
||||
integration with vLLM's distributed execution framework.
|
||||
|
||||
Eventually, this class will be split up and specialized for different
|
||||
configurations, e.g. the presence or absence of shared experts, a gate, etc.
|
||||
This implementation is suitable for most standard MoE use cases. For specialized
|
||||
scenarios like large batch chunking, alternative runners like ChunkingMoERunner
|
||||
may be more appropriate.
|
||||
|
||||
Eventually, this class may be split into more specialized implementations
|
||||
for different configurations (e.g., with/without shared experts, gates, etc.).
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    layer: torch.nn.Module,
    moe_config: FusedMoEConfig,
    router: FusedMoERouter,
    routed_input_transform: torch.nn.Module | None,
    gate: torch.nn.Module | None,
    shared_experts: torch.nn.Module | None,
    quant_method: FusedMoEMethodBase,
    reduce_results: bool,
    enable_dbo: bool,
):
    """Construct the default runner for a FusedMoE layer.

    Args:
        layer: the owning FusedMoE module (source of layer_name).
        moe_config: sizing / parallelism configuration for the MoE.
        router: expert-selection strategy.
        routed_input_transform: optional projection applied to the routed
            input (e.g. latent MoE projection).
        gate: optional internal gate module.
        shared_experts: optional shared-experts module, wrapped in
            SharedExperts when present.
        reduce_results: whether this runner performs the final reduction.
        enable_dbo: enable dual-batch-overlap staging buffers.
    """
    super().__init__()
    self.moe_config = moe_config
    self.router = router
    self.routed_input_transform = routed_input_transform
    self.gate = gate
    self.quant_method = quant_method
    # `reduce_results` is exposed as a read-only property on this class;
    # assigning `self.reduce_results` directly would raise AttributeError,
    # so store the value in the property's backing field.
    self._reduce_results = reduce_results
    self.enable_dbo = enable_dbo

    self.shared_experts: SharedExperts | None = None
    if shared_experts is not None:
        self.shared_experts = SharedExperts(
            shared_experts,
            moe_config=moe_config,
            # Note: For now we must pass quant_method along to SharedExperts so it
            # can properly determine where the shared experts are supposed to be
            # called, i.e. by a MK or by the MoERunner.
            # Once the MK can be created upfront, we can just pass in the proper
            # flags derived from the quant_method's MK.
            reduce_results=reduce_results,
            quant_method=quant_method,
            enable_dbo=enable_dbo,
        )

    # Chunked all2all staging tensors.
    # These need to exist ahead of time due to CUDAgraph construction
    # needing a fixed buffer address.
    self.use_dp_chunking = self.moe_config.moe_parallel_config.use_dp_chunking
    self.batched_hidden_states: torch.Tensor | None = None
    self.batched_router_logits: torch.Tensor | None = None
    self._maybe_init_dp_chunking()

    # Needed for string -> FusedMoE layer lookup in custom ops.
    self.layer_name = layer.layer_name

    self.forward_entry, self.forward_impl = self._select_forward(layer)
|
||||
|
||||
def _select_forward(self, layer: torch.nn.Module) -> tuple[Callable, Callable]:
    """Choose the (entry point, implementation) pair for this configuration."""
    # Chunked implementation only when DP chunking is active.
    if self.use_dp_chunking:
        impl = self._forward_impl_chunked
    else:
        impl = self._forward_impl

    has_shared = self.shared_experts is not None

    if current_platform.is_tpu() or current_platform.is_cpu():
        # TODO: Once the OOM issue for the TPU backend is resolved, we
        # will switch to using the moe_forward custom op.
        # Note: CPU doesn't require wrapped forward_impl.
        entry = _moe_forward_shared if has_shared else _moe_forward
    else:
        entry = (
            torch.ops.vllm.moe_forward_shared
            if has_shared
            else torch.ops.vllm.moe_forward
        )

    return entry, impl
|
||||
|
||||
# TODO(bnell): temporary hack, do not call this method.
def _replace_quant_method(self, quant_method: FusedMoEMethodBase):
    """Swap the quant method on this runner (and its shared experts, if any)."""
    self.quant_method = quant_method
    if self.shared_experts is not None:
        self.shared_experts._quant_method = quant_method
|
||||
|
||||
def is_internal_router(self) -> bool:
    """Whether routing is internal to this runner, i.e. a gate module is owned."""
    return self.gate is not None
|
||||
|
||||
def _maybe_init_dp_chunking(self):
    """Allocate fixed-address staging buffers when DP chunking is enabled."""
    if not self.use_dp_chunking:
        return

    assert self.batched_hidden_states is None

    moe = self.moe_config
    # With DBO, a leading dim of 2 gives one buffer slot per micro-batch.
    lead: tuple[int, ...] = (
        (2, moe.max_num_tokens) if self.enable_dbo else (moe.max_num_tokens,)
    )
    states_shape = lead + (self.moe_config.hidden_dim,)
    logits_shape = lead + (self.moe_config.num_logical_experts,)

    # These need to exist ahead of time due to CUDAgraph construction
    # needing a fixed buffer address.
    device = torch.accelerator.current_device_index()
    self.batched_hidden_states = torch.zeros(
        states_shape,
        dtype=moe.in_dtype,
        device=device,
    )
    self.batched_router_logits = torch.zeros(
        logits_shape,
        dtype=moe.router_logits_dtype,
        device=device,
    )
|
||||
|
||||
def must_reduce_shared_expert_outputs(self) -> bool:
    """
    The shared_experts are typically computed using the RowParallelLinear
    layer. The result of this function is typically used as
    the reduce_results argument to the module.
    When just tensor-parallel is used, it is not required to reduce
    the shared_experts results immediately. Instead we reduce once
    at the end of the MoE op. (Refer to DeepSeekV2MoE module)
    With EP and all2all kernels - this is no longer viable as all
    GPU ranks in DP produce the complete set of hidden_states.
    Therefore it is required that we reduce the shared_experts output
    early.
    """
    kernel = self.quant_method.moe_kernel
    return kernel is not None and kernel.output_is_reduced()
|
||||
|
||||
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
    """
    Some combine kernels reduce across GPU ranks by default.
    Skip the TP all-reduce when the kernel output is already reduced.
    """
    if not self.must_reduce_shared_expert_outputs():
        return tensor_model_parallel_all_reduce(final_hidden_states)
    return final_hidden_states
|
||||
|
||||
def apply_routed_input_transform(
    self, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Apply transform for routed experts (e.g., latent projection).

    This is called by FusedMoE.forward_native. The original hidden_states
    is saved separately so shared experts get [S, hidden_size] while
    routed experts get the transformed [S, moe_latent_size].

    TODO: For latent MoE bandwidth optimization, fc2_latent_proj could be
    moved inside SharedFusedMoE to all-reduce on the smaller latent
    dimension.

    Returns (possibly transformed) hidden states and the input for shared
    experts (or None if there are no shared experts).
    """
    if self.routed_input_transform is None:
        shared_input = hidden_states if self.shared_experts is not None else None
        return hidden_states, shared_input

    transformed = self.routed_input_transform(hidden_states)
    # ReplicatedLinear returns (output, extra_bias) tuple.
    # We only need the output tensor; extra_bias is not used here.
    if isinstance(transformed, tuple):
        transformed = transformed[0]
    return transformed, hidden_states
|
||||
|
||||
def _maybe_reduce_output(
    self,
    states: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
    trunc_sizes: list[int],
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Optionally all-reduce, then truncate hidden dims back to original sizes."""

    def trunc(x: torch.Tensor, trunc_size: int) -> torch.Tensor:
        return x[..., :trunc_size]

    def reduce_and_trunc(x: torch.Tensor, trunc_size: int) -> torch.Tensor:
        return trunc(self.maybe_all_reduce_tensor_model_parallel(x), trunc_size)

    # Reduce only when this runner owns the final reduction and some form
    # of TP/EP parallelism is active.
    needs_reduce = (
        not self.moe_config.is_sequence_parallel
        and not self.use_dp_chunking
        and self.reduce_results
        and (self.moe_config.tp_size > 1 or self.moe_config.ep_size > 1)
    )
    func = reduce_and_trunc if needs_reduce else trunc

    if not isinstance(states, tuple):
        assert len(trunc_sizes) == 1
        return func(states, trunc_sizes[0])
    return tuple(func(s, size) for s, size in zip(states, trunc_sizes))
|
||||
|
||||
def _encode_layer_name(self) -> str | ModuleName:
    """Encode this layer's name for transport across the custom-op boundary."""
    if HAS_OPAQUE_TYPE:
        return ModuleName(self.layer_name)
    # Can be unavailable or None in unittests
    if not is_forward_context_available():
        return self.layer_name
    if get_forward_context().all_moe_layers is not None:
        return "from_forward_context"
    return self.layer_name
|
||||
|
||||
def _maybe_pad_hidden_states(
    self,
    shared_experts_input: torch.Tensor | None,
    hidden_states: torch.Tensor,
) -> tuple[torch.Tensor, list[int]]:
    """Pad routed input up to the MoE hidden dim when required.

    Returns the (possibly padded) hidden states plus the original hidden
    dims so outputs can be truncated back later.
    """
    if shared_experts_input is not None:
        shared_experts_hidden_dim = shared_experts_input.shape[-1]
    else:
        shared_experts_hidden_dim = 0
    transformed_hidden_dim = hidden_states.shape[-1]

    needs_padding = (
        not self.quant_method.skip_forward_padding
        and self.moe_config.hidden_dim != transformed_hidden_dim
    )
    if needs_padding:
        pad_amount = self.moe_config.hidden_dim - transformed_hidden_dim
        hidden_states = F.pad(
            hidden_states,
            (0, pad_amount),
            mode="constant",
            value=0.0,
        )

    if self.shared_experts is None:
        orig_hidden_dims = [transformed_hidden_dim]
    else:
        orig_hidden_dims = [shared_experts_hidden_dim, transformed_hidden_dim]

    return hidden_states, orig_hidden_dims
|
||||
|
||||
def _maybe_apply_shared_experts(
    self,
    shared_experts_input: torch.Tensor | None,
    order: SharedExpertsOrder,
):
    """Run shared experts (if configured) at the requested execution point."""
    if self.shared_experts is None:
        return
    assert shared_experts_input is not None
    self.shared_experts.apply(shared_experts_input, order)
|
||||
def _apply_quant_method(
    self,
    layer: torch.nn.Module,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    shared_experts_input: torch.Tensor | None,
) -> tuple[torch.Tensor | None, torch.Tensor]:
    """Run routing + expert kernels; return (shared output or None, fused output)."""
    # Run this before quant_method to avoid inplace issues.
    # TODO(bnell): probably not needed anymore since inplace is
    # disabled when shared experts are present.
    self._maybe_apply_shared_experts(
        shared_experts_input, SharedExpertsOrder.NO_OVERLAP
    )

    if not self.quant_method.is_monolithic:
        topk_weights, topk_ids = self.router.select_experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
        )
        # Passing shared_experts_input in case SharedExpertsOrder is
        # NO_OVERLAP or MK_INTERNAL_OVERLAPPED.
        fused_out = self.quant_method.apply(
            layer=layer,
            x=hidden_states,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            shared_experts_input=shared_experts_input,
        )
    else:
        fused_out = self.quant_method.apply_monolithic(
            layer=layer,
            x=hidden_states,
            router_logits=router_logits,
        )

    self._maybe_apply_shared_experts(
        shared_experts_input,
        SharedExpertsOrder.MULTI_STREAM_OVERLAPPED,
    )

    shared_out = (
        self.shared_experts.output if self.shared_experts is not None else None
    )
    return shared_out, fused_out
|
||||
|
||||
def _sequence_parallel_context(self):
    """Context manager scoping SP local sizes; no-op without DP metadata."""
    dp_meta = get_forward_context().dp_metadata
    # Preserve truthiness test on dp_metadata (not an `is None` check).
    if not dp_meta:
        return nullcontext()
    return dp_meta.sp_local_sizes(self.moe_config.sp_size)
|
||||
|
||||
def _allocate_dp_chunking_outputs(
    self,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
) -> tuple[torch.Tensor | None, torch.Tensor]:
    """Validate staging buffers and allocate chunked-forward output tensors."""
    assert self.use_dp_chunking

    batched_hs = self.batched_hidden_states
    batched_rl = self.batched_router_logits
    # Assert the inputs are of the proper type and shape.
    assert batched_hs is not None
    assert batched_rl is not None

    assert batched_hs.dtype == hidden_states.dtype, (
        f"{self.batched_hidden_states.dtype} == {hidden_states.dtype}"
    )
    assert batched_rl.dtype == router_logits.dtype, (
        f"{self.batched_router_logits.dtype} == {router_logits.dtype}"
    )

    # Check size compatibility.
    assert batched_hs.size(-1) == hidden_states.size(-1)
    assert batched_rl.size(-1) == router_logits.size(-1)

    final_fused_hidden_states = torch.empty_like(hidden_states)
    final_shared_hidden_states = (
        torch.empty_like(hidden_states) if self.shared_experts is not None else None
    )
    return final_shared_hidden_states, final_fused_hidden_states
|
||||
|
||||
def _maybe_sync_shared_experts_stream(
    self,
    shared_experts_input: torch.Tensor | None,
):
    """Synchronize the shared-experts CUDA stream when shared experts exist.

    (Note: This code runs only when "overlapped mode" is on to allow
    parallel execution of shared experts with the FusedMoE via a
    separate cuda stream.)
    """
    if self.shared_experts is None:
        return
    self.shared_experts.maybe_sync_shared_experts_stream(shared_experts_input)
|
||||
@property
def reduce_results(self) -> bool:
    """Whether this runner performs the final TP/EP reduction itself."""
    return self._reduce_results
|
||||
|
||||
@property
|
||||
def do_naive_dispatch_combine(self) -> bool:
|
||||
@@ -572,195 +101,6 @@ class DefaultMoERunner(MoERunner):
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Invoke the fused moe layer.
|
||||
|
||||
Input:
|
||||
- hidden_states
|
||||
- router_logits
|
||||
|
||||
Output:
|
||||
- The new hidden_states.
|
||||
or
|
||||
- A tuple of (shared experts output, new hidden_states).
|
||||
|
||||
Calling sequence
|
||||
- forward
|
||||
- self.forward_entry (_moe_forward or _moe_forward_shared custom op)
|
||||
- forward_dispatch
|
||||
- forward_impl (_forward_impl or _forward_impl_chunked)
|
||||
|
||||
Note: The existence of _moe_forward and _moe_forward_shared custom ops are due
|
||||
to the following reasons:
|
||||
1. the chunking loop in _forward_impl_chunked cannot be compiled by
|
||||
torch.compile
|
||||
2. pytorch cannot handle union types in custom op signatures so _moe_forward
|
||||
and _moe_forward_shared must be split.
|
||||
|
||||
If _forward_impl_chunked can be implemented via torch.scan we can potentially
|
||||
get rid of _moe_forward and _moe_forward_shared and collapse the whole sequence
|
||||
into the 'forward' method.
|
||||
"""
|
||||
|
||||
# Apply transform for routed experts (e.g., latent projection for latent MoE)
|
||||
hidden_states, shared_experts_input = self.apply_routed_input_transform(
|
||||
hidden_states
|
||||
)
|
||||
|
||||
hidden_states, og_hidden_dims = self._maybe_pad_hidden_states(
|
||||
shared_experts_input,
|
||||
hidden_states,
|
||||
)
|
||||
|
||||
fused_output = self.forward_entry(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
shared_experts_input,
|
||||
self._encode_layer_name(),
|
||||
)
|
||||
|
||||
return self._maybe_reduce_output(fused_output, og_hidden_dims)
|
||||
|
||||
def forward_dispatch(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
# TODO(bnell): this can be removed after MK migration is complete.
|
||||
layer.ensure_moe_quant_config_init()
|
||||
|
||||
# Sync aux and main stream for shared expert multi-stream overlap.
|
||||
self._maybe_sync_shared_experts_stream(shared_experts_input)
|
||||
|
||||
# If the Runner holds the gate, apply it after the stream sync,
|
||||
# so it can run overlapped with the
|
||||
# NOTE: in future PR, MoE runner will always hold the gate.
|
||||
if self.gate is not None:
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
self._maybe_apply_shared_experts(
|
||||
shared_experts_input,
|
||||
SharedExpertsOrder.EXTERNAL,
|
||||
)
|
||||
|
||||
with self._sequence_parallel_context():
|
||||
return self.forward_impl(
|
||||
layer,
|
||||
hidden_states,
|
||||
router_logits,
|
||||
shared_experts_input,
|
||||
)
|
||||
|
||||
def _slice_and_copy_input(
|
||||
self,
|
||||
out_slice: torch.Tensor,
|
||||
orig: torch.Tensor | None,
|
||||
start: int,
|
||||
end: int,
|
||||
) -> torch.Tensor:
|
||||
assert orig is not None
|
||||
slice_size = end - start
|
||||
orig_slice = orig[start:end, :]
|
||||
if self.enable_dbo:
|
||||
assert out_slice.dim() == 3
|
||||
batch_buffer_idx = dbo_current_ubatch_id()
|
||||
out_slice = out_slice[batch_buffer_idx, :]
|
||||
|
||||
assert out_slice.size(0) >= slice_size
|
||||
out_slice = out_slice[:slice_size, :]
|
||||
out_slice.copy_(orig_slice, non_blocking=True)
|
||||
return out_slice
|
||||
|
||||
def _forward_impl_chunked(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
final_shared_hidden_states, final_fused_hidden_states = (
|
||||
self._allocate_dp_chunking_outputs(hidden_states, router_logits)
|
||||
)
|
||||
|
||||
ctx = get_forward_context()
|
||||
# flashinfer_cutlass_kernels can handle: optional DP + TP/EP
|
||||
max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu
|
||||
moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
|
||||
|
||||
# If the input to the MoE is sequence parallel then divide by sp_size
|
||||
# to find the maximum number of tokens for any individual dispatcher.
|
||||
if self.moe_config.is_sequence_parallel:
|
||||
max_tokens_across_dispatchers = cdiv(
|
||||
max_tokens_across_dispatchers, self.moe_config.sp_size
|
||||
)
|
||||
|
||||
num_tokens = hidden_states.size(0)
|
||||
for chunk_idx, chunk_start_ in enumerate(
|
||||
range(0, max_tokens_across_dispatchers, moe_dp_chunk_size_per_rank)
|
||||
):
|
||||
chunk_start = chunk_start_
|
||||
chunk_end = min(
|
||||
chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dispatchers
|
||||
)
|
||||
# clamp start and end
|
||||
chunk_start = min(chunk_start, num_tokens - 1)
|
||||
chunk_end = min(chunk_end, num_tokens)
|
||||
chunk_sizes = ctx.dp_metadata.chunked_sizes(
|
||||
self.moe_config.sp_size, moe_dp_chunk_size_per_rank, chunk_idx
|
||||
)
|
||||
with chunk_sizes:
|
||||
hidden_states_chunk = self._slice_and_copy_input(
|
||||
self.batched_hidden_states,
|
||||
hidden_states,
|
||||
chunk_start,
|
||||
chunk_end,
|
||||
)
|
||||
|
||||
router_logits_chunk = self._slice_and_copy_input(
|
||||
self.batched_router_logits,
|
||||
router_logits,
|
||||
chunk_start,
|
||||
chunk_end,
|
||||
)
|
||||
|
||||
shared_experts_input_chunk = (
|
||||
shared_experts_input[chunk_start:chunk_end, :]
|
||||
if shared_experts_input is not None
|
||||
else None
|
||||
)
|
||||
|
||||
shared_output_chunk, hidden_states_chunk = self._apply_quant_method(
|
||||
layer=layer,
|
||||
hidden_states=hidden_states_chunk,
|
||||
router_logits=router_logits_chunk,
|
||||
shared_experts_input=shared_experts_input_chunk,
|
||||
)
|
||||
|
||||
# Store outputs
|
||||
# TODO(bnell): document when chunk_start >= num_tokens
|
||||
if chunk_start < num_tokens:
|
||||
final_fused_hidden_states[chunk_start:chunk_end, :].copy_(
|
||||
hidden_states_chunk, non_blocking=True
|
||||
)
|
||||
if self.shared_experts is not None:
|
||||
assert shared_output_chunk is not None
|
||||
assert final_shared_hidden_states is not None
|
||||
final_shared_hidden_states[chunk_start:chunk_end, :].copy_(
|
||||
shared_output_chunk, non_blocking=True
|
||||
)
|
||||
|
||||
if self.shared_experts is None:
|
||||
return final_fused_hidden_states
|
||||
else:
|
||||
assert final_shared_hidden_states is not None
|
||||
return (final_shared_hidden_states, final_fused_hidden_states)
|
||||
|
||||
def _forward_impl(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
|
||||
@@ -4,6 +4,13 @@ from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
|
||||
FusedMoEMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
|
||||
SharedExperts,
|
||||
)
|
||||
|
||||
|
||||
class MoERunner(ABC):
|
||||
"""
|
||||
@@ -36,3 +43,13 @@ class MoERunner(ABC):
|
||||
@abstractmethod
|
||||
def is_internal_router(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def shared_experts(self) -> SharedExperts | None:
|
||||
raise NotImplementedError
|
||||
|
||||
# TODO(bnell): temporary hack, do not call this method.
|
||||
@abstractmethod
|
||||
def _replace_quant_method(self, quant_method: FusedMoEMethodBase):
|
||||
raise NotImplementedError
|
||||
|
||||
527
vllm/model_executor/layers/fused_moe/runner/moe_runner_base.py
Normal file
527
vllm/model_executor/layers/fused_moe/runner/moe_runner_base.py
Normal file
@@ -0,0 +1,527 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Callable
|
||||
from contextlib import nullcontext
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.distributed import (
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.forward_context import (
|
||||
ForwardContext,
|
||||
get_forward_context,
|
||||
is_forward_context_available,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
|
||||
FusedMoEMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
|
||||
FusedMoERouter,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.runner.moe_runner import MoERunner
|
||||
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
|
||||
SharedExperts,
|
||||
SharedExpertsOrder,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import (
|
||||
HAS_OPAQUE_TYPE,
|
||||
ModuleName,
|
||||
direct_register_custom_op,
|
||||
)
|
||||
|
||||
|
||||
def get_layer_from_name(layer_name: str) -> torch.nn.Module:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
if layer_name == "from_forward_context":
|
||||
all_moe_layers = forward_context.all_moe_layers
|
||||
assert all_moe_layers is not None
|
||||
moe_layer_index = forward_context.moe_layer_index
|
||||
if moe_layer_index >= len(all_moe_layers):
|
||||
raise AssertionError(
|
||||
"We expected the number of MOE layers in `all_moe_layers` "
|
||||
"to be equal to the number of "
|
||||
"{vllm.moe_forward, vllm.moe_forward_shared} calls."
|
||||
)
|
||||
layer_name = all_moe_layers[moe_layer_index]
|
||||
forward_context.moe_layer_index += 1
|
||||
return forward_context.no_compile_layers[layer_name]
|
||||
|
||||
|
||||
# On torch >= 2.11, layer_name is a hoisted ModuleName opaque object;
|
||||
# on older versions it remains a plain str.
|
||||
if TYPE_CHECKING:
|
||||
from typing import TypeAlias
|
||||
|
||||
_layer_name_type: TypeAlias = str | ModuleName
|
||||
else:
|
||||
_layer_name_type = ModuleName if HAS_OPAQUE_TYPE else str
|
||||
|
||||
|
||||
def _resolve_layer_name(layer_name: str | ModuleName) -> str:
|
||||
return layer_name.value if isinstance(layer_name, ModuleName) else layer_name
|
||||
|
||||
|
||||
# Note: _moe_forward and _moe_forward_shared should not contain any
|
||||
# implementation details, They should merely pass along control to
|
||||
# the runner's 'forward_dispatch' method.
|
||||
def _moe_forward(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
layer_name: _layer_name_type,
|
||||
) -> torch.Tensor:
|
||||
layer = get_layer_from_name(_resolve_layer_name(layer_name))
|
||||
return layer.runner.forward_dispatch(
|
||||
layer,
|
||||
hidden_states,
|
||||
router_logits,
|
||||
shared_experts_input,
|
||||
)
|
||||
|
||||
|
||||
def _moe_forward_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
layer_name: _layer_name_type,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
|
||||
def _moe_forward_shared(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
layer_name: _layer_name_type,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
layer = get_layer_from_name(_resolve_layer_name(layer_name))
|
||||
return layer.runner.forward_dispatch(
|
||||
layer,
|
||||
hidden_states,
|
||||
router_logits,
|
||||
shared_experts_input,
|
||||
)
|
||||
|
||||
|
||||
def _moe_forward_shared_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
layer_name: _layer_name_type,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Output shapes:
|
||||
# - fused_out: same as hidden_states (routed experts use transformed size)
|
||||
# - shared_out: same as shared_experts_input if provided, else same as
|
||||
# hidden_states
|
||||
# (For latent MoE: shared experts use original hidden_size, not latent size)
|
||||
fused_out = torch.empty_like(hidden_states)
|
||||
if shared_experts_input is not None:
|
||||
shared_out = torch.empty_like(shared_experts_input)
|
||||
else:
|
||||
shared_out = torch.empty_like(hidden_states)
|
||||
return shared_out, fused_out
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="moe_forward",
|
||||
op_func=_moe_forward,
|
||||
mutates_args=["hidden_states"], # is this still true?
|
||||
fake_impl=_moe_forward_fake,
|
||||
tags=(torch.Tag.needs_fixed_stride_order,),
|
||||
)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="moe_forward_shared",
|
||||
op_func=_moe_forward_shared,
|
||||
fake_impl=_moe_forward_shared_fake,
|
||||
tags=(torch.Tag.needs_fixed_stride_order,),
|
||||
)
|
||||
|
||||
|
||||
class MoERunnerBase(MoERunner):
|
||||
"""
|
||||
Abstract base class providing common functionality for MoE runner implementations.
|
||||
|
||||
This class serves as the foundation for concrete MoE runner implementations by
|
||||
providing shared state management and common utilities. It handles:
|
||||
- Common initialization and configuration management
|
||||
- Shared expert output reduction logic for tensor parallel scenarios
|
||||
- Base methods for tensor model parallel reductions
|
||||
- Common properties and utility functions used across different runner types
|
||||
|
||||
Concrete subclasses must implement the abstract methods to define their specific
|
||||
execution strategies, such as standard execution, chunked processing, or other
|
||||
specialized approaches. The base class provides the infrastructure while
|
||||
allowing flexibility in the actual MoE computation implementation.
|
||||
|
||||
Key abstract methods that subclasses must implement:
|
||||
- reduce_results: Determines whether results should be reduced across ranks
|
||||
- _forward_impl: The core MoE computation logic specific to each runner type
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_name: str,
|
||||
moe_config: FusedMoEConfig,
|
||||
router: FusedMoERouter,
|
||||
routed_input_transform: torch.nn.Module | None,
|
||||
gate: torch.nn.Module | None,
|
||||
shared_experts: torch.nn.Module | None,
|
||||
quant_method: FusedMoEMethodBase,
|
||||
reduce_results: bool,
|
||||
enable_dbo: bool,
|
||||
):
|
||||
super().__init__()
|
||||
self.moe_config = moe_config
|
||||
self.router = router
|
||||
self.routed_input_transform = routed_input_transform
|
||||
self.gate = gate
|
||||
self.quant_method = quant_method
|
||||
self._reduce_results = reduce_results
|
||||
self.enable_dbo = enable_dbo
|
||||
|
||||
self._shared_experts: SharedExperts | None = None
|
||||
if shared_experts is not None:
|
||||
self._shared_experts = SharedExperts(
|
||||
shared_experts,
|
||||
moe_config=moe_config,
|
||||
# Note: For now we must pass quant_method along to SharedExperts so it
|
||||
# can property determine where the shared experts are supposed to be
|
||||
# called, i.e. by a MK or by the MoERunner.
|
||||
# Once the MK can be created upfront, we can just pass in the proper
|
||||
# flags derived from the quant_method's MK.
|
||||
reduce_results=reduce_results,
|
||||
quant_method=quant_method,
|
||||
enable_dbo=enable_dbo,
|
||||
)
|
||||
|
||||
# Needed for string -> FusedMoE layer lookup in custom ops.
|
||||
self.layer_name = layer_name
|
||||
|
||||
self.forward_entry = self._select_forward()
|
||||
|
||||
def _select_forward(self) -> Callable:
|
||||
if current_platform.is_tpu() or current_platform.is_cpu():
|
||||
# TODO: Once the OOM issue for the TPU backend is resolved, we
|
||||
# will switch to using the moe_forward custom op.
|
||||
# Note: CPU doesn't require wrapped _forward_impl.
|
||||
return _moe_forward if self._shared_experts is None else _moe_forward_shared
|
||||
|
||||
return (
|
||||
torch.ops.vllm.moe_forward
|
||||
if self._shared_experts is None
|
||||
else torch.ops.vllm.moe_forward_shared
|
||||
)
|
||||
|
||||
@property
|
||||
def shared_experts(self) -> SharedExperts | None:
|
||||
return self._shared_experts
|
||||
|
||||
# TODO(bnell): temporary hack, do not call this method.
|
||||
def _replace_quant_method(self, quant_method: FusedMoEMethodBase):
|
||||
if self._shared_experts is not None:
|
||||
self._shared_experts._quant_method = quant_method
|
||||
self.quant_method = quant_method
|
||||
|
||||
def is_internal_router(self) -> bool:
|
||||
return self.gate is not None
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def reduce_results(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def must_reduce_shared_expert_outputs(self) -> bool:
|
||||
"""
|
||||
The shared_experts are typically computed using the RowParallelLinear
|
||||
layer. The result of this function is typically used as
|
||||
the reduce_results argument to the module.
|
||||
When just tensor-parallel is used, it is not required to reduce
|
||||
the shared_experts results immediately. Instead we reduce at the
|
||||
once at the end of the MoE op. (Refer to DeepSeekV2MoE module)
|
||||
With EP and all2all kernels - this is no longer viable as all
|
||||
GPU ranks in DP, produce the complete set of hidden_states.
|
||||
Therefore it is required that we reduce the shared_experts output
|
||||
early.
|
||||
"""
|
||||
return (
|
||||
self.quant_method.moe_kernel is not None
|
||||
and self.quant_method.moe_kernel.output_is_reduced()
|
||||
)
|
||||
|
||||
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
|
||||
"""
|
||||
Some combine kernels reduce across GPU ranks by default.
|
||||
"""
|
||||
if self.must_reduce_shared_expert_outputs():
|
||||
return final_hidden_states
|
||||
else:
|
||||
return tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
|
||||
def apply_routed_input_transform(
|
||||
self, hidden_states: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
"""Apply transform for routed experts (e.g., latent projection).
|
||||
|
||||
This is called by FusedMoE.forward_native. The original hidden_states
|
||||
is saved separately so shared experts get [S, hidden_size] while
|
||||
routed experts get the transformed [S, moe_latent_size].
|
||||
|
||||
TODO: For latent MoE bandwidth optimization, fc2_latent_proj could be
|
||||
moved inside SharedFusedMoE to all-reduce on the smaller latent
|
||||
dimension.
|
||||
|
||||
Returns (possibly transformed) hidden states and the input for shared
|
||||
experts (or None if there are no shared experts).
|
||||
"""
|
||||
if self.routed_input_transform is not None:
|
||||
result = self.routed_input_transform(hidden_states)
|
||||
# ReplicatedLinear returns (output, extra_bias) tuple.
|
||||
# We only need the output tensor; extra_bias is not used here.
|
||||
if isinstance(result, tuple):
|
||||
return result[0], hidden_states
|
||||
return result, hidden_states
|
||||
|
||||
return (
|
||||
hidden_states,
|
||||
hidden_states if self._shared_experts is not None else None,
|
||||
)
|
||||
|
||||
def _maybe_reduce_output(
|
||||
self,
|
||||
states: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
trunc_sizes: list[int],
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
def trunc(x: torch.Tensor, trunc_size: int) -> torch.Tensor:
|
||||
return x[..., :trunc_size]
|
||||
|
||||
def reduce_and_trunc(x: torch.Tensor, trunc_size: int) -> torch.Tensor:
|
||||
return trunc(self.maybe_all_reduce_tensor_model_parallel(x), trunc_size)
|
||||
|
||||
if (
|
||||
not self.moe_config.is_sequence_parallel
|
||||
and self.reduce_results
|
||||
and (self.moe_config.tp_size > 1 or self.moe_config.ep_size > 1)
|
||||
):
|
||||
func = reduce_and_trunc
|
||||
else:
|
||||
func = trunc
|
||||
|
||||
if isinstance(states, tuple):
|
||||
return tuple(
|
||||
[func(s, trunc_size) for s, trunc_size in zip(states, trunc_sizes)]
|
||||
)
|
||||
else:
|
||||
assert len(trunc_sizes) == 1
|
||||
return func(states, trunc_sizes[0])
|
||||
|
||||
def _encode_layer_name(self) -> str | ModuleName:
|
||||
if HAS_OPAQUE_TYPE:
|
||||
return ModuleName(self.layer_name)
|
||||
# Can be unavailable or None in unittests
|
||||
if (
|
||||
is_forward_context_available()
|
||||
and get_forward_context().all_moe_layers is not None
|
||||
):
|
||||
return "from_forward_context"
|
||||
return self.layer_name
|
||||
|
||||
def _maybe_pad_hidden_states(
|
||||
self,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, list[int]]:
|
||||
shared_experts_hidden_dim = (
|
||||
shared_experts_input.shape[-1] if shared_experts_input is not None else 0
|
||||
)
|
||||
transformed_hidden_dim = hidden_states.shape[-1]
|
||||
if (
|
||||
not self.quant_method.skip_forward_padding
|
||||
and self.moe_config.hidden_dim != transformed_hidden_dim
|
||||
):
|
||||
hidden_states = F.pad(
|
||||
hidden_states,
|
||||
(0, self.moe_config.hidden_dim - transformed_hidden_dim),
|
||||
mode="constant",
|
||||
value=0.0,
|
||||
)
|
||||
|
||||
if self._shared_experts is not None:
|
||||
orig_hidden_dims = [shared_experts_hidden_dim, transformed_hidden_dim]
|
||||
else:
|
||||
orig_hidden_dims = [transformed_hidden_dim]
|
||||
|
||||
return hidden_states, orig_hidden_dims
|
||||
|
||||
def _maybe_apply_shared_experts(
|
||||
self,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
order: SharedExpertsOrder,
|
||||
):
|
||||
if self._shared_experts is not None:
|
||||
assert shared_experts_input is not None
|
||||
self._shared_experts.apply(shared_experts_input, order)
|
||||
|
||||
def _apply_quant_method(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor | None, torch.Tensor]:
|
||||
# Run this before quant_method to avoid inplace issues.
|
||||
# TODO(bnell): probably not needed anymore since inplace is
|
||||
# disabled when shared experts are present.
|
||||
self._maybe_apply_shared_experts(
|
||||
shared_experts_input, SharedExpertsOrder.NO_OVERLAP
|
||||
)
|
||||
|
||||
if self.quant_method.is_monolithic:
|
||||
fused_out = self.quant_method.apply_monolithic(
|
||||
layer=layer,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = self.router.select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
# Passing shared_experts_input in case SharedExpertsOrder is
|
||||
# NO_OVERLAP or MK_INTERNAL_OVERLAPPED.
|
||||
fused_out = self.quant_method.apply(
|
||||
layer=layer,
|
||||
x=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
|
||||
self._maybe_apply_shared_experts(
|
||||
shared_experts_input,
|
||||
SharedExpertsOrder.MULTI_STREAM_OVERLAPPED,
|
||||
)
|
||||
|
||||
return (
|
||||
self._shared_experts.output if self._shared_experts is not None else None,
|
||||
fused_out,
|
||||
)
|
||||
|
||||
def _sequence_parallel_context(self):
|
||||
ctx = get_forward_context()
|
||||
return (
|
||||
ctx.dp_metadata.sp_local_sizes(self.moe_config.sp_size)
|
||||
if ctx.dp_metadata
|
||||
else nullcontext()
|
||||
)
|
||||
|
||||
def _maybe_sync_shared_experts_stream(
|
||||
self,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
):
|
||||
# If router/gate provided, then apply it here.
|
||||
# (Note: This code runs only when "overlapped mode" is on to allow
|
||||
# parallel execution of shared experts with the FusedMoE via
|
||||
# separate cuda stream)
|
||||
if self._shared_experts is not None:
|
||||
self._shared_experts.maybe_sync_shared_experts_stream(shared_experts_input)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Invoke the fused moe layer.
|
||||
|
||||
Input:
|
||||
- hidden_states
|
||||
- router_logits
|
||||
|
||||
Output:
|
||||
- The new hidden_states.
|
||||
or
|
||||
- A tuple of (shared experts output, new hidden_states).
|
||||
|
||||
Calling sequence
|
||||
- forward
|
||||
- self.forward_entry (_moe_forward or _moe_forward_shared custom op)
|
||||
- forward_dispatch
|
||||
- _forward_impl
|
||||
|
||||
Note: The existence of _moe_forward and _moe_forward_shared custom ops are due
|
||||
to the following reasons:
|
||||
1. the chunking loop in ChunkingMoERunner._forward_impl cannot be compiled by
|
||||
torch.compile
|
||||
2. pytorch cannot handle union types in custom op signatures so _moe_forward
|
||||
and _moe_forward_shared must be split.
|
||||
|
||||
If ChunkingMoERunner._forward_impl can be implemented via torch.scan we can
|
||||
potentially get rid of _moe_forward and _moe_forward_shared and collapse the
|
||||
whole sequence into the 'forward' method.
|
||||
"""
|
||||
|
||||
# Apply transform for routed experts (e.g., latent projection for latent MoE)
|
||||
hidden_states, shared_experts_input = self.apply_routed_input_transform(
|
||||
hidden_states
|
||||
)
|
||||
|
||||
hidden_states, og_hidden_dims = self._maybe_pad_hidden_states(
|
||||
shared_experts_input,
|
||||
hidden_states,
|
||||
)
|
||||
|
||||
fused_output = self.forward_entry(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
shared_experts_input,
|
||||
self._encode_layer_name(),
|
||||
)
|
||||
|
||||
return self._maybe_reduce_output(fused_output, og_hidden_dims)
|
||||
|
||||
def forward_dispatch(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
# TODO(bnell): this can be removed after MK migration is complete.
|
||||
layer.ensure_moe_quant_config_init()
|
||||
|
||||
# Sync aux and main stream for shared expert multi-stream overlap.
|
||||
self._maybe_sync_shared_experts_stream(shared_experts_input)
|
||||
|
||||
# If the Runner holds the gate, apply it after the stream sync,
|
||||
# so it can run overlapped with the
|
||||
# NOTE: in future PR, MoE runner will always hold the gate.
|
||||
if self.gate is not None:
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
with self._sequence_parallel_context():
|
||||
return self._forward_impl(
|
||||
layer,
|
||||
hidden_states,
|
||||
router_logits,
|
||||
shared_experts_input,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _forward_impl(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,51 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
|
||||
FusedMoEMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
|
||||
FusedMoERouter,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.runner.chunking_moe_runner import (
|
||||
ChunkingMoERunner,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import (
|
||||
DefaultMoERunner,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.runner.moe_runner import MoERunner
|
||||
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
|
||||
SharedExperts,
|
||||
)
|
||||
|
||||
|
||||
def create_moe_runner(
|
||||
layer_name: str,
|
||||
moe_config: FusedMoEConfig,
|
||||
router: FusedMoERouter,
|
||||
routed_input_transform: torch.nn.Module | None,
|
||||
gate: torch.nn.Module | None,
|
||||
shared_experts: SharedExperts | None,
|
||||
quant_method: FusedMoEMethodBase,
|
||||
reduce_results: bool,
|
||||
enable_dbo: bool,
|
||||
) -> MoERunner:
|
||||
runner = DefaultMoERunner(
|
||||
layer_name,
|
||||
moe_config,
|
||||
router,
|
||||
routed_input_transform,
|
||||
gate,
|
||||
shared_experts,
|
||||
quant_method,
|
||||
reduce_results,
|
||||
enable_dbo,
|
||||
)
|
||||
if moe_config.moe_parallel_config.use_dp_chunking:
|
||||
return ChunkingMoERunner(runner)
|
||||
return runner
|
||||
@@ -32,19 +32,14 @@ class SharedExpertsOrder(IntEnum):
|
||||
# No shared experts.
|
||||
NONE = (0,)
|
||||
|
||||
# Get rid of this one? combine with BEFORE?
|
||||
# Note: this might be important for torch.compile reasons. Can
|
||||
# get rid of it after _moe_forward is undone.
|
||||
EXTERNAL = (1,)
|
||||
|
||||
# No overlap - defensively called before MK.
|
||||
NO_OVERLAP = (2,)
|
||||
NO_OVERLAP = (1,)
|
||||
|
||||
# Overlapped with dispatch/combine in DP/EP - called by the MK.
|
||||
MK_INTERNAL_OVERLAPPED = (3,)
|
||||
MK_INTERNAL_OVERLAPPED = (2,)
|
||||
|
||||
# Overlapped with the gate, router, experts in aux stream.
|
||||
MULTI_STREAM_OVERLAPPED = (4,)
|
||||
MULTI_STREAM_OVERLAPPED = (3,)
|
||||
|
||||
|
||||
class SharedExperts:
|
||||
@@ -110,9 +105,6 @@ class SharedExperts:
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> SharedExpertsOrder:
|
||||
if self._use_external_experts:
|
||||
return SharedExpertsOrder.EXTERNAL
|
||||
|
||||
if self._quant_method.mk_owns_shared_expert:
|
||||
return SharedExpertsOrder.MK_INTERNAL_OVERLAPPED
|
||||
|
||||
@@ -205,12 +197,4 @@ class SharedExperts:
|
||||
else:
|
||||
self._output[self._output_idx] = self._layer(shared_experts_input)
|
||||
|
||||
if order == SharedExpertsOrder.EXTERNAL:
|
||||
# TODO: figure out how to combine this with maybe_reduce_output?
|
||||
# or get rid of it completely.
|
||||
assert self._output[self._output_idx] is not None
|
||||
self._output[self._output_idx] = self._maybe_reduce_shared_out(
|
||||
self._output[self._output_idx]
|
||||
)
|
||||
|
||||
assert self._output[self._output_idx] is not None
|
||||
|
||||
@@ -231,6 +231,7 @@ class NemotronHMoE(nn.Module):
|
||||
num_redundant_experts=self.n_redundant_experts,
|
||||
is_sequence_parallel=self.is_sequence_parallel,
|
||||
routed_input_transform=self.fc1_latent_proj,
|
||||
router_logits_dtype=self.gate.out_dtype,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
Reference in New Issue
Block a user