[Bugfix] Fix shared expert input for latent MoE in EP+DP (Nemotron-H) (#34087)

Signed-off-by: Tomer Natan <tbarnatan@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
TomerBN-Nvidia
2026-02-09 18:44:18 +02:00
committed by GitHub
parent d4f123cc48
commit 995bbf38f1
6 changed files with 30 additions and 3 deletions

View File

@@ -1228,13 +1228,28 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
shared_experts_input: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
The _finalize method is a wrapper around self.prepare_finalize.finalize
that handles DBO, async and shared expert overlap.
Args:
shared_experts_input: Optional separate input for shared experts.
When latent MoE is used, hidden_states is the latent-projected
tensor (smaller dimension) used by routed experts, while
shared_experts_input is the original hidden_states (full
dimension) needed by the shared expert MLP.
"""
shared_output: torch.Tensor | None = None
# For latent MoE: shared experts need the original hidden_states
# (full hidden_size), not the latent-projected version used by
# routed experts.
se_hidden_states = (
shared_experts_input if shared_experts_input is not None else hidden_states
)
if not self.prepare_finalize.supports_async():
assert not dbo_enabled()
@@ -1247,7 +1262,7 @@ class FusedMoEModularKernel(torch.nn.Module):
self.fused_experts.finalize_weight_and_reduce_impl(),
)
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
shared_output = self.shared_experts(se_hidden_states)
else:
finalize_ret = self.prepare_finalize.finalize_async(
output,
@@ -1258,7 +1273,7 @@ class FusedMoEModularKernel(torch.nn.Module):
self.fused_experts.finalize_weight_and_reduce_impl(),
)
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
shared_output = self.shared_experts(se_hidden_states)
# TODO(lucas): refactor this in the alternative schedules followup
# currently unpack if we have hook + receiver pair or just
@@ -1298,6 +1313,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
shared_experts_input: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
This function computes a Mixture of Experts (MoE) layer using two sets
@@ -1320,6 +1336,9 @@ class FusedMoEModularKernel(torch.nn.Module):
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is
1.
- shared_experts_input (Optional[torch.Tensor]): Optional separate
input for shared experts. For latent MoE, this is the original
hidden_states before latent projection.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
@@ -1368,4 +1387,5 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights,
topk_ids,
apply_router_weight_on_input,
shared_experts_input=shared_experts_input,
)