[Model] Apply shared experts overlap optimization to all models with shared experts (#26145)
Signed-off-by: Bill Nell <bnell@redhat.com>
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEPermuteExpertsUnpermute,
     FusedMoEPrepareAndFinalize,
 )
+from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.fused_moe.utils import activation_without_mul
 from vllm.triton_utils import HAS_TRITON
 
@@ -42,6 +43,7 @@ __all__ = [
     "FusedMoEPermuteExpertsUnpermute",
     "FusedMoEActivationFormat",
     "FusedMoEPrepareAndFinalize",
+    "SharedFusedMoE",
     "activation_without_mul",
     "override_config",
     "get_config",
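Taken together, these two hunks re-export `SharedFusedMoE` from the `fused_moe` package, so downstream model code no longer needs the separate `shared_fused_moe` package that is deleted at the end of this diff. A minimal sketch of the resulting import, assuming nothing beyond what the hunks add:

```python
# Hedged sketch: with the re-export above, SharedFusedMoE can be imported
# straight from the fused_moe package alongside the other MoE symbols.
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
```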
@@ -18,13 +18,21 @@ class SharedFusedMoE(FusedMoE):
 
     def __init__(
         self,
-        shared_experts: torch.nn.Module,
+        shared_experts: Optional[torch.nn.Module],
         use_overlapped: bool = True,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self._shared_experts = shared_experts
-        self.use_overlapped = use_overlapped
+        # Disable shared expert overlap if EP is disabled or we are not using
+        # flashinfer + DP since there is nothing to be gained in this case.
+        # Disabling the overlap optimization also prevents the shared experts
+        # from being hidden from torch.compile.
+        self.use_overlapped = (
+            use_overlapped
+            and not (self.use_ep or self.use_flashinfer_cutlass_kernels)
+            and self._shared_experts is not None
+        )
 
     @property
     def shared_experts(self) -> Optional[torch.nn.Module]:
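The net effect of the new `use_overlapped` assignment is a three-way guard: the caller's request, the parallel configuration, and the presence of shared experts must all allow overlapping. A hedged sketch of that predicate as a standalone function; `resolve_use_overlapped` and its parameters are illustrative names mirroring the diff, not vLLM API:

```python
# Hedged sketch of the gating logic in the hunk above; `use_ep` and
# `use_flashinfer_cutlass_kernels` stand in for the FusedMoE attributes
# referenced in the diff, and this helper is not part of vLLM itself.
def resolve_use_overlapped(
    use_overlapped: bool,
    use_ep: bool,
    use_flashinfer_cutlass_kernels: bool,
    has_shared_experts: bool,
) -> bool:
    return (
        use_overlapped
        and not (use_ep or use_flashinfer_cutlass_kernels)
        and has_shared_experts
    )


# Overlap stays enabled only when requested, when neither the EP nor the
# flashinfer-cutlass condition trips the guard, and when shared experts exist.
assert resolve_use_overlapped(True, False, False, True) is True
assert resolve_use_overlapped(True, True, False, True) is False
assert resolve_use_overlapped(True, False, False, False) is False
```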
@@ -36,16 +44,19 @@ class SharedFusedMoE(FusedMoE):
         router_logits: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         if not self.use_overlapped:
-            shared_out = self._shared_experts(hidden_states)
-
-            # Reduce outputs if necessary, since the MLP should
-            # have been created with reduce_results=False.
-            if (
-                self.reduce_results
-                and self.tp_size > 1
-                and self.must_reduce_shared_expert_outputs()
-            ):
-                shared_out = tensor_model_parallel_all_reduce(shared_out)
+            if self._shared_experts is not None:
+                shared_out = self._shared_experts(hidden_states)
+
+                # Reduce shared expert outputs if necessary, since the MLP
+                # should have been created with reduce_results=False.
+                if (
+                    self.reduce_results
+                    and self.tp_size > 1
+                    and self.must_reduce_shared_expert_outputs()
+                ):
+                    shared_out = tensor_model_parallel_all_reduce(shared_out)
+            else:
+                shared_out = None
         else:
             shared_out = None
 
         fused_out = super().forward(
             hidden_states=hidden_states,
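A compressed, dependency-light illustration of the non-overlapped path above: the shared-expert MLP is a stand-in `nn.Linear`, the routed-expert path is faked with an identity, and the tensor-parallel all-reduce is elided, so only the control flow mirrors the diff. `MiniSharedMoE` and its members are illustrative, not vLLM classes:

```python
# Hedged, self-contained sketch of the forward control flow in the hunk above.
from typing import Optional

import torch
import torch.nn as nn


class MiniSharedMoE(nn.Module):
    def __init__(self, shared_experts: Optional[nn.Module], use_overlapped: bool):
        super().__init__()
        self._shared_experts = shared_experts
        # Mirror the diff: overlap only applies when shared experts exist.
        self.use_overlapped = use_overlapped and shared_experts is not None
        self.fused = nn.Identity()  # stand-in for the routed-expert path

    def forward(self, x: torch.Tensor) -> tuple[Optional[torch.Tensor], torch.Tensor]:
        if not self.use_overlapped:
            # Eagerly run shared experts when they are not folded into the
            # fused kernel, tolerating models that have none at all.
            shared_out = (
                self._shared_experts(x) if self._shared_experts is not None else None
            )
        else:
            # Overlapped case: shared experts are handled inside the fused path.
            shared_out = None
        return shared_out, self.fused(x)


x = torch.randn(2, 8)
layer = MiniSharedMoE(shared_experts=nn.Linear(8, 8), use_overlapped=False)
shared_out, fused_out = layer(x)
assert shared_out is not None and fused_out.shape == x.shape
```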
@@ -741,6 +741,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.w13_input_scale = None
             layer.w2_input_scale = None
 
+        self.rocm_aiter_moe_enabled = False
+
     def process_weights_after_loading(self, layer: Module) -> None:
         # Lazy import to avoid importing triton too early.
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
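The "Lazy import to avoid importing triton too early" comment in this hunk points at a general pattern: a heavy or optional backend is imported inside the method that needs it rather than at module scope. A hedged, generic sketch of the pattern with a stdlib module standing in for the heavy dependency; the class and method here are placeholders, not vLLM code:

```python
# Hedged sketch of the lazy-import pattern referenced above: defer the heavy
# dependency into the method that needs it, so importing the defining module
# stays cheap. `statistics` stands in for a genuinely heavy backend.
class WeightPostProcessor:
    def process_weights_after_loading(self, weights: list[float]) -> float:
        # Deferred import: paid on first call, not at module import time.
        import statistics

        return statistics.mean(weights)


assert WeightPostProcessor().process_weights_after_loading([1.0, 2.0, 3.0]) == 2.0
```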
@@ -1,5 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from vllm.model_executor.layers.shared_fused_moe.shared_fused_moe import SharedFusedMoE
-
-__all__ = ["SharedFusedMoE"]