[MoE] Enable Shared/Routed Overlap For Latent MoE (Nemotron-H) (#32790)

Signed-off-by: dafrimi <dafrimi@nvidia.com>
This commit is contained in:
danielafrimi
2026-02-02 16:18:50 +02:00
committed by GitHub
parent 9eb58f8cf1
commit 0aca8b8c62
4 changed files with 303 additions and 58 deletions

View File

@@ -1,8 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Iterable
from contextlib import nullcontext
from collections.abc import Callable, Generator, Iterable
from contextlib import contextmanager, nullcontext
from enum import Enum
from typing import Literal, cast, get_args, overload
@@ -351,6 +351,10 @@ class FusedMoE(CustomOp):
"Enabled separate cuda stream for MoE shared_experts", scope="local"
)
# For latent MoE: stores original hidden_states before routed_input_transform
# so shared_experts can use it for cloning (they need original dimension)
self._shared_experts_input: torch.Tensor | None = None
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
@@ -664,6 +668,39 @@ class FusedMoE(CustomOp):
def gate(self) -> torch.nn.Module | None:
return None
def apply_routed_input_transform(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""Hook to transform hidden_states before passing to routed experts.
For latent MoE: transforms [S, hidden_size] → [S, moe_latent_size].
The original hidden_states is saved in _shared_experts_input so
shared_experts still receive the original [S, hidden_size].
Override in subclasses (e.g., SharedFusedMoE) for latent MoE.
"""
return hidden_states
@contextmanager
def _set_shared_experts_input(
self, value: torch.Tensor | None
) -> Generator[None, None, None]:
"""Context manager to safely set/clear _shared_experts_input."""
self._shared_experts_input = value
try:
yield
finally:
self._shared_experts_input = None
def _get_shared_experts_input(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""Get input for shared experts.
For latent MoE: shared_experts need original [S, hidden_size],
not the transformed [S, latent_size] used by routed experts.
"""
return (
self._shared_experts_input
if self._shared_experts_input is not None
else hidden_states
)
@property
def tp_size(self):
return self.moe_parallel_config.tp_size
@@ -855,9 +892,11 @@ class FusedMoE(CustomOp):
if use_shared_experts_stream:
assert self.shared_experts_stream is not None
shared_experts_input = self._get_shared_experts_input(hidden_states)
# Clone BEFORE switching streams to avoid race condition
# where routed_expert kernel may mutate hidden_states.
hidden_states_clone = hidden_states.clone()
hidden_states_clone = shared_experts_input.clone()
# Record that the clone will be used by shared_experts_stream
# to avoid gc issue from deallocation of hidden_states_clone
@@ -1537,11 +1576,20 @@ class FusedMoE(CustomOp):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
og_hidden_states = hidden_states.shape[-1]
if self.hidden_size != og_hidden_states:
# For latent MoE: save ORIGINAL hidden_states before transform
# (shared_experts need original dimension, routed experts use transformed)
original_hidden_states = hidden_states
original_hidden_dim = hidden_states.shape[-1]
# Apply transform for routed experts (e.g., latent projection for latent MoE)
hidden_states = self.apply_routed_input_transform(hidden_states)
# This is the dimension after transform (for routed expert output slicing)
transformed_hidden_dim = hidden_states.shape[-1]
if self.hidden_size != transformed_hidden_dim:
hidden_states = F.pad(
hidden_states,
(0, self.hidden_size - og_hidden_states),
(0, self.hidden_size - transformed_hidden_dim),
mode="constant",
value=0.0,
)
@@ -1576,22 +1624,31 @@ class FusedMoE(CustomOp):
fused_output = torch.ops.vllm.moe_forward(
hidden_states, router_logits, encode_layer_name()
)
return reduce_output(fused_output)[..., :og_hidden_states]
return reduce_output(fused_output)[..., :transformed_hidden_dim]
else:
if current_platform.is_tpu() or current_platform.is_cpu():
# TODO: Once the OOM issue for the TPU backend is resolved, we
# will switch to using the moe_forward custom op.
# Note: CPU doesn't require wrapped forward_impl.
shared_output, fused_output = self.forward_impl(
hidden_states, router_logits
)
with self._set_shared_experts_input(original_hidden_states):
shared_output, fused_output = self.forward_impl(
hidden_states, router_logits
)
else:
# Custom op handles setting/clearing _shared_experts_input internally
# We pass original tensor for shared experts (not transformed)
shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
hidden_states, router_logits, encode_layer_name()
hidden_states,
router_logits,
encode_layer_name(),
original_hidden_states,
)
# shared_output uses original dimension (before transform)
# fused_output uses transformed dimension (after transform)
return (
reduce_output(shared_output)[..., :og_hidden_states],
reduce_output(fused_output)[..., :og_hidden_states],
reduce_output(shared_output)[..., :original_hidden_dim],
reduce_output(fused_output)[..., :transformed_hidden_dim],
)
@property
@@ -1831,7 +1888,8 @@ class FusedMoE(CustomOp):
# because matrix multiply maybe modify the hidden_states.
if has_separate_shared_experts and not use_shared_experts_stream:
assert self.shared_experts is not None
shared_output = self.shared_experts(hidden_states)
shared_input = self._get_shared_experts_input(hidden_states)
shared_output = self.shared_experts(shared_input)
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
@@ -2021,19 +2079,34 @@ def moe_forward_shared(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
shared_experts_input: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
self = get_layer_from_name(layer_name)
assert self.shared_experts is not None
return self.forward_impl(hidden_states, router_logits)
# Set here because torch.compile skips forward_native() setup code
# and calls this op directly. forward_impl() reads from this var.
with self._set_shared_experts_input(shared_experts_input):
return self.forward_impl(hidden_states, router_logits)
def moe_forward_shared_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
shared_experts_input: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
shared_out = torch.empty_like(hidden_states)
# Output shapes:
# - fused_out: same as hidden_states (routed experts use transformed size)
# - shared_out: same as shared_experts_input if provided, else same as hidden_states
# (For latent MoE: shared experts use original hidden_size, not latent size)
fused_out = torch.empty_like(hidden_states)
if shared_experts_input is not None:
shared_out = torch.empty_like(shared_experts_input)
else:
shared_out = torch.empty_like(hidden_states)
return shared_out, fused_out

View File

@@ -23,10 +23,12 @@ class SharedFusedMoE(FusedMoE):
shared_experts: torch.nn.Module | None,
gate: torch.nn.Module | None = None,
use_overlapped: bool = True,
routed_input_transform: torch.nn.Module | None = None,
**kwargs,
):
super().__init__(**kwargs)
self._shared_experts = shared_experts
self._routed_input_transform = routed_input_transform
# Disable shared expert overlap if:
# - we are using eplb with non-default backend, because of correctness issues
@@ -56,6 +58,26 @@ class SharedFusedMoE(FusedMoE):
def is_internal_router(self) -> bool:
return self.gate is not None
def apply_routed_input_transform(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""Apply transform for routed experts (e.g., latent projection).
This is called by FusedMoE.forward_native. The original hidden_states
is saved separately so shared experts get [S, hidden_size] while
routed experts get the transformed [S, moe_latent_size].
TODO: For latent MoE bandwidth optimization, fc2_latent_proj could be
moved inside SharedFusedMoE to all-reduce on the smaller latent
dimension.
"""
if self._routed_input_transform is not None:
result = self._routed_input_transform(hidden_states)
# ReplicatedLinear returns (output, extra_bias) tuple.
# We only need the output tensor; extra_bias is not used here.
if isinstance(result, tuple):
return result[0]
return result
return hidden_states
def forward(
self,
hidden_states: torch.Tensor,