[MoE Refactor] Convert mxfp4 marlin into modular kernel format (#34588)

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
Yongye Zhu
2026-02-18 17:37:14 -05:00
committed by GitHub
parent 8d9babd4de
commit 40da9625a1
2 changed files with 44 additions and 25 deletions

View File

@@ -30,6 +30,7 @@ from vllm.model_executor.layers.fused_moe.utils import (
disable_inplace,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
marlin_make_workspace_new,
marlin_moe_intermediate_size,
marlin_quant_input,
@@ -550,6 +551,8 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
self.w13_g_idx_sort_indices = w13_g_idx_sort_indices
self.w2_g_idx_sort_indices = w2_g_idx_sort_indices
self.is_k_full = is_k_full
self.input_dtype = get_marlin_input_dtype()
super().__init__(
moe_config=moe_config,
quant_config=quant_config,
@@ -736,6 +739,7 @@ class MarlinExperts(MarlinExpertsBase):
sort_indices1=self.w13_g_idx_sort_indices,
sort_indices2=self.w2_g_idx_sort_indices,
is_k_full=self.is_k_full,
input_dtype=self.input_dtype,
)
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:

View File

@@ -16,6 +16,9 @@ from vllm.model_executor.layers.fused_moe import (
MoEActivation,
)
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
mxfp4_mxfp8_moe_quant_config,
@@ -25,7 +28,6 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
BatchedMarlinExperts,
MarlinExperts,
fused_marlin_moe,
)
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
OAITritonExperts,
@@ -52,7 +54,6 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils.flashinfer import has_flashinfer
from vllm.utils.import_utils import has_triton_kernels
from vllm.utils.math_utils import round_up
@@ -219,7 +220,6 @@ class Mxfp4Config(QuantizationConfig):
return XpuMxfp4MoEMethod(layer.moe_config)
else:
quant_method = Mxfp4MoEMethod(layer.moe_config)
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return quant_method
elif isinstance(layer, Attention):
# TODO: Add support for MXFP4 Attention.
@@ -243,7 +243,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.weight_dtype = "mxfp4"
self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
self.marlin_input_dtype = None
self.max_capture_size = (
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
)
@@ -254,6 +253,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
"Please check your environment and try again."
)
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
self.moe_mk: mk.FusedMoEModularKernel | None = None
def create_weights(
self,
@@ -408,7 +408,30 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def process_weights_after_loading(self, layer):
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
prepare_moe_fp4_layer_for_marlin(layer, input_dtype=self.marlin_input_dtype)
prepare_moe_fp4_layer_for_marlin(
layer, input_dtype=get_marlin_input_dtype()
)
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
assert self.moe_quant_config is not None
prepare_finalize = maybe_make_prepare_finalize(
moe=self.moe,
quant_config=self.moe_quant_config,
routing_tables=layer._maybe_init_expert_routing_tables(),
allow_new_interface=True,
)
assert prepare_finalize is not None
self.moe_mk = mk.FusedMoEModularKernel(
prepare_finalize,
MarlinExperts(
self.moe,
self.moe_quant_config,
),
inplace=not self.moe.disable_inplace,
shared_experts=None,
)
elif (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
@@ -910,27 +933,19 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
raise NotImplementedError("EPLB is not supported for mxfp4")
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
return fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
layer.w13_bias,
layer.w2_bias,
layer.w13_weight_scale,
layer.w2_weight_scale,
topk_weights,
topk_ids,
global_scale1=None,
global_scale2=None,
quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
activation=layer.activation,
expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype,
inplace=not self.moe.disable_inplace,
)
assert self.moe_mk is not None
return self.moe_mk(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
assert _can_support_mxfp4(
layer.use_grouped_topk,
layer.topk_group,