diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 58a4d09d9..85ad22137 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -957,18 +957,18 @@ class MarlinMoEWeightData: ) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") def test_fused_marlin_moe( - a_type, - b_type, - c_type, - group_blocks, - m, - n, - k, - e, - topk, - ep_size, - act_order, - is_k_full, + a_type: ScalarType, + b_type: ScalarType, + c_type: ScalarType, + group_blocks: int, + m: int, + n: int, + k: int, + e: int, + topk: int, + ep_size: int, + act_order: bool, + is_k_full: bool, ): torch.cuda.manual_seed(1) group_size = group_blocks if group_blocks <= 0 else group_blocks * 16 @@ -1044,7 +1044,6 @@ def test_fused_marlin_moe( None, w1_data.scales, w2_data.scales, - score, topk_weights, topk_ids, global_num_experts=e, @@ -1120,7 +1119,6 @@ def test_fused_marlin_moe_with_bias(m): w2_data.marlin_bias, w1_data.scales, w2_data.scales, - score, topk_weights, topk_ids, global_num_experts=e, @@ -1199,7 +1197,6 @@ def test_fused_marlin_moe_non_gated(m: int, n: int, k: int, e: int, topk: int): None, # bias2 w1_data.scales, w2_data.scales, - score, topk_weights, topk_ids, global_num_experts=e, @@ -1519,7 +1516,6 @@ def test_batched_fused_marlin_moe( "bias2": None, "w1_scale": w1_data.scales, "w2_scale": w2_data.scales, - "gating_output": score, "global_num_experts": e, "expert_map": None, "global_scale1": w1_data.global_scale, diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 603c8fc96..ce77eafa6 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -210,7 +210,6 @@ def fused_marlin_moe( bias2: torch.Tensor | None, w1_scale: torch.Tensor, w2_scale: torch.Tensor, - gating_output: torch.Tensor | None, topk_weights: torch.Tensor, topk_ids: torch.Tensor, quant_type_id: int, @@ -250,8 +249,6 @@ def fused_marlin_moe( - w2 (torch.Tensor): The second set of expert weights. - w1_scale (torch.Tensor): Scale to be used for w1. - w2_scale (torch.Tensor): Scale to be used for w2. - - gating_output (torch.Tensor|None): The output of the gating - operation (before softmax). - g_idx1 (torch.Tensor|None): The first set of act_order indices. - g_idx2 (torch.Tensor|None): The second set of act_order indices. - sort_indices1 (torch.Tensor|None): The first act_order input @@ -292,8 +289,6 @@ def fused_marlin_moe( topk = topk_ids.size(1) # Check constraints. 
- if gating_output is not None: - assert gating_output.size(0) == M, "Number of tokens mismatch" assert w1.size(1) * 16 == K, "Hidden size mismatch w1" assert w2.size(2) // (num_bits // 2) == K, "Hidden size mismatch w2" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" @@ -381,7 +376,6 @@ def batched_fused_marlin_moe( bias2: torch.Tensor | None, w1_scale: torch.Tensor, w2_scale: torch.Tensor, - gating_output: torch.Tensor | None, quant_type_id: int, apply_router_weight_on_input: bool = False, global_num_experts: int = -1, @@ -718,7 +712,6 @@ class MarlinExperts(MarlinExpertsBase): bias2=self.w2_bias, w1_scale=self.w1_scale, w2_scale=self.w2_scale, - gating_output=None, topk_weights=topk_weights, topk_ids=topk_ids, global_scale1=self.g1_alphas, @@ -833,7 +826,6 @@ class BatchedMarlinExperts(MarlinExpertsBase): bias2=self.w2_bias, w1_scale=self.w1_scale, w2_scale=self.w2_scale, - gating_output=None, quant_type_id=self.quant_type_id, apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py index 9d1ec7fd8..36562142c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py @@ -14,9 +14,6 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, ) -from vllm.model_executor.layers.fused_moe.router.fused_moe_router import ( - FusedMoERouter, -) from vllm.model_executor.layers.quantization.base_config import ( QuantizeMethodBase, ) @@ -108,11 +105,24 @@ class FusedMoEMethodBase(QuantizeMethodBase): def method_name(self) -> str: return self.__class__.__name__ - @abstractmethod + @property + def is_monolithic(self) -> bool: + return False + + # @abstractmethod def apply( self, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 - router: FusedMoERouter, + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError + + # @abstractmethod + def apply_monolithic( + self, + layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py index 64de30ce6..53df0dc85 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py @@ -16,9 +16,6 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel, FusedMoEPrepareAndFinalize, ) -from vllm.model_executor.layers.fused_moe.router.fused_moe_router import ( - FusedMoERouter, -) logger = init_logger(__name__) @@ -40,6 +37,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): not self.fused_experts.supports_expert_map(), ) self.old_quant_method = old_quant_method + assert not self.old_quant_method.is_monolithic logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__) @staticmethod @@ -94,16 +92,11 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): def apply( self, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 - router: FusedMoERouter, x: torch.Tensor, - router_logits: torch.Tensor, + 
topk_weights: torch.Tensor, + topk_ids: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - topk_weights, topk_ids = router.select_experts( - hidden_states=x, - router_logits=router_logits, - ) - - result = self.fused_experts( + return self.fused_experts( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -115,5 +108,3 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): apply_router_weight_on_input=layer.apply_router_weight_on_input, expert_map=None if self.disable_expert_map else layer.expert_map, ) - - return result diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 3b85bc7a2..daf2e0e6b 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -522,8 +522,7 @@ class FusedMoE(CustomOp): self.apply_router_weight_on_input = apply_router_weight_on_input self.activation = activation - # TODO(bnell): in next PR move capture back to layer - capture: Callable[[torch.Tensor], None] | None = None + self.capture: Callable[[torch.Tensor], None] | None = None if ( self.vllm_config.model_config is not None and self.vllm_config.model_config.enable_return_routed_experts @@ -531,7 +530,9 @@ class FusedMoE(CustomOp): # In dummy runs, the capturer is not initialized. capturer = RoutedExpertsCapturer.get_instance() if capturer is not None: - capture = lambda topk_ids: capturer.capture(self.layer_id, topk_ids) + self.capture = lambda topk_ids: capturer.capture( + self.layer_id, topk_ids + ) self.router = create_fused_moe_router( top_k=top_k, @@ -550,7 +551,6 @@ class FusedMoE(CustomOp): # TODO(bnell): once we can construct the MK at init time, we # can make this a value. indices_type_getter=lambda: self.quant_method.topk_indices_dtype, - capture=capture, ) self.routing_method_type: RoutingMethodType = self.router.routing_method_type @@ -1673,12 +1673,27 @@ class FusedMoE(CustomOp): staged_router_logits.copy_(router_logits, non_blocking=True) # Matrix multiply. - final_hidden_states = self.quant_method.apply( - layer=self, - router=self.router, - x=staged_hidden_states, - router_logits=staged_router_logits, - ) + if self.quant_method.is_monolithic: + final_hidden_states = self.quant_method.apply_monolithic( + layer=self, + x=staged_hidden_states, + router_logits=staged_router_logits, + ) + else: + topk_weights, topk_ids = self.router.select_experts( + hidden_states=staged_hidden_states, + router_logits=staged_router_logits, + ) + + if self.capture is not None: + self.capture(topk_ids) + + final_hidden_states = self.quant_method.apply( + layer=self, + x=staged_hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + ) if has_separate_shared_experts: assert not isinstance(final_hidden_states, tuple) @@ -1810,15 +1825,20 @@ class FusedMoE(CustomOp): extra_tensors=extra_tensors, ) if extra_tensors is not None: - hidden_states_combined, router_logits, extra_tensors_combined = ( - dispatch_res - ) + ( + orig_hidden_states, + router_logits, + extra_tensors_combined, + ) = dispatch_res hidden_states_combined = ( - hidden_states_combined, + orig_hidden_states, extra_tensors_combined[0], ) else: hidden_states_combined, router_logits = dispatch_res + orig_hidden_states = hidden_states_combined + else: + orig_hidden_states = hidden_states # Run shared experts before matrix multiply. # because matrix multiply maybe modify the hidden_states. @@ -1840,14 +1860,33 @@ class FusedMoE(CustomOp): ) # Matrix multiply. 
- final_hidden_states = self.quant_method.apply( - layer=self, - router=self.router, - x=hidden_states_combined - if do_naive_dispatch_combine - else hidden_states, - router_logits=router_logits, - ) + x = hidden_states_combined if do_naive_dispatch_combine else hidden_states + + # TODO(bnell): deal with fp4 flashinfer tuple hidden states hack (#30014). + # Figure out a nicer way to do this. + x_orig = orig_hidden_states if do_naive_dispatch_combine else hidden_states + + if self.quant_method.is_monolithic: + final_hidden_states = self.quant_method.apply_monolithic( + layer=self, + x=x, + router_logits=router_logits, + ) + else: + topk_weights, topk_ids = self.router.select_experts( + hidden_states=x_orig, + router_logits=router_logits, + ) + + if self.capture is not None: + self.capture(topk_ids) + + final_hidden_states = self.quant_method.apply( + layer=self, + x=x, # The type signature of this is wrong due to the hack. + topk_weights=topk_weights, + topk_ids=topk_ids, + ) if has_separate_shared_experts: assert self.shared_experts is not None diff --git a/vllm/model_executor/layers/fused_moe/router/base_router.py b/vllm/model_executor/layers/fused_moe/router/base_router.py index 683f7188c..9969818ab 100644 --- a/vllm/model_executor/layers/fused_moe/router/base_router.py +++ b/vllm/model_executor/layers/fused_moe/router/base_router.py @@ -127,7 +127,6 @@ class BaseRouter(FusedMoERouter): self.eplb_state = eplb_state self.enable_eplb = enable_eplb self.indices_type_getter = indices_type_getter - self.capture: Callable[[torch.tensor], None] | None = None def _validate_eplb_state(self) -> None: """Validate that EPLB state is properly initialized if EPLB is enabled.""" @@ -238,8 +237,4 @@ class BaseRouter(FusedMoERouter): # Step 5: Convert indices dtype topk_ids = self._convert_indices_dtype(topk_ids, indices_type) - # TODO(bnell): temporary hack until select_experts is moved into FusedMoE - if self.capture is not None: - self.capture(topk_ids) - return topk_weights, topk_ids diff --git a/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py b/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py index 17c59352f..0367189ca 100644 --- a/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py +++ b/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py @@ -55,4 +55,6 @@ class CustomRoutingRouter(BaseRouter): renormalize=self.renormalize, ) - return topk_weights.to(torch.float32), topk_ids + return topk_weights.to(torch.float32), topk_ids.to( + torch.int32 if indices_type is None else indices_type + ) diff --git a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py index 44b586650..7c230686f 100644 --- a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py +++ b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py @@ -124,7 +124,9 @@ def fused_topk_bias( topk_weights = scores.gather(1, topk_indices) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights.to(torch.float32), topk_indices.to(torch.int32) + return topk_weights.to(torch.float32), topk_indices.to( + torch.int32 if indices_type is None else indices_type + ) class FusedTopKBiasRouter(BaseRouter): @@ -176,6 +178,7 @@ class FusedTopKBiasRouter(BaseRouter): topk=self.top_k, renormalize=self.renormalize, scoring_func=self.scoring_func, + indices_type=indices_type, ) if self.routed_scaling_factor != 1.0: diff --git 
a/vllm/model_executor/layers/fused_moe/router/router_factory.py b/vllm/model_executor/layers/fused_moe/router/router_factory.py index d28f07558..890f846d3 100644 --- a/vllm/model_executor/layers/fused_moe/router/router_factory.py +++ b/vllm/model_executor/layers/fused_moe/router/router_factory.py @@ -6,7 +6,6 @@ import torch import vllm.envs as envs from vllm.distributed.eplb.eplb_state import EplbLayerState -from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter from vllm.model_executor.layers.fused_moe.router.custom_routing_router import ( CustomRoutingRouter, ) @@ -49,7 +48,6 @@ def create_fused_moe_router( # eplb parameters enable_eplb: bool = False, eplb_state: EplbLayerState = EMPTY_EPLB_STATE, - capture: Callable[[torch.tensor], None] | None = None, ) -> FusedMoERouter: """ Factory function to create the appropriate FusedMoERouter subclass based on @@ -90,21 +88,16 @@ def create_fused_moe_router( Returns: An instance of the appropriate FusedMoERouter subclass """ - router: BaseRouter routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY if routing_strategy != "": - router = RoutingSimulatorRouter( + return RoutingSimulatorRouter( top_k=top_k, global_num_experts=global_num_experts, eplb_state=eplb_state, enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, ) - # TODO(bnell): this is temporary until select_experts is - # separated from apply. - router.capture = capture - return router if use_grouped_topk: assert custom_routing_function is None @@ -113,7 +106,7 @@ def create_fused_moe_router( "num_expert_group and topk_group must be provided when " "use_grouped_topk is True" ) - router = GroupedTopKRouter( + return GroupedTopKRouter( top_k=top_k, global_num_experts=global_num_experts, eplb_state=eplb_state, @@ -127,11 +120,9 @@ def create_fused_moe_router( enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, ) - router.capture = capture - return router if custom_routing_function is not None: - router = CustomRoutingRouter( + return CustomRoutingRouter( top_k=top_k, global_num_experts=global_num_experts, eplb_state=eplb_state, @@ -140,11 +131,9 @@ def create_fused_moe_router( enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, ) - router.capture = capture - return router if e_score_correction_bias is not None: - router = FusedTopKBiasRouter( + return FusedTopKBiasRouter( top_k=top_k, global_num_experts=global_num_experts, eplb_state=eplb_state, @@ -155,10 +144,8 @@ def create_fused_moe_router( enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, ) - router.capture = capture - return router - router = FusedTopKRouter( + return FusedTopKRouter( top_k=top_k, global_num_experts=global_num_experts, eplb_state=eplb_state, @@ -167,5 +154,3 @@ def create_fused_moe_router( enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, ) - router.capture = capture - return router diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index 2581a1e56..4b85cc5c2 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable import torch import torch.nn.functional as F @@ -31,9 +32,6 @@ from vllm.model_executor.layers.fused_moe.oracle.unquantized import ( 
make_unquantized_moe_kernel, select_unquantized_moe_backend, ) -from vllm.model_executor.layers.fused_moe.router.fused_moe_router import ( - FusedMoERouter, -) from vllm.model_executor.utils import replace_parameter, set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum @@ -66,6 +64,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): rocm_aiter_ops.is_fused_moe_enabled() and moe.is_act_and_mul ) self.kernel: mk.FusedMoEModularKernel | None = None + self._is_monolithic = current_platform.is_cpu() or current_platform.is_xpu() + + @property + def is_monolithic(self) -> bool: + return self._is_monolithic @property def supports_eplb(self) -> bool: @@ -212,7 +215,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): import intel_extension_for_pytorch as ipex ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts - layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( + self.ipex_fusion = ipex.llm.modules.GatedMLPMOE( layer.w13_weight, layer.w2_weight, use_prepack=True, @@ -244,11 +247,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ) assert packed_w2_weight.size() == layer.w2_weight.size() layer.w2_weight.copy_(packed_w2_weight) - layer.cpu_fused_moe = cpu_fused_moe.SGLFusedMOE(layer) + self.cpu_fused_moe: Callable = cpu_fused_moe.SGLFusedMOE(layer) else: - layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) + self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) else: - layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) + self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) elif current_platform.is_cuda_alike(): self._setup_kernel( layer=layer, @@ -259,15 +262,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def apply( self, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 - router: FusedMoERouter, x: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: return self.forward( - router=router, layer=layer, x=x, - router_logits=router_logits, + topk_weights=topk_weights, + topk_ids=topk_ids, ) def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: @@ -282,18 +285,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def forward_cuda( self, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 - router: FusedMoERouter, x: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.kernel - - topk_weights, topk_ids = router.select_experts( - hidden_states=x, - router_logits=router_logits, - ) - - result = self.kernel( + assert self.kernel is not None + return self.kernel( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -306,24 +303,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map=layer.expert_map, ) - return result - - def forward_cpu( + def forward_monolithic_cpu( self, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 - router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - if ( - layer.enable_eplb is not False - or layer.eplb_state.expert_load_view is not None - or layer.eplb_state.logical_to_physical_map is not None - or layer.eplb_state.logical_replica_count is not None - ): - raise NotImplementedError("Expert load balancing is not supported for CPU.") - - return 
layer.cpu_fused_moe( + return self.cpu_fused_moe( layer, x, layer.use_grouped_topk, @@ -342,21 +328,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): layer.activation, ) - def forward_xpu( + def forward_monolithic_xpu( self, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 - router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - if ( - layer.enable_eplb is not False - or layer.eplb_state.expert_load_view is not None - or layer.eplb_state.logical_to_physical_map is not None - or layer.eplb_state.logical_replica_count is not None - ): - raise NotImplementedError("Expert load balancing is not supported for XPU.") - return layer.ipex_fusion( + return self.ipex_fusion( x, layer.use_grouped_topk, layer.top_k, @@ -368,8 +346,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ) if current_platform.is_cpu(): - forward_native = forward_cpu + forward_native: Callable = forward_monolithic_cpu + apply_monolithic = forward_monolithic_cpu elif current_platform.is_xpu(): - forward_native = forward_xpu + forward_native = forward_monolithic_xpu + apply_monolithic = forward_monolithic_xpu else: forward_native = forward_cuda diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index dfdfc7ea2..e55697458 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -10,7 +10,6 @@ from torch.nn import Parameter import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import FusedMoERouter from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, @@ -762,15 +761,10 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): def apply( self, layer: FusedMoE, - router: FusedMoERouter, x: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - topk_weights, topk_ids = router.select_experts( - hidden_states=x, - router_logits=router_logits, - ) - return fused_marlin_moe( x, layer.w13_qweight, @@ -779,7 +773,6 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): getattr(layer, "w2_bias", None), layer.w13_scales, layer.w2_scales, - router_logits, topk_weights, topk_ids, input_global_scale1=getattr(layer, "w13_input_global_scale", None), diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 542a72810..8b6b1e445 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -6,7 +6,6 @@ from typing import Any, Union import torch from packaging import version -from vllm.model_executor.layers.fused_moe import FusedMoERouter from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, @@ -499,16 +498,12 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): def apply( self, layer: FusedMoE, - router: FusedMoERouter, x: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from vllm.model_executor.layers.fused_moe import fused_experts - topk_weights, topk_ids = router.select_experts( - hidden_states=x, - router_logits=router_logits, - ) # TODO(bnell): Do these 
need to be called on the hot path? if self.quant_config.load_in_8bit: w13, w2 = self._apply_8bit_dequant(layer) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index b48acdcf3..ec120cab4 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -21,7 +21,6 @@ from vllm.model_executor.layers.fused_moe import ( FusedMoEActivationFormat, FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute, - FusedMoERouter, FusedMoeWeightScaleSupported, UnquantizedFusedMoEMethod, ) @@ -126,7 +125,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 layer: torch.nn.Module, layer_name: str, - ) -> "CompressedTensorsMoEMethod": + ) -> FusedMoEMethodBase: # FusedMoE was made by combining multiple Linears so need to # make sure quantization config for Linear can target it quant_config._add_fused_moe_to_target_scheme_map() @@ -345,19 +344,10 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): def apply( self, layer: FusedMoE, - router: FusedMoERouter, x: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - if isinstance(x, tuple): - x_routing, _ = x - else: - x_routing = x - - topk_weights, topk_ids = router.select_experts( - hidden_states=x_routing, - router_logits=router_logits, - ) assert self.kernel is not None return self.kernel( x, @@ -639,41 +629,47 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): a2_scale=layer.w2_input_scale, ) - def apply( + @property + def is_monolithic(self) -> bool: + return ( + self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM + and not self.moe.moe_parallel_config.enable_eplb + ) + + def apply_monolithic( self, layer: FusedMoE, - router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + assert self.is_monolithic assert layer.activation == "silu", "Only SiLU activation is supported." 
- - if ( + assert ( self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM and not layer.enable_eplb - ): - return flashinfer_trtllm_fp4_moe( - layer=layer, - x=x, - router_logits=router_logits, - top_k=layer.top_k, - activation=layer.activation, - global_num_experts=layer.global_num_experts, - num_expert_group=layer.num_expert_group, - topk_group=layer.topk_group, - custom_routing_function=layer.custom_routing_function, - e_score_correction_bias=layer.e_score_correction_bias, - ) - - # Hidden_states in select_experts is only used to extract metadata - if isinstance(x, tuple): - x_routing, _ = x - else: - x_routing = x - topk_weights, topk_ids = router.select_experts( - hidden_states=x_routing, - router_logits=router_logits, ) + return flashinfer_trtllm_fp4_moe( + layer=layer, + x=x, + router_logits=router_logits, + top_k=layer.top_k, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + custom_routing_function=layer.custom_routing_function, + e_score_correction_bias=layer.e_score_correction_bias, + ) + + def apply( + self, + layer: FusedMoE, + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + assert not self.is_monolithic + assert layer.activation == "silu", "Only SiLU activation is supported." # EPLB path if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: @@ -1059,70 +1055,73 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): block_shape=self.weight_block_size, ) - def apply( + @property + def is_monolithic(self) -> bool: + return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM + + def apply_monolithic( self, layer: FusedMoE, - router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: - if layer.enable_eplb: - raise NotImplementedError( - "EPLB not supported for `FlashInfer TRTLLM FP8 MoE`." 
- ) - assert layer.activation == "silu" + assert self.is_monolithic + assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM + assert layer.activation == "silu" - if self.block_quant: - import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 + if self.block_quant: + import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 - e_score_correction_bias = ( - layer.e_score_correction_bias.to(x.dtype) - if layer.e_score_correction_bias is not None - else None - ) - routing_method_type = layer.routing_method_type - return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( - routing_logits=router_logits.to(torch.float32) - if routing_method_type == RoutingMethodType.DeepSeekV3 - else router_logits, - routing_bias=e_score_correction_bias, - x=x, - w13_weight=layer.w13_weight, - w13_weight_scale_inv=layer.w13_weight_scale, - w2_weight=layer.w2_weight, - w2_weight_scale_inv=layer.w2_weight_scale, - global_num_experts=layer.global_num_experts, - top_k=layer.top_k, - num_expert_group=layer.num_expert_group, - topk_group=layer.topk_group, - intermediate_size=layer.intermediate_size_per_partition, - expert_offset=layer.ep_rank * layer.local_num_experts, - local_num_experts=layer.local_num_experts, - block_shape=self.weight_block_size, - routing_method_type=routing_method_type, - routed_scaling=layer.routed_scaling_factor, - ) - else: - return apply_fi_trtllm_fp8_per_tensor_moe( - layer=layer, - hidden_states=x, - router_logits=router_logits, - routing_bias=layer.e_score_correction_bias, - global_num_experts=layer.global_num_experts, - top_k=layer.top_k, - num_expert_group=layer.num_expert_group, - topk_group=layer.topk_group, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - ) - - topk_weights, topk_ids = router.select_experts( - hidden_states=x, - router_logits=router_logits, - ) + e_score_correction_bias = ( + layer.e_score_correction_bias.to(x.dtype) + if layer.e_score_correction_bias is not None + else None + ) + routing_method_type = layer.routing_method_type + return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( + routing_logits=router_logits.to(torch.float32) + if routing_method_type == RoutingMethodType.DeepSeekV3 + else router_logits, + routing_bias=e_score_correction_bias, + x=x, + w13_weight=layer.w13_weight, + w13_weight_scale_inv=layer.w13_weight_scale, + w2_weight=layer.w2_weight, + w2_weight_scale_inv=layer.w2_weight_scale, + global_num_experts=layer.global_num_experts, + top_k=layer.top_k, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + intermediate_size=layer.intermediate_size_per_partition, + expert_offset=layer.ep_rank * layer.local_num_experts, + local_num_experts=layer.local_num_experts, + block_shape=self.weight_block_size, + routing_method_type=routing_method_type, + routed_scaling=layer.routed_scaling_factor, + ) + else: + return apply_fi_trtllm_fp8_per_tensor_moe( + layer=layer, + hidden_states=x, + router_logits=router_logits, + routing_bias=layer.e_score_correction_bias, + global_num_experts=layer.global_num_experts, + top_k=layer.top_k, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + ) + def apply( + self, + layer: FusedMoE, + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + assert not self.is_monolithic assert self.kernel is not None - result = self.kernel( + return self.kernel( x, 
layer.w13_weight, layer.w2_weight, @@ -1137,8 +1136,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): apply_router_weight_on_input=layer.apply_router_weight_on_input, ) - return result - @property def supports_eplb(self) -> bool: return True @@ -1257,17 +1254,12 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): def apply( self, layer: FusedMoE, - router: FusedMoERouter, x: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from vllm.model_executor.layers.fused_moe import fused_experts - topk_weights, topk_ids = router.select_experts( - hidden_states=x, - router_logits=router_logits, - ) - return fused_experts( hidden_states=x, w1=layer.w13_weight, @@ -1621,15 +1613,10 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): def apply( self, layer: FusedMoE, - router: FusedMoERouter, x: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - topk_weights, topk_ids = router.select_experts( - hidden_states=x, - router_logits=router_logits, - ) - return fused_marlin_moe( x, layer.w13_weight_packed, @@ -1638,7 +1625,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): None, layer.w13_weight_scale, layer.w2_weight_scale, - router_logits, topk_weights, topk_ids, input_global_scale1=getattr(layer, "w13_input_global_scale", None), @@ -1873,17 +1859,12 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): def apply( self, layer: FusedMoE, - router: FusedMoERouter, x: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from vllm.model_executor.layers.fused_moe import fused_experts - topk_weights, topk_ids = router.select_experts( - hidden_states=x, - router_logits=router_logits, - ) - return fused_experts( x, layer.w13_weight_packed, @@ -2172,10 +2153,13 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod): # fused_experts; quant config is not needed. return None - def apply( + @property + def is_monolithic(self) -> bool: + return True + + def apply_monolithic( self, layer: FusedMoE, - router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor: @@ -2489,19 +2473,15 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod): def apply( self, layer: FusedMoE, - router: FusedMoERouter, x: torch.Tensor, - router_logits: torch.Tensor, - ): + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if layer.enable_eplb: raise NotImplementedError( "EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet." 
) assert self.moe_quant_config is not None - topk_weights, topk_ids = router.select_experts( - hidden_states=x, - router_logits=router_logits, - ) from vllm.model_executor.layers.fused_moe.cutlass_moe import ( cutlass_moe_w4a8_fp8, diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index c3e7a812e..974c45614 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -10,7 +10,6 @@ from vllm.model_executor.layers.fused_moe import ( FusedMoE, FusedMoEConfig, FusedMoEMethodBase, - FusedMoERouter, ) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, @@ -138,17 +137,12 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): def apply( self, layer: FusedMoE, - router: FusedMoERouter, x: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from vllm.model_executor.layers.fused_moe import fused_experts - topk_weights, topk_ids = router.select_experts( - hidden_states=x, - router_logits=router_logits, - ) - return fused_experts( x, layer.w13_weight, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index d9fa02c6b..fe59022cb 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -22,7 +22,6 @@ from vllm.model_executor.layers.fused_moe import ( FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, - FusedMoERouter, FusedMoeWeightScaleSupported, ) from vllm.model_executor.layers.fused_moe.config import ( @@ -968,71 +967,79 @@ class Fp8MoEMethod(FusedMoEMethodBase): def allow_inplace(self) -> bool: return True - def apply( + @property + def is_monolithic(self) -> bool: + return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM + + def apply_monolithic( self, layer: FusedMoE, - router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: - # TODO(rob): convert this to MK. 
- if layer.enable_eplb: - raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.") - assert layer.activation == "silu", ( - f"Expected 'silu' activation but got {layer.activation}" - ) + assert self.is_monolithic + assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM - if self.block_quant: - import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 - - e_score_correction_bias = ( - layer.e_score_correction_bias.to(x.dtype) - if layer.e_score_correction_bias is not None - else None - ) - routing_method_type = layer.routing_method_type - return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( - routing_logits=router_logits.to(torch.float32) - if routing_method_type == RoutingMethodType.DeepSeekV3 - else router_logits, - routing_bias=e_score_correction_bias, - x=x, - w13_weight=layer.w13_weight, - w13_weight_scale_inv=layer.w13_weight_scale_inv, - w2_weight=layer.w2_weight, - w2_weight_scale_inv=layer.w2_weight_scale_inv, - global_num_experts=layer.global_num_experts, - top_k=layer.top_k, - num_expert_group=layer.num_expert_group, - topk_group=layer.topk_group, - intermediate_size=layer.intermediate_size_per_partition, - expert_offset=layer.ep_rank * layer.local_num_experts, - local_num_experts=layer.local_num_experts, - block_shape=self.weight_block_size, - routing_method_type=routing_method_type, - routed_scaling=layer.routed_scaling_factor, - ) - else: - return apply_fi_trtllm_fp8_per_tensor_moe( - layer=layer, - hidden_states=x, - router_logits=router_logits, - routing_bias=layer.e_score_correction_bias, - global_num_experts=layer.global_num_experts, - top_k=layer.top_k, - num_expert_group=layer.num_expert_group, - topk_group=layer.topk_group, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - ) - - topk_weights, topk_ids = router.select_experts( - hidden_states=x, - router_logits=router_logits, + # TODO(rob): convert this to MK. 
+ if layer.enable_eplb: + raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.") + assert layer.activation == "silu", ( + f"Expected 'silu' activation but got {layer.activation}" ) + if self.block_quant: + import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 + + e_score_correction_bias = ( + layer.e_score_correction_bias.to(x.dtype) + if layer.e_score_correction_bias is not None + else None + ) + routing_method_type = layer.routing_method_type + return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( + routing_logits=router_logits.to(torch.float32) + if routing_method_type == RoutingMethodType.DeepSeekV3 + else router_logits, + routing_bias=e_score_correction_bias, + x=x, + w13_weight=layer.w13_weight, + w13_weight_scale_inv=layer.w13_weight_scale_inv, + w2_weight=layer.w2_weight, + w2_weight_scale_inv=layer.w2_weight_scale_inv, + global_num_experts=layer.global_num_experts, + top_k=layer.top_k, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + intermediate_size=layer.intermediate_size_per_partition, + expert_offset=layer.ep_rank * layer.local_num_experts, + local_num_experts=layer.local_num_experts, + block_shape=self.weight_block_size, + routing_method_type=routing_method_type, + routed_scaling=layer.routed_scaling_factor, + ) + else: + return apply_fi_trtllm_fp8_per_tensor_moe( + layer=layer, + hidden_states=x, + router_logits=router_logits, + routing_bias=layer.e_score_correction_bias, + global_num_experts=layer.global_num_experts, + top_k=layer.top_k, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + ) + + def apply( + self, + layer: FusedMoE, + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.kernel is not None - result = self.kernel( + assert not self.is_monolithic + return self.kernel( x, layer.w13_weight, layer.w2_weight, @@ -1045,8 +1052,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): apply_router_weight_on_input=layer.apply_router_weight_on_input, ) - return result - class Fp8OnlineMoEMethod(Fp8MoEMethod): """MoE method for online FP8 quantization. diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 2b1537089..b40aebaa5 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -12,7 +12,6 @@ from torch.nn.parameter import Parameter, UninitializedParameter from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import FusedMoERouter from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, @@ -633,9 +632,9 @@ class GGUFMoEMethod(FusedMoEMethodBase): def apply( self, layer: FusedMoE, - router: FusedMoERouter, x: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert layer.activation == "silu", "Only SiLU activation is supported." if layer.apply_router_weight_on_input: @@ -644,10 +643,6 @@ class GGUFMoEMethod(FusedMoEMethodBase): "fused GGUF MoE method." 
) - topk_weights, topk_ids = router.select_experts( - hidden_states=x, - router_logits=router_logits, - ) return fused_moe_gguf( x, layer.w13_qweight, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index ec295005b..45c71f366 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -10,7 +10,6 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import FusedMoERouter from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, @@ -898,15 +897,10 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): def apply( self, layer: FusedMoE, - router: FusedMoERouter, x: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - topk_weights, topk_ids = router.select_experts( - hidden_states=x, - router_logits=router_logits, - ) - return fused_marlin_moe( x, layer.w13_qweight, @@ -915,7 +909,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): getattr(layer, "w2_bias", None), layer.w13_scales, layer.w2_scales, - router_logits, topk_weights, topk_ids, input_global_scale1=getattr(layer, "w13_input_global_scale", None), diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py index 9b2198b71..119fb2ef8 100644 --- a/vllm/model_executor/layers/quantization/ipex_quant.py +++ b/vllm/model_executor/layers/quantization/ipex_quant.py @@ -8,9 +8,6 @@ from packaging import version from torch.nn import Module from vllm._ipex_ops import ipex_ops as ops -from vllm.model_executor.layers.fused_moe import ( - FusedMoERouter, -) from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.linear import ( LinearBase, @@ -384,10 +381,13 @@ class XPUFp8MoEMethod(Fp8OnlineMoEMethod): ) -> FusedMoEQuantConfig | None: return None - def apply( + @property + def is_monolithic(self) -> bool: + return True + + def apply_monolithic( self, layer: torch.nn.Module, - router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor: diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 0de9cb88d..e65d23e36 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -13,7 +13,6 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.attention.layer import Attention from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import FusedMoERouter from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, @@ -945,42 +944,49 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): a2_scale=a2_scale, ) - def apply( + @property + def is_monolithic(self) -> bool: + return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM + + def apply_monolithic( self, layer: FusedMoE, - router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: - if layer.enable_eplb: - 
raise NotImplementedError( - "EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend." - ) - # TODO(rob): this validation should happen at kernel selection - # time in the oracle rather than here. - assert layer.activation == "silu", ( - f"Expected 'silu' activation but got {layer.activation}" + assert self.is_monolithic + assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM + if layer.enable_eplb: + raise NotImplementedError( + "EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend." ) - assert not layer.renormalize - return apply_fi_trtllm_fp8_per_tensor_moe( - layer=layer, - hidden_states=x, - router_logits=router_logits, - routing_bias=layer.e_score_correction_bias, - global_num_experts=layer.global_num_experts, - top_k=layer.top_k, - num_expert_group=layer.num_expert_group, - topk_group=layer.topk_group, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - ) - - # Expert selection - topk_weights, topk_ids = router.select_experts( + # TODO(rob): this validation should happen at kernel selection + # time in the oracle rather than here. + assert layer.activation == "silu", ( + f"Expected 'silu' activation but got {layer.activation}" + ) + assert not layer.renormalize + return apply_fi_trtllm_fp8_per_tensor_moe( + layer=layer, hidden_states=x, router_logits=router_logits, + routing_bias=layer.e_score_correction_bias, + global_num_experts=layer.global_num_experts, + top_k=layer.top_k, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + apply_router_weight_on_input=layer.apply_router_weight_on_input, ) + def apply( + self, + layer: FusedMoE, + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + assert not self.is_monolithic + # TODO(rob): this validation should happen at kernel selection # time in the oracle rather than here. 
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: @@ -990,7 +996,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ) assert self.kernel is not None - result = self.kernel( + return self.kernel( x, layer.w13_weight, layer.w2_weight, @@ -1003,8 +1009,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): apply_router_weight_on_input=layer.apply_router_weight_on_input, ) - return result - ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod ModelOptFp8Config.FusedMoEMethodCls = ModelOptFp8MoEMethod @@ -1629,40 +1633,47 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): def supports_eplb(self) -> bool: return True - def apply( + @property + def is_monolithic(self) -> bool: + return ( + self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM + and not self.moe.moe_parallel_config.enable_eplb + ) + + def apply_monolithic( self, layer: FusedMoE, - router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - if ( + assert self.is_monolithic + assert ( self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM and not layer.enable_eplb - ): - return flashinfer_trtllm_fp4_moe( - layer=layer, - x=x, - router_logits=router_logits, - top_k=layer.top_k, - activation=layer.activation, - global_num_experts=layer.global_num_experts, - num_expert_group=layer.num_expert_group, - topk_group=layer.topk_group, - custom_routing_function=layer.custom_routing_function, - e_score_correction_bias=layer.e_score_correction_bias, - ) - - # Hidden_states in select_experts is only used to extract metadata - if isinstance(x, tuple): - x_routing, _ = x - else: - x_routing = x - topk_weights, topk_ids = router.select_experts( - hidden_states=x_routing, - router_logits=router_logits, ) + return flashinfer_trtllm_fp4_moe( + layer=layer, + x=x, + router_logits=router_logits, + top_k=layer.top_k, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + custom_routing_function=layer.custom_routing_function, + e_score_correction_bias=layer.e_score_correction_bias, + ) + + def apply( + self, + layer: FusedMoE, + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + assert not self.is_monolithic + # EPLB path if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: assert layer.enable_eplb diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index 5d29fd01c..38340404d 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -6,7 +6,6 @@ from typing import Any, Optional import torch from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group -from vllm.model_executor.layers.fused_moe import FusedMoERouter from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, int4_w4a16_moe_quant_config, @@ -365,17 +364,13 @@ class MoeWNA16Method(FusedMoEMethodBase): def apply( self, layer: FusedMoE, - router: FusedMoERouter, x: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from vllm.model_executor.layers.fused_moe import fused_experts assert layer.activation == "silu", "Only SiLU activation is supported." 
- topk_weights, topk_ids = router.select_experts( - hidden_states=x, - router_logits=router_logits, - ) return fused_experts( x, diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 01bf6664f..f2232b1db 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -14,7 +14,6 @@ from vllm.model_executor.layers.fused_moe import ( FusedMoE, FusedMoEConfig, FusedMoEMethodBase, - FusedMoERouter, ) from vllm.model_executor.layers.fused_moe import modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import ( @@ -890,22 +889,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): def allow_inplace(self) -> bool: return True + @property + def is_monolithic(self) -> bool: + return ( + self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM + or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 + or self.mxfp4_backend == Mxfp4Backend.TRITON + ) + def apply( self, layer: FusedMoE, - router: FusedMoERouter, x: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + assert not self.is_monolithic if layer.enable_eplb: raise NotImplementedError("EPLB is not supported for mxfp4") if self.mxfp4_backend == Mxfp4Backend.MARLIN: - topk_weights, topk_ids = router.select_experts( - hidden_states=x, - router_logits=router_logits, - ) - return fused_marlin_moe( x, layer.w13_weight, @@ -914,7 +917,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): layer.w2_bias, layer.w13_weight_scale, layer.w2_weight_scale, - router_logits, topk_weights, topk_ids, global_scale1=None, @@ -942,6 +944,98 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): layer.eplb_state.logical_replica_count, ), "MXFP4 are not supported with this configuration." 
+ assert ( + self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS + or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16 + ) + from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe + + # Backend-specific preparation + if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS: + from flashinfer import mxfp8_quantize + + x_quant, x_scale = mxfp8_quantize(x, True, 32) + + fake_input_scale = torch.ones(self.num_experts, device=x.device) + quant_scales = [ + layer.w13_weight_scale.contiguous().view(torch.int32), + fake_input_scale, + layer.w2_weight_scale.contiguous().view(torch.int32), + fake_input_scale, + ] + + fi_input = x_quant + extra_kwargs = dict( + use_mxfp8_act_scaling=True, + input_sf=x_scale, + fc1_expert_weights=layer.w13_weight.contiguous().view(torch.long), + fc2_expert_weights=layer.w2_weight.contiguous().view(torch.long), + ) + elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16: + assert x.dtype == torch.bfloat16 + + quant_scales = [ + layer.w13_weight_scale, + layer.w2_weight_scale, + ] + + fi_input = x + extra_kwargs = dict( + use_w4_group_scaling=True, + fc1_expert_weights=layer.w13_weight, + fc2_expert_weights=layer.w2_weight, + ) + + output = torch.empty_like(x, dtype=torch.bfloat16) + + flashinfer_cutlass_fused_moe( + input=fi_input, + token_selected_experts=topk_ids.to(torch.int).contiguous(), + token_final_scales=topk_weights, + output_dtype=torch.bfloat16, + output=output, + quant_scales=quant_scales, + fc1_expert_biases=layer.w13_bias, + fc2_expert_biases=layer.w2_bias, + swiglu_alpha=layer.gemm1_alpha, + swiglu_beta=layer.gemm1_beta, + swiglu_limit=layer.gemm1_clamp_limit, + tp_size=self.moe.tp_size, + tp_rank=self.moe.tp_rank, + ep_size=self.moe.ep_size, + ep_rank=self.moe.ep_rank, + tune_max_num_tokens=max(self.max_capture_size, 1), + **extra_kwargs, + ) + + return output + + def apply_monolithic( + self, + layer: FusedMoE, + x: torch.Tensor, + router_logits: torch.Tensor, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + assert self.is_monolithic + + if layer.enable_eplb: + raise NotImplementedError("EPLB is not supported for mxfp4") + + assert _can_support_mxfp4( + layer.use_grouped_topk, + layer.topk_group, + layer.num_expert_group, + layer.expert_map, + layer.custom_routing_function, + layer.e_score_correction_bias, + layer.apply_router_weight_on_input, + layer.scoring_func, + layer.activation, + layer.eplb_state.expert_load_view, + layer.eplb_state.logical_to_physical_map, + layer.eplb_state.logical_replica_count, + ), "MXFP4 are not supported with this configuration." 
+ if ( self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 @@ -988,75 +1082,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): tune_max_num_tokens=max(self.max_capture_size, 1), )[0] return trtllm_gen_output - elif ( - self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS - or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16 - ): - from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe - - topk_weights, topk_ids = router.select_experts( - hidden_states=x, - router_logits=router_logits, - ) - - # Backend-specific preparation - if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS: - from flashinfer import mxfp8_quantize - - x_quant, x_scale = mxfp8_quantize(x, True, 32) - - fake_input_scale = torch.ones(self.num_experts, device=x.device) - quant_scales = [ - layer.w13_weight_scale.contiguous().view(torch.int32), - fake_input_scale, - layer.w2_weight_scale.contiguous().view(torch.int32), - fake_input_scale, - ] - - fi_input = x_quant - extra_kwargs = dict( - use_mxfp8_act_scaling=True, - input_sf=x_scale, - fc1_expert_weights=layer.w13_weight.contiguous().view(torch.long), - fc2_expert_weights=layer.w2_weight.contiguous().view(torch.long), - ) - elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16: - assert x.dtype == torch.bfloat16 - - quant_scales = [ - layer.w13_weight_scale, - layer.w2_weight_scale, - ] - - fi_input = x - extra_kwargs = dict( - use_w4_group_scaling=True, - fc1_expert_weights=layer.w13_weight, - fc2_expert_weights=layer.w2_weight, - ) - - output = torch.empty_like(x, dtype=torch.bfloat16) - _ = flashinfer_cutlass_fused_moe( - input=fi_input, - token_selected_experts=topk_ids.to(torch.int).contiguous(), - token_final_scales=topk_weights, - output_dtype=torch.bfloat16, - output=output, - quant_scales=quant_scales, - fc1_expert_biases=layer.w13_bias, - fc2_expert_biases=layer.w2_bias, - swiglu_alpha=layer.gemm1_alpha, - swiglu_beta=layer.gemm1_beta, - swiglu_limit=layer.gemm1_clamp_limit, - tp_size=self.moe.tp_size, - tp_rank=self.moe.tp_rank, - ep_size=self.moe.ep_size, - ep_rank=self.moe.ep_rank, - tune_max_num_tokens=max(self.max_capture_size, 1), - **extra_kwargs, - ) - - return output elif self.mxfp4_backend == Mxfp4Backend.TRITON: from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501 triton_kernel_moe_forward, @@ -1119,10 +1144,13 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod): experts_start_id=ep_rank_start, ) - def apply( + @property + def is_monolithic(self) -> bool: + return True + + def apply_monolithic( self, layer: FusedMoE, - router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor: diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 76ecd055c..d2f0213e8 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -13,7 +13,6 @@ from vllm.model_executor.layers.fused_moe import ( FusedMoE, FusedMoEConfig, FusedMoEMethodBase, - FusedMoERouter, FusedMoeWeightScaleSupported, ) from vllm.model_executor.layers.fused_moe.config import ( @@ -351,15 +350,10 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): def apply( self, layer: FusedMoE, - router: FusedMoERouter, x: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - topk_weights, 
topk_ids = router.select_experts( - hidden_states=x, - router_logits=router_logits, - ) - if self.rocm_aiter_moe_enabled: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( rocm_aiter_fused_experts, @@ -388,7 +382,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): None, layer.w13_weight_scale, layer.w2_weight_scale, - router_logits, topk_weights, topk_ids, quant_type_id=scalar_types.float8_e4m3fn.id, @@ -544,15 +537,10 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod): def apply( self, layer: FusedMoE, - router: FusedMoERouter, x: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - topk_weights, topk_ids = router.select_experts( - hidden_states=x, - router_logits=router_logits, - ) - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( rocm_aiter_fused_experts, ) @@ -753,15 +741,10 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): def apply( self, layer: FusedMoE, - router: FusedMoERouter, x: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - topk_weights, topk_ids = router.select_experts( - hidden_states=x, - router_logits=router_logits, - ) - if not self.emulate: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( rocm_aiter_fused_experts,
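For orientation, a minimal self-contained sketch of the dispatch pattern this diff introduces. The names QuantMethodSketch and forward_sketch are hypothetical stand-ins for FusedMoEMethodBase and FusedMoE.forward_impl, which carry many more arguments in the real code. Expert selection (router.select_experts) and the routed-experts capture hook move out of the quant methods and BaseRouter into the layer, while "monolithic" backends whose kernels do their own routing (e.g. FlashInfer TRTLLM, ipex CPU/XPU, some mxfp4 paths) implement apply_monolithic and keep receiving raw router logits; fused_marlin_moe correspondingly drops its now-unused gating_output argument.

import torch


class QuantMethodSketch:
    """Hypothetical stand-in for FusedMoEMethodBase after this patch."""

    @property
    def is_monolithic(self) -> bool:
        # Overridden to return True by backends that fuse routing into
        # the kernel (FlashInfer TRTLLM fp8/nvfp4, ipex CPU/XPU, etc.).
        return False

    def apply(self, layer, x, topk_weights, topk_ids) -> torch.Tensor:
        raise NotImplementedError

    def apply_monolithic(self, layer, x, router_logits) -> torch.Tensor:
        raise NotImplementedError


def forward_sketch(layer, x: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor:
    quant_method = layer.quant_method
    if quant_method.is_monolithic:
        # Routing happens inside the fused kernel; pass raw logits through.
        return quant_method.apply_monolithic(
            layer=layer, x=x, router_logits=router_logits
        )
    # Otherwise the layer routes first ...
    topk_weights, topk_ids = layer.router.select_experts(
        hidden_states=x, router_logits=router_logits
    )
    # ... runs the capture hook that previously lived inside
    # BaseRouter.select_experts ...
    if layer.capture is not None:
        layer.capture(topk_ids)
    # ... and hands only the routed weights/ids to the quant method.
    return quant_method.apply(
        layer=layer, x=x, topk_weights=topk_weights, topk_ids=topk_ids
    )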