[Model] Apply shared experts overlap optimization to all models with shared experts (#26145)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -43,7 +43,7 @@ from vllm.distributed import (
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
@@ -276,22 +276,6 @@ class BailingMoE(nn.Module):
|
||||
# default value for scoring_func
|
||||
self.score_function = "softmax"
|
||||
|
||||
self.experts = FusedMoE(
|
||||
num_experts=self.num_experts,
|
||||
top_k=self.top_k,
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
reduce_results=False,
|
||||
renormalize=self.norm_expert_prob,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.experts",
|
||||
scoring_func=self.score_function,
|
||||
e_score_correction_bias=self.gate.expert_bias,
|
||||
num_expert_group=self.n_group,
|
||||
topk_group=self.topk_group,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
)
|
||||
|
||||
if self.num_shared_experts > 0:
|
||||
if hasattr(config, "moe_shared_expert_intermediate_size"):
|
||||
intermediate_size = config.moe_shared_expert_intermediate_size
|
||||
@@ -308,11 +292,27 @@ class BailingMoE(nn.Module):
|
||||
else:
|
||||
self.shared_experts = None
|
||||
|
||||
self.experts = SharedFusedMoE(
|
||||
shared_experts=self.shared_experts,
|
||||
num_experts=self.num_experts,
|
||||
top_k=self.top_k,
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
reduce_results=False,
|
||||
renormalize=self.norm_expert_prob,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.experts",
|
||||
scoring_func=self.score_function,
|
||||
e_score_correction_bias=self.gate.expert_bias,
|
||||
num_expert_group=self.n_group,
|
||||
topk_group=self.topk_group,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_size)
|
||||
if self.shared_experts:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states.to(self.router_dtype))
|
||||
router_logits = router_logits.to(hidden_states.dtype)
|
||||
@@ -321,9 +321,14 @@ class BailingMoE(nn.Module):
|
||||
hidden_states=hidden_states, router_logits=router_logits
|
||||
)
|
||||
|
||||
if self.shared_experts is not None:
|
||||
shared_output, final_hidden_states = final_hidden_states
|
||||
else:
|
||||
shared_output = None
|
||||
|
||||
final_hidden_states *= self.routed_scaling_factor
|
||||
|
||||
if self.shared_experts:
|
||||
if shared_output is not None:
|
||||
final_hidden_states = final_hidden_states + shared_output
|
||||
|
||||
if self.tp_size > 1:
|
||||
@@ -475,7 +480,7 @@ class BailingMoeModel(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
return FusedMoE.make_expert_params_mapping(
|
||||
return SharedFusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
|
||||
Reference in New Issue
Block a user