[XPU][MoE Refactor] Refactor xpu mxfp4 support into oracle (#37784)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Kunshang Ji
2026-03-23 19:10:41 +08:00
committed by GitHub
parent 9ace378a63
commit debd6e768c
3 changed files with 54 additions and 101 deletions

View File

@@ -141,7 +141,10 @@ def backend_to_kernel_cls(
return [AiterExperts]
elif backend == Mxfp4MoeBackend.XPU:
raise NotImplementedError("XPU backend uses XpuMxfp4MoEMethod directly.")
from vllm.model_executor.layers.fused_moe.xpu_fused_moe import XPUExpertsMXFp4
return [XPUExpertsMXFp4]
else:
raise ValueError(f"Unknown MXFP4 MoE backend: {backend.value}")
@@ -156,6 +159,7 @@ def map_mxfp4_backend(runner_backend: str) -> Mxfp4MoeBackend:
"triton": Mxfp4MoeBackend.TRITON,
"marlin": Mxfp4MoeBackend.MARLIN,
"ck": Mxfp4MoeBackend.CK,
"xpu": Mxfp4MoeBackend.XPU,
}
if backend := mapping.get(runner_backend):
return backend
@@ -178,6 +182,7 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]:
Mxfp4MoeBackend.TRITON_UNFUSED,
Mxfp4MoeBackend.MARLIN,
Mxfp4MoeBackend.BATCHED_MARLIN,
Mxfp4MoeBackend.XPU,
]
return _AVAILABLE_BACKENDS
@@ -351,7 +356,13 @@ def select_mxfp4_moe_backend(
if current_platform.is_xpu():
backend = Mxfp4MoeBackend.XPU
logger.info_once(_make_log_backend(backend))
return backend, None
return _return_or_raise(
Mxfp4MoeBackend.XPU,
config,
kMxfp4Static,
None,
activation_format,
)
if current_platform.is_cuda() or current_platform.is_rocm():
raise NotImplementedError(
@@ -741,6 +752,16 @@ def convert_to_mxfp4_moe_kernel_format(
w13_bias,
w2_bias,
)
elif mxfp4_backend == Mxfp4MoeBackend.XPU:
# No additional transformation needed for XPU backend
return (
w13_weight,
w2_weight,
w13_weight_scale,
w2_weight_scale,
w13_bias,
w2_bias,
)
else:
raise ValueError(
f"Unsupported mxfp4_backend: {mxfp4_backend}: "

View File

@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8DynamicTensorSym,
kFp8StaticTensorSym,
kMxfp4Static,
)
from vllm.platforms import current_platform
@@ -38,6 +39,7 @@ class XPUExperts(mk.FusedMoEExpertsModular):
num_dispatchers,
)
self.is_fp8 = False
self.is_mxfp4 = False
@property
def expects_unquantized_inputs(self) -> bool:
@@ -137,6 +139,7 @@ class XPUExperts(mk.FusedMoEExpertsModular):
ep_size=self.moe_config.ep_size,
output=output,
is_fp8=self.is_fp8,
is_mxfp4=self.is_mxfp4,
)
@@ -155,3 +158,30 @@ class XPUExpertsFp8(XPUExperts):
num_dispatchers,
)
self.is_fp8 = True
class XPUExpertsMXFp4(XPUExperts):
    """XPU fused-MoE experts kernel specialized for MXFP4 weights.

    Identical to :class:`XPUExperts` except that it flags the forward pass
    as MXFP4 (``is_mxfp4 = True``) and advertises support for statically
    quantized MXFP4 weights with unquantized activations.
    """

    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
        max_num_tokens: int | None = None,
        num_dispatchers: int | None = None,
    ):
        super().__init__(moe_config, quant_config, max_num_tokens, num_dispatchers)
        # Routed through to the underlying XPU kernel to select the MXFP4 path.
        self.is_mxfp4 = True

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Return True only for MXFP4-static weights with no activation quant."""
        return (weight_key, activation_key) == (kMxfp4Static, None)

View File

@@ -10,7 +10,6 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase,
MoEActivation,
)
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import (
@@ -33,7 +32,6 @@ from vllm.model_executor.layers.quantization.base_config import (
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform
logger = init_logger(__name__)
@@ -80,10 +78,7 @@ class Mxfp4Config(QuantizationConfig):
)
return UnquantizedLinearMethod()
elif isinstance(layer, FusedMoE):
if current_platform.is_xpu():
return XpuMxfp4MoEMethod(layer.moe_config)
else:
return Mxfp4MoEMethod(layer.moe_config)
return Mxfp4MoEMethod(layer.moe_config)
elif isinstance(layer, Attention):
logger.debug_once(
"MXFP4 attention layer is not implemented. "
@@ -420,96 +415,3 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
class XpuMxfp4MoEMethod(Mxfp4MoEMethod):
    """MXFP4 MoE quantization method for Intel XPU.

    Reuses the base class's weight creation but skips post-load weight
    repacking entirely and executes the whole MoE forward monolithically
    through the ``vllm_xpu_kernels`` fused-MoE entry point.
    """

    def __init__(self, moe_config: FusedMoEConfig):
        super().__init__(moe_config)
        # Stored locally; NOTE(review): apply_monolithic reads `self.moe`
        # (presumably set by the base class), not this attribute — confirm
        # whether this duplicate reference is still needed.
        self.moe_config = moe_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create weights via the base class, then remember the unpadded
        hidden size (recorded before any padding the base path may apply)."""
        super().create_weights(
            layer,
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            params_dtype,
            **extra_weight_attrs,
        )
        # NOTE(review): nothing in this class reads original_hidden_size —
        # presumably consumed elsewhere; verify before removing.
        self.original_hidden_size = hidden_size

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Intentionally a no-op: the XPU kernel consumes the checkpoint
        # layout directly, so no repacking/transposition is required.
        pass

    @property
    def is_monolithic(self) -> bool:
        # Tells the framework to call apply_monolithic() instead of the
        # modular prepare/finalize pipeline.
        return True

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor:
        """Run routing + fused expert computation in one kernel call.

        Args:
            layer: The FusedMoE layer holding weights, scales and biases.
            x: Input hidden states of shape (num_tokens, hidden_size).
            router_logits: Per-token expert logits used for top-k routing.

        Returns:
            The MoE output tensor produced by ``xpu_fused_moe``.

        Raises:
            AssertionError: If the layer's activation is not SWIGLUOAI
                (the only activation this XPU path supports).
        """
        assert layer.activation == MoEActivation.SWIGLUOAI, (
            "Only swiglu_oai activation is supported for "
            f"XPU MXFP4 MoE, not {layer.activation}."
        )
        # Imported lazily so the module loads on platforms without the
        # XPU kernel package installed.
        from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe

        M, _ = x.size()
        # Output buffers filled in-place by the routing ops below.
        routing_weights = torch.empty(
            M, layer.top_k, dtype=torch.float32, device=x.device
        )
        selected_experts = torch.empty(
            M, layer.top_k, dtype=torch.int32, device=x.device
        )
        token_expert_indices = torch.empty(
            M, layer.top_k, dtype=torch.int32, device=x.device
        )
        if layer.use_grouped_topk:
            # Grouped top-k returns fresh tensors, replacing the buffers above.
            routing_weights, selected_experts = torch.ops._moe_C.fused_grouped_topk(
                x,
                router_logits,
                layer.top_k,
                layer.renormalize,
                n_expert_group=layer.num_expert_group,
                n_topk_group=layer.topk_group,
                scoring_func=layer.scoring_func,
                routed_scaling_factor=layer.routed_scaling_factor,
                bias=layer.e_score_correction_bias,
            )
        else:
            # Standard softmax top-k; writes into the preallocated buffers.
            torch.ops._moe_C.topk_softmax(
                routing_weights,
                selected_experts,
                token_expert_indices,
                router_logits,
                layer.renormalize,
                layer.e_score_correction_bias,
            )
        return xpu_fused_moe(
            hidden_states=x,
            w13=layer.w13_weight,
            w13_bias=layer.w13_bias if self.moe.has_bias else None,
            w13_scales=layer.w13_weight_scale,
            w2=layer.w2_weight,
            w2_bias=layer.w2_bias if self.moe.has_bias else None,
            w2_scales=layer.w2_weight_scale,
            topk_weights=routing_weights,
            topk_ids=selected_experts,
            n_experts_per_token=layer.top_k,
            activation=layer.activation.value,
            num_experts=layer.local_num_experts,
            is_mxfp4=True,
        )