[Bugfix] Fix shared expert input for latent MoE in EP+DP (Nemotron-H) (#34087)
Signed-off-by: Tomer Natan <tbarnatan@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -139,7 +139,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
# work with SP. This will be removed in a follow-up after we get
|
||||
# rid of the FlashInfer specific P/F function.
|
||||
# TODO: the per-tensor fp8 kernels don't work with MNNVL FI A2As.
|
||||
return not moe_parallel_config.is_sequence_parallel
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
|
||||
@@ -101,4 +101,5 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
||||
global_num_experts=layer.global_num_experts,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
expert_map=None if self.disable_expert_map else layer.expert_map,
|
||||
shared_experts_input=layer._get_shared_experts_input(x),
|
||||
)
|
||||
|
||||
@@ -1228,13 +1228,28 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
shared_experts_input: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
The _finalize method is a wrapper around self.prepare_finalize.finalize
|
||||
that handles DBO, async and shared expert overlap.
|
||||
|
||||
Args:
|
||||
shared_experts_input: Optional separate input for shared experts.
|
||||
When latent MoE is used, hidden_states is the latent-projected
|
||||
tensor (smaller dimension) used by routed experts, while
|
||||
shared_experts_input is the original hidden_states (full
|
||||
dimension) needed by the shared expert MLP.
|
||||
"""
|
||||
shared_output: torch.Tensor | None = None
|
||||
|
||||
# For latent MoE: shared experts need the original hidden_states
|
||||
# (full hidden_size), not the latent-projected version used by
|
||||
# routed experts.
|
||||
se_hidden_states = (
|
||||
shared_experts_input if shared_experts_input is not None else hidden_states
|
||||
)
|
||||
|
||||
if not self.prepare_finalize.supports_async():
|
||||
assert not dbo_enabled()
|
||||
|
||||
@@ -1247,7 +1262,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
self.fused_experts.finalize_weight_and_reduce_impl(),
|
||||
)
|
||||
if self.shared_experts is not None:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
shared_output = self.shared_experts(se_hidden_states)
|
||||
else:
|
||||
finalize_ret = self.prepare_finalize.finalize_async(
|
||||
output,
|
||||
@@ -1258,7 +1273,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
self.fused_experts.finalize_weight_and_reduce_impl(),
|
||||
)
|
||||
if self.shared_experts is not None:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
shared_output = self.shared_experts(se_hidden_states)
|
||||
|
||||
# TODO(lucas): refactor this in the alternative schedules followup
|
||||
# currently unpack if we have hook + receiver pair or just
|
||||
@@ -1298,6 +1313,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
shared_experts_input: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
This function computes a Mixture of Experts (MoE) layer using two sets
|
||||
@@ -1320,6 +1336,9 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
- apply_router_weight_on_input (bool): When true, the topk weights are
|
||||
applied directly on the inputs. This is only applicable when topk is
|
||||
1.
|
||||
- shared_experts_input (Optional[torch.Tensor]): Optional separate
|
||||
input for shared experts. For latent MoE, this is the original
|
||||
hidden_states before latent projection.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
||||
@@ -1368,4 +1387,5 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
apply_router_weight_on_input,
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
|
||||
@@ -361,6 +361,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
shared_experts_input=layer._get_shared_experts_input(x),
|
||||
)
|
||||
|
||||
|
||||
@@ -672,6 +673,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
shared_experts_input=layer._get_shared_experts_input(x),
|
||||
)
|
||||
|
||||
|
||||
@@ -1077,6 +1079,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
# https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
shared_experts_input=layer._get_shared_experts_input(x),
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -1023,6 +1023,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
shared_experts_input=layer._get_shared_experts_input(x),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -980,6 +980,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
shared_experts_input=layer._get_shared_experts_input(x),
|
||||
)
|
||||
|
||||
|
||||
@@ -1550,6 +1551,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
shared_experts_input=layer._get_shared_experts_input(x),
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user