From a3f2f40947cfcf9446ab967cb90d9ebc92833f30 Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Fri, 2 Jan 2026 13:54:50 -0800 Subject: [PATCH] [MoE Refactor] Explicit construct mk for flashinfer bf16 kernel (#31504) Signed-off-by: Yongye Zhu Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> --- .../fused_moe/unquantized_fused_moe_method.py | 68 +++++++++---------- 1 file changed, 32 insertions(+), 36 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index 82dbccf3f..029edc44c 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -16,6 +16,9 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, biased_moe_quant_config, ) +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + FlashInferExperts, +) from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( FusedMoEMethodBase, ) @@ -27,7 +30,10 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP, ) -from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + swap_w13_to_w31, +) +from vllm.model_executor.utils import replace_parameter, set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe @@ -73,18 +79,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): logger.info_once( "Enabling FlashInfer CUTLASS MoE for UnquantizedFusedMoEMethod" ) - from functools import partial - - from .flashinfer_cutlass_moe import flashinfer_cutlass_moe - - self.flashinfer_cutlass_moe = partial( - flashinfer_cutlass_moe, - quant_config=FUSED_MOE_UNQUANTIZED_CONFIG, - tp_rank=self.moe.moe_parallel_config.tp_rank, - tp_size=self.moe.moe_parallel_config.tp_size, - ep_rank=self.moe.moe_parallel_config.ep_rank, - ep_size=self.moe.moe_parallel_config.ep_size, - ) else: if ( self.moe.moe_parallel_config.use_ep @@ -101,7 +95,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): "FlashInfer CUTLASS MoE is currently not available for DP.", scope="local", ) - self.flashinfer_cutlass_moe = None # type: ignore @property def supports_eplb(self) -> bool: @@ -222,12 +215,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): layer.w13_weight.data = shuffled_w13 layer.w2_weight.data = shuffled_w2 - if self.flashinfer_cutlass_moe_enabled: - # Swap halves to arrange as [w3; w1] (kernel expectation) - w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1) - w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1) - layer.w13_weight.data = w13_weight_swapped.contiguous() - if current_platform.is_xpu(): import intel_extension_for_pytorch as ipex @@ -271,11 +258,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) elif current_platform.is_cuda_alike(): self.moe_quant_config = self.get_fused_moe_quant_config(layer) - self.kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), - TritonExperts(self.moe_quant_config), - shared_experts=None, - ) + if self.flashinfer_cutlass_moe_enabled: + self.use_inplace = False + # Swap halves to arrange as [w3; w1] (kernel expectation) + w13_weight = swap_w13_to_w31(layer.w13_weight.data) + replace_parameter(layer, "w13_weight", w13_weight) + + self.kernel = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + FlashInferExperts( + out_dtype=layer.params_dtype, + quant_config=self.moe_quant_config, + tp_rank=self.moe.moe_parallel_config.tp_rank, + tp_size=self.moe.moe_parallel_config.tp_size, + ep_rank=self.moe.moe_parallel_config.ep_rank, + ep_size=self.moe.moe_parallel_config.ep_size, + ), + ) + else: + self.use_inplace = True + self.kernel = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + TritonExperts(self.moe_quant_config), + shared_experts=None, + ) def apply( self, @@ -320,16 +326,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): activation=layer.activation, apply_router_weight_on_input=layer.apply_router_weight_on_input, ) - elif self.flashinfer_cutlass_moe_enabled: - return self.flashinfer_cutlass_moe( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - activation=layer.activation, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - ) else: result = self.kernel( hidden_states=x, @@ -337,7 +333,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, - inplace=True, + inplace=self.use_inplace, activation=layer.activation, apply_router_weight_on_input=layer.apply_router_weight_on_input, global_num_experts=layer.global_num_experts,