[Bugfix][ROCm] Fix shared expert weight loading failure in DeepSeek-MTP (#27563)

Signed-off-by: zhyajie <yajizhan@amd.com>
Co-authored-by: zhyajie <yajizhan@amd.com>
commit 8005e606bf
parent 68dfe28eae
Author: 杰兮
Date:   2025-11-24 18:16:52 +08:00
Committed by: GitHub

2 files changed, 121 insertions(+), 46 deletions(-)

@@ -1479,8 +1479,8 @@ class DeepseekV2ForCausalLM(
             if spec_layer is not None:
                 continue  # skip spec decode layers for main model
-            is_fuse_shared_experts_layer = rocm_aiter_moe_shared_expert_enabled and (
-                "mlp.shared_experts" in name
+            is_fusion_moe_shared_experts_layer = (
+                rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
             )
             for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -1495,7 +1495,7 @@ class DeepseekV2ForCausalLM(
                 # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                 if ("mlp.experts." in name) and name not in params_dict:
                     continue
-                if is_fuse_shared_experts_layer:
+                if is_fusion_moe_shared_experts_layer:
                     continue
                 name_mapped = name.replace(weight_name, param_name)
@@ -1531,7 +1531,7 @@ class DeepseekV2ForCausalLM(
                 # appended expert slots mlp.experts.{n_routed_experts + j}.*
                 # accordingly.
                 num_chunks = 1
-                if is_fuse_shared_experts_layer:
+                if is_fusion_moe_shared_experts_layer:
                     num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
                     # Determine split axis based on op type
                     # gate/up: ColumnParallel → split along dim 0
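
For reference, a minimal standalone sketch of the chunk-splitting this flag gates (toy sizes; n_shared_experts, moe_intermediate_size, and hidden_size are assumed stand-ins, not the vLLM implementation):

    import torch

    n_shared_experts = 2       # stands in for config.n_shared_experts
    moe_intermediate_size = 4  # toy per-expert intermediate size
    hidden_size = 8            # toy hidden size

    # gate/up projections are ColumnParallel: shared experts stacked along dim 0
    gate_up = torch.randn(n_shared_experts * moe_intermediate_size, hidden_size)
    # down projection is RowParallel: shared experts stacked along dim 1
    down = torch.randn(hidden_size, n_shared_experts * moe_intermediate_size)

    chunk_size = moe_intermediate_size
    for j in range(n_shared_experts):
        # split along dim 0 for gate/up, dim 1 for down, per the comments above
        gate_up_j = gate_up[j * chunk_size : (j + 1) * chunk_size, :]
        down_j = down[:, j * chunk_size : (j + 1) * chunk_size]
        # each chunk targets the appended slot mlp.experts.{n_routed_experts + j}.*
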
@@ -1548,7 +1548,7 @@ class DeepseekV2ForCausalLM(
                     chunk_name = name
                     weight_to_load = loaded_weight
-                    if is_fuse_shared_experts_layer:
+                    if is_fusion_moe_shared_experts_layer:
                         if split_dim == 0:
                             weight_to_load = loaded_weight[
                                 j * chunk_size : (j + 1) * chunk_size, :
@@ -1599,7 +1599,7 @@ class DeepseekV2ForCausalLM(
                         return_success=True,
                     )
                     if success:
-                        if not is_fuse_shared_experts_layer:
+                        if not is_fusion_moe_shared_experts_layer:
                             name = name_mapped
                         else:
                             loaded_params.add(name_mapped)
@@ -1628,7 +1628,7 @@ class DeepseekV2ForCausalLM(
                     param, "weight_loader", default_weight_loader
                 )
                 weight_loader(param, loaded_weight)
-            if not is_fuse_shared_experts_layer:
+            if not is_fusion_moe_shared_experts_layer:
                 loaded_params.add(name)
         return loaded_params
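
As a usage-level illustration of the name remapping implied by the appended expert slots, a hypothetical helper (remap_shared_expert_name is not part of vLLM; the real mapping happens inline in the load loop above):

    def remap_shared_expert_name(name: str, n_routed_experts: int, j: int) -> str:
        # "model.layers.3.mlp.shared_experts.gate_proj.weight"
        # -> "model.layers.3.mlp.experts.{n_routed_experts + j}.gate_proj.weight"
        return name.replace(
            "mlp.shared_experts", f"mlp.experts.{n_routed_experts + j}"
        )

    assert remap_shared_expert_name(
        "model.layers.3.mlp.shared_experts.gate_proj.weight", 64, 1
    ) == "model.layers.3.mlp.experts.65.gate_proj.weight"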