[Kernels] Overlap shared experts with send/recv (#23273)

Signed-off-by: Bill Nell <bnell@redhat.com>
Author: bnellnm
Date: 2025-09-03 12:35:18 -04:00
Committed by: GitHub
Parent: fa4311d85f
Commit: e9b92dcd89
32 changed files with 885 additions and 227 deletions
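
The change below wires the Llama 4 MoE block into the new SharedFusedMoE layer, which computes the shared expert while the all-to-all send/recv that dispatches tokens to the routed experts is still in flight, and returns both partial outputs from a single call. What follows is a minimal sketch of that overlap pattern, not the actual vLLM kernel code; it assumes an initialized torch.distributed process group, and shared_expert / routed_experts are placeholder modules.

import torch
import torch.distributed as dist


def moe_forward_overlapped(hidden_states: torch.Tensor,
                           shared_expert: torch.nn.Module,
                           routed_experts: torch.nn.Module):
    # Launch the token dispatch as an asynchronous collective; `work` is a
    # handle we can wait on later instead of blocking immediately.
    dispatched = torch.empty_like(hidden_states)
    work = dist.all_to_all_single(dispatched, hidden_states, async_op=True)

    # Overlap: run the shared expert on the local tokens while the
    # send/recv above is still in progress.
    shared_out = shared_expert(hidden_states)

    # Only now block on the communication, then run the routed experts on
    # the received tokens.
    work.wait()
    routed_out = routed_experts(dispatched)

    # Return both partial results; the caller sums them (and all-reduces,
    # since both branches are built with reduce_results=False).
    return shared_out, routed_out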

@@ -36,6 +36,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
@@ -73,7 +74,18 @@ class Llama4MoE(nn.Module):
             quant_config=None,
             prefix=f"{prefix}.router")
-        self.experts = FusedMoE(
+        self.shared_expert = LlamaMLP(
+            hidden_size=config.hidden_size,
+            intermediate_size=intermediate_size_moe,
+            hidden_act="silu",
+            quant_config=quant_config,
+            bias=False,
+            prefix=f"{prefix}.shared_expert",
+            reduce_results=False,
+        )
+        self.experts = SharedFusedMoE(
+            shared_experts=self.shared_expert,
             num_experts=config.num_local_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
@@ -83,22 +95,13 @@
             reduce_results=False,
             renormalize=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.experts")
-        self.shared_expert = LlamaMLP(
-            hidden_size=config.hidden_size,
-            intermediate_size=intermediate_size_moe,
-            hidden_act="silu",
-            quant_config=quant_config,
-            bias=False,
-            prefix=f"{prefix}.shared_expert",
-            reduce_results=self.experts.must_reduce_shared_expert_outputs(),
+            prefix=f"{prefix}.experts",
+        )
     def forward(self, hidden_states):
         router_logits, _ = self.router(hidden_states)
-        shared_out = self.shared_expert(hidden_states)
-        routed_out = self.experts(
+        shared_out, routed_out = self.experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
         )
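
After this change the caller no longer invokes the shared expert itself; one call to self.experts yields both partial outputs. The remainder of Llama4MoE.forward is not shown in this hunk; below is a hedged sketch of the usual combine step, with combine_moe_outputs as a hypothetical helper name, assuming both branches were built with reduce_results=False so a single tensor-parallel all-reduce finishes the job.

import torch
from vllm.distributed import tensor_model_parallel_all_reduce


def combine_moe_outputs(shared_out: torch.Tensor,
                        routed_out: torch.Tensor,
                        tp_size: int) -> torch.Tensor:
    # Each rank holds partial sums from the shared and routed branches
    # (reduce_results=False); add them once, then reduce across the
    # tensor-parallel group in a single collective.
    experts_out = routed_out + shared_out
    if tp_size > 1:
        experts_out = tensor_model_parallel_all_reduce(experts_out)
    return experts_out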