[Model] Apply shared experts overlap optimization to all models with shared experts (#26145)

Signed-off-by: Bill Nell <bnell@redhat.com>
Authored by bnellnm on 2025-10-09 11:31:04 -04:00, committed by GitHub
parent 3b736e1c38
commit 47e66c24e2
15 changed files with 285 additions and 297 deletions
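
The change is largely mechanical across the touched models: the shared-experts MLP is constructed first and handed to the fused-MoE layer instead of being invoked separately, and the MoE layer then returns a pair of outputs that the caller sums. Below is a minimal toy sketch of that calling convention; ToySharedFusedMoE and its signature are illustrative stand-ins, not vLLM's actual SharedFusedMoE, and only the shared_experts argument and the two-element return are taken from the diff.

# Toy sketch of the shared-experts overlap calling convention (hypothetical
# names; not vLLM's implementation).
from typing import Optional

import torch
import torch.nn as nn


class ToySharedFusedMoE(nn.Module):
    def __init__(
        self,
        routed_experts: nn.Module,
        shared_experts: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.routed_experts = routed_experts
        self.shared_experts = shared_experts

    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
        # Stand-in for the fused routed-expert computation (router_logits is
        # ignored in this toy).
        routed_out = self.routed_experts(hidden_states)
        if self.shared_experts is None:
            return routed_out
        # Because the layer owns the shared experts, it can schedule this work
        # to overlap with the routed-expert kernels; the toy just runs it.
        shared_out = self.shared_experts(hidden_states)
        return shared_out, routed_out


# Caller side, mirroring the new AriaTextMoELayer.forward: with shared experts
# present the layer yields a pair, and the caller sums the two halves.
moe = ToySharedFusedMoE(nn.Linear(16, 16), shared_experts=nn.Linear(16, 16))
out = moe(torch.randn(4, 16), torch.randn(4, 8))
final = out[0] + out[1] if moe.shared_experts is not None else out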

vllm/model_executor/models/aria.py

@@ -13,7 +13,7 @@ from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -206,7 +206,7 @@ class AriaProjector(nn.Module):
         return out
 
 
-class AriaFusedMoE(FusedMoE):
+class AriaFusedMoE(SharedFusedMoE):
     def weight_loader(
         self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_id: str
     ) -> None:
@@ -260,7 +260,16 @@ class AriaTextMoELayer(nn.Module):
             torch.empty((self.config.moe_num_experts, self.config.hidden_size))
         )
 
+        self.shared_experts = LlamaMLP(
+            config.hidden_size,
+            config.intermediate_size * config.moe_num_shared_experts,
+            "silu",
+            quant_config=quant_config,
+            bias=config.mlp_bias,
+        )
+
         self.experts = AriaFusedMoE(
+            shared_experts=self.shared_experts,
             num_experts=config.moe_num_experts,
             top_k=config.moe_topk,
             hidden_size=config.hidden_size,
@@ -269,13 +278,6 @@ class AriaTextMoELayer(nn.Module):
             reduce_results=True,
             prefix=f"{prefix}.experts",
         )
-        self.shared_experts = LlamaMLP(
-            config.hidden_size,
-            config.intermediate_size * config.moe_num_shared_experts,
-            "silu",
-            quant_config=quant_config,
-            bias=config.mlp_bias,
-        )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """
@@ -291,12 +293,12 @@ class AriaTextMoELayer(nn.Module):
         router_output = torch.nn.functional.linear(hidden_states, self.router_weight)
 
-        hidden_states_copy = hidden_states.clone()
-        # NOTE: hidden_states will be modified inplace by `FusedMoE`
         sparse_expert_output = self.experts(hidden_states, router_output)
-        shared_expert_output = self.shared_experts(hidden_states_copy)
 
-        return sparse_expert_output + shared_expert_output
+        if self.shared_experts is not None:
+            return sparse_expert_output[0] + sparse_expert_output[1]
+        else:
+            return sparse_expert_output
 
 
 class AriaTextDecoderLayer(LlamaDecoderLayer):