[Model] Apply shared experts overlap optimization to all models with shared experts (#26145)
Signed-off-by: Bill Nell <bnell@redhat.com>
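The hunks below swap Aria's MoE layer from FusedMoE to SharedFusedMoE, so the shared-expert MLP is handed to the fused layer at construction time and its output comes back together with the routed-expert output, instead of being computed as a separate step in forward(). As a reading aid, here is a minimal, self-contained sketch of that wiring; ToySharedFusedMoE, its top-1 routing, and the tensor sizes are illustrative assumptions, not vLLM's FusedMoE/SharedFusedMoE API. The two points mirrored from the diff are the shared_experts constructor argument and the "sum the returned pair" convention.

    # Toy sketch only: names, routing, and sizes are illustrative, not vLLM's API.
    import torch
    import torch.nn as nn


    class ToySharedFusedMoE(nn.Module):
        """Stand-in for SharedFusedMoE: owns the routed experts plus an optional
        shared-experts module whose output is returned alongside the routed one."""

        def __init__(self, num_experts: int, hidden: int, shared_experts: nn.Module | None):
            super().__init__()
            self.shared_experts = shared_experts
            self.experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_experts))

        def forward(self, x: torch.Tensor, router_logits: torch.Tensor):
            top1 = router_logits.argmax(dim=-1)  # toy top-1 routing
            routed = torch.stack([self.experts[int(e)](row) for row, e in zip(x, top1)])
            if self.shared_experts is None:
                return routed
            # Mirrors the new forward(): with shared experts attached, the call
            # yields a pair and the caller sums out[0] + out[1].
            return self.shared_experts(x), routed


    hidden, num_experts = 16, 4
    shared = nn.Sequential(nn.Linear(hidden, hidden), nn.SiLU(), nn.Linear(hidden, hidden))
    layer = ToySharedFusedMoE(num_experts, hidden, shared_experts=shared)

    x = torch.randn(8, hidden)
    router_logits = torch.randn(8, num_experts)
    out = layer(x, router_logits)
    result = out[0] + out[1] if layer.shared_experts is not None else out
    print(result.shape)  # torch.Size([8, 16])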
@@ -13,7 +13,7 @@ from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -206,7 +206,7 @@ class AriaProjector(nn.Module):
         return out
 
 
-class AriaFusedMoE(FusedMoE):
+class AriaFusedMoE(SharedFusedMoE):
     def weight_loader(
         self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_id: str
     ) -> None:
@@ -260,7 +260,16 @@ class AriaTextMoELayer(nn.Module):
             torch.empty((self.config.moe_num_experts, self.config.hidden_size))
         )
 
+        self.shared_experts = LlamaMLP(
+            config.hidden_size,
+            config.intermediate_size * config.moe_num_shared_experts,
+            "silu",
+            quant_config=quant_config,
+            bias=config.mlp_bias,
+        )
+
         self.experts = AriaFusedMoE(
+            shared_experts=self.shared_experts,
             num_experts=config.moe_num_experts,
             top_k=config.moe_topk,
             hidden_size=config.hidden_size,
@@ -269,13 +278,6 @@ class AriaTextMoELayer(nn.Module):
             reduce_results=True,
             prefix=f"{prefix}.experts",
         )
-        self.shared_experts = LlamaMLP(
-            config.hidden_size,
-            config.intermediate_size * config.moe_num_shared_experts,
-            "silu",
-            quant_config=quant_config,
-            bias=config.mlp_bias,
-        )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """
@@ -291,12 +293,12 @@ class AriaTextMoELayer(nn.Module):
 
         router_output = torch.nn.functional.linear(hidden_states, self.router_weight)
 
-        hidden_states_copy = hidden_states.clone()
-        # NOTE: hidden_states will be modified inplace by `FusedMoE`
         sparse_expert_output = self.experts(hidden_states, router_output)
-        shared_expert_output = self.shared_experts(hidden_states_copy)
-
-        return sparse_expert_output + shared_expert_output
+
+        if self.shared_experts is not None:
+            return sparse_expert_output[0] + sparse_expert_output[1]
+        else:
+            return sparse_expert_output
 
 
 class AriaTextDecoderLayer(LlamaDecoderLayer):
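The "overlap" in the commit title refers to letting the shared-expert MLP run concurrently with the routed-expert work rather than as the serial clone-then-compute step the old forward() used. The snippet below is only a generic illustration of that idea using CUDA streams; it is an assumption about how such overlap can be expressed, not vLLM's SharedFusedMoE implementation, and it falls back to sequential execution on CPU.

    # Generic stream-overlap illustration (assumption, not vLLM's mechanism).
    import torch
    import torch.nn as nn

    hidden = 16
    shared = nn.Sequential(nn.Linear(hidden, hidden), nn.SiLU(), nn.Linear(hidden, hidden))
    routed = nn.Linear(hidden, hidden)  # stand-in for the fused routed-expert op

    device = "cuda" if torch.cuda.is_available() else "cpu"
    shared, routed = shared.to(device), routed.to(device)
    x = torch.randn(8, hidden, device=device)

    if device == "cuda":
        side = torch.cuda.Stream()
        side.wait_stream(torch.cuda.current_stream())  # x must be ready on the side stream
        with torch.cuda.stream(side):                  # shared experts on a side stream
            shared_out = shared(x)
        routed_out = routed(x)                         # routed experts on the default stream
        torch.cuda.current_stream().wait_stream(side)  # join before combining
    else:
        shared_out, routed_out = shared(x), routed(x)  # no overlap on CPU

    print((shared_out + routed_out).shape)  # torch.Size([8, 16])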