diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py index 4b693d8c8..2ef4424c2 100644 --- a/tests/kernels/moe/utils.py +++ b/tests/kernels/moe/utils.py @@ -603,7 +603,6 @@ def make_shared_experts( def modular_triton_fused_moe( moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, - shared_experts: torch.nn.Module | None = None, ) -> FusedMoEKernel: return FusedMoEKernel( maybe_make_prepare_finalize( @@ -613,6 +612,5 @@ def modular_triton_fused_moe( use_monolithic=False, ), TritonExperts(moe_config, quant_config), - shared_experts, inplace=False, ) diff --git a/tools/ep_kernels/install_python_libraries.sh b/tools/ep_kernels/install_python_libraries.sh index 3372dd10f..c3deb7d60 100755 --- a/tools/ep_kernels/install_python_libraries.sh +++ b/tools/ep_kernels/install_python_libraries.sh @@ -103,6 +103,7 @@ pushd "$WORKSPACE" echo "Downloading NVSHMEM ${NVSHMEM_VER} for ${NVSHMEM_SUBDIR} ..." curl -fSL "${NVSHMEM_URL}" -o "${NVSHMEM_FILE}" tar -xf "${NVSHMEM_FILE}" +rm -rf nvshmem mv "${NVSHMEM_FILE%.tar.xz}" nvshmem rm -f "${NVSHMEM_FILE}" rm -rf nvshmem/lib/bin nvshmem/lib/share diff --git a/vllm/distributed/elastic_ep/elastic_execute.py b/vllm/distributed/elastic_ep/elastic_execute.py index 91f7b91f5..a316a54bd 100644 --- a/vllm/distributed/elastic_ep/elastic_execute.py +++ b/vllm/distributed/elastic_ep/elastic_execute.py @@ -410,8 +410,7 @@ class ElasticEPScalingExecutor: # for the new EP size by resetting quant_method to base for module in moe_modules: if hasattr(module.quant_method, "old_quant_method"): - module.quant_method = module.quant_method.old_quant_method - module.runner = module._init_runner() + module._replace_quant_method(module.quant_method.old_quant_method) prepare_communication_buffer_for_model(self.worker.model_runner.model) eplb_model_state.communicator = create_eplb_communicator( diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index 0a8ce562e..01efe3e47 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -595,10 +595,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): def maybe_all_reduce_tensor_model_parallel(self, *args, **kwargs): return self.base_layer.maybe_all_reduce_tensor_model_parallel(*args, **kwargs) - @property - def _shared_experts(self): - return self.base_layer._shared_experts - @property def quant_method(self): return self.base_layer.quant_method diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 5b5835392..0c93dc6a7 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -937,6 +937,15 @@ class FusedMoEParallelConfig: all2all_backend: str # all2all backend for MoE communication enable_eplb: bool # whether to enable expert load balancing + @property + def use_dp_chunking(self) -> bool: + return ( + self.use_deepep_ll_kernels + or self.use_mori_kernels + or self.use_fi_nvl_two_sided_kernels + or self.use_nixl_ep_kernels + ) and envs.VLLM_ENABLE_MOE_DP_CHUNK + @property def is_sequence_parallel(self) -> bool: return self.sp_size > 1 diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 75ee77664..43082b367 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -1194,6 +1194,8 @@ def cutlass_moe_w4a8_fp8( quant_config=quant_config, group_size=group_size, ), + shared_experts=None, + inplace=False, ) return fn.apply( diff 
--git a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py index 4cb12a8c1..aa78e10b6 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py @@ -53,6 +53,7 @@ class TrtLlmFp8ExpertsBase: self.local_num_experts = moe_config.num_local_experts self.ep_rank = moe_config.moe_parallel_config.ep_rank + self.moe_config = moe_config self.quant_config = quant_config @staticmethod diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py index d951439d3..a239dfea9 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py @@ -40,9 +40,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): def mk_owns_shared_expert(self) -> bool: # NOTE(rob): temporary attribute to indicate support for # completed migration to the new internal MK interface. - return ( - self.moe_kernel is not None and self.moe_kernel.shared_experts is not None - ) + return self.moe_kernel is not None and self.moe_kernel.owns_shared_experts @abstractmethod def create_weights( @@ -163,7 +161,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: raise NotImplementedError def apply_monolithic( @@ -171,5 +169,5 @@ class FusedMoEMethodBase(QuantizeMethodBase): layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 x: torch.Tensor, router_logits: torch.Tensor, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py index 0065c11f3..142e18078 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py @@ -16,6 +16,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEKernel, FusedMoEPrepareAndFinalizeModular, ) +from vllm.model_executor.layers.fused_moe.runner.shared_experts import ( + SharedExperts, +) logger = init_logger(__name__) @@ -44,7 +47,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): moe_layer: torch.nn.Module, old_quant_method: FusedMoEMethodBase, prepare_finalize: FusedMoEPrepareAndFinalizeModular, - shared_experts: torch.nn.Module | None, + shared_experts: SharedExperts | None, inplace: bool = False, ) -> "FusedMoEModularMethod": return FusedMoEModularMethod( @@ -52,8 +55,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): FusedMoEKernel( prepare_finalize, old_quant_method.select_gemm_impl(prepare_finalize, moe_layer), - shared_experts, - moe_parallel_config=moe_layer.moe_parallel_config, + shared_experts=shared_experts, inplace=inplace, ), ) @@ -89,7 +91,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: assert self.moe_kernel is not None return self.moe_kernel.apply( hidden_states=x, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index f56edce22..2fb61615b 100644 
--- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -42,6 +42,9 @@ from vllm.model_executor.layers.fused_moe.router.router_factory import ( from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import ( DefaultMoERunner, ) +from vllm.model_executor.layers.fused_moe.runner.shared_experts import ( + SharedExperts, +) from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import ( UnquantizedFusedMoEMethod, ) @@ -275,8 +278,6 @@ class FusedMoE(CustomOp): ): super().__init__() - self._gate = gate - self._shared_experts = shared_experts self._routed_input_transform = routed_input_transform if params_dtype is None: @@ -486,7 +487,7 @@ class FusedMoE(CustomOp): device=vllm_config.device_config.device, routing_method=self.routing_method_type, # TODO: in_dtype == out_dtype? - disable_inplace=disable_inplace() or self._shared_experts is not None, + disable_inplace=disable_inplace() or shared_experts is not None, ) if self.moe_config.use_mori_kernels: assert self.rocm_aiter_fmoe_enabled, ( @@ -564,34 +565,20 @@ class FusedMoE(CustomOp): moe_quant_params["intermediate_size_full"] = intermediate_size self.quant_method.create_weights(layer=self, **moe_quant_params) + + # TODO(bnell): this is un-needed and removed in a follow up PR. self.base_quant_method = self.quant_method - # Disable shared expert overlap if: - # - we are using eplb with non-default backend, because of correctness issues - # - we are using flashinfer with DP, since there nothing to gain - # - we are using marlin kernels - backend = self.moe_parallel_config.all2all_backend - self.use_overlapped = ( - not ( - (self.enable_eplb and backend != "allgather_reducescatter") - or self.moe_parallel_config.use_fi_nvl_two_sided_kernels - ) - and self._shared_experts is not None - ) - - self.runner = self._init_runner() - - def _init_runner(self): # Storing the runner in the FusedMoE is an intermediate state, eventually # the runner will own the FusedMoE layer and provide the execution interface # for MoE ops. - return DefaultMoERunner( + self.runner = DefaultMoERunner( layer=self, moe_config=self.moe_config, router=self.router, routed_input_transform=self._routed_input_transform, - gate=self.gate, - shared_experts=self.shared_experts, + gate=gate, + shared_experts=shared_experts, quant_method=self.quant_method, reduce_results=self.reduce_results, enable_dbo=self.vllm_config.parallel_config.enable_dbo, @@ -602,10 +589,7 @@ class FusedMoE(CustomOp): # intrusive way to do this. def _replace_quant_method(self, mk: FusedMoEMethodBase): self.quant_method = mk - # We need to force reconstruction of runner because we're swapping out - # the quant_method with a FusedMoEModularMethod. This logic can go - # away once the FusedMoEModularMethod is eliminated. - self.runner = self._init_runner() + self.runner._replace_quant_method(mk) # Note: maybe_init_modular_kernel should only be called by # prepare_communication_buffer_for_model. 
@@ -639,8 +623,8 @@ class FusedMoE(CustomOp): ) @property - def shared_experts(self) -> torch.nn.Module | None: - return self._shared_experts if self.use_overlapped else None + def shared_experts(self) -> SharedExperts | None: + return self.runner.shared_experts @property def layer_id(self): @@ -649,10 +633,6 @@ class FusedMoE(CustomOp): return extract_layer_index(self.layer_name) - @property - def gate(self) -> torch.nn.Module | None: - return self._gate if self.use_overlapped else None - @property def tp_size(self): return self.moe_parallel_config.tp_size @@ -676,7 +656,7 @@ class FusedMoE(CustomOp): @property def is_internal_router(self) -> bool: # By default, router/gate is called before FusedMoE forward pass - return self.gate is not None + return self.runner.is_internal_router() def _maybe_init_expert_routing_tables( self, @@ -1467,7 +1447,12 @@ class FusedMoE(CustomOp): assert all( weight.is_contiguous() for name, weight in weights - if not (name.startswith("_shared_experts.") or name.startswith("_gate.")) + if not ( + name.startswith("_shared_experts.") + or name.startswith("_gate.") + or name.startswith("_routed_input_transform.") + or name.startswith("_routed_output_transform.") + ) and name not in NON_EXPERT_WEIGHTS ) @@ -1477,8 +1462,11 @@ class FusedMoE(CustomOp): if name not in NON_EXPERT_WEIGHTS and weight.shape != torch.Size([]) and not name.startswith("_shared_experts.") - # exclude parameters from non-expert submodules (e.g. gate/shared) + # exclude parameters from non-expert submodules, + # e.g. gate/shared/transforms. and not name.startswith("_gate.") + and not name.startswith("_routed_input_transform.") + and not name.startswith("_routed_output_transform.") ] def set_eplb_state( diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 90379b834..f2e6e2560 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -21,6 +21,10 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, RoutingMethodType, ) +from vllm.model_executor.layers.fused_moe.runner.shared_experts import ( + SharedExperts, + SharedExpertsOrder, +) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, disable_inplace, @@ -235,6 +239,13 @@ class FusedMoEPrepareAndFinalize(ABC): """ raise NotImplementedError + def supports_async(self) -> bool: + """ + Indicates whether or not this class implements prepare_async and + finalize_async. + """ + return False + # TODO: pass FusedMoEParallelConfig in as ctor parameter? class FusedMoEPrepareAndFinalizeModular(FusedMoEPrepareAndFinalize): @@ -281,13 +292,6 @@ class FusedMoEPrepareAndFinalizeModular(FusedMoEPrepareAndFinalize): """ raise NotImplementedError - def supports_async(self) -> bool: - """ - Indicates whether or not this class implements prepare_async and - finalize_async. 
- """ - return False - def prepare_async( self, a1: torch.Tensor, @@ -1003,15 +1007,20 @@ class FusedMoEKernelModularImpl: self, prepare_finalize: FusedMoEPrepareAndFinalizeModular, fused_experts: FusedMoEExpertsModular, - shared_experts: torch.nn.Module | None = None, - moe_parallel_config: FusedMoEParallelConfig | None = None, + shared_experts: SharedExperts | None, inplace: bool = False, ): self.prepare_finalize = prepare_finalize self.fused_experts = fused_experts - self.shared_experts = shared_experts - self.moe_parallel_config = moe_parallel_config + # Only accept shared experts if they can be run w/async. + # The MoERunner/SharedExperts class will coordinate with the MK to ensure + # that the SharedExperts are executed only once. + self.shared_experts = ( + shared_experts if prepare_finalize.supports_async() else None + ) self.inplace = inplace + moe_parallel_config = fused_experts.moe_config.moe_parallel_config + self.moe_parallel_config = moe_parallel_config self.is_dp_ep = ( moe_parallel_config is not None and moe_parallel_config.dp_size > 1 @@ -1081,6 +1090,17 @@ class FusedMoEKernelModularImpl: return workspace13, workspace2, fused_out + def _maybe_apply_shared_experts( + self, + shared_experts_input: torch.Tensor | None, + ): + if self.shared_experts is not None: + assert shared_experts_input is not None + self.shared_experts.apply( + shared_experts_input, + SharedExpertsOrder.MK_INTERNAL_OVERLAPPED, + ) + def _prepare( self, hidden_states: torch.Tensor, @@ -1253,15 +1273,6 @@ class FusedMoEKernelModularImpl: shared_experts_input is the original hidden_states (full dimension) needed by the shared expert MLP. """ - shared_output: torch.Tensor | None = None - - # For latent MoE: shared experts need the original hidden_states - # (full hidden_size), not the latent-projected version used by - # routed experts. - se_hidden_states = ( - shared_experts_input if shared_experts_input is not None else hidden_states - ) - if not self.prepare_finalize.supports_async(): assert not dbo_enabled() @@ -1273,8 +1284,6 @@ class FusedMoEKernelModularImpl: apply_router_weight_on_input, self.fused_experts.finalize_weight_and_reduce_impl(), ) - if self.shared_experts is not None: - shared_output = self.shared_experts(se_hidden_states) else: finalize_ret = self.prepare_finalize.finalize_async( output, @@ -1284,8 +1293,7 @@ class FusedMoEKernelModularImpl: apply_router_weight_on_input, self.fused_experts.finalize_weight_and_reduce_impl(), ) - if self.shared_experts is not None: - shared_output = self.shared_experts(se_hidden_states) + self._maybe_apply_shared_experts(shared_experts_input) # TODO(lucas): refactor this in the alternative schedules followup # currently unpack if we have hook + receiver pair or just @@ -1308,11 +1316,7 @@ class FusedMoEKernelModularImpl: receiver() - if self.shared_experts is None: - return output - else: - assert shared_output is not None - return shared_output, output + return output def apply( self, @@ -1326,7 +1330,7 @@ class FusedMoEKernelModularImpl: expert_map: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, shared_experts_input: torch.Tensor | None = None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. 
@@ -1469,12 +1473,10 @@ class FusedMoEKernel: self, prepare_finalize: FusedMoEPrepareAndFinalize, fused_experts: FusedMoEExperts, - shared_experts: torch.nn.Module | None = None, - moe_parallel_config: FusedMoEParallelConfig | None = None, + shared_experts: SharedExperts | None = None, inplace: bool = False, ): super().__init__() - self.shared_experts = shared_experts # NOTE: check if we can remove # Initialize the implementation (monolithic or modular). self.impl: FusedMoEKernelModularImpl | FusedMoEKernelMonolithicImpl @@ -1485,14 +1487,12 @@ class FusedMoEKernel: prepare_finalize, fused_experts, shared_experts, - moe_parallel_config, inplace, ) elif isinstance( prepare_finalize, FusedMoEPrepareAndFinalizeMonolithic ) and isinstance(fused_experts, FusedMoEExpertsMonolithic): - assert shared_experts is None assert not inplace self.impl = FusedMoEKernelMonolithicImpl( prepare_finalize, @@ -1508,6 +1508,13 @@ class FusedMoEKernel: self._post_init_setup() + @property + def owns_shared_experts(self) -> bool: + if isinstance(self.impl, FusedMoEKernelModularImpl): + return self.impl.shared_experts is not None + else: + return False + @property def is_monolithic(self) -> bool: return isinstance(self.impl, FusedMoEKernelMonolithicImpl) diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index a63c02663..3d9a49902 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -18,6 +18,9 @@ from vllm.model_executor.layers.fused_moe.config import ( fp8_w8a8_moe_quant_config, fp8_w8a16_moe_quant_config, ) +from vllm.model_executor.layers.fused_moe.runner.shared_experts import ( + SharedExperts, +) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, get_flashinfer_moe_backend, @@ -545,7 +548,7 @@ def make_fp8_moe_kernel( experts_cls: type[mk.FusedMoEExperts], fp8_backend: Fp8MoeBackend, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, - shared_experts: torch.nn.Module | None = None, + shared_experts: SharedExperts | None = None, ) -> mk.FusedMoEKernel: # Create Prepare/Finalize. 
prepare_finalize = maybe_make_prepare_finalize( @@ -581,12 +584,7 @@ def make_fp8_moe_kernel( kernel = mk.FusedMoEKernel( prepare_finalize, experts, - shared_experts=( - shared_experts - if moe_config.moe_parallel_config.use_deepep_ll_kernels - else None - ), - moe_parallel_config=moe_config.moe_parallel_config, + shared_experts=shared_experts, inplace=( not moe_config.disable_inplace and fp8_backend != Fp8MoeBackend.FLASHINFER_CUTLASS diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py index 9008bdeec..7249b425f 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py @@ -859,7 +859,6 @@ def make_mxfp4_moe_kernel( if moe_config.moe_parallel_config.use_deepep_ll_kernels else None ), - moe_parallel_config=moe_config.moe_parallel_config, inplace=( not moe_config.disable_inplace and mxfp4_backend not in TRTLLM_BACKENDS ), diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py index 35451e87d..d946c5eb5 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py @@ -17,6 +17,9 @@ from vllm.model_executor.layers.fused_moe.config import ( nvfp4_moe_quant_config, nvfp4_w4a16_moe_quant_config, ) +from vllm.model_executor.layers.fused_moe.runner.shared_experts import ( + SharedExperts, +) from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( prepare_nvfp4_moe_layer_for_fi_or_cutlass, ) @@ -386,7 +389,7 @@ def make_nvfp4_moe_kernel( moe_config: FusedMoEConfig, experts_cls: type[mk.FusedMoEExperts], routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, - shared_experts: torch.nn.Module | None = None, + shared_experts: SharedExperts | None = None, ) -> mk.FusedMoEKernel: # Create Prepare/Finalize. 
prepare_finalize = maybe_make_prepare_finalize( @@ -422,12 +425,7 @@ def make_nvfp4_moe_kernel( kernel = mk.FusedMoEKernel( prepare_finalize, experts, - shared_experts=( - shared_experts - if moe_config.moe_parallel_config.use_deepep_ll_kernels - else None - ), - moe_parallel_config=moe_config.moe_parallel_config, + shared_experts=shared_experts, inplace=False, ) diff --git a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py index 7f26d0b86..33bf7a0c7 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py +++ b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py @@ -18,6 +18,9 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, ) +from vllm.model_executor.layers.fused_moe.runner.shared_experts import ( + SharedExperts, +) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, convert_moe_weights_to_flashinfer_trtllm_block_layout, @@ -321,7 +324,7 @@ def make_unquantized_moe_kernel( backend: UnquantizedMoeBackend, experts_cls: type[mk.FusedMoEExperts], routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, - shared_experts: torch.nn.Module | None = None, + shared_experts: SharedExperts | None = None, ) -> mk.FusedMoEKernel: # Create Prepare/Finalize is_monolithic = issubclass(experts_cls, mk.FusedMoEExpertsMonolithic) @@ -355,12 +358,7 @@ def make_unquantized_moe_kernel( kernel = mk.FusedMoEKernel( prepare_finalize, experts, - shared_experts=( - shared_experts - if moe_config.moe_parallel_config.use_deepep_ll_kernels - else None - ), - moe_parallel_config=moe_config.moe_parallel_config, + shared_experts=shared_experts, inplace=(not moe_config.disable_inplace and not is_monolithic), ) diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/deepep_ll.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/deepep_ll.py index a3266f5e8..0c6e32ae4 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize/deepep_ll.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/deepep_ll.py @@ -325,7 +325,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): **(dict(use_nvfp4=True) if use_nvfp4 else dict()), **( dict(x_global_scale=qc_a1_gscale_or_scale) - if qc_a1_gscale_or_scale is not None + if qc_a1_gscale_or_scale is not None and nvfp4_dispatch else dict() ), async_finish=False, diff --git a/vllm/model_executor/layers/fused_moe/router/fused_moe_router.py b/vllm/model_executor/layers/fused_moe/router/fused_moe_router.py index c322a8cd4..d7aed4fde 100644 --- a/vllm/model_executor/layers/fused_moe/router/fused_moe_router.py +++ b/vllm/model_executor/layers/fused_moe/router/fused_moe_router.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod +from collections.abc import Callable import torch @@ -13,6 +14,13 @@ class FusedMoERouter(ABC): method that is used for routing hidden states based on router logits. 
""" + @abstractmethod + def set_capture_fn( + self, + capture_fn: Callable[[torch.Tensor], None] | None, + ) -> None: + raise NotImplementedError + @property @abstractmethod def routing_method_type(self) -> RoutingMethodType: diff --git a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py index a09273fc8..4f9409e2c 100644 --- a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py +++ b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py @@ -7,7 +7,6 @@ from typing import TYPE_CHECKING import torch import torch.nn.functional as F -import vllm.envs as envs from vllm.distributed import ( get_ep_group, get_pcp_group, @@ -29,13 +28,15 @@ from vllm.model_executor.layers.fused_moe.router.fused_moe_router import ( FusedMoERouter, ) from vllm.model_executor.layers.fused_moe.runner.moe_runner import MoERunner +from vllm.model_executor.layers.fused_moe.runner.shared_experts import ( + SharedExperts, + SharedExpertsOrder, +) from vllm.platforms import current_platform from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import ( HAS_OPAQUE_TYPE, ModuleName, - aux_stream, - current_stream, direct_register_custom_op, ) from vllm.v1.worker.ubatching import dbo_current_ubatch_id @@ -74,6 +75,9 @@ def _resolve_layer_name(layer_name: str | ModuleName) -> str: return layer_name.value if isinstance(layer_name, ModuleName) else layer_name +# Note: _moe_forward and _moe_forward_shared should not contain any +# implementation details, They should merely pass along control to +# the runner's 'forward_dispatch' method. def _moe_forward( hidden_states: torch.Tensor, router_logits: torch.Tensor, @@ -81,24 +85,12 @@ def _moe_forward( layer_name: _layer_name_type, ) -> torch.Tensor: layer = get_layer_from_name(_resolve_layer_name(layer_name)) - # TODO(bnell): this can be removed after MK migration is complete. - layer.ensure_moe_quant_config_init() - runner = layer.runner - with runner._sequence_parallel_context(): - if runner.use_dp_chunking: - return runner.forward_impl_chunked( - layer, - hidden_states, - router_logits, - shared_experts_input, - ) - else: - return runner.forward_impl( - layer, - hidden_states, - router_logits, - shared_experts_input, - ) + return layer.runner.forward_dispatch( + layer, + hidden_states, + router_logits, + shared_experts_input, + ) def _moe_forward_fake( @@ -117,24 +109,12 @@ def _moe_forward_shared( layer_name: _layer_name_type, ) -> tuple[torch.Tensor, torch.Tensor]: layer = get_layer_from_name(_resolve_layer_name(layer_name)) - # TODO(bnell): this can be removed after MK migration is complete. - layer.ensure_moe_quant_config_init() - runner = layer.runner - with runner._sequence_parallel_context(): - if runner.use_dp_chunking: - return runner.forward_impl_chunked( - layer, - hidden_states, - router_logits, - shared_experts_input, - ) - else: - return runner.forward_impl( - layer, - hidden_states, - router_logits, - shared_experts_input, - ) + return layer.runner.forward_dispatch( + layer, + hidden_states, + router_logits, + shared_experts_input, + ) def _moe_forward_shared_fake( @@ -159,7 +139,7 @@ def _moe_forward_shared_fake( direct_register_custom_op( op_name="moe_forward", op_func=_moe_forward, - mutates_args=["hidden_states"], + mutates_args=["hidden_states"], # is this still true? 
fake_impl=_moe_forward_fake, tags=(torch.Tag.needs_fixed_stride_order,), ) @@ -168,7 +148,6 @@ direct_register_custom_op( direct_register_custom_op( op_name="moe_forward_shared", op_func=_moe_forward_shared, - mutates_args=["hidden_states"], fake_impl=_moe_forward_shared_fake, tags=(torch.Tag.needs_fixed_stride_order,), ) @@ -213,87 +192,68 @@ class DefaultMoERunner(MoERunner): self.router = router self.routed_input_transform = routed_input_transform self.gate = gate - self.shared_experts = shared_experts self.quant_method = quant_method self.reduce_results = reduce_results self.enable_dbo = enable_dbo + self.shared_experts: SharedExperts | None = None + if shared_experts is not None: + self.shared_experts = SharedExperts( + shared_experts, + moe_config=moe_config, + # Note: For now we must pass quant_method along to SharedExperts so it + # can property determine where the shared experts are supposed to be + # called, i.e. by a MK or by the MoERunner. + # Once the MK can be created upfront, we can just pass in the proper + # flags derived from the quant_method's MK. + reduce_results=reduce_results, + quant_method=quant_method, + enable_dbo=enable_dbo, + ) + # Chunked all2all staging tensor - # TODO(bnell) rename these? + # These need to exist ahead of time due to CUDAgraph construction + # needing a fixed buffer address. + self.use_dp_chunking = self.moe_config.moe_parallel_config.use_dp_chunking self.batched_hidden_states: torch.Tensor | None = None self.batched_router_logits: torch.Tensor | None = None self._maybe_init_dp_chunking() - # Allow disabling of the separate shared experts stream for - # debug purposes. - # TODO: Remove this after more extensive testings with TP/DP - # and other execution modes - self.use_shared_experts_stream = False - if envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM: - logger.debug_once("Disabling MoE shared_experts cuda stream", scope="local") - self.shared_experts_stream = None - else: - # TODO(rob): enable shared expert overlap with non-cuda-alike. - # aux_stream() returns None on non-cuda-alike platforms. - self.shared_experts_stream = aux_stream() - if self.shared_experts_stream is not None: - logger.debug_once( - "Enabled separate cuda stream for MoE shared_experts", scope="local" - ) - # Needed for string -> FusedMoE layer lookup in custom ops. self.layer_name = layer.layer_name - self.moe_forward = self._select_forward(layer) + self.forward_entry, self.forward_impl = self._select_forward(layer) + + def _select_forward(self, layer: torch.nn.Module) -> tuple[Callable, Callable]: + # Select implementation based on presence of DP chunking. + forward_impl_fn = ( + self._forward_impl_chunked if self.use_dp_chunking else self._forward_impl + ) - def _select_forward(self, layer: torch.nn.Module) -> Callable: if current_platform.is_tpu() or current_platform.is_cpu(): # TODO: Once the OOM issue for the TPU backend is resolved, we # will switch to using the moe_forward custom op. # Note: CPU doesn't require wrapped forward_impl. 
- return _moe_forward if self.shared_experts is None else _moe_forward_shared + return ( + _moe_forward if self.shared_experts is None else _moe_forward_shared, + forward_impl_fn, + ) return ( torch.ops.vllm.moe_forward if self.shared_experts is None - else torch.ops.vllm.moe_forward_shared + else torch.ops.vllm.moe_forward_shared, + forward_impl_fn, ) - @property - def use_dp_chunking(self) -> bool: - return ( - self.moe_config.moe_parallel_config.use_deepep_ll_kernels - or self.moe_config.moe_parallel_config.use_mori_kernels - or self.moe_config.moe_parallel_config.use_fi_nvl_two_sided_kernels - or self.moe_config.moe_parallel_config.use_nixl_ep_kernels - ) and envs.VLLM_ENABLE_MOE_DP_CHUNK + # TODO(bnell): temporary hack, do not call this method. + def _replace_quant_method(self, quant_method: FusedMoEMethodBase): + if self.shared_experts is not None: + self.shared_experts._quant_method = quant_method + self.quant_method = quant_method - def _maybe_setup_shared_experts_stream( - self, - hidden_states: torch.Tensor, - shared_input: torch.Tensor | None, - ): - if self.use_shared_experts_stream: - assert self.shared_experts_stream is not None - assert self.moe_config.disable_inplace - - shared_experts_input = ( - shared_input if shared_input is not None else hidden_states - ) - - # Record that the shared_experts_input will be used in the - # shared_experts_stream to avoid gc issue from - # deallocation. For more details: - # https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501 - # NOTE: We don't need shared_output.record_stream(current_stream()) - # because we synch the streams before using shared_output. - shared_experts_input.record_stream(self.shared_experts_stream) - - # Mark sync start point for the separate shared experts - # stream here since we want to run in parallel with the - # router/gate (next op below) - assert self.shared_experts_stream is not None - self.shared_experts_stream.wait_stream(current_stream()) + def is_internal_router(self) -> bool: + return self.gate is not None def _maybe_init_dp_chunking(self): if not self.use_dp_chunking: @@ -325,38 +285,6 @@ class DefaultMoERunner(MoERunner): device=device, ) - @property - def has_separate_shared_experts(self) -> bool: - return ( - not self.quant_method.mk_owns_shared_expert - and self.shared_experts is not None - ) - - def _apply_shared_experts( - self, - hidden_states: torch.Tensor, - allow_streaming: bool = False, - ) -> torch.Tensor | None: - shared_output: torch.Tensor | None = None - if self.has_separate_shared_experts: - assert self.shared_experts is not None - - if self.use_shared_experts_stream and allow_streaming: - # Run shared experts in parallel on a separate stream - # NOTE: We start the separate stream here and mark the - # sync end point immediately after it is done. This is - # important to avoid excessive stream allocations by the cuda - # graph replay later. 
- with torch.cuda.stream(self.shared_experts_stream): - # Note that hidden_states clone() is necessary here to avoid - # conflict with the main stream - shared_output = self.shared_experts(hidden_states) - current_stream().wait_stream(self.shared_experts_stream) - else: - shared_output = self.shared_experts(hidden_states) - - return shared_output - def must_reduce_shared_expert_outputs(self) -> bool: """ The shared_experts are typically computed using the RowParallelLinear @@ -384,7 +312,9 @@ class DefaultMoERunner(MoERunner): else: return tensor_model_parallel_all_reduce(final_hidden_states) - def apply_routed_input_transform(self, hidden_states: torch.Tensor) -> torch.Tensor: + def apply_routed_input_transform( + self, hidden_states: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor | None]: """Apply transform for routed experts (e.g., latent projection). This is called by FusedMoE.forward_native. The original hidden_states @@ -394,15 +324,22 @@ class DefaultMoERunner(MoERunner): TODO: For latent MoE bandwidth optimization, fc2_latent_proj could be moved inside SharedFusedMoE to all-reduce on the smaller latent dimension. + + Returns (possibly transformed) hidden states and the input for shared + experts (or None if there are no shared experts). """ if self.routed_input_transform is not None: result = self.routed_input_transform(hidden_states) # ReplicatedLinear returns (output, extra_bias) tuple. # We only need the output tensor; extra_bias is not used here. if isinstance(result, tuple): - return result[0] - return result - return hidden_states + return result[0], hidden_states + return result, hidden_states + + return ( + hidden_states, + hidden_states if self.shared_experts is not None else None, + ) def _maybe_reduce_output( self, @@ -446,13 +383,11 @@ class DefaultMoERunner(MoERunner): def _maybe_pad_hidden_states( self, - original_hidden_states: torch.Tensor | None, + shared_experts_input: torch.Tensor | None, hidden_states: torch.Tensor, ) -> tuple[torch.Tensor, list[int]]: - original_hidden_dim = ( - original_hidden_states.shape[-1] - if original_hidden_states is not None - else 0 + shared_experts_hidden_dim = ( + shared_experts_input.shape[-1] if shared_experts_input is not None else 0 ) transformed_hidden_dim = hidden_states.shape[-1] if ( @@ -467,29 +402,37 @@ class DefaultMoERunner(MoERunner): ) if self.shared_experts is not None: - orig_hidden_dims = [original_hidden_dim, transformed_hidden_dim] + orig_hidden_dims = [shared_experts_hidden_dim, transformed_hidden_dim] else: orig_hidden_dims = [transformed_hidden_dim] return hidden_states, orig_hidden_dims + def _maybe_apply_shared_experts( + self, + shared_experts_input: torch.Tensor | None, + order: SharedExpertsOrder, + ): + if self.shared_experts is not None: + assert shared_experts_input is not None + self.shared_experts.apply(shared_experts_input, order) + def _apply_quant_method( self, layer: torch.nn.Module, hidden_states: torch.Tensor, router_logits: torch.Tensor, - shared_input: torch.Tensor | None, - run_shared_experts_before: bool = True, + shared_experts_input: torch.Tensor | None, ) -> tuple[torch.Tensor | None, torch.Tensor]: - shared_input = shared_input if shared_input is not None else hidden_states - shared_output: torch.Tensor | None = None - # Run this before quant_method to avoid inplace issues. - if run_shared_experts_before: - shared_output = self._apply_shared_experts(shared_input, False) + # TODO(bnell): probably not needed anymore since inplace is + # disabled when shared experts are present. 
+ self._maybe_apply_shared_experts( + shared_experts_input, SharedExpertsOrder.NO_OVERLAP + ) if self.quant_method.is_monolithic: - result = self.quant_method.apply_monolithic( + fused_out = self.quant_method.apply_monolithic( layer=layer, x=hidden_states, router_logits=router_logits, @@ -500,25 +443,25 @@ class DefaultMoERunner(MoERunner): router_logits=router_logits, ) - result = self.quant_method.apply( + # Passing shared_experts_input in case SharedExpertsOrder is + # NO_OVERLAP or MK_INTERNAL_OVERLAPPED. + fused_out = self.quant_method.apply( layer=layer, x=hidden_states, topk_weights=topk_weights, topk_ids=topk_ids, - shared_experts_input=shared_input, + shared_experts_input=shared_experts_input, ) - if isinstance(result, tuple): - assert shared_output is None - shared_output, hidden_states = result - else: - hidden_states = result + self._maybe_apply_shared_experts( + shared_experts_input, + SharedExpertsOrder.MULTI_STREAM_OVERLAPPED, + ) - if not run_shared_experts_before and self.has_separate_shared_experts: - assert shared_output is None - shared_output = self._apply_shared_experts(shared_input, True) - - return shared_output, hidden_states + return ( + self.shared_experts.output if self.shared_experts is not None else None, + fused_out, + ) def _sequence_parallel_context(self): ctx = get_forward_context() @@ -558,18 +501,16 @@ class DefaultMoERunner(MoERunner): return final_shared_hidden_states, final_fused_hidden_states - def _maybe_gate( + def _maybe_sync_shared_experts_stream( self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - ) -> torch.Tensor: + shared_experts_input: torch.Tensor | None, + ): # If router/gate provided, then apply it here. # (Note: This code runs only when "overlapped mode" is on to allow # parallel execution of shared experts with the FusedMoE via # separate cuda stream) - if self.gate is not None: - router_logits, _ = self.gate(hidden_states) - return router_logits + if self.shared_experts is not None: + self.shared_experts.maybe_sync_shared_experts_stream(shared_experts_input) @property def do_naive_dispatch_combine(self) -> bool: @@ -624,7 +565,6 @@ class DefaultMoERunner(MoERunner): hidden_states, dim=0, ) - # need RS for shared_output? if self.shared_experts is not None: assert shared_output is not None @@ -637,30 +577,86 @@ class DefaultMoERunner(MoERunner): hidden_states: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - # For latent MoE: save ORIGINAL hidden_states before transform - # (shared_experts need original dimension, routed experts use transformed) - if self.shared_experts is not None: - original_hidden_states = hidden_states - else: - original_hidden_states = None + """Invoke the fused moe layer. + + Input: + - hidden_states + - router_logits + + Output: + - The new hidden_states. + or + - A tuple of (shared experts output, new hidden_states). + + Calling sequence + - forward + - self.forward_entry (_moe_forward or _moe_forward_shared custom op) + - forward_dispatch + - forward_impl (_forward_impl or _forward_impl_chunked) + + Note: The existence of _moe_forward and _moe_forward_shared custom ops are due + to the following reasons: + 1. the chunking loop in _forward_impl_chunked cannot be compiled by + torch.compile + 2. pytorch cannot handle union types in custom op signatures so _moe_forward + and _moe_forward_shared must be split. 
+ + If _forward_impl_chunked can be implemented via torch.scan we can potentially + get rid of _moe_forward and _moe_forward_shared and collapse the whole sequence + into the 'forward' method. + """ # Apply transform for routed experts (e.g., latent projection for latent MoE) - hidden_states = self.apply_routed_input_transform(hidden_states) + hidden_states, shared_experts_input = self.apply_routed_input_transform( + hidden_states + ) hidden_states, og_hidden_dims = self._maybe_pad_hidden_states( - original_hidden_states, + shared_experts_input, hidden_states, ) - fused_output = self.moe_forward( + fused_output = self.forward_entry( hidden_states, router_logits, - original_hidden_states, + shared_experts_input, self._encode_layer_name(), ) return self._maybe_reduce_output(fused_output, og_hidden_dims) + def forward_dispatch( + self, + layer: torch.nn.Module, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + shared_experts_input: torch.Tensor | None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + # TODO(bnell): this can be removed after MK migration is complete. + layer.ensure_moe_quant_config_init() + + # Sync aux and main stream for shared expert multi-stream overlap. + self._maybe_sync_shared_experts_stream(shared_experts_input) + + # If the Runner holds the gate, apply it after the stream sync, + # so it can run overlapped with the + # NOTE: in future PR, MoE runner will always hold the gate. + if self.gate is not None: + router_logits, _ = self.gate(hidden_states) + + self._maybe_apply_shared_experts( + shared_experts_input, + SharedExpertsOrder.EXTERNAL, + ) + + with self._sequence_parallel_context(): + return self.forward_impl( + layer, + hidden_states, + router_logits, + shared_experts_input, + ) + def _slice_and_copy_input( self, out_slice: torch.Tensor, @@ -681,17 +677,13 @@ class DefaultMoERunner(MoERunner): out_slice.copy_(orig_slice, non_blocking=True) return out_slice - def forward_impl_chunked( + def _forward_impl_chunked( self, layer: torch.nn.Module, hidden_states: torch.Tensor, router_logits: torch.Tensor, - shared_input: torch.Tensor | None, + shared_experts_input: torch.Tensor | None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - # Gate overlap not supported when chunking is enabled. Run the - # gate first. 
- router_logits = self._maybe_gate(hidden_states, router_logits) - final_shared_hidden_states, final_fused_hidden_states = ( self._allocate_dp_chunking_outputs(hidden_states, router_logits) ) @@ -737,9 +729,9 @@ class DefaultMoERunner(MoERunner): chunk_end, ) - shared_input_chunk = ( - shared_input[chunk_start:chunk_end, :] - if shared_input is not None + shared_experts_input_chunk = ( + shared_experts_input[chunk_start:chunk_end, :] + if shared_experts_input is not None else None ) @@ -747,7 +739,7 @@ class DefaultMoERunner(MoERunner): layer=layer, hidden_states=hidden_states_chunk, router_logits=router_logits_chunk, - shared_input=shared_input_chunk, + shared_experts_input=shared_experts_input_chunk, ) # Store outputs @@ -769,40 +761,13 @@ class DefaultMoERunner(MoERunner): assert final_shared_hidden_states is not None return (final_shared_hidden_states, final_fused_hidden_states) - def forward_impl( + def _forward_impl( self, layer: torch.nn.Module, hidden_states: torch.Tensor, router_logits: torch.Tensor, - shared_input: torch.Tensor | None, + shared_experts_input: torch.Tensor | None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - self.use_shared_experts_stream = ( - current_platform.is_cuda() - and self.has_separate_shared_experts - and not self.use_dp_chunking - and self.shared_experts_stream is not None - and ( - hidden_states.shape[0] - <= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD - ) - ) - - # Check if we need to run shared experts before matrix multiply because - # matrix multiply may modify the hidden_states. - run_shared_experts_before = ( - self.has_separate_shared_experts and not self.use_shared_experts_stream - ) - - # The shared experts stream must be set up before calling the gate so they - # can be overlapped. - if not run_shared_experts_before: - self._maybe_setup_shared_experts_stream( - hidden_states, - shared_input, - ) - - router_logits = self._maybe_gate(hidden_states, router_logits) - # TODO(bnell): parts of the dispatch/combine steps will go away once # #32567 lands and the remaining kernels are made MKs. 
The PCP # code will probably remain @@ -816,8 +781,7 @@ class DefaultMoERunner(MoERunner): layer=layer, hidden_states=hidden_states, router_logits=router_logits, - shared_input=shared_input, - run_shared_experts_before=run_shared_experts_before, + shared_experts_input=shared_experts_input, ) return self._maybe_combine( diff --git a/vllm/model_executor/layers/fused_moe/runner/moe_runner.py b/vllm/model_executor/layers/fused_moe/runner/moe_runner.py index b298cc2d0..720e997cd 100644 --- a/vllm/model_executor/layers/fused_moe/runner/moe_runner.py +++ b/vllm/model_executor/layers/fused_moe/runner/moe_runner.py @@ -32,3 +32,7 @@ class MoERunner(ABC): final_hidden_states: torch.Tensor, ): raise NotImplementedError + + @abstractmethod + def is_internal_router(self) -> bool: + raise NotImplementedError diff --git a/vllm/model_executor/layers/fused_moe/runner/shared_experts.py b/vllm/model_executor/layers/fused_moe/runner/shared_experts.py new file mode 100644 index 000000000..6d2189cb4 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/runner/shared_experts.py @@ -0,0 +1,216 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from enum import IntEnum + +import torch + +import vllm.envs as envs +from vllm.distributed import ( + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, +) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizeMethodBase, +) +from vllm.platforms import current_platform +from vllm.utils.torch_utils import ( + aux_stream, + current_stream, +) +from vllm.v1.worker.ubatching import ( + dbo_current_ubatch_id, +) + +logger = init_logger(__name__) + + +class SharedExpertsOrder(IntEnum): + # No shared experts. + NONE = (0,) + + # Get rid of this one? combine with BEFORE? + # Note: this might be important for torch.compile reasons. Can + # get rid of it after _moe_forward is undone. + EXTERNAL = (1,) + + # No overlap - defensively called before MK. + NO_OVERLAP = (2,) + + # Overlapped with dispatch/combine in DP/EP - called by the MK. + MK_INTERNAL_OVERLAPPED = (3,) + + # Overlapped with the gate, router, experts in aux stream. + MULTI_STREAM_OVERLAPPED = (4,) + + +class SharedExperts: + def __init__( + self, + layer: torch.nn.Module, + moe_config: FusedMoEConfig, + quant_method: QuantizeMethodBase, + reduce_results: bool, + enable_dbo: bool, + ): + from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( + FusedMoEMethodBase, + ) + + # quant_method must be a FusedMoEMethodBase but we can't use the type + # due to circular imports. + assert isinstance(quant_method, FusedMoEMethodBase) + + # The SharedExperts need to handle DBO since they can be called from + # an MK's finalize method. We keep a list of outputs indexed by current + # DBO ubatch id to handle this case. If DBO is not enabled, the + # index is always 0 and the second output list element is ignored. + self.enable_dbo = enable_dbo + self._output: list[torch.Tensor | None] = [None, None] + self._layer = layer + self._moe_config = moe_config + self._quant_method = quant_method + self._reduce_results = reduce_results + self._use_dp_chunking = moe_config.moe_parallel_config.use_dp_chunking + + # Allow disabling of the separate shared experts stream for + # debug purposes. 
+ # TODO: Remove this after more extensive testings with TP/DP + # and other execution modes + if envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM: + logger.debug_once("Disabling MoE shared_experts cuda stream", scope="local") + self._stream = None + else: + # TODO(rob): enable shared expert overlap with non-cuda-alike. + # aux_stream() returns None on non-cuda-alike platforms. + self._stream = aux_stream() + if self._stream is not None: + logger.debug_once( + "Enabled separate cuda stream for MoE shared_experts", scope="local" + ) + + @property + def _has_external_experts(self) -> bool: + # Disable shared expert overlap if: + # - we are using eplb with non-default backend, because of correctness issues + # - we are using flashinfer with DP, since there nothing to gain + backend = self._moe_config.moe_parallel_config.all2all_backend + return not ( + ( + self._moe_config.moe_parallel_config.enable_eplb + and backend != "allgather_reducescatter" + ) + or self._moe_config.moe_parallel_config.use_fi_nvl_two_sided_kernels + ) + + def _determine_shared_experts_order( + self, + hidden_states: torch.Tensor, + ) -> SharedExpertsOrder: + if self._has_external_experts and not self._use_dp_chunking: + return SharedExpertsOrder.EXTERNAL + + if self._quant_method.mk_owns_shared_expert: + return SharedExpertsOrder.MK_INTERNAL_OVERLAPPED + + should_run_shared_in_aux_stream = ( + current_platform.is_cuda() + and not self._use_dp_chunking + and self._stream is not None + and hidden_states.shape[0] + <= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD + ) + + if should_run_shared_in_aux_stream: + return SharedExpertsOrder.MULTI_STREAM_OVERLAPPED + else: + return SharedExpertsOrder.NO_OVERLAP + + def maybe_sync_shared_experts_stream( + self, + shared_experts_input: torch.Tensor, + ): + experts_order = self._determine_shared_experts_order(shared_experts_input) + + if experts_order == SharedExpertsOrder.MULTI_STREAM_OVERLAPPED: + assert self._stream is not None + assert self._moe_config.disable_inplace + + # Record that the clone will be used by shared_experts_stream + # to avoid gc issue from deallocation of hidden_states_clone + # For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501 + # NOTE: We don't need shared_output.record_stream(current_stream()) + # because we synch the streams before using shared_output. + shared_experts_input.record_stream(self._stream) + + # Mark sync start point for the aux stream since we will + # run in parallel with router/gate. + self._stream.wait_stream(current_stream()) + + def _run_in_aux_stream( + self, + shared_experts_input: torch.Tensor, + ) -> torch.Tensor: + # TODO: assert that maybe_sync_shared_experts_stream has been called. + + # Run shared experts in parallel on a separate stream. + with torch.cuda.stream(self._stream): + output = self._layer(shared_experts_input) + current_stream().wait_stream(self._stream) + + return output + + def _maybe_reduce_shared_out(self, shared_out: torch.Tensor) -> torch.Tensor: + # Reduce shared expert outputs if necessary, since the MLP + # should have been created with reduce_results=False. 
+ if ( + self._reduce_results + and self._quant_method.moe_kernel is not None + and self._quant_method.moe_kernel.output_is_reduced() + and get_tensor_model_parallel_world_size() > 1 + ): + shared_out = tensor_model_parallel_all_reduce(shared_out) + return shared_out + + @property + def _output_idx(self) -> int: + return dbo_current_ubatch_id() if self.enable_dbo else 0 + + @property + def output(self) -> torch.Tensor: + assert self._output[self._output_idx] is not None + output = self._output[self._output_idx] + self._output[self._output_idx] = None + return output + + def apply( + self, + shared_experts_input: torch.Tensor, + order: SharedExpertsOrder, + ): + experts_order = self._determine_shared_experts_order(shared_experts_input) + + if order != experts_order: + return None + + assert self._output[self._output_idx] is None + + if order == SharedExpertsOrder.MULTI_STREAM_OVERLAPPED: + self._output[self._output_idx] = self._run_in_aux_stream( + shared_experts_input + ) + else: + self._output[self._output_idx] = self._layer(shared_experts_input) + + if order == SharedExpertsOrder.EXTERNAL: + # TODO: figure out how to combine this with maybe_reduce_output? + # or get rid of it completely. + assert self._output[self._output_idx] is not None + self._output[self._output_idx] = self._maybe_reduce_shared_out( + self._output[self._output_idx] + ) + + assert self._output[self._output_idx] is not None diff --git a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py index 37336df17..ed243e992 100644 --- a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py @@ -3,14 +3,10 @@ import torch -from vllm.distributed import ( - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce, -) from vllm.model_executor.layers.fused_moe.layer import FusedMoE -# TODO(bnell): Add shared + fused combo function? e.g. + +# TODO(bnell): Remove this entirely class SharedFusedMoE(FusedMoE): """ A FusedMoE operation that also computes the results of shared experts. @@ -23,36 +19,11 @@ class SharedFusedMoE(FusedMoE): hidden_states: torch.Tensor, router_logits: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - if not self.use_overlapped: - if self._shared_experts is not None: - shared_out = self._shared_experts(hidden_states) - - # Reduce shared expert outputs if necessary, since the MLP - # should have been created with reduce_results=False. 
- if ( - self.reduce_results - and get_tensor_model_parallel_world_size() > 1 - and self.must_reduce_shared_expert_outputs() - ): - shared_out = tensor_model_parallel_all_reduce(shared_out) - else: - shared_out = None - - fused_out = super().forward( - hidden_states=hidden_states, - router_logits=router_logits, - ) + result = super().forward( + hidden_states=hidden_states, + router_logits=router_logits, + ) + if self.shared_experts is None: + return None, result else: - shared_out, fused_out = super().forward( - hidden_states=hidden_states, - router_logits=router_logits, - ) - # ensure early TP reduction of shared expert outputs when required - if ( - shared_out is not None - and self.reduce_results - and get_tensor_model_parallel_world_size() > 1 - and self.must_reduce_shared_expert_outputs() - ): - shared_out = tensor_model_parallel_all_reduce(shared_out) - return shared_out, fused_out + return result diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index 2b149a553..f0d799824 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -245,7 +245,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: return self.forward( layer=layer, x=x, @@ -261,7 +261,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: assert self.moe_kernel is not None return self.moe_kernel.apply( hidden_states=x, @@ -283,7 +283,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: return self.forward_native( layer, x, topk_weights, topk_ids, shared_experts_input ) @@ -293,7 +293,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 x: torch.Tensor, router_logits: torch.Tensor, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: assert self.is_monolithic if self.unquantized_backend == UnquantizedMoeBackend.CPU: assert self.moe_kernel is None diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 03dfaa794..eff571ef2 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -811,7 +811,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: return fused_marlin_moe( x, layer.w13_qweight, diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 716a20090..729924663 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -483,7 +483,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): topk_weights: 
diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
index 2b149a553..f0d799824 100644
--- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
+++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
@@ -245,7 +245,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         shared_experts_input: torch.Tensor | None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         return self.forward(
             layer=layer,
             x=x,
@@ -261,7 +261,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         shared_experts_input: torch.Tensor | None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         assert self.moe_kernel is not None
         return self.moe_kernel.apply(
             hidden_states=x,
@@ -283,7 +283,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         shared_experts_input: torch.Tensor | None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         return self.forward_native(
             layer, x, topk_weights, topk_ids, shared_experts_input
         )
@@ -293,7 +293,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
         x: torch.Tensor,
         router_logits: torch.Tensor,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         assert self.is_monolithic
         if self.unquantized_backend == UnquantizedMoeBackend.CPU:
             assert self.moe_kernel is None
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index 03dfaa794..eff571ef2 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -811,7 +811,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         shared_experts_input: torch.Tensor | None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         return fused_marlin_moe(
             x,
             layer.w13_qweight,
diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py
index 716a20090..729924663 100644
--- a/vllm/model_executor/layers/quantization/bitsandbytes.py
+++ b/vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -483,7 +483,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         shared_experts_input: torch.Tensor | None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe import fused_experts
 
         # TODO(bnell): Do these need to be called on the hot path?
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 1b8b726d9..bce63bcbe 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -355,7 +355,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         shared_experts_input: torch.Tensor | None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         assert self.moe_kernel is not None
         return self.moe_kernel.apply(
             x,
@@ -603,7 +603,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
         layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         assert self.is_monolithic
         assert self.moe_kernel is not None
         return self.moe_kernel.apply_monolithic(
@@ -628,7 +628,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         shared_experts_input: torch.Tensor | None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         assert self.moe_kernel is not None
         return self.moe_kernel.apply(
             x,
@@ -963,7 +963,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         assert self.moe_kernel is not None
         return self.moe_kernel.apply_monolithic(
             x,
@@ -987,7 +987,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         shared_experts_input: torch.Tensor | None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         assert not self.is_monolithic
         assert self.moe_kernel is not None
         return self.moe_kernel.apply(
@@ -1127,7 +1127,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         shared_experts_input: torch.Tensor | None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe import fused_experts
 
         return fused_experts(
@@ -1611,7 +1611,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
         layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         assert self.kernel_backend == "Flashinfer"
         return flashinfer_trtllm_mxint4_moe(
             x=x,
@@ -1638,7 +1638,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         shared_experts_input: torch.Tensor | None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         assert self.kernel_backend == "Marlin"
         return fused_marlin_moe(
             x,
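Several of the hunks above pair an apply_monolithic() guarded by "assert self.is_monolithic" with an apply() guarded by its negation: monolithic backends take raw router logits and do routing inside the kernel, while modular backends receive the already computed top-k weights and ids. A toy sketch of that dispatch split (illustrative class, not the vLLM method base):

    import torch

    class ToyMoEMethod:
        def __init__(self, is_monolithic: bool) -> None:
            self.is_monolithic = is_monolithic

        def run(self, x: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor:
            if self.is_monolithic:
                return self.apply_monolithic(x, router_logits)
            # Routing happens on the caller side for modular backends.
            topk_weights, topk_ids = torch.topk(
                torch.softmax(router_logits, dim=-1), k=2
            )
            return self.apply(x, topk_weights, topk_ids)

        def apply_monolithic(
            self, x: torch.Tensor, router_logits: torch.Tensor
        ) -> torch.Tensor:
            assert self.is_monolithic
            return x  # stand-in for a kernel that fuses routing + experts

        def apply(
            self, x: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor
        ) -> torch.Tensor:
            assert not self.is_monolithic
            return x  # stand-in for an experts-only kernel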
@@ -1887,7 +1887,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         shared_experts_input: torch.Tensor | None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe import fused_experts
 
         return fused_experts(
@@ -2502,7 +2502,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         shared_experts_input: torch.Tensor | None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         if layer.enable_eplb:
             raise NotImplementedError(
                 "EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet."
diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py
index d971f3b5b..301441ff0 100644
--- a/vllm/model_executor/layers/quantization/experts_int8.py
+++ b/vllm/model_executor/layers/quantization/experts_int8.py
@@ -141,7 +141,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         shared_experts_input: torch.Tensor | None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe import fused_experts
 
         return fused_experts(
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index fffcfa5e6..965e1af72 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -877,7 +877,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         assert self.is_monolithic
         assert self.moe_kernel is not None
         return self.moe_kernel.apply_monolithic(
@@ -902,7 +902,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         shared_experts_input: torch.Tensor | None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         assert not self.is_monolithic
         assert self.moe_kernel is not None
         return self.moe_kernel.apply(
diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
index 145610e9c..2a72da26c 100644
--- a/vllm/model_executor/layers/quantization/gguf.py
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -650,7 +650,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         shared_experts_input: torch.Tensor | None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         if layer.apply_router_weight_on_input:
             raise NotImplementedError(
                 "Apply router weight on input is not supported for"
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 8e367c883..ce0dc0f4e 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -907,7 +907,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         shared_experts_input: torch.Tensor | None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         return fused_marlin_moe(
             x,
             layer.w13_qweight,
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index eb9591936..b0562ee43 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -935,7 +935,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         assert self.is_monolithic
         assert self.moe_kernel is not None
         return self.moe_kernel.apply_monolithic(
@@ -960,7 +960,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         shared_experts_input: torch.Tensor | None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         assert not self.is_monolithic
         assert self.moe_kernel is not None
         return self.moe_kernel.apply(
@@ -1419,7 +1419,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         assert self.is_monolithic
         assert self.moe_kernel is not None
         return self.moe_kernel.apply_monolithic(
@@ -1444,7 +1444,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         shared_experts_input: torch.Tensor | None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         assert not self.is_monolithic
         assert self.moe_kernel is not None
         return self.moe_kernel.apply(
diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py
index f5c679840..a327ac17b 100644
--- a/vllm/model_executor/layers/quantization/moe_wna16.py
+++ b/vllm/model_executor/layers/quantization/moe_wna16.py
@@ -369,7 +369,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         shared_experts_input: torch.Tensor | None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe import fused_experts
 
         assert layer.activation == MoEActivation.SILU, (
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index c69e99a68..adb191b0a 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -377,7 +377,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         shared_experts_input: torch.Tensor | None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         assert not self.is_monolithic
         assert self.moe_kernel is not None
         return self.moe_kernel.apply(
@@ -398,7 +398,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         layer: FusedMoE,
         x: torch.Tensor,
         router_logits: torch.Tensor,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         assert self.is_monolithic
         assert self.moe_kernel is not None
         return self.moe_kernel.apply_monolithic(
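The same one-line change repeats across every quantization method above: apply() is narrowed from "torch.Tensor | tuple[torch.Tensor, torch.Tensor]" to a plain torch.Tensor, since the shared-expert output no longer rides along with the fused result. The practical effect for a caller, in a small illustrative helper (not code from this diff):

    import torch

    def combine(fused: torch.Tensor, shared: torch.Tensor | None) -> torch.Tensor:
        # With the old union return type a caller had to unpack defensively:
        #     if isinstance(fused, tuple):
        #         shared, fused = fused
        # With apply() always returning a single tensor, that branch disappears.
        return fused if shared is None else fused + shared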
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index a58ee5c44..c48e49fe8 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -444,7 +444,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         shared_experts_input: torch.Tensor | None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         if self.rocm_aiter_moe_enabled:
             from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
                 rocm_aiter_fused_experts,
@@ -634,7 +634,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         shared_experts_input: torch.Tensor | None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
             rocm_aiter_fused_experts,
         )
@@ -1027,7 +1027,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         shared_experts_input: torch.Tensor | None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         if not self.emulate:
             from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
                 rocm_aiter_fused_experts,
diff --git a/vllm/model_executor/models/transformers/moe.py b/vllm/model_executor/models/transformers/moe.py
index 5f8352fae..f65a197ab 100644
--- a/vllm/model_executor/models/transformers/moe.py
+++ b/vllm/model_executor/models/transformers/moe.py
@@ -94,6 +94,8 @@ def transformers_moe_forward(
     self = forward_context.no_compile_layers[layer_name]
     self._topk_ids = topk_ids
     # Clone hidden_states because it will be mutated in-place in FusedMoE
+    # TODO(bnell): figure out a way to avoid calling runner directly.
+    # It is a hack that the weights are being passed via the logits argument.
     return self.runner.forward(hidden_states.clone(), topk_weights)