[ROCm][FEAT] Fuse DeepSeek shared experts into AITER fused_moe ops (#24097)
Signed-off-by: chenjun <junchen2@amd.com> Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com> Co-authored-by: valarLip <103567126+valarLip@users.noreply.github.com> Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
This commit is contained in:
@@ -50,6 +50,10 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_fusion_shared_expert_enabled,
|
||||
is_rocm_aiter_moe_enabled,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
@@ -203,7 +207,10 @@ class DeepseekV2MoE(nn.Module):
|
||||
self.physical_expert_start + self.n_local_physical_experts
|
||||
)
|
||||
|
||||
if config.n_shared_experts is None:
|
||||
if (
|
||||
config.n_shared_experts is None
|
||||
or is_rocm_aiter_fusion_shared_expert_enabled()
|
||||
):
|
||||
self.shared_experts = None
|
||||
else:
|
||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||
@@ -233,11 +240,17 @@ class DeepseekV2MoE(nn.Module):
|
||||
prefix=f"{prefix}.experts",
|
||||
scoring_func=config.scoring_func,
|
||||
# we do scaling outside, set factor to 1.0 to avoid double mul
|
||||
routed_scaling_factor=1.0,
|
||||
# aiter applies routed_scaling_factor internally
|
||||
routed_scaling_factor=1.0
|
||||
if not is_rocm_aiter_moe_enabled()
|
||||
else self.routed_scaling_factor,
|
||||
e_score_correction_bias=self.gate.e_score_correction_bias,
|
||||
enable_eplb=self.enable_eplb,
|
||||
num_redundant_experts=self.n_redundant_experts,
|
||||
is_sequence_parallel=self.is_sequence_parallel,
|
||||
n_shared_experts=config.n_shared_experts
|
||||
if is_rocm_aiter_fusion_shared_expert_enabled()
|
||||
else None,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@@ -258,16 +271,15 @@ class DeepseekV2MoE(nn.Module):
|
||||
hidden_states=hidden_states, router_logits=router_logits
|
||||
)
|
||||
|
||||
if self.shared_experts is not None:
|
||||
shared_output, final_hidden_states = fused_moe_out
|
||||
else:
|
||||
shared_output = None
|
||||
final_hidden_states = fused_moe_out
|
||||
shared_output, final_hidden_states = fused_moe_out
|
||||
if self.shared_experts is None:
|
||||
assert shared_output is None
|
||||
|
||||
# Fix FP16 overflow
|
||||
# See DeepseekV2DecoderLayer for more details.
|
||||
if hidden_states.dtype != torch.float16:
|
||||
final_hidden_states *= self.routed_scaling_factor
|
||||
if not is_rocm_aiter_moe_enabled():
|
||||
final_hidden_states *= self.routed_scaling_factor
|
||||
elif self.shared_experts is not None:
|
||||
assert shared_output is not None
|
||||
shared_output *= 1.0 / self.routed_scaling_factor
|
||||
@@ -1316,7 +1328,12 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.n_routed_experts,
|
||||
num_experts=self.config.n_routed_experts
|
||||
+ (
|
||||
self.config.n_shared_experts
|
||||
if is_rocm_aiter_fusion_shared_expert_enabled()
|
||||
else 0
|
||||
),
|
||||
num_redundant_experts=self.num_redundant_experts,
|
||||
)
|
||||
|
||||
@@ -1330,6 +1347,11 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR
|
||||
if spec_layer is not None:
|
||||
continue # skip spec decode layers for main model
|
||||
|
||||
is_fuse_shared_experts_layer = (
|
||||
is_rocm_aiter_fusion_shared_expert_enabled()
|
||||
and ("mlp.shared_experts" in name)
|
||||
)
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
if weight_name not in name:
|
||||
@@ -1342,6 +1364,8 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR
|
||||
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
||||
if ("mlp.experts." in name) and name not in params_dict:
|
||||
continue
|
||||
if is_fuse_shared_experts_layer:
|
||||
continue
|
||||
name_mapped = name.replace(weight_name, param_name)
|
||||
|
||||
# QKV fusion is optional, fall back to normal
|
||||
@@ -1366,65 +1390,115 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR
|
||||
break
|
||||
else:
|
||||
is_expert_weight = False
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
# Anyway, this is an expert weight and should not be
|
||||
# attempted to load as other weights later
|
||||
is_expert_weight = True
|
||||
|
||||
# Do not modify `name` since the loop may continue here
|
||||
# Instead, create a new variable
|
||||
name_mapped = name.replace(weight_name, param_name)
|
||||
|
||||
if is_pp_missing_parameter(name_mapped, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name_mapped]
|
||||
# We should ask the weight loader to return success or not
|
||||
# here since otherwise we may skip experts with other
|
||||
# available replicas.
|
||||
weight_loader = typing.cast(
|
||||
Callable[..., bool], param.weight_loader
|
||||
# Special handling: when AITER fusion_shared_experts is enabled,
|
||||
# checkpoints may provide a single widened shared_experts tensor
|
||||
# without explicit expert indices
|
||||
# (e.g. ...mlp.shared_experts.gate_proj.weight).
|
||||
# For models with multiple shared experts, split that tensor
|
||||
# evenly into per-shared-expert slices and load them into
|
||||
# appended expert slots mlp.experts.{n_routed_experts + j}.*
|
||||
# accordingly.
|
||||
num_chunks = 1
|
||||
if is_fuse_shared_experts_layer:
|
||||
num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
|
||||
# Determine split axis based on op type
|
||||
# gate/up: ColumnParallel → split along dim 0
|
||||
# down: RowParallel → split along dim 1
|
||||
split_dim = 1 if "down_proj.weight" in name else 0
|
||||
total = loaded_weight.shape[split_dim]
|
||||
assert total % num_chunks == 0, (
|
||||
f"Shared expert weight dim {total} "
|
||||
f"not divisible by num_chunks {num_chunks}"
|
||||
)
|
||||
success = weight_loader(
|
||||
param,
|
||||
loaded_weight,
|
||||
name_mapped,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
return_success=True,
|
||||
)
|
||||
if success:
|
||||
name = name_mapped
|
||||
break
|
||||
else:
|
||||
if is_expert_weight:
|
||||
# We've checked that this is an expert weight
|
||||
# However it's not mapped locally to this rank
|
||||
# So we simply skip it
|
||||
continue
|
||||
chunk_size = total // num_chunks
|
||||
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
for j in range(num_chunks):
|
||||
chunk_name = name
|
||||
weight_to_load = loaded_weight
|
||||
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
if is_fuse_shared_experts_layer:
|
||||
if split_dim == 0:
|
||||
weight_to_load = loaded_weight[
|
||||
j * chunk_size : (j + 1) * chunk_size, :
|
||||
]
|
||||
else:
|
||||
weight_to_load = loaded_weight[
|
||||
:, j * chunk_size : (j + 1) * chunk_size
|
||||
]
|
||||
# Synthesize an expert-style name so expert mapping
|
||||
# can route it
|
||||
chunk_name = name.replace(
|
||||
"mlp.shared_experts",
|
||||
f"mlp.experts.{self.config.n_routed_experts + j}",
|
||||
)
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
# Use expert_params_mapping to locate the destination
|
||||
# param and delegate to its expert-aware weight_loader
|
||||
# with expert_id.
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
if weight_name not in chunk_name:
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
# Anyway, this is an expert weight and should not be
|
||||
# attempted to load as other weights later
|
||||
is_expert_weight = True
|
||||
|
||||
# Do not modify `name` since the loop may continue here
|
||||
# Instead, create a new variable
|
||||
name_mapped = chunk_name.replace(weight_name, param_name)
|
||||
|
||||
if is_pp_missing_parameter(name_mapped, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name_mapped]
|
||||
# We should ask the weight loader to return success or
|
||||
# not here since otherwise we may skip experts with
|
||||
# other available replicas.
|
||||
weight_loader = typing.cast(
|
||||
Callable[..., bool], param.weight_loader
|
||||
)
|
||||
success = weight_loader(
|
||||
param,
|
||||
weight_to_load,
|
||||
name_mapped,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
return_success=True,
|
||||
)
|
||||
if success:
|
||||
if not is_fuse_shared_experts_layer:
|
||||
name = name_mapped
|
||||
else:
|
||||
loaded_params.add(name_mapped)
|
||||
break
|
||||
else:
|
||||
if is_expert_weight:
|
||||
# We've checked that this is an expert weight
|
||||
# However it's not mapped locally to this rank
|
||||
# So we simply skip it
|
||||
continue
|
||||
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
if not is_fuse_shared_experts_layer:
|
||||
loaded_params.add(name)
|
||||
|
||||
return loaded_params
|
||||
|
||||
|
||||
Reference in New Issue
Block a user