[Model] Apply shared experts overlap optimization to all models with shared experts (#26145)

Signed-off-by: Bill Nell <bnell@redhat.com>
Commit 47e66c24e2 (parent 3b736e1c38)
Author: bnellnm
Date: 2025-10-09 11:31:04 -04:00
Committed by: GitHub

15 changed files with 285 additions and 297 deletions
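
For model authors, the upshot is that a single SharedFusedMoE layer now computes both the routed and the shared experts. A rough sketch of the intended call pattern; the constructor arguments, sizes, and the shared_mlp/hidden_states/router_logits names are hypothetical stand-ins, since the model-side diffs are not shown in this excerpt:

from vllm.model_executor.layers.fused_moe import SharedFusedMoE

# Hypothetical construction; the FusedMoE kwargs vary per model.
experts = SharedFusedMoE(
    shared_experts=shared_mlp,  # the model's shared-expert MLP, built with
                                # reduce_results=False (see forward() below)
    num_experts=64,
    top_k=6,
    hidden_size=4096,
    intermediate_size=1408,
    reduce_results=False,
)

# forward() returns a (shared_out, fused_out) pair; see the forward() hunk below.
shared_out, fused_out = experts(hidden_states, router_logits)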

--- a/vllm/model_executor/layers/fused_moe/__init__.py
+++ b/vllm/model_executor/layers/fused_moe/__init__.py

@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEPermuteExpertsUnpermute,
     FusedMoEPrepareAndFinalize,
 )
+from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.fused_moe.utils import activation_without_mul
 from vllm.triton_utils import HAS_TRITON
@@ -42,6 +43,7 @@ __all__ = [
     "FusedMoEPermuteExpertsUnpermute",
     "FusedMoEActivationFormat",
     "FusedMoEPrepareAndFinalize",
+    "SharedFusedMoE",
     "activation_without_mul",
     "override_config",
     "get_config",

--- a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py

@@ -18,13 +18,21 @@ class SharedFusedMoE(FusedMoE):
     def __init__(
         self,
-        shared_experts: torch.nn.Module,
+        shared_experts: Optional[torch.nn.Module],
         use_overlapped: bool = True,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self._shared_experts = shared_experts
-        self.use_overlapped = use_overlapped
+
+        # Disable shared expert overlap if EP is disabled or we are not using
+        # flashinfer + DP since there is nothing to be gained in this case.
+        # Disabling the overlap optimization also prevents the shared experts
+        # from being hidden from torch.compile.
+        self.use_overlapped = (
+            use_overlapped
+            and not (self.use_ep or self.use_flashinfer_cutlass_kernels)
+            and self._shared_experts is not None
+        )
 
     @property
     def shared_experts(self) -> Optional[torch.nn.Module]:
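
To make the new gating concrete, the flag resolution distills to the boolean below; resolve_use_overlapped is a hypothetical stand-alone helper mirroring the expression in the hunk, not a function that exists in vLLM:

# Mirrors the __init__ hunk: overlap stays enabled only when it was requested,
# neither EP nor the flashinfer cutlass kernels are in use, and shared experts
# actually exist.
def resolve_use_overlapped(
    use_overlapped: bool,
    use_ep: bool,
    use_flashinfer_cutlass_kernels: bool,
    has_shared_experts: bool,
) -> bool:
    return (
        use_overlapped
        and not (use_ep or use_flashinfer_cutlass_kernels)
        and has_shared_experts
    )

assert resolve_use_overlapped(True, False, False, True)       # overlap kept
assert not resolve_use_overlapped(True, True, False, True)    # EP path: disabled
assert not resolve_use_overlapped(True, False, False, False)  # no shared experts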
@@ -36,16 +44,19 @@
         router_logits: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         if not self.use_overlapped:
-            shared_out = self._shared_experts(hidden_states)
-
-            # Reduce outputs if necessary, since the MLP should
-            # have been created with reduce_results=False.
-            if (
-                self.reduce_results
-                and self.tp_size > 1
-                and self.must_reduce_shared_expert_outputs()
-            ):
-                shared_out = tensor_model_parallel_all_reduce(shared_out)
+            if self._shared_experts is not None:
+                shared_out = self._shared_experts(hidden_states)
+
+                # Reduce shared expert outputs if necessary, since the MLP
+                # should have been created with reduce_results=False.
+                if (
+                    self.reduce_results
+                    and self.tp_size > 1
+                    and self.must_reduce_shared_expert_outputs()
+                ):
+                    shared_out = tensor_model_parallel_all_reduce(shared_out)
+            else:
+                shared_out = None
             fused_out = super().forward(
                 hidden_states=hidden_states,
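
On the caller side, the non-overlapped path can now hand back shared_out = None whenever no shared-expert module was supplied, so model code must tolerate a None first element. A hypothetical consumer; the additive combination is purely illustrative, as each model combines the two outputs in its own way:

# shared_out may be None if the layer was constructed without shared experts.
shared_out, fused_out = experts(hidden_states, router_logits)
hidden = fused_out if shared_out is None else fused_out + shared_out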

--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py

@@ -741,6 +741,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.w13_input_scale = None
             layer.w2_input_scale = None
 
+        self.rocm_aiter_moe_enabled = False
+
     def process_weights_after_loading(self, layer: Module) -> None:
         # Lazy import to avoid importing triton too early.
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (

--- a/vllm/model_executor/layers/shared_fused_moe/__init__.py
+++ /dev/null

@@ -1,5 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from vllm.model_executor.layers.shared_fused_moe.shared_fused_moe import SharedFusedMoE
-__all__ = ["SharedFusedMoE"]
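
Because the old shared_fused_moe package is removed outright, out-of-tree code importing from it must move to the new location. A sketch of the migration (old path from the deleted file above, new path from the __init__.py hunk at the top):

# Before this commit (module now deleted):
#   from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
# After this commit:
from vllm.model_executor.layers.fused_moe import SharedFusedMoE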