[MoE] Enable Shared/Routed Overlap For Latent MoE (Nemotron-H) (#32790)

Signed-off-by: dafrimi <dafrimi@nvidia.com>
Author: danielafrimi
Date: 2026-02-02 16:18:50 +02:00
Committed by: GitHub
Parent: 9eb58f8cf1
Commit: 0aca8b8c62
4 changed files with 303 additions and 58 deletions


@@ -0,0 +1,162 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for SharedFusedMoE with routed_input_transform.
Verifies that applying routed_input_transform inside SharedFusedMoE
produces the same results as applying the transform manually outside.
"""
import pytest
import torch
import torch.nn as nn
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE


class SimpleLinear(nn.Module):
"""A simple linear transform mimicking latent projection in latent MoE."""
def __init__(self, in_features: int, out_features: int, dtype: torch.dtype):
super().__init__()
self.weight = nn.Parameter(
torch.randn(out_features, in_features, device="cuda", dtype=dtype) / 10
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return nn.functional.linear(x, self.weight)


class SimpleSharedExperts(nn.Module):
"""A simple 2-layer MLP mimicking shared experts."""
def __init__(self, hidden_size: int, intermediate_size: int, dtype: torch.dtype):
super().__init__()
self.up = nn.Linear(
hidden_size, intermediate_size * 2, bias=False, device="cuda", dtype=dtype
)
self.down = nn.Linear(
intermediate_size, hidden_size, bias=False, device="cuda", dtype=dtype
)
with torch.no_grad():
self.up.weight.div_(10)
self.down.weight.div_(10)
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate_up = self.up(x)
gate, up = gate_up.chunk(2, dim=-1)
return self.down(nn.functional.silu(gate) * up)


@pytest.fixture(autouse=True)
def setup_cuda():
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
torch.set_default_device("cuda")


@pytest.mark.parametrize("num_tokens", [1, 32])
@pytest.mark.parametrize("hidden_size,latent_size", [(256, 128), (128, 64)])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_routed_input_transform_inside_vs_outside(
num_tokens: int,
hidden_size: int,
latent_size: int,
dtype: torch.dtype,
dist_init,
workspace_init,
):
"""Compare SharedFusedMoE with transform inside vs manually applying outside.
Method A (inside): SharedFusedMoE with routed_input_transform
Method B (outside): Manually transform, then SharedFusedMoE without transform
"""
torch.manual_seed(42)
num_experts = 8
top_k = 2
intermediate_size = hidden_size * 2
vllm_config = VllmConfig()
vllm_config.compilation_config.static_forward_context = dict()
shared_experts = SimpleSharedExperts(hidden_size, intermediate_size, dtype)
routed_transform = SimpleLinear(hidden_size, latent_size, dtype)
with set_current_vllm_config(vllm_config):
# Method A: SharedFusedMoE WITH routed_input_transform
moe_with_transform = SharedFusedMoE(
shared_experts=shared_experts,
routed_input_transform=routed_transform,
num_experts=num_experts,
top_k=top_k,
hidden_size=latent_size,
intermediate_size=intermediate_size,
reduce_results=False,
renormalize=True,
params_dtype=dtype,
tp_size=1,
dp_size=1,
pcp_size=1,
prefix="moe_with_transform",
)
# Method B: SharedFusedMoE WITHOUT routed_input_transform.
# Note: shared_experts=None here because the input to this instance is already
# projected to latent_size, which the shared experts (built for hidden_size)
# cannot consume; the shared-expert output is verified separately below.
moe_without_transform = SharedFusedMoE(
shared_experts=None,
routed_input_transform=None,
num_experts=num_experts,
top_k=top_k,
hidden_size=latent_size,
intermediate_size=intermediate_size,
reduce_results=False,
renormalize=True,
params_dtype=dtype,
tp_size=1,
dp_size=1,
pcp_size=1,
prefix="moe_without_transform",
)
with torch.no_grad():
moe_without_transform.w13_weight.copy_(moe_with_transform.w13_weight)
moe_without_transform.w2_weight.copy_(moe_with_transform.w2_weight)
moe_with_transform.quant_method.process_weights_after_loading(
moe_with_transform
)
moe_without_transform.quant_method.process_weights_after_loading(
moe_without_transform
)
hidden_states = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype)
router_logits = torch.randn(num_tokens, num_experts, device="cuda", dtype=dtype)
with set_forward_context(None, vllm_config, num_tokens=num_tokens):
shared_out_A, routed_out_A = moe_with_transform(
hidden_states, router_logits
)
transformed_hidden = routed_transform(hidden_states)
shared_out_B, routed_out_B = moe_without_transform(
transformed_hidden, router_logits
)
torch.testing.assert_close(
routed_out_A,
routed_out_B,
atol=1e-3,
rtol=1e-3,
msg="Routed output should match: transform inside vs outside",
)
expected_shared_out = shared_experts(hidden_states)
torch.testing.assert_close(
shared_out_A,
expected_shared_out,
atol=1e-3,
rtol=1e-3,
)
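The invariant this test exercises can be stated without any vLLM machinery: the routed path consumes the (optionally projected) latent input, while the shared path always consumes the original activations. Below is a minimal, framework-free sketch of that contract; ToySharedRoutedMoE is a hypothetical stand-in, not the real SharedFusedMoE.

import torch
import torch.nn as nn


class ToySharedRoutedMoE(nn.Module):
    """Hypothetical stand-in: the shared path sees the original input, the
    routed path sees the output of an optional routed_input_transform."""

    def __init__(
        self,
        shared: nn.Module | None,
        routed: nn.Module,
        routed_input_transform: nn.Module | None = None,
    ):
        super().__init__()
        self.shared = shared
        self.routed = routed
        self.routed_input_transform = routed_input_transform

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor | None, torch.Tensor]:
        shared_out = None if self.shared is None else self.shared(x)  # [tokens, hidden]
        routed_in = x if self.routed_input_transform is None else self.routed_input_transform(x)
        return shared_out, self.routed(routed_in)  # routed path works on [tokens, latent]


torch.manual_seed(0)
tokens, hidden, latent = 4, 16, 8
proj = nn.Linear(hidden, latent, bias=False)    # plays the role of fc1_latent_proj
routed = nn.Linear(latent, latent, bias=False)  # plays the role of the routed experts
shared = nn.Linear(hidden, hidden, bias=False)  # plays the role of the shared experts
x = torch.randn(tokens, hidden)

# Method A ("inside"): the wrapper applies the projection itself.
shared_a, routed_a = ToySharedRoutedMoE(shared, routed, proj)(x)
# Method B ("outside"): the caller projects manually; the shared path is skipped
# because it needs the original hidden-size input.
_, routed_b = ToySharedRoutedMoE(None, routed)(proj(x))

torch.testing.assert_close(routed_a, routed_b)
assert shared_a.shape == (tokens, hidden) and routed_a.shape == (tokens, latent)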


@@ -1,8 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Callable, Iterable
from collections.abc import Callable, Generator, Iterable
-from contextlib import nullcontext
from contextlib import contextmanager, nullcontext
from enum import Enum
from typing import Literal, cast, get_args, overload
@@ -351,6 +351,10 @@ class FusedMoE(CustomOp):
"Enabled separate cuda stream for MoE shared_experts", scope="local" "Enabled separate cuda stream for MoE shared_experts", scope="local"
) )
# For latent MoE: stores original hidden_states before routed_input_transform
# so shared_experts can use it for cloning (they need original dimension)
self._shared_experts_input: torch.Tensor | None = None
if params_dtype is None: if params_dtype is None:
params_dtype = torch.get_default_dtype() params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype self.params_dtype = params_dtype
@@ -664,6 +668,39 @@ class FusedMoE(CustomOp):
def gate(self) -> torch.nn.Module | None:
return None
def apply_routed_input_transform(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""Hook to transform hidden_states before passing to routed experts.
For latent MoE: transforms [S, hidden_size] → [S, moe_latent_size].
The original hidden_states is saved in _shared_experts_input so
shared_experts still receive the original [S, hidden_size].
Override in subclasses (e.g., SharedFusedMoE) for latent MoE.
"""
return hidden_states
@contextmanager
def _set_shared_experts_input(
self, value: torch.Tensor | None
) -> Generator[None, None, None]:
"""Context manager to safely set/clear _shared_experts_input."""
self._shared_experts_input = value
try:
yield
finally:
self._shared_experts_input = None
def _get_shared_experts_input(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""Get input for shared experts.
For latent MoE: shared_experts need original [S, hidden_size],
not the transformed [S, latent_size] used by routed experts.
"""
return (
self._shared_experts_input
if self._shared_experts_input is not None
else hidden_states
)
@property
def tp_size(self):
return self.moe_parallel_config.tp_size
@@ -855,9 +892,11 @@ class FusedMoE(CustomOp):
if use_shared_experts_stream:
assert self.shared_experts_stream is not None
shared_experts_input = self._get_shared_experts_input(hidden_states)
# Clone BEFORE switching streams to avoid race condition
# where routed_expert kernel may mutate hidden_states.
-hidden_states_clone = hidden_states.clone()
hidden_states_clone = shared_experts_input.clone()
# Record that the clone will be used by shared_experts_stream
# to avoid gc issue from deallocation of hidden_states_clone
@@ -1537,11 +1576,20 @@ class FusedMoE(CustomOp):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-og_hidden_states = hidden_states.shape[-1]
-if self.hidden_size != og_hidden_states:
# For latent MoE: save ORIGINAL hidden_states before transform
# (shared_experts need original dimension, routed experts use transformed)
original_hidden_states = hidden_states
original_hidden_dim = hidden_states.shape[-1]
# Apply transform for routed experts (e.g., latent projection for latent MoE)
hidden_states = self.apply_routed_input_transform(hidden_states)
# This is the dimension after transform (for routed expert output slicing)
transformed_hidden_dim = hidden_states.shape[-1]
if self.hidden_size != transformed_hidden_dim:
hidden_states = F.pad(
hidden_states,
-(0, self.hidden_size - og_hidden_states),
(0, self.hidden_size - transformed_hidden_dim),
mode="constant",
value=0.0,
)
@@ -1576,22 +1624,31 @@ class FusedMoE(CustomOp):
fused_output = torch.ops.vllm.moe_forward(
hidden_states, router_logits, encode_layer_name()
)
-return reduce_output(fused_output)[..., :og_hidden_states]
return reduce_output(fused_output)[..., :transformed_hidden_dim]
else:
if current_platform.is_tpu() or current_platform.is_cpu():
# TODO: Once the OOM issue for the TPU backend is resolved, we
# will switch to using the moe_forward custom op.
# Note: CPU doesn't require wrapped forward_impl.
-shared_output, fused_output = self.forward_impl(
-hidden_states, router_logits
-)
with self._set_shared_experts_input(original_hidden_states):
shared_output, fused_output = self.forward_impl(
hidden_states, router_logits
)
else:
# Custom op handles setting/clearing _shared_experts_input internally
# We pass original tensor for shared experts (not transformed)
shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
-hidden_states, router_logits, encode_layer_name()
hidden_states,
router_logits,
encode_layer_name(),
original_hidden_states,
)
# shared_output uses original dimension (before transform)
# fused_output uses transformed dimension (after transform)
return (
-reduce_output(shared_output)[..., :og_hidden_states],
reduce_output(shared_output)[..., :original_hidden_dim],
-reduce_output(fused_output)[..., :og_hidden_states],
reduce_output(fused_output)[..., :transformed_hidden_dim],
)
@property
@@ -1831,7 +1888,8 @@ class FusedMoE(CustomOp):
# because matrix multiply maybe modify the hidden_states.
if has_separate_shared_experts and not use_shared_experts_stream:
assert self.shared_experts is not None
-shared_output = self.shared_experts(hidden_states)
shared_input = self._get_shared_experts_input(hidden_states)
shared_output = self.shared_experts(shared_input)
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
@@ -2021,19 +2079,34 @@ def moe_forward_shared(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
shared_experts_input: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
self = get_layer_from_name(layer_name)
assert self.shared_experts is not None
-return self.forward_impl(hidden_states, router_logits)
# Set here because torch.compile skips forward_native() setup code
# and calls this op directly. forward_impl() reads from this var.
with self._set_shared_experts_input(shared_experts_input):
return self.forward_impl(hidden_states, router_logits)
def moe_forward_shared_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
shared_experts_input: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
-shared_out = torch.empty_like(hidden_states)
# Output shapes:
# - fused_out: same as hidden_states (routed experts use transformed size)
# - shared_out: same as shared_experts_input if provided, else same as hidden_states
# (For latent MoE: shared experts use original hidden_size, not latent size)
fused_out = torch.empty_like(hidden_states)
if shared_experts_input is not None:
shared_out = torch.empty_like(shared_experts_input)
else:
shared_out = torch.empty_like(hidden_states)
return shared_out, fused_out
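The custom-op path above cannot rely on forward_native's Python-level setup once torch.compile traces the graph, so the original tensor is stashed on the layer and read back inside forward_impl. Below is a condensed sketch of that stash/read pattern; LayerSketch is a hypothetical stand-in for FusedMoE, not the vLLM class.

from collections.abc import Generator
from contextlib import contextmanager

import torch


class LayerSketch:
    """Hypothetical stand-in showing only the stash/read pattern: the caller
    stashes the original hidden_states, forward_impl reads it back for the
    shared experts, and the finally-block guarantees the slot is cleared."""

    def __init__(self) -> None:
        self._shared_experts_input: torch.Tensor | None = None

    @contextmanager
    def _set_shared_experts_input(
        self, value: torch.Tensor | None
    ) -> Generator[None, None, None]:
        self._shared_experts_input = value
        try:
            yield
        finally:
            self._shared_experts_input = None

    def _get_shared_experts_input(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Fall back to the (possibly latent-projected) hidden_states when nothing is stashed.
        if self._shared_experts_input is not None:
            return self._shared_experts_input
        return hidden_states

    def forward_impl(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        shared_in = self._get_shared_experts_input(hidden_states)  # original dim if stashed
        return shared_in * 2.0, hidden_states * 3.0  # placeholder shared / routed outputs


layer = LayerSketch()
original = torch.randn(4, 16)   # [tokens, hidden_size]
latent = original[:, :8]        # pretend this is fc1_latent_proj(original)
with layer._set_shared_experts_input(original):
    shared_out, routed_out = layer.forward_impl(latent)
assert shared_out.shape == (4, 16) and routed_out.shape == (4, 8)
assert layer._shared_experts_input is None  # cleared once the with-block exits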


@@ -23,10 +23,12 @@ class SharedFusedMoE(FusedMoE):
shared_experts: torch.nn.Module | None,
gate: torch.nn.Module | None = None,
use_overlapped: bool = True,
routed_input_transform: torch.nn.Module | None = None,
**kwargs,
):
super().__init__(**kwargs)
self._shared_experts = shared_experts
self._routed_input_transform = routed_input_transform
# Disable shared expert overlap if:
# - we are using eplb with non-default backend, because of correctness issues
@@ -56,6 +58,26 @@ class SharedFusedMoE(FusedMoE):
def is_internal_router(self) -> bool:
return self.gate is not None
def apply_routed_input_transform(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""Apply transform for routed experts (e.g., latent projection).
This is called by FusedMoE.forward_native. The original hidden_states
is saved separately so shared experts get [S, hidden_size] while
routed experts get the transformed [S, moe_latent_size].
TODO: For latent MoE bandwidth optimization, fc2_latent_proj could be
moved inside SharedFusedMoE to all-reduce on the smaller latent
dimension.
"""
if self._routed_input_transform is not None:
result = self._routed_input_transform(hidden_states)
# ReplicatedLinear returns (output, extra_bias) tuple.
# We only need the output tensor; extra_bias is not used here.
if isinstance(result, tuple):
return result[0]
return result
return hidden_states
def forward(
self,
hidden_states: torch.Tensor,
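Taken together, the two diffs above form a base-hook/override pair: FusedMoE exposes a no-op apply_routed_input_transform and SharedFusedMoE overrides it, unwrapping the (output, bias) tuple that projection layers such as ReplicatedLinear return instead of a bare tensor. A condensed, self-contained sketch of that pattern follows; BaseMoESketch, TupleProjection, and LatentMoESketch are hypothetical stand-ins, not the vLLM classes.

import torch
import torch.nn as nn


class BaseMoESketch(nn.Module):
    """Stand-in for the base class: the default hook leaves hidden_states untouched."""

    def apply_routed_input_transform(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        routed_in = self.apply_routed_input_transform(hidden_states)
        return routed_in  # real code would run the routed experts here


class TupleProjection(nn.Module):
    """Stand-in for a ReplicatedLinear-style layer that returns (output, bias)."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=False)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, None]:
        return self.linear(x), None


class LatentMoESketch(BaseMoESketch):
    """Stand-in for the subclass: project the routed input, unwrapping tuple returns."""

    def __init__(self, routed_input_transform: nn.Module | None = None):
        super().__init__()
        self._routed_input_transform = routed_input_transform

    def apply_routed_input_transform(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self._routed_input_transform is None:
            return hidden_states
        result = self._routed_input_transform(hidden_states)
        return result[0] if isinstance(result, tuple) else result


x = torch.randn(4, 16)
print(BaseMoESketch()(x).shape)                          # torch.Size([4, 16])
print(LatentMoESketch(TupleProjection(16, 8))(x).shape)  # torch.Size([4, 8])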


@@ -188,10 +188,29 @@ class NemotronHMoE(nn.Module):
prefix=f"{prefix}.shared_experts", prefix=f"{prefix}.shared_experts",
) )
if self.use_latent_moe:
self.fc1_latent_proj = ReplicatedLinear(
input_size=config.hidden_size,
output_size=self.moe_hidden_size,
bias=config.mlp_bias,
quant_config=quant_config,
disable_tp=self.is_sequence_parallel,
prefix=f"{prefix}.fc1_latent_proj",
)
self.fc2_latent_proj = ReplicatedLinear(
input_size=self.moe_hidden_size,
output_size=config.hidden_size,
bias=config.mlp_bias,
quant_config=quant_config,
disable_tp=self.is_sequence_parallel,
prefix=f"{prefix}.fc2_latent_proj",
)
else:
self.fc1_latent_proj = None
self.fc2_latent_proj = None
self.experts = SharedFusedMoE(
-# TODO: make it possible for shared experts to have
-# different input in SharedFusedMoE
-shared_experts=self.shared_experts if not self.use_latent_moe else None,
shared_experts=self.shared_experts,
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=self.moe_hidden_size,
@@ -211,30 +230,9 @@ class NemotronHMoE(nn.Module):
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
router_logits_dtype=router_logits_dtype,
routed_input_transform=self.fc1_latent_proj,
)
-if self.use_latent_moe:
-self.fc1_latent_proj = ReplicatedLinear(
-input_size=config.hidden_size,
-output_size=self.moe_hidden_size,
-bias=config.mlp_bias,
-quant_config=quant_config,
-disable_tp=self.is_sequence_parallel,
-prefix=f"{prefix}.fc1_latent_proj",
-)
-self.fc2_latent_proj = ReplicatedLinear(
-input_size=self.moe_hidden_size,
-output_size=config.hidden_size,
-bias=config.mlp_bias,
-quant_config=quant_config,
-disable_tp=self.is_sequence_parallel,
-prefix=f"{prefix}.fc2_latent_proj",
-)
-else:
-self.fc1_latent_proj = None
-self.fc2_latent_proj = None
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
@@ -244,38 +242,28 @@ class NemotronHMoE(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
-shared_output = None
-if self.use_latent_moe:
-if self.shared_experts is not None:
-shared_output = self.shared_experts(hidden_states)
-hidden_states, _ = self.fc1_latent_proj(hidden_states)
-fused_moe_out = self.experts(
# SharedFusedMoE handles:
# - shared experts (with original hidden_states)
# - routed_input_transform (fc1_latent_proj) for latent MoE
# - multistream parallelism between shared and routed experts
shared_output, final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
-if self.use_latent_moe:
-_, final_hidden_states = fused_moe_out
-else:
-shared_output, final_hidden_states = fused_moe_out
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16:
final_hidden_states *= self.routed_scaling_factor
elif self.shared_experts is not None:
-assert shared_output is not None
shared_output *= 1.0 / self.routed_scaling_factor
-# TODO: currently latent up_proj is done before all-reduce for simplicity.
-# if and when shared experts will be part of SharedFusedMoE,
-# we should do the up_proj after all-reduce,
-# to have the all-reduce in the smaller latent dimension.
# TODO: See SharedFusedMoE.apply_routed_input_transform
# for bandwidth optimization
if self.use_latent_moe:
final_hidden_states, _ = self.fc2_latent_proj(final_hidden_states)
if self.shared_experts is not None:
-assert shared_output is not None
final_hidden_states += shared_output
if self.is_sequence_parallel: