[MoE Refactor] Convert mxfp4 marlin into modular kernel format (#34588)
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
@@ -30,6 +30,7 @@ from vllm.model_executor.layers.fused_moe.utils import (
|
||||
disable_inplace,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
get_marlin_input_dtype,
|
||||
marlin_make_workspace_new,
|
||||
marlin_moe_intermediate_size,
|
||||
marlin_quant_input,
|
||||
@@ -550,6 +551,8 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
self.w13_g_idx_sort_indices = w13_g_idx_sort_indices
|
||||
self.w2_g_idx_sort_indices = w2_g_idx_sort_indices
|
||||
self.is_k_full = is_k_full
|
||||
self.input_dtype = get_marlin_input_dtype()
|
||||
|
||||
super().__init__(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
@@ -736,6 +739,7 @@ class MarlinExperts(MarlinExpertsBase):
|
||||
sort_indices1=self.w13_g_idx_sort_indices,
|
||||
sort_indices2=self.w2_g_idx_sort_indices,
|
||||
is_k_full=self.is_k_full,
|
||||
input_dtype=self.input_dtype,
|
||||
)
|
||||
|
||||
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
|
||||
|
||||
@@ -16,6 +16,9 @@ from vllm.model_executor.layers.fused_moe import (
|
||||
MoEActivation,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
maybe_make_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
mxfp4_mxfp8_moe_quant_config,
|
||||
@@ -25,7 +28,6 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||
BatchedMarlinExperts,
|
||||
MarlinExperts,
|
||||
fused_marlin_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
|
||||
OAITritonExperts,
|
||||
@@ -52,7 +54,6 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils.flashinfer import has_flashinfer
|
||||
from vllm.utils.import_utils import has_triton_kernels
|
||||
from vllm.utils.math_utils import round_up
|
||||
@@ -219,7 +220,6 @@ class Mxfp4Config(QuantizationConfig):
|
||||
return XpuMxfp4MoEMethod(layer.moe_config)
|
||||
else:
|
||||
quant_method = Mxfp4MoEMethod(layer.moe_config)
|
||||
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
|
||||
return quant_method
|
||||
elif isinstance(layer, Attention):
|
||||
# TODO: Add support for MXFP4 Attention.
|
||||
@@ -243,7 +243,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
self.weight_dtype = "mxfp4"
|
||||
self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
|
||||
|
||||
self.marlin_input_dtype = None
|
||||
self.max_capture_size = (
|
||||
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
|
||||
)
|
||||
@@ -254,6 +253,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
"Please check your environment and try again."
|
||||
)
|
||||
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
|
||||
self.moe_mk: mk.FusedMoEModularKernel | None = None
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -408,7 +408,30 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
|
||||
prepare_moe_fp4_layer_for_marlin(layer, input_dtype=self.marlin_input_dtype)
|
||||
prepare_moe_fp4_layer_for_marlin(
|
||||
layer, input_dtype=get_marlin_input_dtype()
|
||||
)
|
||||
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
assert self.moe_quant_config is not None
|
||||
|
||||
prepare_finalize = maybe_make_prepare_finalize(
|
||||
moe=self.moe,
|
||||
quant_config=self.moe_quant_config,
|
||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||
allow_new_interface=True,
|
||||
)
|
||||
assert prepare_finalize is not None
|
||||
|
||||
self.moe_mk = mk.FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
MarlinExperts(
|
||||
self.moe,
|
||||
self.moe_quant_config,
|
||||
),
|
||||
inplace=not self.moe.disable_inplace,
|
||||
shared_experts=None,
|
||||
)
|
||||
elif (
|
||||
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
|
||||
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
|
||||
@@ -910,27 +933,19 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
raise NotImplementedError("EPLB is not supported for mxfp4")
|
||||
|
||||
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
|
||||
return fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
layer.w13_bias,
|
||||
layer.w2_bias,
|
||||
layer.w13_weight_scale,
|
||||
layer.w2_weight_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_scale1=None,
|
||||
global_scale2=None,
|
||||
quant_type_id=scalar_types.float4_e2m1f.id,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
activation=layer.activation,
|
||||
expert_map=layer.expert_map,
|
||||
input_dtype=self.marlin_input_dtype,
|
||||
inplace=not self.moe.disable_inplace,
|
||||
)
|
||||
assert self.moe_mk is not None
|
||||
|
||||
return self.moe_mk(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
assert _can_support_mxfp4(
|
||||
layer.use_grouped_topk,
|
||||
layer.topk_group,
|
||||
|
||||
Reference in New Issue
Block a user