[XPU][4/N] add mxfp4 moe model support (#33679)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Kunshang Ji authored on 2026-02-06 13:03:59 +08:00, committed by GitHub
parent ac04dd374f
commit 7439e4f41b

@@ -215,7 +215,7 @@ class Mxfp4Config(QuantizationConfig):
             return UnquantizedLinearMethod()
         elif isinstance(layer, FusedMoE):
             if current_platform.is_xpu():
-                return IpexMxfp4MoEMethod(layer.moe_config)
+                return XpuMxfp4MoEMethod(layer.moe_config)
             else:
                 quant_method = Mxfp4MoEMethod(layer.moe_config)
                 quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
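
On XPU, Mxfp4Config.get_quant_method now returns the renamed XpuMxfp4MoEMethod; every other platform keeps the Marlin-backed Mxfp4MoEMethod. A minimal, self-contained sketch of that dispatch, with stub classes and a helper name standing in for vLLM's real ones (both are assumptions for illustration):

    class Mxfp4MoEMethod:                        # stub for vLLM's real class
        def __init__(self, moe_config):
            self.moe_config = moe_config

    class XpuMxfp4MoEMethod(Mxfp4MoEMethod):     # new XPU-native subclass
        pass

    def get_moe_quant_method(moe_config, is_xpu: bool):
        # Mirrors the hunk above: XPU gets XpuMxfp4MoEMethod, everything
        # else keeps the existing Marlin/Triton-backed Mxfp4MoEMethod.
        return XpuMxfp4MoEMethod(moe_config) if is_xpu else Mxfp4MoEMethod(moe_config)

    print(type(get_moe_quant_method({}, is_xpu=True)).__name__)  # XpuMxfp4MoEMethod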
@@ -1096,7 +1096,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
 
 
-class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
+class XpuMxfp4MoEMethod(Mxfp4MoEMethod):
     def __init__(self, moe_config: FusedMoEConfig):
         super().__init__(moe_config)
         self.moe_config = moe_config
@@ -1121,21 +1121,7 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
         self.original_hidden_size = hidden_size
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        import intel_extension_for_pytorch as ipex
-
-        layer.w13_weight.data = layer.w13_weight.data.view(torch.int32)
-        layer.w2_weight.data = layer.w2_weight.data.view(torch.int32)
-        ep_rank_start = self.moe_config.ep_rank * self.moe_config.num_local_experts
-        layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
-            layer.w13_weight,
-            layer.w2_weight,
-            w1_scale_inv=layer.w13_weight_scale,
-            w2_scale_inv=layer.w2_weight_scale,
-            w13_bias=layer.w13_bias,
-            w2_bias=layer.w2_bias,
-            is_mxfp4=True,
-            experts_start_id=ep_rank_start,
-        )
+        pass
 
     @property
     def is_monolithic(self) -> bool:
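
The removed hook shows what the IPEX path did at load time: reinterpret the packed uint8 MXFP4 weights as int32 and hand them to ipex.llm.modules.GatedMLPMOE, offset by the expert-parallel rank. The XPU path keeps the weights exactly as loaded and defers layout handling to the fused kernel, so the hook becomes a no-op. A short sketch of the zero-copy dtype reinterpretation the old code relied on (toy shapes, not vLLM's real weight layout):

    import torch

    # Four packed uint8 bytes reinterpret as one int32 without copying.
    w = torch.randint(0, 256, (2, 8), dtype=torch.uint8)  # packed FP4 nibbles
    w_i32 = w.view(torch.int32)                           # same storage, shape (2, 2)
    assert w_i32.data_ptr() == w.data_ptr()               # zero-copy reinterpret
    print(w.shape, "->", w_i32.shape)                     # (2, 8) -> (2, 2)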
@@ -1148,19 +1134,55 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
         router_logits: torch.Tensor,
     ) -> torch.Tensor:
         assert layer.activation == "swigluoai", (
-            "Only swiglu_oai activation is supported for IPEX MXFP4 MoE"
+            "Only swiglu_oai activation is supported for XPU MXFP4 MoE"
         )
-        hidden_size_pad = round_up(self.original_hidden_size, 128)
-        x_pad = torch.nn.functional.pad(x, (0, hidden_size_pad - x.size(-1)))
-        hidden_states = layer.ipex_fusion(
-            x_pad,
-            layer.use_grouped_topk,
-            layer.top_k,
-            router_logits,
-            layer.renormalize,
-            layer.topk_group,
-            layer.num_expert_group,
-            activation="swiglu_oai",
-        )
-        hidden_states = hidden_states[..., : self.original_hidden_size].contiguous()
-        return hidden_states
+        from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe
+
+        M, _ = x.size()
+        routing_weights = torch.empty(
+            M, layer.top_k, dtype=torch.float32, device=x.device
+        )
+        selected_experts = torch.empty(
+            M, layer.top_k, dtype=torch.int32, device=x.device
+        )
+        token_expert_indices = torch.empty(
+            M, layer.top_k, dtype=torch.int32, device=x.device
+        )
+        if layer.use_grouped_topk:
+            routing_weights, selected_experts = torch.ops._moe_C.fused_grouped_topk(
+                x,
+                router_logits,
+                layer.top_k,
+                layer.renormalize,
+                n_expert_group=layer.num_expert_group,
+                n_topk_group=layer.topk_group,
+                scoring_func=layer.scoring_func,
+                routed_scaling_factor=layer.routed_scaling_factor,
+                bias=layer.e_score_correction_bias,
+            )
+        else:
+            torch.ops._moe_C.topk_softmax(
+                routing_weights,
+                selected_experts,
+                token_expert_indices,
+                router_logits,
+                layer.renormalize,
+                layer.e_score_correction_bias,
+            )
+        return xpu_fused_moe(
+            hidden_states=x,
+            w13=layer.w13_weight,
+            w13_bias=layer.w13_bias if self.moe.has_bias else None,
+            w13_scales=layer.w13_weight_scale,
+            w2=layer.w2_weight,
+            w2_bias=layer.w2_bias if self.moe.has_bias else None,
+            w2_scales=layer.w2_weight_scale,
+            topk_weights=routing_weights,
+            topk_ids=selected_experts,
+            n_experts_per_token=layer.top_k,
+            activation=layer.activation,
+            num_experts=layer.local_num_experts,
+            is_mxfp4=True,
+        )
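
The new forward path pre-allocates the routing buffers and lets the _moe_C custom ops fill them in place (topk_softmax for plain softmax routing, fused_grouped_topk for grouped routing), then feeds the resulting weights/ids pair straight into xpu_fused_moe; because the kernel handles the swiglu_oai activation and the full hidden size directly, the old pad-to-128/un-pad dance disappears. For intuition, a pure-PyTorch reference of what the non-grouped branch computes, assuming topk_softmax is a softmax over the router logits followed by top-k with optional renormalization; unlike the in-place op, this sketch returns fresh tensors and ignores e_score_correction_bias:

    import torch

    def reference_topk_softmax(router_logits, top_k, renormalize):
        # Softmax over the expert dimension, then keep top_k experts per token.
        probs = torch.softmax(router_logits.float(), dim=-1)
        routing_weights, selected_experts = probs.topk(top_k, dim=-1)
        if renormalize:
            # Rescale the kept weights so they sum to 1 per token.
            routing_weights = routing_weights / routing_weights.sum(-1, keepdim=True)
        return routing_weights, selected_experts.to(torch.int32)

    logits = torch.randn(4, 8)                # 4 tokens, 8 experts
    w, ids = reference_topk_softmax(logits, top_k=2, renormalize=True)
    print(w.shape, ids.shape)                 # torch.Size([4, 2]) torch.Size([4, 2])

Pre-allocating routing_weights and selected_experts up front, as the diff does, lets the custom op write its results directly into those buffers instead of allocating inside the kernel launch path.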