[Bugfix] Fix shared expert input for latent MoE in EP+DP (Nemotron-H) (#34087)

Signed-off-by: Tomer Natan <tbarnatan@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
TomerBN-Nvidia
2026-02-09 18:44:18 +02:00
committed by GitHub
parent d4f123cc48
commit 995bbf38f1
6 changed files with 30 additions and 3 deletions

View File

@@ -139,7 +139,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
# work with SP. This will be removed in follow up after we get
# rid of the FlashInfer specific P/F function.
# TODO: the per-tensor fp8 kernels don't work with MNNVL FI A2As.
return not moe_parallel_config.is_sequence_parallel
return True
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:

View File

@@ -101,4 +101,5 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=None if self.disable_expert_map else layer.expert_map,
shared_experts_input=layer._get_shared_experts_input(x),
)

View File

@@ -1228,13 +1228,28 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
shared_experts_input: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
The _finalize method is a wrapper around self.prepare_finalize.finalize
that handles DBO, async and shared expert overlap.
Args:
shared_experts_input: Optional separate input for shared experts.
When latent MoE is used, hidden_states is the latent-projected
tensor (smaller dimension) used by routed experts, while
shared_experts_input is the original hidden_states (full
dimension) needed by the shared expert MLP.
"""
shared_output: torch.Tensor | None = None
# For latent MoE: shared experts need the original hidden_states
# (full hidden_size), not the latent-projected version used by
# routed experts.
se_hidden_states = (
shared_experts_input if shared_experts_input is not None else hidden_states
)
if not self.prepare_finalize.supports_async():
assert not dbo_enabled()
@@ -1247,7 +1262,7 @@ class FusedMoEModularKernel(torch.nn.Module):
self.fused_experts.finalize_weight_and_reduce_impl(),
)
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
shared_output = self.shared_experts(se_hidden_states)
else:
finalize_ret = self.prepare_finalize.finalize_async(
output,
@@ -1258,7 +1273,7 @@ class FusedMoEModularKernel(torch.nn.Module):
self.fused_experts.finalize_weight_and_reduce_impl(),
)
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
shared_output = self.shared_experts(se_hidden_states)
# TODO(lucas): refactor this in the alternative schedules followup
# currently unpack if we have hook + receiver pair or just
@@ -1298,6 +1313,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
shared_experts_input: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
This function computes a Mixture of Experts (MoE) layer using two sets
@@ -1320,6 +1336,9 @@ class FusedMoEModularKernel(torch.nn.Module):
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is
1.
- shared_experts_input (Optional[torch.Tensor]): Optional separate
input for shared experts. For latent MoE, this is the original
hidden_states before latent projection.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
@@ -1368,4 +1387,5 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights,
topk_ids,
apply_router_weight_on_input,
shared_experts_input=shared_experts_input,
)

View File

@@ -361,6 +361,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=layer._get_shared_experts_input(x),
)
@@ -672,6 +673,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=layer._get_shared_experts_input(x),
)
@@ -1077,6 +1079,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=layer._get_shared_experts_input(x),
)
@property

View File

@@ -1023,6 +1023,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=layer._get_shared_experts_input(x),
)

View File

@@ -980,6 +980,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=layer._get_shared_experts_input(x),
)
@@ -1550,6 +1551,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=layer._get_shared_experts_input(x),
)