[Kernels] Overlap shared experts with send/recv (#23273)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -36,6 +36,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 
@@ -73,7 +74,18 @@ class Llama4MoE(nn.Module):
                                        quant_config=None,
                                        prefix=f"{prefix}.router")
 
-        self.experts = FusedMoE(
+        self.shared_expert = LlamaMLP(
+            hidden_size=config.hidden_size,
+            intermediate_size=intermediate_size_moe,
+            hidden_act="silu",
+            quant_config=quant_config,
+            bias=False,
+            prefix=f"{prefix}.shared_expert",
+            reduce_results=False,
+        )
+
+        self.experts = SharedFusedMoE(
+            shared_experts=self.shared_expert,
             num_experts=config.num_local_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
@@ -83,22 +95,13 @@ class Llama4MoE(nn.Module):
             reduce_results=False,
             renormalize=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.experts")
-
-        self.shared_expert = LlamaMLP(
-            hidden_size=config.hidden_size,
-            intermediate_size=intermediate_size_moe,
-            hidden_act="silu",
-            quant_config=quant_config,
-            bias=False,
-            prefix=f"{prefix}.shared_expert",
-            reduce_results=self.experts.must_reduce_shared_expert_outputs(),
+            prefix=f"{prefix}.experts",
         )
 
     def forward(self, hidden_states):
         router_logits, _ = self.router(hidden_states)
-        shared_out = self.shared_expert(hidden_states)
-        routed_out = self.experts(
+
+        shared_out, routed_out = self.experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
         )
Reference in New Issue
Block a user