diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 57fb3561d..e5f32ebd1 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -30,6 +30,7 @@ from vllm.model_executor.layers.fused_moe.utils import ( disable_inplace, ) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + get_marlin_input_dtype, marlin_make_workspace_new, marlin_moe_intermediate_size, marlin_quant_input, @@ -550,6 +551,8 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute): self.w13_g_idx_sort_indices = w13_g_idx_sort_indices self.w2_g_idx_sort_indices = w2_g_idx_sort_indices self.is_k_full = is_k_full + self.input_dtype = get_marlin_input_dtype() + super().__init__( moe_config=moe_config, quant_config=quant_config, @@ -736,6 +739,7 @@ class MarlinExperts(MarlinExpertsBase): sort_indices1=self.w13_g_idx_sort_indices, sort_indices2=self.w2_g_idx_sort_indices, is_k_full=self.is_k_full, + input_dtype=self.input_dtype, ) def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None: diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 5c6837e7a..4b24885b4 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -16,6 +16,9 @@ from vllm.model_executor.layers.fused_moe import ( MoEActivation, ) from vllm.model_executor.layers.fused_moe import modular_kernel as mk +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, mxfp4_mxfp8_moe_quant_config, @@ -25,7 +28,6 @@ from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( BatchedMarlinExperts, MarlinExperts, - fused_marlin_moe, ) from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( OAITritonExperts, @@ -52,7 +54,6 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.scalar_type import scalar_types from vllm.utils.flashinfer import has_flashinfer from vllm.utils.import_utils import has_triton_kernels from vllm.utils.math_utils import round_up @@ -219,7 +220,6 @@ class Mxfp4Config(QuantizationConfig): return XpuMxfp4MoEMethod(layer.moe_config) else: quant_method = Mxfp4MoEMethod(layer.moe_config) - quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix) return quant_method elif isinstance(layer, Attention): # TODO: Add support for MXFP4 Attention. @@ -243,7 +243,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): self.weight_dtype = "mxfp4" self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled) - self.marlin_input_dtype = None self.max_capture_size = ( get_current_vllm_config().compilation_config.max_cudagraph_capture_size ) @@ -254,6 +253,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): "Please check your environment and try again." ) self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {} + self.moe_mk: mk.FusedMoEModularKernel | None = None def create_weights( self, @@ -408,7 +408,30 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): def process_weights_after_loading(self, layer): if self.mxfp4_backend == Mxfp4Backend.MARLIN: - prepare_moe_fp4_layer_for_marlin(layer, input_dtype=self.marlin_input_dtype) + prepare_moe_fp4_layer_for_marlin( + layer, input_dtype=get_marlin_input_dtype() + ) + + self.moe_quant_config = self.get_fused_moe_quant_config(layer) + assert self.moe_quant_config is not None + + prepare_finalize = maybe_make_prepare_finalize( + moe=self.moe, + quant_config=self.moe_quant_config, + routing_tables=layer._maybe_init_expert_routing_tables(), + allow_new_interface=True, + ) + assert prepare_finalize is not None + + self.moe_mk = mk.FusedMoEModularKernel( + prepare_finalize, + MarlinExperts( + self.moe, + self.moe_quant_config, + ), + inplace=not self.moe.disable_inplace, + shared_experts=None, + ) elif ( self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 @@ -910,27 +933,19 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): raise NotImplementedError("EPLB is not supported for mxfp4") if self.mxfp4_backend == Mxfp4Backend.MARLIN: - return fused_marlin_moe( - x, - layer.w13_weight, - layer.w2_weight, - layer.w13_bias, - layer.w2_bias, - layer.w13_weight_scale, - layer.w2_weight_scale, - topk_weights, - topk_ids, - global_scale1=None, - global_scale2=None, - quant_type_id=scalar_types.float4_e2m1f.id, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - global_num_experts=layer.global_num_experts, - activation=layer.activation, - expert_map=layer.expert_map, - input_dtype=self.marlin_input_dtype, - inplace=not self.moe.disable_inplace, - ) + assert self.moe_mk is not None + return self.moe_mk( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + ) assert _can_support_mxfp4( layer.use_grouped_topk, layer.topk_group,