[Model] Apply shared experts overlap optimization to all models with shared experts (#26145)
Signed-off-by: Bill Nell <bnell@redhat.com>
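In sketch form, the pattern this commit wires into Qwen2-MoE is: the shared expert becomes a sigmoid-gated MLP that is handed to the fused MoE layer, and the layer hands back the shared and routed partial results as a pair for the caller to sum. The snippet below is a minimal, framework-free sketch in plain PyTorch; `GatedSharedExpertSketch` and its parameter names are illustrative stand-ins, not vLLM's `Qwen2MoeMLP` or `SharedFusedMoE`, and it does not model the actual overlap of the shared expert with routed-expert dispatch/communication that the fused layer performs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedSharedExpertSketch(nn.Module):
    """Illustrative stand-in for the shared-expert MLP in the diff: a SwiGLU-style
    MLP whose output is scaled by sigmoid(expert_gate(x)) when a gate is given."""

    def __init__(self, hidden_size: int, intermediate_size: int,
                 expert_gate: nn.Linear | None = None) -> None:
        super().__init__()
        # Merged gate/up projection followed by silu(gate) * up, mirroring
        # MergedColumnParallelLinear + SiluAndMul in the real module.
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.expert_gate = expert_gate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        out = self.down_proj(F.silu(gate) * up)
        if self.expert_gate is not None:
            # Scalar sigmoid gate over the input, as in the diffed forward().
            out = F.sigmoid(self.expert_gate(x)) * out
        return out


if __name__ == "__main__":
    hidden, inter = 16, 32
    shared = GatedSharedExpertSketch(hidden, inter, nn.Linear(hidden, 1, bias=False))
    x = torch.randn(4, hidden)
    shared_out = shared(x)
    routed_out = torch.zeros_like(shared_out)  # placeholder for the routed experts
    # Caller-side contract mirrored from the diff: sum the two partial results.
    final = shared_out + routed_out
    print(final.shape)
```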
@@ -40,7 +40,7 @@ from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
@@ -79,6 +79,7 @@ class Qwen2MoeMLP(nn.Module):
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        expert_gate: Optional[torch.nn.Linear] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -102,12 +103,17 @@ class Qwen2MoeMLP(nn.Module):
                 f"Unsupported activation: {hidden_act}. Only silu is supported for now."
             )
         self.act_fn = SiluAndMul()
+        self.expert_gate = expert_gate

     def forward(self, x):
         gate_up, _ = self.gate_up_proj(x)
-        x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
-        return x
+        out = self.act_fn(gate_up)
+        out, _ = self.down_proj(out)
+
+        if self.expert_gate is not None:
+            out = F.sigmoid(self.expert_gate(x)) * out
+
+        return out


 class Qwen2MoeSparseMoeBlock(nn.Module):
@@ -126,7 +132,31 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 f"the number of experts {config.num_experts}."
             )

-        self.experts = FusedMoE(
+        self.gate = ReplicatedLinear(
+            config.hidden_size,
+            config.num_experts,
+            bias=False,
+            quant_config=None,
+            prefix=f"{prefix}.gate",
+        )
+
+        self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
+
+        if config.shared_expert_intermediate_size > 0:
+            self.shared_expert = Qwen2MoeMLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=config.shared_expert_intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                reduce_results=False,
+                expert_gate=self.shared_expert_gate,
+                prefix=f"{prefix}.shared_expert",
+            )
+        else:
+            self.shared_expert = None
+
+        self.experts = SharedFusedMoE(
+            shared_experts=self.shared_expert,
             num_experts=config.num_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
@@ -137,46 +167,19 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             prefix=f"{prefix}.experts",
         )

-        self.gate = ReplicatedLinear(
-            config.hidden_size,
-            config.num_experts,
-            bias=False,
-            quant_config=None,
-            prefix=f"{prefix}.gate",
-        )
-        if config.shared_expert_intermediate_size > 0:
-            self.shared_expert = Qwen2MoeMLP(
-                hidden_size=config.hidden_size,
-                intermediate_size=config.shared_expert_intermediate_size,
-                hidden_act=config.hidden_act,
-                quant_config=quant_config,
-                reduce_results=self.experts.must_reduce_shared_expert_outputs(),
-                prefix=f"{prefix}.shared_expert",
-            )
-        else:
-            self.shared_expert = None
-        self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # NOTE: hidden_states can have either 1D or 2D shape.
         orig_shape = hidden_states.shape
         hidden_dim = hidden_states.shape[-1]
         hidden_states = hidden_states.view(-1, hidden_dim)
-        shared_output = None
-        if self.shared_expert is not None:
-            shared_output = self.shared_expert(hidden_states)
-            if self.shared_expert_gate is not None:
-                shared_output = (
-                    F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output
-                )

         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = self.experts(
             hidden_states=hidden_states, router_logits=router_logits
         )
-        if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
+        if self.shared_expert is not None:
+            final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
         if self.tp_size > 1:
             final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
                 final_hidden_states
@@ -418,7 +421,7 @@ class Qwen2MoeModel(nn.Module):
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        return FusedMoE.make_expert_params_mapping(
+        return SharedFusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
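The last hunk only changes which class supplies the checkpoint-to-parameter mapping; the shape of each entry is the `(param_name, weight_name, expert_id, shard_id)` tuple named in the inline comment. A purely hypothetical example of one such entry, with made-up names rather than vLLM's real output:

```python
# Hypothetical illustration of one expert-params mapping entry
# (param_name, weight_name, expert_id, shard_id); the real entries come from
# SharedFusedMoE.make_expert_params_mapping(...) and may differ.
example_entry: tuple[str, str, int, str] = (
    "experts.w13_weight",          # fused parameter name on the model side
    "experts.0.gate_proj.weight",  # checkpoint weight it loads from
    0,                             # expert index
    "w1",                          # shard id within the fused parameter
)
```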