[Bugfix][Rocm] Fix shared expert weight loading failure in DeepSeek-MTP (#27563)
Signed-off-by: zhyajie <yajizhan@amd.com>
Co-authored-by: zhyajie <yajizhan@amd.com>
This commit is contained in:
@@ -1479,8 +1479,8 @@ class DeepseekV2ForCausalLM(
|
||||
if spec_layer is not None:
|
||||
continue # skip spec decode layers for main model
|
||||
|
||||
- is_fuse_shared_experts_layer = rocm_aiter_moe_shared_expert_enabled and (
- "mlp.shared_experts" in name
+ is_fusion_moe_shared_experts_layer = (
+ rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
|
||||
)
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
@@ -1495,7 +1495,7 @@ class DeepseekV2ForCausalLM(
|
||||
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
||||
if ("mlp.experts." in name) and name not in params_dict:
|
||||
continue
|
||||
- if is_fuse_shared_experts_layer:
+ if is_fusion_moe_shared_experts_layer:
|
||||
continue
|
||||
name_mapped = name.replace(weight_name, param_name)
|
||||
|
||||
@@ -1531,7 +1531,7 @@ class DeepseekV2ForCausalLM(
|
||||
# appended expert slots mlp.experts.{n_routed_experts + j}.*
|
||||
# accordingly.
|
||||
num_chunks = 1
|
||||
- if is_fuse_shared_experts_layer:
+ if is_fusion_moe_shared_experts_layer:
|
||||
num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
|
||||
# Determine split axis based on op type
|
||||
# gate/up: ColumnParallel → split along dim 0
|
||||
@@ -1548,7 +1548,7 @@ class DeepseekV2ForCausalLM(
|
||||
chunk_name = name
|
||||
weight_to_load = loaded_weight
|
||||
|
||||
- if is_fuse_shared_experts_layer:
+ if is_fusion_moe_shared_experts_layer:
|
||||
if split_dim == 0:
|
||||
weight_to_load = loaded_weight[
|
||||
j * chunk_size : (j + 1) * chunk_size, :
|
||||
@@ -1599,7 +1599,7 @@ class DeepseekV2ForCausalLM(
|
||||
return_success=True,
|
||||
)
|
||||
if success:
|
||||
- if not is_fuse_shared_experts_layer:
+ if not is_fusion_moe_shared_experts_layer:
|
||||
name = name_mapped
|
||||
else:
|
||||
loaded_params.add(name_mapped)
|
||||
@@ -1628,7 +1628,7 @@ class DeepseekV2ForCausalLM(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
- if not is_fuse_shared_experts_layer:
+ if not is_fusion_moe_shared_experts_layer:
|
||||
loaded_params.add(name)
|
||||
|
||||
return loaded_params
|
||||
|
||||
Reference in New Issue
Block a user