diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
index ddc6588dc..9d1c8e27b 100644
--- a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
+++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
@@ -141,7 +141,10 @@ def backend_to_kernel_cls(
 
         return [AiterExperts]
     elif backend == Mxfp4MoeBackend.XPU:
-        raise NotImplementedError("XPU backend uses XpuMxfp4MoEMethod directly.")
+        from vllm.model_executor.layers.fused_moe.xpu_fused_moe import XPUExpertsMXFp4
+
+        return [XPUExpertsMXFp4]
+
     else:
         raise ValueError(f"Unknown MXFP4 MoE backend: {backend.value}")
 
@@ -156,6 +159,7 @@ def map_mxfp4_backend(runner_backend: str) -> Mxfp4MoeBackend:
         "triton": Mxfp4MoeBackend.TRITON,
         "marlin": Mxfp4MoeBackend.MARLIN,
         "ck": Mxfp4MoeBackend.CK,
+        "xpu": Mxfp4MoeBackend.XPU,
     }
     if backend := mapping.get(runner_backend):
         return backend
@@ -178,6 +182,7 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]:
         Mxfp4MoeBackend.TRITON_UNFUSED,
         Mxfp4MoeBackend.MARLIN,
         Mxfp4MoeBackend.BATCHED_MARLIN,
+        Mxfp4MoeBackend.XPU,
     ]
 
     return _AVAILABLE_BACKENDS
@@ -351,7 +356,13 @@ def select_mxfp4_moe_backend(
     if current_platform.is_xpu():
         backend = Mxfp4MoeBackend.XPU
         logger.info_once(_make_log_backend(backend))
-        return backend, None
+        return _return_or_raise(
+            Mxfp4MoeBackend.XPU,
+            config,
+            kMxfp4Static,
+            None,
+            activation_format,
+        )
 
     if current_platform.is_cuda() or current_platform.is_rocm():
         raise NotImplementedError(
@@ -741,6 +752,16 @@ def convert_to_mxfp4_moe_kernel_format(
             w13_bias,
             w2_bias,
         )
+    elif mxfp4_backend == Mxfp4MoeBackend.XPU:
+        # No additional transformation needed for XPU backend
+        return (
+            w13_weight,
+            w2_weight,
+            w13_weight_scale,
+            w2_weight_scale,
+            w13_bias,
+            w2_bias,
+        )
     else:
         raise ValueError(
             f"Unsupported mxfp4_backend: {mxfp4_backend}: "
diff --git a/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py
index b8d3ffec3..9cc0ade28 100644
--- a/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
     kFp8DynamicTensorSym,
     kFp8StaticTensorSym,
+    kMxfp4Static,
 )
 
 from vllm.platforms import current_platform
@@ -38,6 +39,7 @@ class XPUExperts(mk.FusedMoEExpertsModular):
             num_dispatchers,
         )
         self.is_fp8 = False
+        self.is_mxfp4 = False
 
     @property
     def expects_unquantized_inputs(self) -> bool:
@@ -137,6 +139,7 @@ class XPUExperts(mk.FusedMoEExpertsModular):
             ep_size=self.moe_config.ep_size,
             output=output,
             is_fp8=self.is_fp8,
+            is_mxfp4=self.is_mxfp4,
         )
 
 
@@ -155,3 +158,30 @@ class XPUExpertsFp8(XPUExperts):
             num_dispatchers,
         )
         self.is_fp8 = True
+
+
+class XPUExpertsMXFp4(XPUExperts):
+    def __init__(
+        self,
+        moe_config: FusedMoEConfig,
+        quant_config: FusedMoEQuantConfig,
+        max_num_tokens: int | None = None,
+        num_dispatchers: int | None = None,
+    ):
+        super().__init__(
+            moe_config,
+            quant_config,
+            max_num_tokens,
+            num_dispatchers,
+        )
+        self.is_mxfp4 = True
+
+    @staticmethod
+    def _supports_quant_scheme(
+        weight_key: QuantKey | None,
+        activation_key: QuantKey | None,
+    ) -> bool:
+        SUPPORTED_W_A = [
+            (kMxfp4Static, None),
+        ]
+        return (weight_key, activation_key) in SUPPORTED_W_A
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index 751ee6dfd..b8b0e5f36 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -10,7 +10,6 @@ from vllm.model_executor.layers.fused_moe import (
     FusedMoE,
     FusedMoEConfig,
     FusedMoEMethodBase,
-    MoEActivation,
 )
 from vllm.model_executor.layers.fused_moe import modular_kernel as mk
 from vllm.model_executor.layers.fused_moe.config import (
@@ -33,7 +32,6 @@ from vllm.model_executor.layers.quantization.base_config import (
 )
 from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
 from vllm.model_executor.utils import replace_parameter, set_weight_attrs
-from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
 
@@ -80,10 +78,7 @@ class Mxfp4Config(QuantizationConfig):
             )
             return UnquantizedLinearMethod()
         elif isinstance(layer, FusedMoE):
-            if current_platform.is_xpu():
-                return XpuMxfp4MoEMethod(layer.moe_config)
-            else:
-                return Mxfp4MoEMethod(layer.moe_config)
+            return Mxfp4MoEMethod(layer.moe_config)
         elif isinstance(layer, Attention):
             logger.debug_once(
                 "MXFP4 attention layer is not implemented. "
@@ -420,96 +415,3 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             expert_map=layer.expert_map,
             apply_router_weight_on_input=layer.apply_router_weight_on_input,
         )
-
-
-class XpuMxfp4MoEMethod(Mxfp4MoEMethod):
-    def __init__(self, moe_config: FusedMoEConfig):
-        super().__init__(moe_config)
-        self.moe_config = moe_config
-
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        num_experts: int,
-        hidden_size: int,
-        intermediate_size_per_partition: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        super().create_weights(
-            layer,
-            num_experts,
-            hidden_size,
-            intermediate_size_per_partition,
-            params_dtype,
-            **extra_weight_attrs,
-        )
-        self.original_hidden_size = hidden_size
-
-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        pass
-
-    @property
-    def is_monolithic(self) -> bool:
-        return True
-
-    def apply_monolithic(
-        self,
-        layer: FusedMoE,
-        x: torch.Tensor,
-        router_logits: torch.Tensor,
-    ) -> torch.Tensor:
-        assert layer.activation == MoEActivation.SWIGLUOAI, (
-            "Only swiglu_oai activation is supported for "
-            f"XPU MXFP4 MoE, not {layer.activation}."
-        )
-        from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe
-
-        M, _ = x.size()
-        routing_weights = torch.empty(
-            M, layer.top_k, dtype=torch.float32, device=x.device
-        )
-        selected_experts = torch.empty(
-            M, layer.top_k, dtype=torch.int32, device=x.device
-        )
-        token_expert_indices = torch.empty(
-            M, layer.top_k, dtype=torch.int32, device=x.device
-        )
-
-        if layer.use_grouped_topk:
-            routing_weights, selected_experts = torch.ops._moe_C.fused_grouped_topk(
-                x,
-                router_logits,
-                layer.top_k,
-                layer.renormalize,
-                n_expert_group=layer.num_expert_group,
-                n_topk_group=layer.topk_group,
-                scoring_func=layer.scoring_func,
-                routed_scaling_factor=layer.routed_scaling_factor,
-                bias=layer.e_score_correction_bias,
-            )
-        else:
-            torch.ops._moe_C.topk_softmax(
-                routing_weights,
-                selected_experts,
-                token_expert_indices,
-                router_logits,
-                layer.renormalize,
-                layer.e_score_correction_bias,
-            )
-
-        return xpu_fused_moe(
-            hidden_states=x,
-            w13=layer.w13_weight,
-            w13_bias=layer.w13_bias if self.moe.has_bias else None,
-            w13_scales=layer.w13_weight_scale,
-            w2=layer.w2_weight,
-            w2_bias=layer.w2_bias if self.moe.has_bias else None,
-            w2_scales=layer.w2_weight_scale,
-            topk_weights=routing_weights,
-            topk_ids=selected_experts,
-            n_experts_per_token=layer.top_k,
-            activation=layer.activation.value,
-            num_experts=layer.local_num_experts,
-            is_mxfp4=True,
-        )
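
A minimal sanity-check sketch of the new wiring, not part of the patch: it exercises only symbols whose signatures appear verbatim in the hunks above, and assumes they are importable exactly as laid out there.

# Sketch only: verifies the runner-backend mapping and the quant-scheme
# gate introduced by this diff. Import paths mirror the files touched above.
from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
    Mxfp4MoeBackend,
    map_mxfp4_backend,
)
from vllm.model_executor.layers.fused_moe.xpu_fused_moe import XPUExpertsMXFp4
from vllm.model_executor.layers.quantization.utils.quant_utils import kMxfp4Static

# "xpu" now resolves to a real backend instead of hitting the
# "Unknown MXFP4 MoE backend" error path.
assert map_mxfp4_backend("xpu") == Mxfp4MoeBackend.XPU

# XPUExpertsMXFp4 accepts exactly one (weight, activation) pair:
# static MXFP4 weights with unquantized activations.
assert XPUExpertsMXFp4._supports_quant_scheme(kMxfp4Static, None)
assert not XPUExpertsMXFp4._supports_quant_scheme(kMxfp4Static, kMxfp4Static)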