[MoE Refactor] Introduce MoERunner abstraction and move execution logic from FusedMoE to DefaultMoERunner (#32344)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -32,7 +32,7 @@ th {
|
||||
|
||||
| Backend | Output act. format | Quant. types | Quant. format | Async | Apply Weight On Input | Subclass |
|
||||
|---------|--------------------|--------------|---------------|-------|-----------------------|-----------|
|
||||
| naive | standard | all<sup>1</sup> | G,A,T | N | <sup>6</sup> | [layer.py][vllm.model_executor.layers.fused_moe.layer.FusedMoE.forward_impl] |
|
||||
| naive | standard | all<sup>1</sup> | G,A,T | N | <sup>6</sup> | [layer.py][vllm.model_executor.layers.fused_moe.layer.FusedMoE |
|
||||
| pplx | batched | fp8,int8 | G,A,T | Y | Y | [`PplxPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.pplx_prepare_finalize.PplxPrepareAndFinalize] |
|
||||
| deepep_high_throughput | standard | fp8 | G(128),A,T<sup>2</sup> | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize.DeepEPLLPrepareAndFinalize] |
|
||||
| deepep_low_latency | batched | fp8 | G(128),A,T<sup>3</sup> | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize.DeepEPHTPrepareAndFinalize] |
|
||||
|
||||
@@ -585,6 +585,7 @@ def make_modular_kernel(
|
||||
tp_size_=get_tensor_model_parallel_world_size(),
|
||||
pcp_size_=get_pcp_group().world_size,
|
||||
dp_size_=get_dp_group().world_size,
|
||||
sp_size_=1,
|
||||
vllm_parallel_config=vllm_config.parallel_config,
|
||||
)
|
||||
|
||||
@@ -594,6 +595,7 @@ def make_modular_kernel(
|
||||
hidden_dim=config.K,
|
||||
intermediate_size_per_partition=config.N,
|
||||
num_local_experts=config.num_local_experts,
|
||||
num_logical_experts=config.E,
|
||||
moe_parallel_config=moe_parallel_config,
|
||||
in_dtype=config.dtype,
|
||||
max_num_tokens=next_power_of_2(config.M),
|
||||
|
||||
@@ -52,6 +52,7 @@ def make_dummy_moe_config(
|
||||
hidden_dim=hidden_dim,
|
||||
intermediate_size_per_partition=intermediate_size_per_partition,
|
||||
num_local_experts=num_experts,
|
||||
num_logical_experts=num_experts,
|
||||
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
|
||||
activation="silu",
|
||||
in_dtype=in_dtype,
|
||||
|
||||
@@ -913,12 +913,16 @@ class FusedMoEParallelConfig:
|
||||
pcp_rank: int
|
||||
dp_rank: int
|
||||
ep_rank: int
|
||||
sp_size: int
|
||||
|
||||
use_ep: bool # whether to use EP or not
|
||||
all2all_backend: str # all2all backend for MoE communication
|
||||
is_sequence_parallel: bool # whether sequence parallelism is used
|
||||
enable_eplb: bool # whether to enable expert load balancing
|
||||
|
||||
@property
|
||||
def is_sequence_parallel(self) -> bool:
|
||||
return self.sp_size > 1
|
||||
|
||||
@property
|
||||
def use_all2all_kernels(self):
|
||||
return self.dp_size > 1 and self.use_ep
|
||||
@@ -974,6 +978,7 @@ class FusedMoEParallelConfig:
|
||||
tp_size_: int,
|
||||
pcp_size_: int,
|
||||
dp_size_: int,
|
||||
sp_size_: int,
|
||||
vllm_parallel_config: ParallelConfig,
|
||||
) -> "FusedMoEParallelConfig":
|
||||
"""
|
||||
@@ -1073,9 +1078,9 @@ class FusedMoEParallelConfig:
|
||||
dp_rank=dp_rank,
|
||||
ep_size=1,
|
||||
ep_rank=0,
|
||||
sp_size=sp_size_,
|
||||
use_ep=False,
|
||||
all2all_backend=vllm_parallel_config.all2all_backend,
|
||||
is_sequence_parallel=vllm_parallel_config.use_sequence_parallel_moe,
|
||||
enable_eplb=vllm_parallel_config.enable_eplb,
|
||||
)
|
||||
# DP + EP / TP + EP / DP + TP + EP
|
||||
@@ -1093,9 +1098,9 @@ class FusedMoEParallelConfig:
|
||||
dp_rank=dp_rank,
|
||||
ep_size=ep_size,
|
||||
ep_rank=ep_rank,
|
||||
sp_size=sp_size_,
|
||||
use_ep=True,
|
||||
all2all_backend=vllm_parallel_config.all2all_backend,
|
||||
is_sequence_parallel=vllm_parallel_config.use_sequence_parallel_moe,
|
||||
enable_eplb=vllm_parallel_config.enable_eplb,
|
||||
)
|
||||
|
||||
@@ -1111,10 +1116,10 @@ class FusedMoEParallelConfig:
|
||||
dp_rank=0,
|
||||
ep_size=1,
|
||||
ep_rank=0,
|
||||
sp_size=1,
|
||||
use_ep=False,
|
||||
all2all_backend="naive",
|
||||
enable_eplb=False,
|
||||
is_sequence_parallel=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -1126,6 +1131,7 @@ class FusedMoEConfig:
|
||||
hidden_dim: int
|
||||
intermediate_size_per_partition: int
|
||||
num_local_experts: int
|
||||
num_logical_experts: int
|
||||
activation: str
|
||||
device: torch.device | str
|
||||
routing_method: RoutingMethodType
|
||||
@@ -1175,6 +1181,14 @@ class FusedMoEConfig:
|
||||
def ep_size(self):
|
||||
return self.moe_parallel_config.ep_size
|
||||
|
||||
@property
|
||||
def sp_size(self):
|
||||
return self.moe_parallel_config.sp_size
|
||||
|
||||
@property
|
||||
def is_sequence_parallel(self):
|
||||
return self.moe_parallel_config.is_sequence_parallel
|
||||
|
||||
@property
|
||||
def tp_rank(self):
|
||||
return self.moe_parallel_config.tp_rank
|
||||
|
||||
@@ -121,17 +121,16 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
def is_monolithic(self) -> bool:
|
||||
return False
|
||||
|
||||
# @abstractmethod
|
||||
def apply(
|
||||
self,
|
||||
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
# @abstractmethod
|
||||
def apply_monolithic(
|
||||
self,
|
||||
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
|
||||
|
||||
@@ -89,6 +89,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.moe_mk is not None
|
||||
return self.moe_mk(
|
||||
@@ -101,5 +102,5 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
||||
global_num_experts=layer.global_num_experts,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
expert_map=None if self.disable_expert_map else layer.expert_map,
|
||||
shared_experts_input=layer._get_shared_experts_input(x),
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable, Generator, Iterable
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from collections.abc import Callable, Iterable
|
||||
from enum import Enum
|
||||
from typing import Literal, cast, get_args, overload
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import UninitializedParameter
|
||||
|
||||
import vllm.envs as envs
|
||||
@@ -16,17 +14,10 @@ from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.config.parallel import ExpertPlacementStrategy
|
||||
from vllm.distributed import (
|
||||
get_dp_group,
|
||||
get_ep_group,
|
||||
get_pcp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.distributed.eplb.eplb_state import EplbLayerState, EplbState
|
||||
from vllm.forward_context import (
|
||||
ForwardContext,
|
||||
get_forward_context,
|
||||
is_forward_context_available,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
@@ -47,6 +38,9 @@ from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
from vllm.model_executor.layers.fused_moe.router.router_factory import (
|
||||
create_fused_moe_router,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import (
|
||||
DefaultMoERunner,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
|
||||
UnquantizedFusedMoEMethod,
|
||||
)
|
||||
@@ -57,13 +51,7 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv, round_up
|
||||
from vllm.utils.torch_utils import (
|
||||
aux_stream,
|
||||
current_stream,
|
||||
direct_register_custom_op,
|
||||
)
|
||||
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -264,6 +252,7 @@ def maybe_roundup_hidden_size(
|
||||
)
|
||||
|
||||
current_mxfp4_backend = get_mxfp4_backend(is_lora_enabled)
|
||||
|
||||
if (
|
||||
current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
|
||||
or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
|
||||
@@ -273,6 +262,7 @@ def maybe_roundup_hidden_size(
|
||||
current_platform.is_rocm()
|
||||
or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
|
||||
or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
|
||||
or current_mxfp4_backend == Mxfp4Backend.MARLIN
|
||||
):
|
||||
hidden_size = round_up(hidden_size, 256)
|
||||
|
||||
@@ -338,29 +328,15 @@ class FusedMoE(CustomOp):
|
||||
expert_mapping: list[tuple[str, str, int, str]] | None = None,
|
||||
n_shared_experts: int | None = None,
|
||||
router_logits_dtype: torch.dtype | None = None,
|
||||
has_shared_experts: bool = False,
|
||||
gate: torch.nn.Module | None = None,
|
||||
shared_experts: torch.nn.Module | None = None,
|
||||
routed_input_transform: torch.nn.Module | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Allow disabling of the separate shared experts stream for
|
||||
# debug purposes.
|
||||
# TODO: Remove this after more extensive testings with TP/DP
|
||||
# and other execution modes
|
||||
if envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM:
|
||||
logger.debug_once("Disabling MoE shared_experts cuda stream", scope="local")
|
||||
self.shared_experts_stream = None
|
||||
else:
|
||||
# TODO(rob): enable shared expert overlap with non-cuda-alike.
|
||||
# aux_stream() returns None on non-cuda-alike platforms.
|
||||
self.shared_experts_stream = aux_stream()
|
||||
if self.shared_experts_stream is not None:
|
||||
logger.debug_once(
|
||||
"Enabled separate cuda stream for MoE shared_experts", scope="local"
|
||||
)
|
||||
|
||||
# For latent MoE: stores original hidden_states before routed_input_transform
|
||||
# so shared_experts can use it for cloning (they need original dimension)
|
||||
self._shared_experts_input: torch.Tensor | None = None
|
||||
self._gate = gate
|
||||
self._shared_experts = shared_experts
|
||||
self._routed_input_transform = routed_input_transform
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
@@ -392,9 +368,12 @@ class FusedMoE(CustomOp):
|
||||
tp_size_=tp_size_,
|
||||
pcp_size_=pcp_size_,
|
||||
dp_size_=dp_size_,
|
||||
sp_size_=self.sp_size,
|
||||
vllm_parallel_config=vllm_config.parallel_config,
|
||||
)
|
||||
|
||||
assert self.moe_parallel_config.is_sequence_parallel == is_sequence_parallel
|
||||
|
||||
self.global_num_experts = num_experts + num_redundant_experts
|
||||
self.logical_num_experts = num_experts
|
||||
|
||||
@@ -410,6 +389,7 @@ class FusedMoE(CustomOp):
|
||||
self.layer_name = prefix
|
||||
|
||||
self.enable_eplb = enable_eplb
|
||||
# TODO(bnell): should this be owned by router?
|
||||
self.eplb_state = EplbLayerState()
|
||||
self.expert_placement_strategy: ExpertPlacementStrategy = (
|
||||
vllm_config.parallel_config.expert_placement_strategy
|
||||
@@ -506,7 +486,8 @@ class FusedMoE(CustomOp):
|
||||
self.reduce_results = reduce_results
|
||||
self.renormalize = renormalize
|
||||
|
||||
# TODO(bnell): these attributes are only used by cpu/xpu/mxfp4
|
||||
# TODO(bnell): these attributes are only used by monolithic kernels.
|
||||
# Put them in a MoERouterConfig dataclass?
|
||||
self.use_grouped_topk = use_grouped_topk
|
||||
if self.use_grouped_topk:
|
||||
assert num_expert_group is not None and topk_group is not None
|
||||
@@ -565,6 +546,7 @@ class FusedMoE(CustomOp):
|
||||
hidden_dim=hidden_size,
|
||||
intermediate_size_per_partition=self.intermediate_size_per_partition,
|
||||
num_local_experts=self.local_num_experts,
|
||||
num_logical_experts=self.logical_num_experts,
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
in_dtype=moe_in_dtype,
|
||||
router_logits_dtype=router_logits_dtype,
|
||||
@@ -576,9 +558,9 @@ class FusedMoE(CustomOp):
|
||||
device=vllm_config.device_config.device,
|
||||
routing_method=self.routing_method_type,
|
||||
# TODO: in_dtype == out_dtype?
|
||||
disable_inplace=disable_inplace() or has_shared_experts,
|
||||
disable_inplace=disable_inplace() or self._shared_experts is not None,
|
||||
)
|
||||
if self.use_mori_kernels:
|
||||
if self.moe_config.use_mori_kernels:
|
||||
assert self.rocm_aiter_fmoe_enabled, (
|
||||
"Mori needs to be used with aiter fused_moe for now."
|
||||
)
|
||||
@@ -641,9 +623,36 @@ class FusedMoE(CustomOp):
|
||||
|
||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||
|
||||
# Chunked all2all staging tensor
|
||||
self.batched_hidden_states: torch.Tensor | None = None
|
||||
self.batched_router_logits: torch.Tensor | None = None
|
||||
# Disable shared expert overlap if:
|
||||
# - we are using eplb with non-default backend, because of correctness issues
|
||||
# - we are using flashinfer with DP, since there nothing to gain
|
||||
# - we are using marlin kernels
|
||||
backend = self.moe_parallel_config.all2all_backend
|
||||
self.use_overlapped = (
|
||||
not (
|
||||
(self.enable_eplb and backend != "allgather_reducescatter")
|
||||
or self.moe_parallel_config.use_fi_all2allv_kernels
|
||||
)
|
||||
and self._shared_experts is not None
|
||||
)
|
||||
|
||||
self.runner = self._init_runner()
|
||||
|
||||
def _init_runner(self):
|
||||
# Storing the runner in the FusedMoE is an intermediate state, eventually
|
||||
# the runner will own the FusedMoE layer and provide the execution interface
|
||||
# for MoE ops.
|
||||
return DefaultMoERunner(
|
||||
layer=self,
|
||||
moe_config=self.moe_config,
|
||||
router=self.router,
|
||||
routed_input_transform=self._routed_input_transform,
|
||||
gate=self.gate,
|
||||
shared_experts=self.shared_experts,
|
||||
quant_method=self.quant_method,
|
||||
reduce_results=self.reduce_results,
|
||||
enable_dbo=self.vllm_config.parallel_config.enable_dbo,
|
||||
)
|
||||
|
||||
# Note: maybe_init_modular_kernel should only be called by
|
||||
# prepare_communication_buffer_for_model.
|
||||
@@ -673,10 +682,14 @@ class FusedMoE(CustomOp):
|
||||
self.shared_experts,
|
||||
inplace=not self.moe_config.disable_inplace,
|
||||
)
|
||||
# We need to force reconstruction of runner because we're swapping out
|
||||
# the quant_method with a FusedMoEModularMethod. This logic can go
|
||||
# away once the FusedMoEModularMethod is eliminated.
|
||||
self.runner = self._init_runner()
|
||||
|
||||
@property
|
||||
def shared_experts(self) -> torch.nn.Module | None:
|
||||
return None
|
||||
return self._shared_experts if self.use_overlapped else None
|
||||
|
||||
@property
|
||||
def layer_id(self):
|
||||
@@ -687,53 +700,12 @@ class FusedMoE(CustomOp):
|
||||
|
||||
@property
|
||||
def gate(self) -> torch.nn.Module | None:
|
||||
return None
|
||||
|
||||
def apply_routed_input_transform(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""Hook to transform hidden_states before passing to routed experts.
|
||||
For latent MoE: transforms [S, hidden_size] → [S, moe_latent_size].
|
||||
The original hidden_states is saved in _shared_experts_input so
|
||||
shared_experts still receive the original [S, hidden_size].
|
||||
|
||||
Override in subclasses (e.g., SharedFusedMoE) for latent MoE.
|
||||
"""
|
||||
return hidden_states
|
||||
|
||||
@contextmanager
|
||||
def _set_shared_experts_input(
|
||||
self, value: torch.Tensor | None
|
||||
) -> Generator[None, None, None]:
|
||||
"""Context manager to safely set/clear _shared_experts_input."""
|
||||
self._shared_experts_input = value
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._shared_experts_input = None
|
||||
|
||||
def _get_shared_experts_input(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""Get input for shared experts.
|
||||
|
||||
For latent MoE: shared_experts need original [S, hidden_size],
|
||||
not the transformed [S, latent_size] used by routed experts.
|
||||
"""
|
||||
return (
|
||||
self._shared_experts_input
|
||||
if self._shared_experts_input is not None
|
||||
else hidden_states
|
||||
)
|
||||
return self._gate
|
||||
|
||||
@property
|
||||
def tp_size(self):
|
||||
return self.moe_parallel_config.tp_size
|
||||
|
||||
@property
|
||||
def dp_size(self):
|
||||
return self.moe_parallel_config.dp_size
|
||||
|
||||
@property
|
||||
def pcp_size(self):
|
||||
return self.moe_parallel_config.pcp_size
|
||||
|
||||
@property
|
||||
def ep_size(self):
|
||||
return self.moe_parallel_config.ep_size
|
||||
@@ -742,14 +714,6 @@ class FusedMoE(CustomOp):
|
||||
def tp_rank(self):
|
||||
return self.moe_parallel_config.tp_rank
|
||||
|
||||
@property
|
||||
def dp_rank(self):
|
||||
return self.moe_parallel_config.dp_rank
|
||||
|
||||
@property
|
||||
def pcp_rank(self):
|
||||
return self.moe_parallel_config.pcp_rank
|
||||
|
||||
@property
|
||||
def ep_rank(self):
|
||||
return self.moe_parallel_config.ep_rank
|
||||
@@ -758,39 +722,10 @@ class FusedMoE(CustomOp):
|
||||
def use_ep(self):
|
||||
return self.moe_parallel_config.use_ep
|
||||
|
||||
@property
|
||||
def use_pplx_kernels(self):
|
||||
return self.moe_parallel_config.use_pplx_kernels
|
||||
|
||||
@property
|
||||
def use_deepep_ht_kernels(self):
|
||||
return self.moe_parallel_config.use_deepep_ht_kernels
|
||||
|
||||
@property
|
||||
def use_deepep_ll_kernels(self):
|
||||
return self.moe_parallel_config.use_deepep_ll_kernels
|
||||
|
||||
@property
|
||||
def use_mori_kernels(self):
|
||||
return self.moe_parallel_config.use_mori_kernels
|
||||
|
||||
@property
|
||||
def use_marlin_kernels(self):
|
||||
return getattr(self.quant_method, "use_marlin", False)
|
||||
|
||||
@property
|
||||
def use_dp_chunking(self) -> bool:
|
||||
return (
|
||||
self.moe_parallel_config.use_pplx_kernels
|
||||
or self.moe_parallel_config.use_deepep_ll_kernels
|
||||
or self.moe_parallel_config.use_mori_kernels
|
||||
or self.moe_parallel_config.use_fi_all2allv_kernels
|
||||
) and envs.VLLM_ENABLE_MOE_DP_CHUNK
|
||||
|
||||
@property
|
||||
def is_internal_router(self) -> bool:
|
||||
# By default, router/gate is called before FusedMoE forward pass
|
||||
return False
|
||||
return self._gate is not None
|
||||
|
||||
def _maybe_init_expert_routing_tables(
|
||||
self,
|
||||
@@ -799,7 +734,7 @@ class FusedMoE(CustomOp):
|
||||
# with DeepEP-ll all2all backend.
|
||||
if (
|
||||
self.expert_placement_strategy != "round_robin"
|
||||
or not self.use_deepep_ll_kernels
|
||||
or not self.moe_parallel_config.use_deepep_ll_kernels
|
||||
):
|
||||
return None
|
||||
|
||||
@@ -892,48 +827,6 @@ class FusedMoE(CustomOp):
|
||||
dp_size=get_dp_group().world_size,
|
||||
)
|
||||
|
||||
def _maybe_setup_shared_experts_stream(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
has_separate_shared_experts: bool,
|
||||
use_chunked_impl: bool,
|
||||
) -> tuple[bool, torch.Tensor | None]:
|
||||
use_shared_experts_stream = (
|
||||
current_platform.is_cuda()
|
||||
and has_separate_shared_experts
|
||||
and not use_chunked_impl
|
||||
and self.shared_experts_stream is not None
|
||||
and (
|
||||
hidden_states.shape[0]
|
||||
<= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
|
||||
)
|
||||
)
|
||||
|
||||
hidden_states_clone: torch.Tensor | None = None
|
||||
if use_shared_experts_stream:
|
||||
assert self.shared_experts_stream is not None
|
||||
|
||||
shared_experts_input = self._get_shared_experts_input(hidden_states)
|
||||
|
||||
# Clone BEFORE switching streams to avoid race condition
|
||||
# where routed_expert kernel may mutate hidden_states.
|
||||
hidden_states_clone = shared_experts_input.clone()
|
||||
|
||||
# Record that the clone will be used by shared_experts_stream
|
||||
# to avoid gc issue from deallocation of hidden_states_clone
|
||||
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
|
||||
# NOTE: We don't need shared_output.record_stream(current_stream())
|
||||
# because we synch the streams before using shared_output.
|
||||
hidden_states_clone.record_stream(self.shared_experts_stream)
|
||||
|
||||
# Mark sync start point for the separate shared experts
|
||||
# stream here since we want to run in parallel with the
|
||||
# router/gate (next op below)
|
||||
assert self.shared_experts_stream is not None
|
||||
self.shared_experts_stream.wait_stream(current_stream())
|
||||
|
||||
return use_shared_experts_stream, hidden_states_clone
|
||||
|
||||
def _load_per_tensor_weight_scale(
|
||||
self,
|
||||
shard_id: str,
|
||||
@@ -1191,7 +1084,7 @@ class FusedMoE(CustomOp):
|
||||
# compressed-tensors checkpoints with packed weights are stored flipped
|
||||
# TODO (mgoin): check self.quant_method.quant_config.quant_format
|
||||
# against known CompressionFormat enum values that have this quality
|
||||
if self.quant_method.__class__.__name__ in (
|
||||
if quant_method_name in (
|
||||
"CompressedTensorsWNA16MarlinMoEMethod",
|
||||
"CompressedTensorsWNA16MoEMethod",
|
||||
):
|
||||
@@ -1488,7 +1381,7 @@ class FusedMoE(CustomOp):
|
||||
assert all(
|
||||
weight.is_contiguous()
|
||||
for name, weight in weights
|
||||
if not name.startswith("_shared_experts.")
|
||||
if not (name.startswith("_shared_experts.") or name.startswith("_gate."))
|
||||
)
|
||||
|
||||
# Filter out the non-expert weights.
|
||||
@@ -1538,32 +1431,6 @@ class FusedMoE(CustomOp):
|
||||
self.ensure_moe_quant_config_init()
|
||||
return self.quant_method.moe_quant_config
|
||||
|
||||
def ensure_dp_chunking_init(self):
|
||||
if not self.use_dp_chunking or self.batched_hidden_states is not None:
|
||||
return
|
||||
|
||||
states_shape: tuple[int, ...]
|
||||
logits_shape: tuple[int, ...]
|
||||
|
||||
moe = self.moe_config
|
||||
|
||||
if self.vllm_config.parallel_config.enable_dbo:
|
||||
states_shape = (2, moe.max_num_tokens, self.hidden_size)
|
||||
logits_shape = (2, moe.max_num_tokens, self.logical_num_experts)
|
||||
else:
|
||||
states_shape = (moe.max_num_tokens, self.hidden_size)
|
||||
logits_shape = (moe.max_num_tokens, self.logical_num_experts)
|
||||
|
||||
self.batched_hidden_states = torch.zeros(
|
||||
states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
|
||||
)
|
||||
|
||||
self.batched_router_logits = torch.zeros(
|
||||
logits_shape,
|
||||
dtype=moe.router_logits_dtype,
|
||||
device=torch.cuda.current_device(),
|
||||
)
|
||||
|
||||
def must_reduce_shared_expert_outputs(self) -> bool:
|
||||
"""
|
||||
The shared_experts are typically computed using the RowParallelLinear
|
||||
@@ -1577,100 +1444,24 @@ class FusedMoE(CustomOp):
|
||||
Therefore it is required that we reduce the shared_experts output
|
||||
early.
|
||||
"""
|
||||
assert self.quant_method is not None
|
||||
return (
|
||||
isinstance(self.quant_method, FusedMoEModularMethod)
|
||||
and self.quant_method.moe_mk.output_is_reduced() # type: ignore[union-attr]
|
||||
)
|
||||
return self.runner.must_reduce_shared_expert_outputs()
|
||||
|
||||
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
|
||||
"""
|
||||
Some combine kernels reduce across GPU ranks by default.
|
||||
"""
|
||||
if self.must_reduce_shared_expert_outputs():
|
||||
return final_hidden_states
|
||||
else:
|
||||
return tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
return self.runner.maybe_all_reduce_tensor_model_parallel(final_hidden_states)
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
# For latent MoE: save ORIGINAL hidden_states before transform
|
||||
# (shared_experts need original dimension, routed experts use transformed)
|
||||
original_hidden_states = hidden_states
|
||||
original_hidden_dim = hidden_states.shape[-1]
|
||||
|
||||
# Apply transform for routed experts (e.g., latent projection for latent MoE)
|
||||
hidden_states = self.apply_routed_input_transform(hidden_states)
|
||||
|
||||
# This is the dimension after transform (for routed expert output slicing)
|
||||
transformed_hidden_dim = hidden_states.shape[-1]
|
||||
if self.hidden_size != transformed_hidden_dim:
|
||||
hidden_states = F.pad(
|
||||
hidden_states,
|
||||
(0, self.hidden_size - transformed_hidden_dim),
|
||||
mode="constant",
|
||||
value=0.0,
|
||||
)
|
||||
|
||||
def reduce_output(states: torch.Tensor) -> torch.Tensor:
|
||||
if (
|
||||
not self.is_sequence_parallel
|
||||
and not self.use_dp_chunking
|
||||
and self.reduce_results
|
||||
and (self.tp_size > 1 or self.ep_size > 1)
|
||||
):
|
||||
states = self.maybe_all_reduce_tensor_model_parallel(states)
|
||||
return states
|
||||
|
||||
def encode_layer_name() -> str:
|
||||
# Can be unavailable or None in unittests
|
||||
if (
|
||||
is_forward_context_available()
|
||||
and get_forward_context().all_moe_layers is not None
|
||||
):
|
||||
return "from_forward_context"
|
||||
return self.layer_name
|
||||
|
||||
if self.shared_experts is None:
|
||||
if current_platform.is_tpu() or current_platform.is_cpu():
|
||||
# TODO: Once the OOM issue for the TPU backend is resolved, we
|
||||
# will switch to using the moe_forward custom op.
|
||||
# Note: CPU doesn't require wrapped forward_impl.
|
||||
fused_output = self.forward_impl(hidden_states, router_logits)
|
||||
assert not isinstance(fused_output, tuple)
|
||||
else:
|
||||
fused_output = torch.ops.vllm.moe_forward(
|
||||
hidden_states, router_logits, encode_layer_name()
|
||||
)
|
||||
return reduce_output(fused_output)[..., :transformed_hidden_dim]
|
||||
else:
|
||||
if current_platform.is_tpu() or current_platform.is_cpu():
|
||||
# TODO: Once the OOM issue for the TPU backend is resolved, we
|
||||
# will switch to using the moe_forward custom op.
|
||||
# Note: CPU doesn't require wrapped forward_impl.
|
||||
with self._set_shared_experts_input(original_hidden_states):
|
||||
shared_output, fused_output = self.forward_impl(
|
||||
hidden_states, router_logits
|
||||
)
|
||||
else:
|
||||
# Custom op handles setting/clearing _shared_experts_input internally
|
||||
# We pass original tensor for shared experts (not transformed)
|
||||
shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
encode_layer_name(),
|
||||
original_hidden_states,
|
||||
)
|
||||
|
||||
# shared_output uses original dimension (before transform)
|
||||
# fused_output uses transformed dimension (after transform)
|
||||
return (
|
||||
reduce_output(shared_output)[..., :original_hidden_dim],
|
||||
reduce_output(fused_output)[..., :transformed_hidden_dim],
|
||||
)
|
||||
self.ensure_moe_quant_config_init()
|
||||
return self.runner.forward(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
)
|
||||
|
||||
@property
|
||||
def expert_map(self) -> torch.Tensor | None:
|
||||
@@ -1685,312 +1476,6 @@ class FusedMoE(CustomOp):
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
return self.forward_native(hidden_states, router_logits)
|
||||
|
||||
def forward_impl_chunked(
|
||||
self,
|
||||
full_hidden_states: torch.Tensor,
|
||||
full_router_logits: torch.Tensor,
|
||||
has_separate_shared_experts: bool,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.batched_hidden_states is not None
|
||||
assert self.batched_router_logits is not None
|
||||
assert self.batched_hidden_states.dtype == full_hidden_states.dtype, (
|
||||
f"{self.batched_hidden_states.dtype} == {full_hidden_states.dtype}"
|
||||
)
|
||||
assert self.batched_router_logits.dtype == full_router_logits.dtype, (
|
||||
f"{self.batched_router_logits.dtype} == {full_router_logits.dtype}"
|
||||
)
|
||||
# Check size compatibility.
|
||||
assert self.batched_hidden_states.size(-1) == full_hidden_states.size(-1)
|
||||
assert self.batched_router_logits.size(-1) == full_router_logits.size(-1)
|
||||
|
||||
full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
|
||||
if self.shared_experts is not None:
|
||||
full_shared_final_hidden_states = torch.empty_like(full_hidden_states)
|
||||
|
||||
def process_chunk(chunk_start, chunk_end, skip_result_store=False):
|
||||
chunk_size = chunk_end - chunk_start
|
||||
hidden_states = full_hidden_states[chunk_start:chunk_end, :]
|
||||
router_logits = full_router_logits[chunk_start:chunk_end, :]
|
||||
|
||||
assert self.batched_hidden_states is not None
|
||||
assert self.batched_router_logits is not None
|
||||
# This is only true when DBO has been enabled in the config.
|
||||
# Both tensors will have an outer dimension for the ubatch id
|
||||
if self.batched_hidden_states.dim() == 3:
|
||||
assert self.batched_router_logits.dim() == 3
|
||||
batch_buffer_idx = dbo_current_ubatch_id()
|
||||
batched_hidden_states = self.batched_hidden_states[batch_buffer_idx, :]
|
||||
batched_router_logits = self.batched_router_logits[batch_buffer_idx, :]
|
||||
else:
|
||||
batched_hidden_states = self.batched_hidden_states
|
||||
batched_router_logits = self.batched_router_logits
|
||||
|
||||
assert (
|
||||
batched_hidden_states.size(0) # type: ignore
|
||||
>= chunk_size
|
||||
)
|
||||
assert (
|
||||
batched_router_logits.size(0) # type: ignore
|
||||
>= chunk_size
|
||||
)
|
||||
staged_hidden_states = batched_hidden_states[:chunk_size, :] # type: ignore
|
||||
staged_router_logits = batched_router_logits[:chunk_size, :] # type: ignore
|
||||
staged_hidden_states.copy_(hidden_states, non_blocking=True)
|
||||
staged_router_logits.copy_(router_logits, non_blocking=True)
|
||||
|
||||
# Matrix multiply.
|
||||
if self.quant_method.is_monolithic:
|
||||
final_hidden_states = self.quant_method.apply_monolithic(
|
||||
layer=self,
|
||||
x=staged_hidden_states,
|
||||
router_logits=staged_router_logits,
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = self.router.select_experts(
|
||||
hidden_states=staged_hidden_states,
|
||||
router_logits=staged_router_logits,
|
||||
)
|
||||
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
x=staged_hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
)
|
||||
|
||||
if has_separate_shared_experts:
|
||||
assert not isinstance(final_hidden_states, tuple)
|
||||
assert self.shared_experts is not None
|
||||
|
||||
shared_output = self.shared_experts(staged_hidden_states)
|
||||
|
||||
final_hidden_states = (
|
||||
shared_output,
|
||||
final_hidden_states,
|
||||
)
|
||||
|
||||
if not skip_result_store:
|
||||
if self.shared_experts is None:
|
||||
full_fused_final_hidden_states[chunk_start:chunk_end, :].copy_(
|
||||
final_hidden_states, non_blocking=True
|
||||
)
|
||||
else:
|
||||
full_shared_final_hidden_states[chunk_start:chunk_end, :].copy_(
|
||||
final_hidden_states[0], non_blocking=True
|
||||
)
|
||||
full_fused_final_hidden_states[chunk_start:chunk_end, :].copy_(
|
||||
final_hidden_states[1], non_blocking=True
|
||||
)
|
||||
|
||||
ctx = get_forward_context()
|
||||
# flashinfer_cutlass_kernels can handle: optional DP + TP/EP
|
||||
max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu
|
||||
moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
|
||||
|
||||
# If the input to the MoE is sequence parallel then divide by sp_size
|
||||
# to find the maximum number of tokens for any individual dispatcher.
|
||||
if self.is_sequence_parallel:
|
||||
max_tokens_across_dispatchers = cdiv(
|
||||
max_tokens_across_dispatchers, self.sp_size
|
||||
)
|
||||
|
||||
num_tokens = full_hidden_states.size(0)
|
||||
for chunk_idx, chunk_start_ in enumerate(
|
||||
range(0, max_tokens_across_dispatchers, moe_dp_chunk_size_per_rank)
|
||||
):
|
||||
chunk_start = chunk_start_
|
||||
chunk_end = min(
|
||||
chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dispatchers
|
||||
)
|
||||
# clamp start and end
|
||||
chunk_start = min(chunk_start, num_tokens - 1)
|
||||
chunk_end = min(chunk_end, num_tokens)
|
||||
with ctx.dp_metadata.chunked_sizes(
|
||||
self.sp_size, moe_dp_chunk_size_per_rank, chunk_idx
|
||||
):
|
||||
process_chunk(
|
||||
chunk_start, chunk_end, skip_result_store=chunk_start_ >= num_tokens
|
||||
)
|
||||
|
||||
if self.shared_experts is None:
|
||||
return full_fused_final_hidden_states
|
||||
else:
|
||||
return (full_shared_final_hidden_states, full_fused_final_hidden_states)
|
||||
|
||||
def forward_impl(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.quant_method is not None
|
||||
|
||||
self.ensure_moe_quant_config_init()
|
||||
self.ensure_dp_chunking_init()
|
||||
|
||||
has_separate_shared_experts = (
|
||||
not self.quant_method.mk_owns_shared_expert
|
||||
and self.shared_experts is not None
|
||||
)
|
||||
|
||||
use_chunked_impl = self.use_dp_chunking
|
||||
|
||||
use_shared_experts_stream, hidden_states_clone = (
|
||||
self._maybe_setup_shared_experts_stream(
|
||||
hidden_states, has_separate_shared_experts, use_chunked_impl
|
||||
)
|
||||
)
|
||||
|
||||
# If router/gate provided, then apply it here.
|
||||
# (Note: This code runs only when "overlapped mode" is on to allow
|
||||
# parallel execution of shared experts with the FusedMoE via
|
||||
# separate cuda stream)
|
||||
if self.gate is not None:
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
if use_chunked_impl:
|
||||
return self.forward_impl_chunked(
|
||||
hidden_states, router_logits, has_separate_shared_experts
|
||||
)
|
||||
|
||||
# NOTE(rob): once we finish migrating all the quant methods to use
|
||||
# MKs, we can remove the naive dispatch/combine path from here.
|
||||
do_naive_dispatch_combine = (
|
||||
self.dp_size > 1 and not self.quant_method.supports_internal_mk
|
||||
)
|
||||
|
||||
ctx = get_forward_context()
|
||||
sp_ctx = (
|
||||
ctx.dp_metadata.sp_local_sizes(self.sp_size)
|
||||
if ctx.dp_metadata
|
||||
else nullcontext()
|
||||
)
|
||||
|
||||
with sp_ctx:
|
||||
extra_tensors = None
|
||||
if do_naive_dispatch_combine:
|
||||
post_quant_allgather = (
|
||||
self.quant_method is not None
|
||||
and self.dp_size > 1
|
||||
and self.use_ep
|
||||
and getattr(self.quant_method, "do_post_quant_allgather", False)
|
||||
)
|
||||
if post_quant_allgather:
|
||||
hidden_states_to_dispatch, extra_tensors = (
|
||||
self.quant_method.prepare_dp_allgather_tensor(
|
||||
self, hidden_states, router_logits
|
||||
)
|
||||
)
|
||||
else:
|
||||
hidden_states_to_dispatch = hidden_states
|
||||
|
||||
dispatch_res = get_ep_group().dispatch_router_logits(
|
||||
hidden_states_to_dispatch,
|
||||
router_logits,
|
||||
self.is_sequence_parallel,
|
||||
extra_tensors=extra_tensors,
|
||||
)
|
||||
if extra_tensors is not None:
|
||||
(
|
||||
orig_hidden_states,
|
||||
router_logits,
|
||||
extra_tensors_combined,
|
||||
) = dispatch_res
|
||||
hidden_states_combined = (
|
||||
orig_hidden_states,
|
||||
extra_tensors_combined[0],
|
||||
)
|
||||
else:
|
||||
hidden_states_combined, router_logits = dispatch_res
|
||||
orig_hidden_states = hidden_states_combined
|
||||
else:
|
||||
orig_hidden_states = hidden_states
|
||||
|
||||
# Run shared experts before matrix multiply.
|
||||
# because matrix multiply maybe modify the hidden_states.
|
||||
if has_separate_shared_experts and not use_shared_experts_stream:
|
||||
assert self.shared_experts is not None
|
||||
shared_input = self._get_shared_experts_input(hidden_states)
|
||||
shared_output = self.shared_experts(shared_input)
|
||||
|
||||
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
|
||||
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
|
||||
# we should modify All2AllManager abstract to better support PCP.
|
||||
if self.pcp_size > 1:
|
||||
hidden_states = get_pcp_group().all_gather(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
)
|
||||
router_logits = get_pcp_group().all_gather(
|
||||
router_logits,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# Matrix multiply.
|
||||
x = hidden_states_combined if do_naive_dispatch_combine else hidden_states
|
||||
|
||||
# TODO(bnell): deal with fp4 flashinfer tuple hidden states hack (#30014).
|
||||
# Figure out nicer way to do this.
|
||||
x_orig = orig_hidden_states if do_naive_dispatch_combine else hidden_states
|
||||
|
||||
if self.quant_method.is_monolithic:
|
||||
final_hidden_states = self.quant_method.apply_monolithic(
|
||||
layer=self,
|
||||
x=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = self.router.select_experts(
|
||||
hidden_states=x_orig,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
x=x, # The type signture of this is wrong due to the hack.
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
)
|
||||
|
||||
if has_separate_shared_experts:
|
||||
assert self.shared_experts is not None
|
||||
|
||||
if use_shared_experts_stream:
|
||||
# Run shared experts in parallel on a separate stream
|
||||
# NOTE: We start the separate stream here and mark the
|
||||
# sync end point immediately after it is done. This is
|
||||
# important to avoid excessive stream allocations by the cuda
|
||||
# graph replay later.
|
||||
with torch.cuda.stream(self.shared_experts_stream):
|
||||
# Note that hidden_states clone() is necessary here to avoid
|
||||
# conflict with the main stream
|
||||
shared_output = self.shared_experts(hidden_states_clone)
|
||||
current_stream().wait_stream(self.shared_experts_stream)
|
||||
|
||||
final_hidden_states = (
|
||||
shared_output,
|
||||
final_hidden_states,
|
||||
)
|
||||
|
||||
def combine_output(states: torch.Tensor) -> torch.Tensor:
|
||||
if do_naive_dispatch_combine:
|
||||
states = get_ep_group().combine(states, self.is_sequence_parallel)
|
||||
|
||||
if self.pcp_size > 1:
|
||||
states = get_pcp_group().reduce_scatter(
|
||||
states,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
return states
|
||||
|
||||
if self.shared_experts is not None:
|
||||
return (
|
||||
final_hidden_states[0],
|
||||
combine_output(final_hidden_states[1]),
|
||||
)
|
||||
else:
|
||||
return combine_output(final_hidden_states)
|
||||
|
||||
@classmethod
|
||||
def make_expert_params_mapping(
|
||||
cls,
|
||||
@@ -2051,94 +1536,6 @@ class FusedMoE(CustomOp):
|
||||
return s
|
||||
|
||||
|
||||
def get_layer_from_name(layer_name: str) -> FusedMoE:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
if layer_name == "from_forward_context":
|
||||
all_moe_layers = forward_context.all_moe_layers
|
||||
assert all_moe_layers is not None
|
||||
moe_layer_index = forward_context.moe_layer_index
|
||||
if moe_layer_index >= len(all_moe_layers):
|
||||
raise AssertionError(
|
||||
"We expected the number of MOE layers in `all_moe_layers` "
|
||||
"to be equal to the number of "
|
||||
"{vllm.moe_forward, vllm.moe_forward_shared} calls."
|
||||
)
|
||||
layer_name = all_moe_layers[moe_layer_index]
|
||||
forward_context.moe_layer_index += 1
|
||||
self = cast(FusedMoE, forward_context.no_compile_layers[layer_name])
|
||||
return self
|
||||
|
||||
|
||||
def moe_forward(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> torch.Tensor:
|
||||
self = get_layer_from_name(layer_name)
|
||||
assert self.shared_experts is None
|
||||
return self.forward_impl(hidden_states, router_logits)
|
||||
|
||||
|
||||
def moe_forward_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="moe_forward",
|
||||
op_func=moe_forward,
|
||||
mutates_args=["hidden_states"],
|
||||
fake_impl=moe_forward_fake,
|
||||
tags=(torch.Tag.needs_fixed_stride_order,),
|
||||
)
|
||||
|
||||
|
||||
def moe_forward_shared(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
layer_name: str,
|
||||
shared_experts_input: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
self = get_layer_from_name(layer_name)
|
||||
assert self.shared_experts is not None
|
||||
|
||||
# Set here because torch.compile skips forward_native() setup code
|
||||
# and calls this op directly. forward_impl() reads from this var.
|
||||
with self._set_shared_experts_input(shared_experts_input):
|
||||
return self.forward_impl(hidden_states, router_logits)
|
||||
|
||||
|
||||
def moe_forward_shared_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
layer_name: str,
|
||||
shared_experts_input: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Output shapes:
|
||||
# - fused_out: same as hidden_states (routed experts use transformed size)
|
||||
# - shared_out: same as shared_experts_input if provided, else same as hidden_states
|
||||
# (For latent MoE: shared experts use original hidden_size, not latent size)
|
||||
fused_out = torch.empty_like(hidden_states)
|
||||
|
||||
if shared_experts_input is not None:
|
||||
shared_out = torch.empty_like(shared_experts_input)
|
||||
else:
|
||||
shared_out = torch.empty_like(hidden_states)
|
||||
|
||||
return shared_out, fused_out
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="moe_forward_shared",
|
||||
op_func=moe_forward_shared,
|
||||
mutates_args=["hidden_states"],
|
||||
fake_impl=moe_forward_shared_fake,
|
||||
tags=(torch.Tag.needs_fixed_stride_order,),
|
||||
)
|
||||
|
||||
# Mark the FusedMoE weight_loader as supporting MoE-specific parameters
|
||||
# to avoid expensive runtime reflection in model loading code
|
||||
FusedMoE.weight_loader.supports_moe_loading = True # type: ignore[attr-defined]
|
||||
|
||||
@@ -1228,7 +1228,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
shared_experts_input: torch.Tensor | None = None,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
The _finalize method is a wrapper around self.prepare_finalize.finalize
|
||||
|
||||
2
vllm/model_executor/layers/fused_moe/runner/__init__.py
Normal file
2
vllm/model_executor/layers/fused_moe/runner/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
@@ -0,0 +1,743 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed import (
|
||||
get_ep_group,
|
||||
get_pcp_group,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.forward_context import (
|
||||
ForwardContext,
|
||||
get_forward_context,
|
||||
is_forward_context_available,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
|
||||
FusedMoEMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
|
||||
FusedMoERouter,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.runner.moe_runner import MoERunner
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import (
|
||||
aux_stream,
|
||||
current_stream,
|
||||
direct_register_custom_op,
|
||||
)
|
||||
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def get_layer_from_name(layer_name: str) -> torch.nn.Module:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
if layer_name == "from_forward_context":
|
||||
all_moe_layers = forward_context.all_moe_layers
|
||||
assert all_moe_layers is not None
|
||||
moe_layer_index = forward_context.moe_layer_index
|
||||
if moe_layer_index >= len(all_moe_layers):
|
||||
raise AssertionError(
|
||||
"We expected the number of MOE layers in `all_moe_layers` "
|
||||
"to be equal to the number of "
|
||||
"{vllm.moe_forward, vllm.moe_forward_shared} calls."
|
||||
)
|
||||
layer_name = all_moe_layers[moe_layer_index]
|
||||
forward_context.moe_layer_index += 1
|
||||
return forward_context.no_compile_layers[layer_name]
|
||||
|
||||
|
||||
def _moe_forward(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
layer_name: str,
|
||||
) -> torch.Tensor:
|
||||
layer = get_layer_from_name(layer_name)
|
||||
return layer.runner.forward_impl(
|
||||
layer, hidden_states, router_logits, shared_experts_input
|
||||
)
|
||||
|
||||
|
||||
def _moe_forward_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
layer_name: str,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
|
||||
def _moe_forward_shared(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
layer_name: str,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
layer = get_layer_from_name(layer_name)
|
||||
return layer.runner.forward_impl(
|
||||
layer, hidden_states, router_logits, shared_experts_input
|
||||
)
|
||||
|
||||
|
||||
def _moe_forward_shared_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
layer_name: str,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Output shapes:
|
||||
# - fused_out: same as hidden_states (routed experts use transformed size)
|
||||
# - shared_out: same as shared_experts_input if provided, else same as
|
||||
# hidden_states
|
||||
# (For latent MoE: shared experts use original hidden_size, not latent size)
|
||||
fused_out = torch.empty_like(hidden_states)
|
||||
|
||||
if shared_experts_input is not None:
|
||||
shared_out = torch.empty_like(shared_experts_input)
|
||||
else:
|
||||
shared_out = torch.empty_like(hidden_states)
|
||||
|
||||
return shared_out, fused_out
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="moe_forward",
|
||||
op_func=_moe_forward,
|
||||
mutates_args=["hidden_states"],
|
||||
fake_impl=_moe_forward_fake,
|
||||
tags=(torch.Tag.needs_fixed_stride_order,),
|
||||
)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="moe_forward_shared",
|
||||
op_func=_moe_forward_shared,
|
||||
mutates_args=["hidden_states"],
|
||||
fake_impl=_moe_forward_shared_fake,
|
||||
tags=(torch.Tag.needs_fixed_stride_order,),
|
||||
)
|
||||
|
||||
|
||||
class DefaultMoERunner(MoERunner):
|
||||
"""
|
||||
Default implementation of the MoE runner for executing Mixture of Experts layers.
|
||||
|
||||
This class provides a comprehensive implementation for running MoE computations
|
||||
with support for:
|
||||
- Expert routing and token dispatching
|
||||
- Shared experts computation with optional parallel execution using CUDA streams
|
||||
- Data parallel (DP) chunking for large batch processing
|
||||
- Tensor model parallel and expert parallel operations
|
||||
- Various quantization methods and custom operators
|
||||
- Both monolithic and decomposed expert execution paths
|
||||
|
||||
The runner handles the complete MoE forward pass including routing tokens to
|
||||
experts, executing expert computations, and combining results. It supports
|
||||
advanced features like overlapped execution of shared experts and optimized
|
||||
kernels for different parallel execution modes.
|
||||
|
||||
Eventually, this class will be split up and specialized for different
|
||||
configurations, e.g. the presense or absence of shared experts, a gate, etc.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
moe_config: FusedMoEConfig,
|
||||
router: FusedMoERouter,
|
||||
routed_input_transform: torch.nn.Module | None,
|
||||
gate: torch.nn.Module | None,
|
||||
shared_experts: torch.nn.Module | None,
|
||||
quant_method: FusedMoEMethodBase,
|
||||
reduce_results: bool,
|
||||
enable_dbo: bool,
|
||||
):
|
||||
super().__init__()
|
||||
self.moe_config = moe_config
|
||||
self.router = router
|
||||
self.routed_input_transform = routed_input_transform
|
||||
self.gate = gate
|
||||
self.shared_experts = shared_experts
|
||||
self.quant_method = quant_method
|
||||
self.reduce_results = reduce_results
|
||||
self.enable_dbo = enable_dbo
|
||||
|
||||
# Allow disabling of the separate shared experts stream for
|
||||
# debug purposes.
|
||||
# TODO: Remove this after more extensive testings with TP/DP
|
||||
# and other execution modes
|
||||
if envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM:
|
||||
logger.debug_once("Disabling MoE shared_experts cuda stream", scope="local")
|
||||
self.shared_experts_stream = None
|
||||
else:
|
||||
# TODO(rob): enable shared expert overlap with non-cuda-alike.
|
||||
# aux_stream() returns None on non-cuda-alike platforms.
|
||||
self.shared_experts_stream = aux_stream()
|
||||
if self.shared_experts_stream is not None:
|
||||
logger.debug_once(
|
||||
"Enabled separate cuda stream for MoE shared_experts", scope="local"
|
||||
)
|
||||
|
||||
# Needed for string -> FusedMoE layer lookup in custom ops.
|
||||
self.layer_name = layer.layer_name
|
||||
|
||||
if current_platform.is_tpu() or current_platform.is_cpu():
|
||||
# TODO: Once the OOM issue for the TPU backend is resolved, we
|
||||
# will switch to using the moe_forward custom op.
|
||||
# Note: CPU doesn't require wrapped forward_impl.
|
||||
if self.shared_experts is None:
|
||||
self.moe_forward = _moe_forward
|
||||
else:
|
||||
self.moe_forward = _moe_forward_shared
|
||||
else:
|
||||
if self.shared_experts is None:
|
||||
self.moe_forward = torch.ops.vllm.moe_forward
|
||||
else:
|
||||
self.moe_forward = torch.ops.vllm.moe_forward_shared
|
||||
|
||||
# Chunked all2all staging tensor
|
||||
self.batched_hidden_states: torch.Tensor | None = None
|
||||
self.batched_router_logits: torch.Tensor | None = None
|
||||
|
||||
@property
|
||||
def use_dp_chunking(self) -> bool:
|
||||
return (
|
||||
self.moe_config.moe_parallel_config.use_pplx_kernels
|
||||
or self.moe_config.moe_parallel_config.use_deepep_ll_kernels
|
||||
or self.moe_config.moe_parallel_config.use_mori_kernels
|
||||
or self.moe_config.moe_parallel_config.use_fi_all2allv_kernels
|
||||
) and envs.VLLM_ENABLE_MOE_DP_CHUNK
|
||||
|
||||
def _maybe_setup_shared_experts_stream(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
shared_input: torch.Tensor | None,
|
||||
has_separate_shared_experts: bool,
|
||||
use_chunked_impl: bool,
|
||||
) -> tuple[bool, torch.Tensor | None]:
|
||||
use_shared_experts_stream = (
|
||||
current_platform.is_cuda()
|
||||
and has_separate_shared_experts
|
||||
and not use_chunked_impl
|
||||
and self.shared_experts_stream is not None
|
||||
and (
|
||||
hidden_states.shape[0]
|
||||
<= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
|
||||
)
|
||||
)
|
||||
|
||||
hidden_states_clone: torch.Tensor | None = None
|
||||
if use_shared_experts_stream:
|
||||
assert self.shared_experts_stream is not None
|
||||
|
||||
shared_experts_input = (
|
||||
shared_input if shared_input is not None else hidden_states
|
||||
)
|
||||
|
||||
# Clone BEFORE switching streams to avoid race condition
|
||||
# where routed_expert kernel may mutate hidden_states.
|
||||
hidden_states_clone = shared_experts_input.clone()
|
||||
|
||||
# Record that the clone will be used by shared_experts_stream
|
||||
# to avoid gc issue from deallocation of hidden_states_clone
|
||||
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
|
||||
# NOTE: We don't need shared_output.record_stream(current_stream())
|
||||
# because we synch the streams before using shared_output.
|
||||
hidden_states_clone.record_stream(self.shared_experts_stream)
|
||||
|
||||
# Mark sync start point for the separate shared experts
|
||||
# stream here since we want to run in parallel with the
|
||||
# router/gate (next op below)
|
||||
assert self.shared_experts_stream is not None
|
||||
self.shared_experts_stream.wait_stream(current_stream())
|
||||
|
||||
return use_shared_experts_stream, hidden_states_clone
|
||||
|
||||
def ensure_dp_chunking_init(self):
|
||||
if not self.use_dp_chunking or self.batched_hidden_states is not None:
|
||||
return
|
||||
|
||||
states_shape: tuple[int, ...]
|
||||
logits_shape: tuple[int, ...]
|
||||
|
||||
moe = self.moe_config
|
||||
|
||||
if self.enable_dbo:
|
||||
states_shape = (2, moe.max_num_tokens, self.moe_config.hidden_dim)
|
||||
logits_shape = (2, moe.max_num_tokens, self.moe_config.num_logical_experts)
|
||||
else:
|
||||
states_shape = (moe.max_num_tokens, self.moe_config.hidden_dim)
|
||||
logits_shape = (moe.max_num_tokens, self.moe_config.num_logical_experts)
|
||||
|
||||
self.batched_hidden_states = torch.zeros(
|
||||
states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
|
||||
)
|
||||
|
||||
self.batched_router_logits = torch.zeros(
|
||||
logits_shape,
|
||||
dtype=moe.router_logits_dtype,
|
||||
device=torch.cuda.current_device(),
|
||||
)
|
||||
|
||||
def must_reduce_shared_expert_outputs(self) -> bool:
|
||||
"""
|
||||
The shared_experts are typically computed using the RowParallelLinear
|
||||
layer. The result of this function is typically used as
|
||||
the reduce_results argument to the module.
|
||||
When just tensor-parallel is used, it is not required to reduce
|
||||
the shared_experts results immediately. Instead we reduce at the
|
||||
once at the end of the MoE op. (Refer to DeepSeekV2MoE module)
|
||||
With EP and all2all kernels - this is no longer viable as all
|
||||
GPU ranks in DP, produce the complete set of hidden_states.
|
||||
Therefore it is required that we reduce the shared_experts output
|
||||
early.
|
||||
"""
|
||||
assert self.quant_method is not None
|
||||
return (
|
||||
self.quant_method.moe_mk is not None
|
||||
and self.quant_method.moe_mk.output_is_reduced()
|
||||
)
|
||||
|
||||
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
|
||||
"""
|
||||
Some combine kernels reduce across GPU ranks by default.
|
||||
"""
|
||||
if self.must_reduce_shared_expert_outputs():
|
||||
return final_hidden_states
|
||||
else:
|
||||
return tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
|
||||
def apply_routed_input_transform(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply transform for routed experts (e.g., latent projection).
|
||||
|
||||
This is called by FusedMoE.forward_native. The original hidden_states
|
||||
is saved separately so shared experts get [S, hidden_size] while
|
||||
routed experts get the transformed [S, moe_latent_size].
|
||||
|
||||
TODO: For latent MoE bandwidth optimization, fc2_latent_proj could be
|
||||
moved inside SharedFusedMoE to all-reduce on the smaller latent
|
||||
dimension.
|
||||
"""
|
||||
if self.routed_input_transform is not None:
|
||||
result = self.routed_input_transform(hidden_states)
|
||||
# ReplicatedLinear returns (output, extra_bias) tuple.
|
||||
# We only need the output tensor; extra_bias is not used here.
|
||||
if isinstance(result, tuple):
|
||||
return result[0]
|
||||
return result
|
||||
return hidden_states
|
||||
|
||||
def _reduce_output(
|
||||
self,
|
||||
states: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
trunc_sizes: list[int],
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
def trunc(x: torch.Tensor, trunc_size: int) -> torch.Tensor:
|
||||
return x[..., :trunc_size]
|
||||
|
||||
def reduce_and_trunc(x: torch.Tensor, trunc_size: int) -> torch.Tensor:
|
||||
return trunc(self.maybe_all_reduce_tensor_model_parallel(x), trunc_size)
|
||||
|
||||
if (
|
||||
not self.moe_config.is_sequence_parallel
|
||||
and not self.use_dp_chunking
|
||||
and self.reduce_results
|
||||
and (self.moe_config.tp_size > 1 or self.moe_config.ep_size > 1)
|
||||
):
|
||||
func = reduce_and_trunc
|
||||
else:
|
||||
func = trunc
|
||||
|
||||
if isinstance(states, tuple):
|
||||
return tuple(
|
||||
[func(s, trunc_size) for s, trunc_size in zip(states, trunc_sizes)]
|
||||
)
|
||||
else:
|
||||
assert len(trunc_sizes) == 1
|
||||
return func(states, trunc_sizes[0])
|
||||
|
||||
def _encode_layer_name(self) -> str:
|
||||
# Can be unavailable or None in unittests
|
||||
if (
|
||||
is_forward_context_available()
|
||||
and get_forward_context().all_moe_layers is not None
|
||||
):
|
||||
return "from_forward_context"
|
||||
return self.layer_name
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
# For latent MoE: save ORIGINAL hidden_states before transform
|
||||
# (shared_experts need original dimension, routed experts use transformed)
|
||||
original_hidden_states = hidden_states
|
||||
original_hidden_dim = hidden_states.shape[-1]
|
||||
|
||||
# Apply transform for routed experts (e.g., latent projection for latent MoE)
|
||||
hidden_states = self.apply_routed_input_transform(hidden_states)
|
||||
|
||||
# This is the dimension after transform (for routed expert output slicing)
|
||||
transformed_hidden_dim = hidden_states.shape[-1]
|
||||
if self.moe_config.hidden_dim != transformed_hidden_dim:
|
||||
hidden_states = F.pad(
|
||||
hidden_states,
|
||||
(0, self.moe_config.hidden_dim - transformed_hidden_dim),
|
||||
mode="constant",
|
||||
value=0.0,
|
||||
)
|
||||
|
||||
fused_output = self.moe_forward(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
original_hidden_states,
|
||||
self._encode_layer_name(),
|
||||
)
|
||||
|
||||
if isinstance(fused_output, tuple):
|
||||
orig_hidden_dims = [original_hidden_dim, transformed_hidden_dim]
|
||||
else:
|
||||
orig_hidden_dims = [transformed_hidden_dim]
|
||||
|
||||
return self._reduce_output(fused_output, orig_hidden_dims)
|
||||
|
||||
def forward_impl_chunked(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
full_hidden_states: torch.Tensor,
|
||||
full_router_logits: torch.Tensor,
|
||||
shared_input: torch.Tensor | None,
|
||||
has_separate_shared_experts: bool,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.batched_hidden_states is not None
|
||||
assert self.batched_router_logits is not None
|
||||
assert self.batched_hidden_states.dtype == full_hidden_states.dtype, (
|
||||
f"{self.batched_hidden_states.dtype} == {full_hidden_states.dtype}"
|
||||
)
|
||||
assert self.batched_router_logits.dtype == full_router_logits.dtype, (
|
||||
f"{self.batched_router_logits.dtype} == {full_router_logits.dtype}"
|
||||
)
|
||||
# Check size compatibility.
|
||||
assert self.batched_hidden_states.size(-1) == full_hidden_states.size(-1)
|
||||
assert self.batched_router_logits.size(-1) == full_router_logits.size(-1)
|
||||
|
||||
# TODO(bnell): Fix shared_expert_inputs w/chunking.
|
||||
# assert shared_input is None, (
|
||||
# "Routed input transform is not currently supported with DP chunking."
|
||||
# )
|
||||
|
||||
full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
|
||||
if self.shared_experts is not None:
|
||||
full_shared_final_hidden_states = torch.empty_like(full_hidden_states)
|
||||
|
||||
def process_chunk(chunk_start, chunk_end, skip_result_store=False):
|
||||
chunk_size = chunk_end - chunk_start
|
||||
hidden_states = full_hidden_states[chunk_start:chunk_end, :]
|
||||
router_logits = full_router_logits[chunk_start:chunk_end, :]
|
||||
|
||||
assert self.batched_hidden_states is not None
|
||||
assert self.batched_router_logits is not None
|
||||
# This is only true when DBO has been enabled in the config.
|
||||
# Both tensors will have an outer dimension for the ubatch id
|
||||
if self.batched_hidden_states.dim() == 3:
|
||||
assert self.batched_router_logits.dim() == 3
|
||||
batch_buffer_idx = dbo_current_ubatch_id()
|
||||
batched_hidden_states = self.batched_hidden_states[batch_buffer_idx, :]
|
||||
batched_router_logits = self.batched_router_logits[batch_buffer_idx, :]
|
||||
else:
|
||||
batched_hidden_states = self.batched_hidden_states
|
||||
batched_router_logits = self.batched_router_logits
|
||||
|
||||
assert (
|
||||
batched_hidden_states.size(0) # type: ignore
|
||||
>= chunk_size
|
||||
)
|
||||
assert (
|
||||
batched_router_logits.size(0) # type: ignore
|
||||
>= chunk_size
|
||||
)
|
||||
staged_hidden_states = batched_hidden_states[:chunk_size, :] # type: ignore
|
||||
staged_router_logits = batched_router_logits[:chunk_size, :] # type: ignore
|
||||
staged_hidden_states.copy_(hidden_states, non_blocking=True)
|
||||
staged_router_logits.copy_(router_logits, non_blocking=True)
|
||||
|
||||
# Matrix multiply.
|
||||
if self.quant_method.is_monolithic:
|
||||
final_hidden_states = self.quant_method.apply_monolithic(
|
||||
layer=layer,
|
||||
x=staged_hidden_states,
|
||||
router_logits=staged_router_logits,
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = self.router.select_experts(
|
||||
hidden_states=staged_hidden_states,
|
||||
router_logits=staged_router_logits,
|
||||
)
|
||||
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=layer,
|
||||
x=staged_hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
shared_experts_input=shared_input,
|
||||
)
|
||||
|
||||
if has_separate_shared_experts:
|
||||
assert not isinstance(final_hidden_states, tuple)
|
||||
assert self.shared_experts is not None
|
||||
|
||||
shared_output = self.shared_experts(staged_hidden_states)
|
||||
|
||||
final_hidden_states = (
|
||||
shared_output,
|
||||
final_hidden_states,
|
||||
)
|
||||
|
||||
if not skip_result_store:
|
||||
if self.shared_experts is None:
|
||||
full_fused_final_hidden_states[chunk_start:chunk_end, :].copy_(
|
||||
final_hidden_states, non_blocking=True
|
||||
)
|
||||
else:
|
||||
full_shared_final_hidden_states[chunk_start:chunk_end, :].copy_(
|
||||
final_hidden_states[0], non_blocking=True
|
||||
)
|
||||
full_fused_final_hidden_states[chunk_start:chunk_end, :].copy_(
|
||||
final_hidden_states[1], non_blocking=True
|
||||
)
|
||||
|
||||
ctx = get_forward_context()
|
||||
# flashinfer_cutlass_kernels can handle: optional DP + TP/EP
|
||||
max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu
|
||||
moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
|
||||
|
||||
# If the input to the MoE is sequence parallel then divide by sp_size
|
||||
# to find the maximum number of tokens for any individual dispatcher.
|
||||
if self.moe_config.is_sequence_parallel:
|
||||
max_tokens_across_dispatchers = cdiv(
|
||||
max_tokens_across_dispatchers, self.moe_config.sp_size
|
||||
)
|
||||
|
||||
num_tokens = full_hidden_states.size(0)
|
||||
for chunk_idx, chunk_start_ in enumerate(
|
||||
range(0, max_tokens_across_dispatchers, moe_dp_chunk_size_per_rank)
|
||||
):
|
||||
chunk_start = chunk_start_
|
||||
chunk_end = min(
|
||||
chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dispatchers
|
||||
)
|
||||
# clamp start and end
|
||||
chunk_start = min(chunk_start, num_tokens - 1)
|
||||
chunk_end = min(chunk_end, num_tokens)
|
||||
with ctx.dp_metadata.chunked_sizes(
|
||||
self.moe_config.sp_size, moe_dp_chunk_size_per_rank, chunk_idx
|
||||
):
|
||||
process_chunk(
|
||||
chunk_start, chunk_end, skip_result_store=chunk_start_ >= num_tokens
|
||||
)
|
||||
|
||||
if self.shared_experts is None:
|
||||
return full_fused_final_hidden_states
|
||||
else:
|
||||
return (full_shared_final_hidden_states, full_fused_final_hidden_states)
|
||||
|
||||
def forward_impl(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.quant_method is not None
|
||||
|
||||
self.ensure_dp_chunking_init()
|
||||
|
||||
has_separate_shared_experts = (
|
||||
not self.quant_method.mk_owns_shared_expert
|
||||
and self.shared_experts is not None
|
||||
)
|
||||
|
||||
use_chunked_impl = self.use_dp_chunking
|
||||
|
||||
use_shared_experts_stream, hidden_states_clone = (
|
||||
self._maybe_setup_shared_experts_stream(
|
||||
hidden_states,
|
||||
shared_input,
|
||||
has_separate_shared_experts,
|
||||
use_chunked_impl,
|
||||
)
|
||||
)
|
||||
|
||||
# If router/gate provided, then apply it here.
|
||||
# (Note: This code runs only when "overlapped mode" is on to allow
|
||||
# parallel execution of shared experts with the FusedMoE via
|
||||
# separate cuda stream)
|
||||
if self.gate is not None:
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
if use_chunked_impl:
|
||||
return self.forward_impl_chunked(
|
||||
layer,
|
||||
hidden_states,
|
||||
router_logits,
|
||||
shared_input,
|
||||
has_separate_shared_experts,
|
||||
)
|
||||
|
||||
# NOTE(rob): once we finish migrating all the quant methods to use
|
||||
# MKs, we can remove the naive dispatch/combine path from here.
|
||||
do_naive_dispatch_combine = (
|
||||
self.moe_config.dp_size > 1 and not self.quant_method.supports_internal_mk
|
||||
)
|
||||
|
||||
ctx = get_forward_context()
|
||||
sp_ctx = (
|
||||
ctx.dp_metadata.sp_local_sizes(self.moe_config.sp_size)
|
||||
if ctx.dp_metadata
|
||||
else nullcontext()
|
||||
)
|
||||
|
||||
with sp_ctx:
|
||||
extra_tensors = None
|
||||
if do_naive_dispatch_combine:
|
||||
post_quant_allgather = (
|
||||
self.quant_method is not None
|
||||
and self.moe_config.dp_size > 1
|
||||
and self.moe_config.use_ep
|
||||
and getattr(self.quant_method, "do_post_quant_allgather", False)
|
||||
)
|
||||
if post_quant_allgather:
|
||||
hidden_states_to_dispatch, extra_tensors = (
|
||||
self.quant_method.prepare_dp_allgather_tensor(
|
||||
layer, hidden_states, router_logits
|
||||
)
|
||||
)
|
||||
else:
|
||||
hidden_states_to_dispatch = hidden_states
|
||||
|
||||
dispatch_res = get_ep_group().dispatch_router_logits(
|
||||
hidden_states_to_dispatch,
|
||||
router_logits,
|
||||
self.moe_config.is_sequence_parallel,
|
||||
extra_tensors=extra_tensors,
|
||||
)
|
||||
if extra_tensors is not None:
|
||||
(
|
||||
orig_hidden_states,
|
||||
router_logits,
|
||||
extra_tensors_combined,
|
||||
) = dispatch_res
|
||||
hidden_states_combined = (
|
||||
orig_hidden_states,
|
||||
extra_tensors_combined[0],
|
||||
)
|
||||
else:
|
||||
hidden_states_combined, router_logits = dispatch_res
|
||||
orig_hidden_states = hidden_states_combined
|
||||
else:
|
||||
orig_hidden_states = hidden_states
|
||||
|
||||
# Run shared experts before matrix multiply.
|
||||
# because matrix multiply maybe modify the hidden_states.
|
||||
if has_separate_shared_experts and not use_shared_experts_stream:
|
||||
assert self.shared_experts is not None
|
||||
shared_input = (
|
||||
shared_input if shared_input is not None else hidden_states
|
||||
)
|
||||
shared_output = self.shared_experts(shared_input)
|
||||
|
||||
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
|
||||
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
|
||||
# we should modify All2AllManager abstract to better support PCP.
|
||||
if self.moe_config.pcp_size > 1:
|
||||
hidden_states = get_pcp_group().all_gather(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
)
|
||||
router_logits = get_pcp_group().all_gather(
|
||||
router_logits,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# TODO(bnell): deal with fp4 flashinfer tuple hidden states hack (#30014).
|
||||
# Figure out nicer way to do this.
|
||||
if do_naive_dispatch_combine:
|
||||
x = hidden_states_combined
|
||||
x_orig = orig_hidden_states
|
||||
else:
|
||||
x = hidden_states
|
||||
x_orig = hidden_states
|
||||
|
||||
# Matrix multiply.
|
||||
if self.quant_method.is_monolithic:
|
||||
final_hidden_states = self.quant_method.apply_monolithic(
|
||||
layer=layer,
|
||||
x=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = self.router.select_experts(
|
||||
hidden_states=x_orig,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=layer,
|
||||
x=x, # The type signture of this is wrong due to the hack.
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
shared_experts_input=shared_input,
|
||||
)
|
||||
|
||||
if has_separate_shared_experts:
|
||||
assert self.shared_experts is not None
|
||||
|
||||
if use_shared_experts_stream:
|
||||
# Run shared experts in parallel on a separate stream
|
||||
# NOTE: We start the separate stream here and mark the
|
||||
# sync end point immediately after it is done. This is
|
||||
# important to avoid excessive stream allocations by the cuda
|
||||
# graph replay later.
|
||||
with torch.cuda.stream(self.shared_experts_stream):
|
||||
# Note that hidden_states clone() is necessary here to avoid
|
||||
# conflict with the main stream
|
||||
shared_output = self.shared_experts(hidden_states_clone)
|
||||
current_stream().wait_stream(self.shared_experts_stream)
|
||||
|
||||
final_hidden_states = (
|
||||
shared_output,
|
||||
final_hidden_states,
|
||||
)
|
||||
|
||||
def combine_output(states: torch.Tensor) -> torch.Tensor:
|
||||
if do_naive_dispatch_combine:
|
||||
states = get_ep_group().combine(
|
||||
states, self.moe_config.is_sequence_parallel
|
||||
)
|
||||
|
||||
if self.moe_config.pcp_size > 1:
|
||||
states = get_pcp_group().reduce_scatter(
|
||||
states,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
return states
|
||||
|
||||
if self.shared_experts is not None:
|
||||
return (
|
||||
final_hidden_states[0],
|
||||
combine_output(final_hidden_states[1]),
|
||||
)
|
||||
else:
|
||||
return combine_output(final_hidden_states)
|
||||
34
vllm/model_executor/layers/fused_moe/runner/moe_runner.py
Normal file
34
vllm/model_executor/layers/fused_moe/runner/moe_runner.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class MoERunner(ABC):
|
||||
"""
|
||||
Abstract base class for Mixture of Experts (MoE) runners.
|
||||
|
||||
This class defines the interface that all MoE runner implementations must follow.
|
||||
MoE runners are responsible for executing the forward pass of MoE layers, handling
|
||||
expert routing, and managing tensor parallel operations.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def must_reduce_shared_expert_outputs(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def maybe_all_reduce_tensor_model_parallel(
|
||||
self,
|
||||
final_hidden_states: torch.Tensor,
|
||||
):
|
||||
raise NotImplementedError
|
||||
@@ -18,70 +18,6 @@ class SharedFusedMoE(FusedMoE):
|
||||
can be interleaved with the fused all2all dispatch communication step.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shared_experts: torch.nn.Module | None,
|
||||
gate: torch.nn.Module | None = None,
|
||||
use_overlapped: bool = True,
|
||||
routed_input_transform: torch.nn.Module | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
# Pass has_shared_experts so FusedMoE.__init__ can set disable_inplace
|
||||
# without accessing self.shared_experts (submodules cannot be set before
|
||||
# Module.__init__()).
|
||||
kwargs["has_shared_experts"] = shared_experts is not None
|
||||
super().__init__(**kwargs)
|
||||
self._shared_experts = shared_experts
|
||||
self._routed_input_transform = routed_input_transform
|
||||
|
||||
# Disable shared expert overlap if:
|
||||
# - we are using eplb with non-default backend, because of correctness issues
|
||||
# - we are using flashinfer with DP, since there nothing to gain
|
||||
# - we are using marlin kernels
|
||||
backend = self.moe_parallel_config.all2all_backend
|
||||
self.use_overlapped = (
|
||||
use_overlapped
|
||||
and not (
|
||||
(self.enable_eplb and backend != "allgather_reducescatter")
|
||||
or self.moe_parallel_config.use_fi_all2allv_kernels
|
||||
)
|
||||
and self._shared_experts is not None
|
||||
)
|
||||
|
||||
self._gate = gate
|
||||
|
||||
@property
|
||||
def shared_experts(self) -> torch.nn.Module | None:
|
||||
return self._shared_experts if self.use_overlapped else None
|
||||
|
||||
@property
|
||||
def gate(self) -> torch.nn.Module | None:
|
||||
return self._gate if self.use_overlapped else None
|
||||
|
||||
@property
|
||||
def is_internal_router(self) -> bool:
|
||||
return self.gate is not None
|
||||
|
||||
def apply_routed_input_transform(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply transform for routed experts (e.g., latent projection).
|
||||
|
||||
This is called by FusedMoE.forward_native. The original hidden_states
|
||||
is saved separately so shared experts get [S, hidden_size] while
|
||||
routed experts get the transformed [S, moe_latent_size].
|
||||
|
||||
TODO: For latent MoE bandwidth optimization, fc2_latent_proj could be
|
||||
moved inside SharedFusedMoE to all-reduce on the smaller latent
|
||||
dimension.
|
||||
"""
|
||||
if self._routed_input_transform is not None:
|
||||
result = self._routed_input_transform(hidden_states)
|
||||
# ReplicatedLinear returns (output, extra_bias) tuple.
|
||||
# We only need the output tensor; extra_bias is not used here.
|
||||
if isinstance(result, tuple):
|
||||
return result[0]
|
||||
return result
|
||||
return hidden_states
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
@@ -55,6 +55,8 @@ logger = init_logger(__name__)
|
||||
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
"""MoE method without quantization."""
|
||||
|
||||
# --8<-- [end:unquantized_fused_moe]
|
||||
|
||||
def __init__(self, moe: FusedMoEConfig):
|
||||
super().__init__(moe)
|
||||
self.unquantized_backend = select_unquantized_moe_backend(
|
||||
@@ -90,8 +92,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
return self.forward_cuda(layer, x, topk_weights, topk_ids)
|
||||
return self.forward_cuda(layer, x, topk_weights, topk_ids, shared_experts_input)
|
||||
|
||||
@property
|
||||
def is_monolithic(self) -> bool:
|
||||
@@ -293,12 +296,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
return self.forward(
|
||||
layer=layer,
|
||||
x=x,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
|
||||
@@ -316,6 +321,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.kernel is not None
|
||||
|
||||
@@ -329,6 +335,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
|
||||
def forward_monolithic_cuda(
|
||||
|
||||
@@ -764,6 +764,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
return fused_marlin_moe(
|
||||
x,
|
||||
|
||||
@@ -501,6 +501,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
|
||||
@@ -349,6 +349,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.moe_mk is not None
|
||||
return self.moe_mk(
|
||||
@@ -361,7 +362,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
shared_experts_input=layer._get_shared_experts_input(x),
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
|
||||
|
||||
@@ -645,6 +646,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert not self.is_monolithic
|
||||
assert layer.activation == "silu", "Only SiLU activation is supported."
|
||||
@@ -673,7 +675,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
shared_experts_input=layer._get_shared_experts_input(x),
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
|
||||
|
||||
@@ -1064,6 +1066,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert not self.is_monolithic
|
||||
assert self.moe_mk is not None
|
||||
@@ -1079,7 +1082,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
# https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
shared_experts_input=layer._get_shared_experts_input(x),
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -1203,6 +1206,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
@@ -1713,6 +1717,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.kernel_backend == "Marlin"
|
||||
return fused_marlin_moe(
|
||||
@@ -1961,6 +1966,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
@@ -2575,6 +2581,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if layer.enable_eplb:
|
||||
raise NotImplementedError(
|
||||
|
||||
@@ -140,6 +140,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
|
||||
@@ -1010,6 +1010,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.moe_mk is not None
|
||||
assert not self.is_monolithic
|
||||
@@ -1023,7 +1024,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
shared_experts_input=layer._get_shared_experts_input(x),
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -635,6 +635,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert layer.activation == "silu", "Only SiLU activation is supported."
|
||||
if layer.apply_router_weight_on_input:
|
||||
|
||||
@@ -900,6 +900,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
return fused_marlin_moe(
|
||||
x,
|
||||
|
||||
@@ -958,6 +958,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert not self.is_monolithic
|
||||
|
||||
@@ -980,7 +981,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
shared_experts_input=layer._get_shared_experts_input(x),
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
|
||||
|
||||
@@ -1524,6 +1525,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert not self.is_monolithic
|
||||
|
||||
@@ -1551,7 +1553,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
shared_experts_input=layer._get_shared_experts_input(x),
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -367,6 +367,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
|
||||
@@ -900,6 +900,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert not self.is_monolithic
|
||||
if layer.enable_eplb:
|
||||
|
||||
@@ -419,6 +419,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
@@ -607,6 +608,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
rocm_aiter_fused_experts,
|
||||
@@ -977,6 +979,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if not self.emulate:
|
||||
if (
|
||||
|
||||
@@ -816,10 +816,14 @@ class Worker(WorkerBase):
|
||||
for module in moe_modules:
|
||||
module.moe_config.num_experts = num_local_experts * new_ep_size
|
||||
module.global_num_experts = module.moe_config.num_experts
|
||||
tp_size = get_tp_group().world_size
|
||||
is_sequence_parallel = parallel_config.use_sequence_parallel_moe
|
||||
sp_size = tp_size if is_sequence_parallel else 1
|
||||
module.moe_parallel_config = FusedMoEParallelConfig.make(
|
||||
tp_size_=get_tp_group().world_size,
|
||||
tp_size_=tp_size,
|
||||
pcp_size_=get_pcp_group().world_size,
|
||||
dp_size_=get_dp_group().world_size,
|
||||
sp_size_=sp_size,
|
||||
vllm_parallel_config=parallel_config,
|
||||
)
|
||||
module.moe_config.moe_parallel_config = module.moe_parallel_config
|
||||
|
||||
Reference in New Issue
Block a user