[BugFix] Support EP/DP + EPLB with MTP (#25311)
Signed-off-by: ilmarkov <markovilya197@gmail.com> Signed-off-by: Sage Moore <sage@neuralmagic.com> Co-authored-by: Sage Moore <sage@neuralmagic.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
This commit is contained in:
@@ -166,7 +166,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
|
||||
self.ep_group = get_ep_group().device_group
|
||||
self.ep_rank = self.ep_group.rank()
|
||||
self.ep_rank = get_ep_group().rank_in_group
|
||||
self.ep_size = self.ep_group.size()
|
||||
self.n_routed_experts: int = config.n_routed_experts
|
||||
self.n_shared_experts: int = config.n_shared_experts
|
||||
@@ -1122,7 +1122,6 @@ class DeepseekV2Model(nn.Module):
|
||||
)
|
||||
else:
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: DeepseekV2DecoderLayer(
|
||||
@@ -1172,7 +1171,50 @@ class DeepseekV2Model(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoRA):
|
||||
class DeepseekV2MixtureOfExperts(MixtureOfExperts):
|
||||
moe_mlp_layers: list[DeepseekV2MoE]
|
||||
"""
|
||||
List of MoE MLP layers in the model.
|
||||
"""
|
||||
|
||||
def extract_moe_parameters(self, example_moe: DeepseekV2MoE | None):
|
||||
if example_moe is None:
|
||||
self.num_moe_layers = 0
|
||||
self.num_expert_groups = 0
|
||||
self.num_logical_experts = 0
|
||||
self.num_physical_experts = 0
|
||||
self.num_local_physical_experts = 0
|
||||
self.num_routed_experts = 0
|
||||
self.num_shared_experts = 0
|
||||
self.num_redundant_experts = 0
|
||||
logger.warning("DeepSeekV2: No DeepseekV2MoE layer found in model.layers.")
|
||||
else:
|
||||
self.num_logical_experts = example_moe.n_logical_experts
|
||||
self.num_physical_experts = example_moe.n_physical_experts
|
||||
self.num_local_physical_experts = example_moe.n_local_physical_experts
|
||||
self.num_routed_experts = example_moe.n_routed_experts
|
||||
self.num_shared_experts = example_moe.n_shared_experts
|
||||
self.num_redundant_experts = example_moe.n_redundant_experts
|
||||
|
||||
def update_physical_experts_metadata(
|
||||
self,
|
||||
num_physical_experts: int,
|
||||
num_local_physical_experts: int,
|
||||
) -> None:
|
||||
assert self.num_local_physical_experts == num_local_physical_experts
|
||||
self.num_physical_experts = num_physical_experts
|
||||
self.num_local_physical_experts = num_local_physical_experts
|
||||
self.num_redundant_experts = num_physical_experts - self.num_logical_experts
|
||||
for moe in self.moe_mlp_layers:
|
||||
moe.n_local_physical_experts = num_local_physical_experts
|
||||
moe.n_physical_experts = num_physical_experts
|
||||
moe.n_redundant_experts = self.num_redundant_experts
|
||||
moe.experts.update_expert_map()
|
||||
|
||||
|
||||
class DeepseekV2ForCausalLM(
|
||||
nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA
|
||||
):
|
||||
packed_modules_mapping = {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
}
|
||||
@@ -1213,13 +1255,19 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
# Set MoE hyperparameters
|
||||
self.num_moe_layers = (
|
||||
self.config.num_hidden_layers - self.config.first_k_dense_replace
|
||||
)
|
||||
self.set_moe_parameters()
|
||||
|
||||
def set_moe_parameters(self):
|
||||
self.expert_weights = []
|
||||
|
||||
# Set MoE hyperparameters
|
||||
self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace
|
||||
self.num_expert_groups = config.n_group
|
||||
self.num_expert_groups = self.config.n_group
|
||||
|
||||
self.moe_layers: list[SharedFusedMoE] = []
|
||||
self.moe_layers = []
|
||||
self.moe_mlp_layers = []
|
||||
example_moe = None
|
||||
for layer in self.model.layers:
|
||||
if isinstance(layer, PPMissingLayer):
|
||||
@@ -1229,50 +1277,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR
|
||||
if isinstance(layer.mlp, DeepseekV2MoE):
|
||||
# Pick last one layer since the first ones may be dense layers.
|
||||
example_moe = layer.mlp
|
||||
self.moe_mlp_layers.append(layer.mlp)
|
||||
self.moe_layers.append(layer.mlp.experts)
|
||||
|
||||
if example_moe is None:
|
||||
raise RuntimeError("No DeepseekV2MoE layer found in model.layers.")
|
||||
|
||||
self.num_logical_experts = example_moe.n_logical_experts
|
||||
self.num_physical_experts = example_moe.n_physical_experts
|
||||
self.num_local_physical_experts = example_moe.n_local_physical_experts
|
||||
self.num_routed_experts = example_moe.n_routed_experts
|
||||
self.num_shared_experts = example_moe.n_shared_experts
|
||||
self.num_redundant_experts = example_moe.n_redundant_experts
|
||||
|
||||
def set_eplb_state(
|
||||
self,
|
||||
expert_load_view: torch.Tensor,
|
||||
logical_to_physical_map: torch.Tensor,
|
||||
logical_replica_count: torch.Tensor,
|
||||
) -> None:
|
||||
for layer_idx, layer in enumerate(self.moe_layers):
|
||||
# Register the expert weights.
|
||||
self.expert_weights.append(layer.get_expert_weights())
|
||||
layer.set_eplb_state(
|
||||
moe_layer_idx=layer_idx,
|
||||
expert_load_view=expert_load_view,
|
||||
logical_to_physical_map=logical_to_physical_map,
|
||||
logical_replica_count=logical_replica_count,
|
||||
)
|
||||
|
||||
def update_physical_experts_metadata(
|
||||
self,
|
||||
num_physical_experts: int,
|
||||
num_local_physical_experts: int,
|
||||
) -> None:
|
||||
assert self.num_local_physical_experts == num_local_physical_experts
|
||||
self.num_physical_experts = num_physical_experts
|
||||
self.num_local_physical_experts = num_local_physical_experts
|
||||
self.num_redundant_experts = num_physical_experts - self.num_logical_experts
|
||||
for layer in self.model.layers:
|
||||
if isinstance(layer.mlp, DeepseekV2MoE):
|
||||
moe = layer.mlp
|
||||
moe.n_local_physical_experts = num_local_physical_experts
|
||||
moe.n_physical_experts = num_physical_experts
|
||||
moe.n_redundant_experts = self.num_redundant_experts
|
||||
moe.experts.update_expert_map()
|
||||
self.extract_moe_parameters(example_moe)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
Reference in New Issue
Block a user