[BugFix] Support EP/DP + EPLB with MTP (#25311)

Signed-off-by: ilmarkov <markovilya197@gmail.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Author: Ilya Markov
Date: 2025-11-05 16:22:17 +01:00
Committed by: GitHub
Parent: 5d16d0fa62
Commit: e50c454672
27 changed files with 957 additions and 529 deletions


@@ -125,7 +125,7 @@ class MoEMixin(MixtureOfExperts):
         logical_to_physical_map: torch.Tensor,
         logical_replica_count: torch.Tensor,
     ):
-        for moe_layer_idx, mlp_layer in enumerate(self.mlp_layers):
+        for moe_layer_idx, mlp_layer in enumerate(self.mlp_moe_layers):
             mlp_layer.experts.set_eplb_state(
                 moe_layer_idx=moe_layer_idx,
                 expert_load_view=expert_load_view,
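
For orientation, here is a minimal, self-contained sketch of the registration pattern this hunk touches. The class and the stored slices are stand-ins, not vLLM's actual FusedMoE implementation; only the general shape of the set_eplb_state wiring is assumed.

import torch


class ExpertsSketch:
    """Stand-in for a fused-experts module (hypothetical, simplified)."""

    def set_eplb_state(
        self,
        moe_layer_idx: int,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> None:
        # Keep per-layer views into the shared EPLB tensors, indexed by the
        # layer's position among the MoE layers -- which is why the caller
        # must enumerate only MoE-bearing MLP modules (mlp_moe_layers).
        self.expert_load_view = expert_load_view[moe_layer_idx]
        self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
        self.logical_replica_count = logical_replica_count[moe_layer_idx]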
@@ -142,7 +142,7 @@ class MoEMixin(MixtureOfExperts):
         self.num_physical_experts = num_physical_experts
         self.num_local_physical_experts = num_local_physical_experts
         self.num_redundant_experts = num_physical_experts - self.num_logical_experts
-        for mlp in self.mlp_layers:
+        for mlp in self.mlp_moe_layers:
             mlp.n_local_physical_experts = num_local_physical_experts
             mlp.n_physical_experts = num_physical_experts
             mlp.n_redundant_experts = self.num_redundant_experts
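
The metadata refresh above relies on the EPLB bookkeeping identity num_redundant_experts = num_physical_experts - num_logical_experts, which this hunk now propagates to every MoE-bearing MLP module. A small numeric sketch with invented values:

# Invented example values: 64 logical experts replicated across 72
# physical expert slots on 8 expert-parallel (EP) ranks.
num_logical_experts = 64
num_physical_experts = 72
ep_size = 8

num_redundant_experts = num_physical_experts - num_logical_experts
num_local_physical_experts = num_physical_experts // ep_size

assert num_redundant_experts == 8       # extra replicas available for rebalancing
assert num_local_physical_experts == 9  # physical experts hosted per EP rank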
@@ -240,7 +240,8 @@ class MoEMixin(MixtureOfExperts):
         # MixtureOfExperts mixin settings
         ep_size = get_ep_group().world_size
 
-        self.mlp_layers = []  # Used for MixtureOfExperts methods
+        self.mlp_moe_layers = []  # Used for MixtureOfExperts methods
+        self.moe_layers = []
         self.expert_weights = []
         self.num_moe_layers = 0
         self.num_expert_groups = 1 if num_expert_group is None else num_expert_group
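
The new moe_layers list sits alongside mlp_moe_layers because the MixtureOfExperts interface is, plausibly, consumed in terms of the fused-experts modules themselves rather than their MLP wrappers. A hedged, simplified Protocol sketch of what such an interface might require (not vLLM's actual definition):

from typing import Protocol

import torch
from torch import nn


class MixtureOfExpertsLike(Protocol):
    # Simplified attribute surface; the real vLLM interface differs.
    moe_layers: list[nn.Module]         # fused-experts module per MoE layer
    expert_weights: list[torch.Tensor]  # weights exposed for EPLB rearrangement
    num_moe_layers: int
    num_expert_groups: int
    num_logical_experts: int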
@@ -298,7 +299,8 @@ class MoEMixin(MixtureOfExperts):
             mlp.experts = fused_experts
             log_replacement(qual_name, experts, fused_experts)
             # Update MixtureOfExperts mixin state
-            self.mlp_layers.append(mlp)
+            self.mlp_moe_layers.append(mlp)
+            self.moe_layers.append(fused_experts)
             self.expert_weights.append(fused_experts.get_expert_weights())
             self.num_moe_layers += 1
             # If results are not all-reduced in FusedMoE, ensure they
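
Taken together, the hunks maintain two parallel lists: mlp_moe_layers keeps the wrapper MLP modules whose per-layer counters are patched during rebalancing, while moe_layers keeps the fused-experts modules themselves. A condensed, hypothetical rendering of that bookkeeping (names mirror the diff; surrounding machinery omitted):

class MoEMixinSketch:
    """Condensed stand-in for the mixin's MoE bookkeeping (hypothetical)."""

    def __init__(self) -> None:
        self.mlp_moe_layers = []  # wrapper MLP modules (per-layer counters)
        self.moe_layers = []      # fused-experts modules (interface surface)
        self.expert_weights = []
        self.num_moe_layers = 0

    def register_moe_layer(self, mlp, fused_experts) -> None:
        # Mirrors the mixin-state update in the hunk above.
        self.mlp_moe_layers.append(mlp)
        self.moe_layers.append(fused_experts)
        self.expert_weights.append(fused_experts.get_expert_weights())
        self.num_moe_layers += 1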