[MoE Refactor] DefaultMoERunner simplification (#33049)

Signed-off-by: Bill Nell <bnell@redhat.com>
bnellnm
2026-03-19 15:07:44 -04:00
committed by GitHub
parent 7454096199
commit 9279c59a0e
2 changed files with 393 additions and 309 deletions
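
Editor's note: at a high level, the second file's changes move the sequence-parallel context and the chunked-vs-unchunked decision out of forward_impl and into the _moe_forward/_moe_forward_shared wrappers, and split the old monolithic forward_impl into small helpers (_maybe_gate, _maybe_dispatch, _apply_quant_method, _maybe_combine). Below is a rough, self-contained sketch of the resulting control flow; every helper body is a placeholder standing in for the real vLLM logic, and the layer argument, padding, sequence-parallel context, and shared-experts stream are omitted.

# Control-flow sketch only; every helper body below is a placeholder.
import torch

class RunnerSketch:
    use_dp_chunking = False

    def forward_impl(self, hidden_states, router_logits, shared_input=None):
        router_logits = self._maybe_gate(hidden_states, router_logits)
        hidden_states, router_logits = self._maybe_dispatch(hidden_states, router_logits)
        shared_out, hidden_states = self._apply_quant_method(
            hidden_states, router_logits, shared_input
        )
        return self._maybe_combine(shared_out, hidden_states)

    forward_impl_chunked = forward_impl  # stand-in for the chunked loop

    def _maybe_gate(self, x, logits):          # optional fused gate
        return logits

    def _maybe_dispatch(self, x, logits):      # naive DP/EP + PCP dispatch
        return x, logits

    def _apply_quant_method(self, x, logits, shared_input):  # expert compute
        return None, x

    def _maybe_combine(self, shared_out, x):   # combine / reduce-scatter
        return x if shared_out is None else (shared_out, x)

def moe_forward_sketch(runner, hidden_states, router_logits):
    # The wrapper, not forward_impl, now owns the chunked/non-chunked branch.
    if runner.use_dp_chunking:
        return runner.forward_impl_chunked(hidden_states, router_logits)
    return runner.forward_impl(hidden_states, router_logits)

out = moe_forward_sketch(RunnerSketch(), torch.randn(4, 8), torch.randn(4, 2))
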

View File

@@ -504,6 +504,8 @@ class FusedMoE(CustomOp):
self.apply_router_weight_on_input = apply_router_weight_on_input
self.activation = MoEActivation.from_str(activation)
# TODO(bnell): we should not have to create a router if the kernel is
# monolithic.
self.router = create_fused_moe_router(
top_k=top_k,
global_num_experts=self.global_num_experts,

View File

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from contextlib import nullcontext
from typing import TYPE_CHECKING
@@ -82,9 +83,22 @@ def _moe_forward(
layer = get_layer_from_name(_resolve_layer_name(layer_name))
# TODO(bnell): this can be removed after MK migration is complete.
layer.ensure_moe_quant_config_init()
return layer.runner.forward_impl(
layer, hidden_states, router_logits, shared_experts_input
)
runner = layer.runner
with runner._sequence_parallel_context():
if runner.use_dp_chunking:
return runner.forward_impl_chunked(
layer,
hidden_states,
router_logits,
shared_experts_input,
)
else:
return runner.forward_impl(
layer,
hidden_states,
router_logits,
shared_experts_input,
)
def _moe_forward_fake(
@@ -105,9 +119,22 @@ def _moe_forward_shared(
layer = get_layer_from_name(_resolve_layer_name(layer_name))
# TODO(bnell): this can be removed after MK migration is complete.
layer.ensure_moe_quant_config_init()
return layer.runner.forward_impl(
layer, hidden_states, router_logits, shared_experts_input
)
runner = layer.runner
with runner._sequence_parallel_context():
if runner.use_dp_chunking:
return runner.forward_impl_chunked(
layer,
hidden_states,
router_logits,
shared_experts_input,
)
else:
return runner.forward_impl(
layer,
hidden_states,
router_logits,
shared_experts_input,
)
def _moe_forward_shared_fake(
@@ -191,10 +218,17 @@ class DefaultMoERunner(MoERunner):
self.reduce_results = reduce_results
self.enable_dbo = enable_dbo
# Chunked all2all staging tensor
# TODO(bnell) rename these?
self.batched_hidden_states: torch.Tensor | None = None
self.batched_router_logits: torch.Tensor | None = None
self._maybe_init_dp_chunking()
# Allow disabling of the separate shared experts stream for
# debug purposes.
# TODO: Remove this after more extensive testing with TP/DP
# and other execution modes
self.use_shared_experts_stream = False
if envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM:
logger.debug_once("Disabling MoE shared_experts cuda stream", scope="local")
self.shared_experts_stream = None
@@ -210,23 +244,20 @@ class DefaultMoERunner(MoERunner):
# Needed for string -> FusedMoE layer lookup in custom ops.
self.layer_name = layer.layer_name
self.moe_forward = self._select_forward(layer)
def _select_forward(self, layer: torch.nn.Module) -> Callable:
if current_platform.is_tpu() or current_platform.is_cpu():
# TODO: Once the OOM issue for the TPU backend is resolved, we
# will switch to using the moe_forward custom op.
# Note: CPU doesn't require wrapped forward_impl.
if self.shared_experts is None:
self.moe_forward = _moe_forward
else:
self.moe_forward = _moe_forward_shared
else:
if self.shared_experts is None:
self.moe_forward = torch.ops.vllm.moe_forward
else:
self.moe_forward = torch.ops.vllm.moe_forward_shared
return _moe_forward if self.shared_experts is None else _moe_forward_shared
# Chunked all2all staging tensor
self.batched_hidden_states: torch.Tensor | None = None
self.batched_router_logits: torch.Tensor | None = None
return (
torch.ops.vllm.moe_forward
if self.shared_experts is None
else torch.ops.vllm.moe_forward_shared
)
@property
def use_dp_chunking(self) -> bool:
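
For CUDA, _select_forward returns the torch.ops.vllm.moe_forward / moe_forward_shared custom ops (the *_fake wrappers above supply the shape-only implementations used while tracing), while TPU/CPU call the Python wrappers directly. A generic sketch of the custom-op-plus-fake-impl pattern this relies on, registered under a throwaway namespace rather than vLLM's:

# Generic illustration of a custom op with a fake (meta) implementation;
# the "demo" namespace and op name are made up for this sketch.
import torch

@torch.library.custom_op("demo::moe_forward", mutates_args=())
def demo_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor:
    # Real eager implementation: opaque to torch.compile, so the compiler
    # never traces into the MoE dispatch machinery.
    return hidden_states.clone()

@demo_moe_forward.register_fake
def _(hidden_states: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor:
    # Fake implementation: only shapes/dtypes, used when tracing/compiling.
    return torch.empty_like(hidden_states)

out = demo_moe_forward(torch.randn(2, 8), torch.randn(2, 4))
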
@@ -241,22 +272,8 @@ class DefaultMoERunner(MoERunner):
self,
hidden_states: torch.Tensor,
shared_input: torch.Tensor | None,
has_separate_shared_experts: bool,
use_chunked_impl: bool,
) -> tuple[bool, torch.Tensor | None]:
use_shared_experts_stream = (
current_platform.is_cuda()
and has_separate_shared_experts
and not use_chunked_impl
and self.shared_experts_stream is not None
and (
hidden_states.shape[0]
<= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
)
)
shared_experts_input: torch.Tensor | None = None
if use_shared_experts_stream:
):
if self.use_shared_experts_stream:
assert self.shared_experts_stream is not None
assert self.moe_config.disable_inplace
@@ -278,12 +295,11 @@ class DefaultMoERunner(MoERunner):
assert self.shared_experts_stream is not None
self.shared_experts_stream.wait_stream(current_stream())
return use_shared_experts_stream, shared_experts_input
def ensure_dp_chunking_init(self):
if not self.use_dp_chunking or self.batched_hidden_states is not None:
def _maybe_init_dp_chunking(self):
if not self.use_dp_chunking:
return
assert self.batched_hidden_states is None
states_shape: tuple[int, ...]
logits_shape: tuple[int, ...]
@@ -309,6 +325,38 @@ class DefaultMoERunner(MoERunner):
device=device,
)
@property
def has_separate_shared_experts(self) -> bool:
return (
not self.quant_method.mk_owns_shared_expert
and self.shared_experts is not None
)
def _apply_shared_experts(
self,
hidden_states: torch.Tensor,
allow_streaming: bool = False,
) -> torch.Tensor | None:
shared_output: torch.Tensor | None = None
if self.has_separate_shared_experts:
assert self.shared_experts is not None
if self.use_shared_experts_stream and allow_streaming:
# Run shared experts in parallel on a separate stream
# NOTE: We start the separate stream here and mark the
# sync end point immediately after it is done. This is
# important to avoid excessive stream allocations by the cuda
# graph replay later.
with torch.cuda.stream(self.shared_experts_stream):
# Note that hidden_states clone() is necessary here to avoid
# conflict with the main stream
shared_output = self.shared_experts(hidden_states)
current_stream().wait_stream(self.shared_experts_stream)
else:
shared_output = self.shared_experts(hidden_states)
return shared_output
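
The streaming branch above follows the usual side-stream overlap recipe: the side stream waits on the current stream before reading the activations (done in _maybe_setup_shared_experts_stream), the shared experts run under torch.cuda.stream, and the main stream waits on the side stream before consuming the result. A minimal stand-alone sketch of that pattern, with placeholder callables (assumes a CUDA device):

# Minimal side-stream overlap sketch; main_fn/side_fn are placeholder callables.
import torch

def overlapped(main_fn, side_fn, x: torch.Tensor, side_stream: torch.cuda.Stream):
    # Side stream must see all prior writes to x from the current stream.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        side_out = side_fn(x)                 # runs concurrently with main_fn
    main_out = main_fn(x)
    # Sync point: main stream waits before reading side_out.
    torch.cuda.current_stream().wait_stream(side_stream)
    return main_out, side_out

if torch.cuda.is_available():
    stream = torch.cuda.Stream()
    t = torch.randn(16, 32, device="cuda")
    fused, shared = overlapped(lambda v: v @ v.T, lambda v: v * 2, t, stream)
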
def must_reduce_shared_expert_outputs(self) -> bool:
"""
The shared_experts are typically computed using the RowParallelLinear
@@ -322,7 +370,6 @@ class DefaultMoERunner(MoERunner):
Therefore it is required that we reduce the shared_experts output
early.
"""
assert self.quant_method is not None
return (
self.quant_method.moe_kernel is not None
and self.quant_method.moe_kernel.output_is_reduced()
@@ -357,7 +404,7 @@ class DefaultMoERunner(MoERunner):
return result
return hidden_states
def _reduce_output(
def _maybe_reduce_output(
self,
states: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
trunc_sizes: list[int],
@@ -397,23 +444,16 @@ class DefaultMoERunner(MoERunner):
return "from_forward_context"
return self.layer_name
def forward(
def _maybe_pad_hidden_states(
self,
original_hidden_states: torch.Tensor | None,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
# For latent MoE: save ORIGINAL hidden_states before transform
# (shared_experts need original dimension, routed experts use transformed)
if self.shared_experts is not None:
original_hidden_states = hidden_states
original_hidden_dim = hidden_states.shape[-1]
else:
original_hidden_states = None
# Apply transform for routed experts (e.g., latent projection for latent MoE)
hidden_states = self.apply_routed_input_transform(hidden_states)
# This is the dimension after transform (for routed expert output slicing)
) -> tuple[torch.Tensor, list[int]]:
original_hidden_dim = (
original_hidden_states.shape[-1]
if original_hidden_states is not None
else 0
)
transformed_hidden_dim = hidden_states.shape[-1]
if (
not self.quant_method.skip_forward_padding
@@ -426,6 +466,192 @@ class DefaultMoERunner(MoERunner):
value=0.0,
)
if self.shared_experts is not None:
orig_hidden_dims = [original_hidden_dim, transformed_hidden_dim]
else:
orig_hidden_dims = [transformed_hidden_dim]
return hidden_states, orig_hidden_dims
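
_maybe_pad_hidden_states zero-pads the last (hidden) dimension when the quant method does not skip forward padding, and returns the pre-pad dims so _maybe_reduce_output can slice the result back afterwards. A small stand-alone pad/truncate round trip, with an arbitrary target size, just to illustrate the bookkeeping:

# Pad/truncate round trip; the target size 128 is an arbitrary example.
import torch
import torch.nn.functional as F

hidden = torch.randn(7, 100)
orig_dim = hidden.size(-1)
pad = 128 - orig_dim
padded = F.pad(hidden, (0, pad), value=0.0)    # zero-pad the last dim on the right
assert padded.size(-1) == 128
restored = padded[..., :orig_dim]              # truncate back using the saved dim
assert torch.equal(restored, hidden)
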
def _apply_quant_method(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_input: torch.Tensor | None,
run_shared_experts_before: bool = True,
) -> tuple[torch.Tensor | None, torch.Tensor]:
shared_input = shared_input if shared_input is not None else hidden_states
shared_output: torch.Tensor | None = None
# Run this before quant_method to avoid inplace issues.
if run_shared_experts_before:
shared_output = self._apply_shared_experts(shared_input, False)
if self.quant_method.is_monolithic:
result = self.quant_method.apply_monolithic(
layer=layer,
x=hidden_states,
router_logits=router_logits,
)
else:
topk_weights, topk_ids = self.router.select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
)
result = self.quant_method.apply(
layer=layer,
x=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
shared_experts_input=shared_input,
)
if isinstance(result, tuple):
assert shared_output is None
shared_output, hidden_states = result
else:
hidden_states = result
if not run_shared_experts_before and self.has_separate_shared_experts:
assert shared_output is None
shared_output = self._apply_shared_experts(shared_input, True)
return shared_output, hidden_states
def _sequence_parallel_context(self):
ctx = get_forward_context()
return (
ctx.dp_metadata.sp_local_sizes(self.moe_config.sp_size)
if ctx.dp_metadata
else nullcontext()
)
def _allocate_dp_chunking_outputs(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor | None, torch.Tensor]:
assert self.use_dp_chunking
# Assert the inputs are of the proper type and shape.
assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None
assert self.batched_hidden_states.dtype == hidden_states.dtype, (
f"{self.batched_hidden_states.dtype} == {hidden_states.dtype}"
)
assert self.batched_router_logits.dtype == router_logits.dtype, (
f"{self.batched_router_logits.dtype} == {router_logits.dtype}"
)
# Check size compatibility.
assert self.batched_hidden_states.size(-1) == hidden_states.size(-1)
assert self.batched_router_logits.size(-1) == router_logits.size(-1)
final_fused_hidden_states = torch.empty_like(hidden_states)
if self.shared_experts is not None:
final_shared_hidden_states = torch.empty_like(hidden_states)
else:
final_shared_hidden_states = None
return final_shared_hidden_states, final_fused_hidden_states
def _maybe_gate(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor:
# If router/gate provided, then apply it here.
# (Note: This code runs only when "overlapped mode" is on to allow
# parallel execution of shared experts with the FusedMoE via
# separate cuda stream)
if self.gate is not None:
router_logits, _ = self.gate(hidden_states)
return router_logits
@property
def do_naive_dispatch_combine(self) -> bool:
return (
self.moe_config.dp_size > 1 and not self.quant_method.supports_internal_mk
)
def _maybe_dispatch(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
# For naive dispatch/combine Dp/Ep, dispatch the hidden states and
# router logits to all experts.
# NOTE: this will be removed once all kernels are migrated into the
# MoEKernel framework.
if self.do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch_router_logits(
hidden_states,
router_logits,
self.moe_config.is_sequence_parallel,
)
# NOTE: Similar to DP, PCP also needs dispatch and combine. For
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
# we should modify All2AllManager abstraction to better support PCP.
if self.moe_config.pcp_size > 1:
hidden_states = get_pcp_group().all_gather(
hidden_states,
dim=0,
)
router_logits = get_pcp_group().all_gather(
router_logits,
dim=0,
)
return hidden_states, router_logits
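
The naive dispatch path all-gathers tokens (and router logits) along dim 0 so every rank sees the full batch, and _maybe_combine later reduce-scatters the partial expert outputs so each rank keeps only its own rows again. A single-process, plain-tensor stand-in for that shape/reduction round trip (the real code uses the EP/PCP process groups):

# Plain-tensor stand-in for all_gather(dim=0) followed by reduce_scatter(dim=0).
import torch

world_size, local_tokens, hidden = 2, 3, 4
local_inputs = [torch.randn(local_tokens, hidden) for _ in range(world_size)]

# all_gather(dim=0): every rank ends up with world_size * local_tokens rows.
gathered = torch.cat(local_inputs, dim=0)

# Each rank produces partial expert outputs over the gathered rows
# (scaled copies here as placeholders).
partials = [gathered * (rank + 1) for rank in range(world_size)]

# reduce_scatter(dim=0): elementwise-sum the partials, then each rank keeps
# only its own local_tokens slice.
summed = torch.stack(partials).sum(dim=0)
per_rank_outputs = list(summed.chunk(world_size, dim=0))
assert per_rank_outputs[0].shape == local_inputs[0].shape
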
def _maybe_combine(
self,
shared_output: torch.Tensor | None,
hidden_states: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
if self.do_naive_dispatch_combine:
hidden_states = get_ep_group().combine(
hidden_states, self.moe_config.is_sequence_parallel
)
if self.moe_config.pcp_size > 1:
hidden_states = get_pcp_group().reduce_scatter(
hidden_states,
dim=0,
)
# need RS for shared_output?
if self.shared_experts is not None:
assert shared_output is not None
return shared_output, hidden_states
else:
return hidden_states
def forward(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
# For latent MoE: save ORIGINAL hidden_states before transform
# (shared_experts need original dimension, routed experts use transformed)
if self.shared_experts is not None:
original_hidden_states = hidden_states
else:
original_hidden_states = None
# Apply transform for routed experts (e.g., latent projection for latent MoE)
hidden_states = self.apply_routed_input_transform(hidden_states)
hidden_states, og_hidden_dims = self._maybe_pad_hidden_states(
original_hidden_states,
hidden_states,
)
fused_output = self.moe_forward(
hidden_states,
router_logits,
@@ -433,127 +659,42 @@ class DefaultMoERunner(MoERunner):
self._encode_layer_name(),
)
if self.shared_experts is not None:
orig_hidden_dims = [original_hidden_dim, transformed_hidden_dim]
else:
orig_hidden_dims = [transformed_hidden_dim]
return self._maybe_reduce_output(fused_output, og_hidden_dims)
return self._reduce_output(fused_output, orig_hidden_dims)
def _slice_and_copy_input(
self,
out_slice: torch.Tensor,
orig: torch.Tensor | None,
start: int,
end: int,
) -> torch.Tensor:
assert orig is not None
slice_size = end - start
orig_slice = orig[start:end, :]
if self.enable_dbo:
assert out_slice.dim() == 3
batch_buffer_idx = dbo_current_ubatch_id()
out_slice = out_slice[batch_buffer_idx, :]
assert out_slice.size(0) >= slice_size
out_slice = out_slice[:slice_size, :]
out_slice.copy_(orig_slice, non_blocking=True)
return out_slice
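
_slice_and_copy_input copies a variable-length chunk of the full input into the preallocated staging buffer (selecting the right ubatch slice first when DBO gives the buffer an extra leading dimension) and returns a view sized to the chunk. The staging pattern in isolation, with arbitrary example sizes:

# Copying a variable-size chunk into a fixed, preallocated staging buffer.
import torch

max_chunk, hidden = 8, 16
staging = torch.empty(max_chunk, hidden)     # allocated once, reused every chunk

chunk = torch.randn(5, hidden)               # this iteration's slice (5 <= max_chunk)
view = staging[: chunk.size(0), :]           # view over the front of the buffer
view.copy_(chunk, non_blocking=True)         # async when host memory is pinned
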
def forward_impl_chunked(
self,
layer: torch.nn.Module,
full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor,
full_shared_input: torch.Tensor | None,
has_separate_shared_experts: bool,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None
assert self.batched_hidden_states.dtype == full_hidden_states.dtype, (
f"{self.batched_hidden_states.dtype} == {full_hidden_states.dtype}"
# Gate overlap not supported when chunking is enabled. Run the
# gate first.
router_logits = self._maybe_gate(hidden_states, router_logits)
final_shared_hidden_states, final_fused_hidden_states = (
self._allocate_dp_chunking_outputs(hidden_states, router_logits)
)
assert self.batched_router_logits.dtype == full_router_logits.dtype, (
f"{self.batched_router_logits.dtype} == {full_router_logits.dtype}"
)
# Check size compatibility.
assert self.batched_hidden_states.size(-1) == full_hidden_states.size(-1)
assert self.batched_router_logits.size(-1) == full_router_logits.size(-1)
# TODO(bnell): Fix shared_expert_inputs w/chunking.
# assert shared_input is None, (
# "Routed input transform is not currently supported with DP chunking."
# )
full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
if self.shared_experts is not None:
full_shared_final_hidden_states = torch.empty_like(full_hidden_states)
def process_chunk(chunk_start, chunk_end, skip_result_store=False):
chunk_size = chunk_end - chunk_start
hidden_states = full_hidden_states[chunk_start:chunk_end, :]
router_logits = full_router_logits[chunk_start:chunk_end, :]
shared_input = (
full_shared_input[chunk_start:chunk_end, :]
if full_shared_input is not None
else None
)
assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None
# This is only true when DBO has been enabled in the config.
# Both tensors will have an outer dimension for the ubatch id
if self.batched_hidden_states.dim() == 3:
assert self.batched_router_logits.dim() == 3
batch_buffer_idx = dbo_current_ubatch_id()
batched_hidden_states = self.batched_hidden_states[batch_buffer_idx, :]
batched_router_logits = self.batched_router_logits[batch_buffer_idx, :]
else:
batched_hidden_states = self.batched_hidden_states
batched_router_logits = self.batched_router_logits
assert (
batched_hidden_states.size(0) # type: ignore
>= chunk_size
)
assert (
batched_router_logits.size(0) # type: ignore
>= chunk_size
)
staged_hidden_states = batched_hidden_states[:chunk_size, :] # type: ignore
staged_router_logits = batched_router_logits[:chunk_size, :] # type: ignore
staged_hidden_states.copy_(hidden_states, non_blocking=True)
staged_router_logits.copy_(router_logits, non_blocking=True)
shared_input = (
shared_input if shared_input is not None else staged_hidden_states
)
# Matrix multiply.
if self.quant_method.is_monolithic:
assert has_separate_shared_experts or self.shared_experts is None
final_hidden_states = self.quant_method.apply_monolithic(
layer=layer,
x=staged_hidden_states,
router_logits=staged_router_logits,
)
else:
topk_weights, topk_ids = self.router.select_experts(
hidden_states=staged_hidden_states,
router_logits=staged_router_logits,
)
final_hidden_states = self.quant_method.apply(
layer=layer,
x=staged_hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
shared_experts_input=shared_input,
)
if has_separate_shared_experts:
assert not isinstance(final_hidden_states, tuple)
assert self.shared_experts is not None
shared_output = self.shared_experts(shared_input)
final_hidden_states = (
shared_output,
final_hidden_states,
)
if not skip_result_store:
if self.shared_experts is None:
full_fused_final_hidden_states[chunk_start:chunk_end, :].copy_(
final_hidden_states, non_blocking=True
)
else:
full_shared_final_hidden_states[chunk_start:chunk_end, :].copy_(
final_hidden_states[0], non_blocking=True
)
full_fused_final_hidden_states[chunk_start:chunk_end, :].copy_(
final_hidden_states[1], non_blocking=True
)
ctx = get_forward_context()
# flashinfer_cutlass_kernels can handle: optional DP + TP/EP
@@ -567,7 +708,7 @@ class DefaultMoERunner(MoERunner):
max_tokens_across_dispatchers, self.moe_config.sp_size
)
num_tokens = full_hidden_states.size(0)
num_tokens = hidden_states.size(0)
for chunk_idx, chunk_start_ in enumerate(
range(0, max_tokens_across_dispatchers, moe_dp_chunk_size_per_rank)
):
@@ -578,17 +719,55 @@ class DefaultMoERunner(MoERunner):
# clamp start and end
chunk_start = min(chunk_start, num_tokens - 1)
chunk_end = min(chunk_end, num_tokens)
with ctx.dp_metadata.chunked_sizes(
chunk_sizes = ctx.dp_metadata.chunked_sizes(
self.moe_config.sp_size, moe_dp_chunk_size_per_rank, chunk_idx
):
process_chunk(
chunk_start, chunk_end, skip_result_store=chunk_start_ >= num_tokens
)
with chunk_sizes:
hidden_states_chunk = self._slice_and_copy_input(
self.batched_hidden_states,
hidden_states,
chunk_start,
chunk_end,
)
router_logits_chunk = self._slice_and_copy_input(
self.batched_router_logits,
router_logits,
chunk_start,
chunk_end,
)
shared_input_chunk = (
shared_input[chunk_start:chunk_end, :]
if shared_input is not None
else None
)
shared_output_chunk, hidden_states_chunk = self._apply_quant_method(
layer=layer,
hidden_states=hidden_states_chunk,
router_logits=router_logits_chunk,
shared_input=shared_input_chunk,
)
# Store outputs
# TODO(bnell): document when chunk_start >= num_tokens
if chunk_start < num_tokens:
final_fused_hidden_states[chunk_start:chunk_end, :].copy_(
hidden_states_chunk, non_blocking=True
)
if self.shared_experts is not None:
assert shared_output_chunk is not None
assert final_shared_hidden_states is not None
final_shared_hidden_states[chunk_start:chunk_end, :].copy_(
shared_output_chunk, non_blocking=True
)
if self.shared_experts is None:
return full_fused_final_hidden_states
return final_fused_hidden_states
else:
return (full_shared_final_hidden_states, full_fused_final_hidden_states)
assert final_shared_hidden_states is not None
return (final_shared_hidden_states, final_fused_hidden_states)
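
The chunk loop iterates up to the largest token count across dispatchers rather than this rank's own count; chunk_start/chunk_end are clamped to keep slicing valid, and iterations where chunk_start is already past num_tokens skip the result store (presumably kept so every rank issues the same number of collective calls, per the TODO above). A toy version of just the bounds arithmetic, with made-up sizes:

# Toy chunk-loop bounds; numbers are made up to show the clamping behaviour.
num_tokens = 10                      # tokens on this rank
max_tokens_across_dispatchers = 25   # largest token count on any rank
chunk_size = 8

for start in range(0, max_tokens_across_dispatchers, chunk_size):
    end = start + chunk_size
    c_start = min(start, num_tokens - 1)   # keep the slice start in range
    c_end = min(end, num_tokens)           # keep the slice end in range
    store = start < num_tokens             # past-the-end chunks store nothing
    print(f"chunk [{c_start}, {c_end}) store={store}")
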
def forward_impl(
self,
@@ -597,148 +776,51 @@ class DefaultMoERunner(MoERunner):
router_logits: torch.Tensor,
shared_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.quant_method is not None
self.ensure_dp_chunking_init()
has_separate_shared_experts = (
not self.quant_method.mk_owns_shared_expert
and self.shared_experts is not None
self.use_shared_experts_stream = (
current_platform.is_cuda()
and self.has_separate_shared_experts
and not self.use_dp_chunking
and self.shared_experts_stream is not None
and (
hidden_states.shape[0]
<= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
)
)
use_chunked_impl = self.use_dp_chunking
# Check if we need to run shared experts before matrix multiply because
# matrix multiply may modify the hidden_states.
run_shared_experts_before = (
self.has_separate_shared_experts and not self.use_shared_experts_stream
)
use_shared_experts_stream, shared_experts_input = (
# The shared experts stream must be set up before calling the gate so they
# can be overlapped.
if not run_shared_experts_before:
self._maybe_setup_shared_experts_stream(
hidden_states,
shared_input,
has_separate_shared_experts,
use_chunked_impl,
)
)
# If router/gate provided, then apply it here.
# (Note: This code runs only when "overlapped mode" is on to allow
# parallel execution of shared experts with the FusedMoE via
# separate cuda stream)
if self.gate is not None:
router_logits, _ = self.gate(hidden_states)
if use_chunked_impl:
return self.forward_impl_chunked(
layer,
hidden_states,
router_logits,
shared_input,
has_separate_shared_experts,
)
# NOTE(rob): once we finish migrating all the quant methods to use
# MKs, we can remove the naive dispatch/combine path from here.
do_naive_dispatch_combine = (
self.moe_config.dp_size > 1 and not self.quant_method.supports_internal_mk
router_logits = self._maybe_gate(hidden_states, router_logits)
# TODO(bnell): parts of the dispatch/combine steps will go away once
# #32567 lands and the remaining kernels are made MKs. The PCP
# code will probably remain
hidden_states, router_logits = self._maybe_dispatch(
layer,
hidden_states,
router_logits,
)
ctx = get_forward_context()
sp_ctx = (
ctx.dp_metadata.sp_local_sizes(self.moe_config.sp_size)
if ctx.dp_metadata
else nullcontext()
shared_output, hidden_states = self._apply_quant_method(
layer=layer,
hidden_states=hidden_states,
router_logits=router_logits,
shared_input=shared_input,
run_shared_experts_before=run_shared_experts_before,
)
with sp_ctx:
# Run shared experts before matrix multiply.
# because matrix multiply maybe modify the hidden_states.
if has_separate_shared_experts and not use_shared_experts_stream:
assert self.shared_experts is not None
shared_input = (
shared_input if shared_input is not None else hidden_states
)
shared_output = self.shared_experts(shared_input)
# For naive dispatch/combine Dp/Ep, dispatch the hidden states and
# router logits to all experts.
# NOTE: this will be removed once all kernels are migrated into the
# MoEKernel framework.
if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch_router_logits(
hidden_states,
router_logits,
self.moe_config.is_sequence_parallel,
)
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
# we should modify All2AllManager abstract to better support PCP.
if self.moe_config.pcp_size > 1:
hidden_states = get_pcp_group().all_gather(
hidden_states,
dim=0,
)
router_logits = get_pcp_group().all_gather(
router_logits,
dim=0,
)
# Matrix multiply.
if self.quant_method.is_monolithic:
final_hidden_states = self.quant_method.apply_monolithic(
layer=layer,
x=hidden_states,
router_logits=router_logits,
)
else:
topk_weights, topk_ids = self.router.select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
)
final_hidden_states = self.quant_method.apply(
layer=layer,
x=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
shared_experts_input=shared_input,
)
if has_separate_shared_experts:
assert self.shared_experts is not None
if use_shared_experts_stream:
# Run shared experts in parallel on a separate stream
# NOTE: We start the separate stream here and mark the
# sync end point immediately after it is done. This is
# important to avoid excessive stream allocations by the cuda
# graph replay later.
with torch.cuda.stream(self.shared_experts_stream):
# Note that hidden_states clone() is necessary here to avoid
# conflict with the main stream
shared_output = self.shared_experts(shared_experts_input)
current_stream().wait_stream(self.shared_experts_stream)
final_hidden_states = (
shared_output,
final_hidden_states,
)
def combine_output(states: torch.Tensor) -> torch.Tensor:
if do_naive_dispatch_combine:
states = get_ep_group().combine(
states, self.moe_config.is_sequence_parallel
)
if self.moe_config.pcp_size > 1:
states = get_pcp_group().reduce_scatter(
states,
dim=0,
)
return states
if self.shared_experts is not None:
return (
final_hidden_states[0],
combine_output(final_hidden_states[1]),
)
else:
return combine_output(final_hidden_states)
return self._maybe_combine(
shared_output,
hidden_states,
)
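
The comments in forward_impl above capture the one ordering constraint left after the refactor: when the shared experts are not overlapped on a side stream, they must run before the quant method because the expert matmul may modify hidden_states in place; when they are overlapped, the stream setup happens before the gate (and disable_inplace is asserted) so the two can actually run concurrently. A condensed sketch of that decision, with placeholder callables:

# Ordering sketch only; run_experts/run_shared are placeholder callables.
import torch

def run_moe(hidden_states, run_experts, run_shared, use_side_stream: bool):
    if not use_side_stream:
        # Shared experts first: run_experts may modify hidden_states in place.
        shared_out = run_shared(hidden_states)
        fused_out = run_experts(hidden_states)
    else:
        # Side stream (set up earlier) lets the two overlap on the GPU; the
        # expert path must then avoid in-place updates to hidden_states.
        fused_out = run_experts(hidden_states)
        shared_out = run_shared(hidden_states)
    return shared_out, fused_out

shared, fused = run_moe(
    torch.randn(4, 8), lambda x: x * 2, lambda x: x + 1, use_side_stream=False
)
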