[XPU][4/N] add mxfp4 moe model support (#33679)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
@@ -215,7 +215,7 @@ class Mxfp4Config(QuantizationConfig):
             return UnquantizedLinearMethod()
         elif isinstance(layer, FusedMoE):
             if current_platform.is_xpu():
-                return IpexMxfp4MoEMethod(layer.moe_config)
+                return XpuMxfp4MoEMethod(layer.moe_config)
             else:
                 quant_method = Mxfp4MoEMethod(layer.moe_config)
                 quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
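For readers without the full file: this hunk sits inside Mxfp4Config.get_quant_method. A condensed sketch of the post-change control flow (signature per vLLM's QuantizationConfig API; the LinearBase branch and trailing return None are assumptions based on vLLM's usual structure, not shown in this hunk):

    def get_quant_method(self, layer, prefix: str):
        # Assumed surrounding structure; only the XPU branch is from this diff.
        if isinstance(layer, LinearBase):
            return UnquantizedLinearMethod()
        elif isinstance(layer, FusedMoE):
            if current_platform.is_xpu():
                # New: XPU now routes to the vllm_xpu_kernels-backed method.
                return XpuMxfp4MoEMethod(layer.moe_config)
            else:
                quant_method = Mxfp4MoEMethod(layer.moe_config)
                quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
                return quant_method
        return None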
@@ -1096,7 +1096,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
 
 
-class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
+class XpuMxfp4MoEMethod(Mxfp4MoEMethod):
     def __init__(self, moe_config: FusedMoEConfig):
         super().__init__(moe_config)
         self.moe_config = moe_config
@@ -1121,21 +1121,7 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
         self.original_hidden_size = hidden_size
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        import intel_extension_for_pytorch as ipex
-
-        layer.w13_weight.data = layer.w13_weight.data.view(torch.int32)
-        layer.w2_weight.data = layer.w2_weight.data.view(torch.int32)
-        ep_rank_start = self.moe_config.ep_rank * self.moe_config.num_local_experts
-        layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
-            layer.w13_weight,
-            layer.w2_weight,
-            w1_scale_inv=layer.w13_weight_scale,
-            w2_scale_inv=layer.w2_weight_scale,
-            w13_bias=layer.w13_bias,
-            w2_bias=layer.w2_bias,
-            is_mxfp4=True,
-            experts_start_id=ep_rank_start,
-        )
+        pass
 
     @property
     def is_monolithic(self) -> bool:
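Background on the weight format: MXFP4 packs two 4-bit e2m1 values per byte, with one shared e8m0 exponent scale per 32-element block (per the OCP MX spec). The removed IPEX path merely reinterpreted the packed bytes as int32 before handing them to the fused op; the new XPU path consumes the buffers as loaded, so no post-load processing remains. A minimal pure-PyTorch dequantization sketch, for illustration only (nibble order and the 32-element block size are assumptions from the MX spec, not this kernel's documented layout):

    import torch

    # e2m1 magnitudes for codes 0..7; codes 8..15 are the negated values.
    FP4_E2M1 = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
         -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
    )

    def dequant_mxfp4(packed: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
        """packed: uint8 [..., K // 2], two FP4 codes per byte (low nibble
        first, assumed). scales: uint8 [..., K // 32], one e8m0 exponent
        per 32-element block."""
        lo = (packed & 0x0F).long()
        hi = (packed >> 4).long()
        codes = torch.stack([lo, hi], dim=-1).flatten(-2)      # [..., K]
        vals = FP4_E2M1.to(packed.device)[codes]
        block_scale = torch.exp2(scales.float() - 127.0)       # e8m0: 2^(u8 - 127)
        return (vals.unflatten(-1, (-1, 32)) * block_scale.unsqueeze(-1)).flatten(-2)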
@@ -1148,19 +1134,55 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
         router_logits: torch.Tensor,
     ) -> torch.Tensor:
         assert layer.activation == "swigluoai", (
-            "Only swiglu_oai activation is supported for IPEX MXFP4 MoE"
+            "Only swiglu_oai activation is supported for XPU MXFP4 MoE"
         )
-        hidden_size_pad = round_up(self.original_hidden_size, 128)
-        x_pad = torch.nn.functional.pad(x, (0, hidden_size_pad - x.size(-1)))
-        hidden_states = layer.ipex_fusion(
-            x_pad,
-            layer.use_grouped_topk,
-            layer.top_k,
-            router_logits,
-            layer.renormalize,
-            layer.topk_group,
-            layer.num_expert_group,
-            activation="swiglu_oai",
-        )
-        hidden_states = hidden_states[..., : self.original_hidden_size].contiguous()
-        return hidden_states
+        from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe
+
+        M, _ = x.size()
+        routing_weights = torch.empty(
+            M, layer.top_k, dtype=torch.float32, device=x.device
+        )
+        selected_experts = torch.empty(
+            M, layer.top_k, dtype=torch.int32, device=x.device
+        )
+        token_expert_indices = torch.empty(
+            M, layer.top_k, dtype=torch.int32, device=x.device
+        )
+
+        if layer.use_grouped_topk:
+            routing_weights, selected_experts = torch.ops._moe_C.fused_grouped_topk(
+                x,
+                router_logits,
+                layer.top_k,
+                layer.renormalize,
+                n_expert_group=layer.num_expert_group,
+                n_topk_group=layer.topk_group,
+                scoring_func=layer.scoring_func,
+                routed_scaling_factor=layer.routed_scaling_factor,
+                bias=layer.e_score_correction_bias,
+            )
+        else:
+            torch.ops._moe_C.topk_softmax(
+                routing_weights,
+                selected_experts,
+                token_expert_indices,
+                router_logits,
+                layer.renormalize,
+                layer.e_score_correction_bias,
+            )
+
+        return xpu_fused_moe(
+            hidden_states=x,
+            w13=layer.w13_weight,
+            w13_bias=layer.w13_bias if self.moe.has_bias else None,
+            w13_scales=layer.w13_weight_scale,
+            w2=layer.w2_weight,
+            w2_bias=layer.w2_bias if self.moe.has_bias else None,
+            w2_scales=layer.w2_weight_scale,
+            topk_weights=routing_weights,
+            topk_ids=selected_experts,
+            n_experts_per_token=layer.top_k,
+            activation=layer.activation,
+            num_experts=layer.local_num_experts,
+            is_mxfp4=True,
+        )
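The new path fills routing_weights / selected_experts with the _moe_C custom ops before dispatching to xpu_fused_moe. As a reference for what the non-grouped branch computes, here is a plain-PyTorch equivalent of top-k softmax routing (illustrative only; the correction-bias argument accepted by topk_softmax is omitted):

    import torch

    def topk_softmax_ref(router_logits: torch.Tensor, top_k: int, renormalize: bool):
        # Softmax over the expert dimension, then keep the top-k experts per token.
        scores = torch.softmax(router_logits.float(), dim=-1)
        routing_weights, selected_experts = torch.topk(scores, top_k, dim=-1)
        if renormalize:
            # Rescale so each token's kept weights sum to 1.
            routing_weights = routing_weights / routing_weights.sum(-1, keepdim=True)
        return routing_weights, selected_experts.to(torch.int32)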