diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index 50009445d..b9dec4530 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -215,7 +215,7 @@ class Mxfp4Config(QuantizationConfig):
             return UnquantizedLinearMethod()
         elif isinstance(layer, FusedMoE):
             if current_platform.is_xpu():
-                return IpexMxfp4MoEMethod(layer.moe_config)
+                return XpuMxfp4MoEMethod(layer.moe_config)
             else:
                 quant_method = Mxfp4MoEMethod(layer.moe_config)
                 quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
@@ -1096,7 +1096,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
 
 
-class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
+class XpuMxfp4MoEMethod(Mxfp4MoEMethod):
     def __init__(self, moe_config: FusedMoEConfig):
         super().__init__(moe_config)
         self.moe_config = moe_config
@@ -1121,21 +1121,7 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
         self.original_hidden_size = hidden_size
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        import intel_extension_for_pytorch as ipex
-
-        layer.w13_weight.data = layer.w13_weight.data.view(torch.int32)
-        layer.w2_weight.data = layer.w2_weight.data.view(torch.int32)
-        ep_rank_start = self.moe_config.ep_rank * self.moe_config.num_local_experts
-        layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
-            layer.w13_weight,
-            layer.w2_weight,
-            w1_scale_inv=layer.w13_weight_scale,
-            w2_scale_inv=layer.w2_weight_scale,
-            w13_bias=layer.w13_bias,
-            w2_bias=layer.w2_bias,
-            is_mxfp4=True,
-            experts_start_id=ep_rank_start,
-        )
+        pass
 
     @property
     def is_monolithic(self) -> bool:
@@ -1148,19 +1134,55 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
         router_logits: torch.Tensor,
     ) -> torch.Tensor:
         assert layer.activation == "swigluoai", (
-            "Only swiglu_oai activation is supported for IPEX MXFP4 MoE"
+            "Only swiglu_oai activation is supported for XPU MXFP4 MoE"
         )
-        hidden_size_pad = round_up(self.original_hidden_size, 128)
-        x_pad = torch.nn.functional.pad(x, (0, hidden_size_pad - x.size(-1)))
-        hidden_states = layer.ipex_fusion(
-            x_pad,
-            layer.use_grouped_topk,
-            layer.top_k,
-            router_logits,
-            layer.renormalize,
-            layer.topk_group,
-            layer.num_expert_group,
-            activation="swiglu_oai",
+        from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe
+
+        M, _ = x.size()
+        routing_weights = torch.empty(
+            M, layer.top_k, dtype=torch.float32, device=x.device
+        )
+        selected_experts = torch.empty(
+            M, layer.top_k, dtype=torch.int32, device=x.device
+        )
+        token_expert_indices = torch.empty(
+            M, layer.top_k, dtype=torch.int32, device=x.device
+        )
+
+        if layer.use_grouped_topk:
+            routing_weights, selected_experts = torch.ops._moe_C.fused_grouped_topk(
+                x,
+                router_logits,
+                layer.top_k,
+                layer.renormalize,
+                n_expert_group=layer.num_expert_group,
+                n_topk_group=layer.topk_group,
+                scoring_func=layer.scoring_func,
+                routed_scaling_factor=layer.routed_scaling_factor,
+                bias=layer.e_score_correction_bias,
+            )
+        else:
+            torch.ops._moe_C.topk_softmax(
+                routing_weights,
+                selected_experts,
+                token_expert_indices,
+                router_logits,
+                layer.renormalize,
+                layer.e_score_correction_bias,
+            )
+
+        return xpu_fused_moe(
+            hidden_states=x,
+            w13=layer.w13_weight,
+            w13_bias=layer.w13_bias if self.moe.has_bias else None,
+            w13_scales=layer.w13_weight_scale,
+            w2=layer.w2_weight,
+            w2_bias=layer.w2_bias if self.moe.has_bias else None,
+            w2_scales=layer.w2_weight_scale,
+            topk_weights=routing_weights,
+            topk_ids=selected_experts,
+            n_experts_per_token=layer.top_k,
+            activation=layer.activation,
+            num_experts=layer.local_num_experts,
+            is_mxfp4=True,
         )
-        hidden_states = hidden_states[..., : self.original_hidden_size].contiguous()
-        return hidden_states