[XPU][MoE Refactor] Refactor xpu mxfp4 support into oracle (#37784)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
@@ -141,7 +141,10 @@ def backend_to_kernel_cls(
|
||||
return [AiterExperts]
|
||||
|
||||
elif backend == Mxfp4MoeBackend.XPU:
|
||||
raise NotImplementedError("XPU backend uses XpuMxfp4MoEMethod directly.")
|
||||
from vllm.model_executor.layers.fused_moe.xpu_fused_moe import XPUExpertsMXFp4
|
||||
|
||||
return [XPUExpertsMXFp4]
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown MXFP4 MoE backend: {backend.value}")
|
||||
|
||||
@@ -156,6 +159,7 @@ def map_mxfp4_backend(runner_backend: str) -> Mxfp4MoeBackend:
|
||||
"triton": Mxfp4MoeBackend.TRITON,
|
||||
"marlin": Mxfp4MoeBackend.MARLIN,
|
||||
"ck": Mxfp4MoeBackend.CK,
|
||||
"xpu": Mxfp4MoeBackend.XPU,
|
||||
}
|
||||
if backend := mapping.get(runner_backend):
|
||||
return backend
|
||||
@@ -178,6 +182,7 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]:
|
||||
Mxfp4MoeBackend.TRITON_UNFUSED,
|
||||
Mxfp4MoeBackend.MARLIN,
|
||||
Mxfp4MoeBackend.BATCHED_MARLIN,
|
||||
Mxfp4MoeBackend.XPU,
|
||||
]
|
||||
return _AVAILABLE_BACKENDS
|
||||
|
||||
@@ -351,7 +356,13 @@ def select_mxfp4_moe_backend(
|
||||
if current_platform.is_xpu():
|
||||
backend = Mxfp4MoeBackend.XPU
|
||||
logger.info_once(_make_log_backend(backend))
|
||||
return backend, None
|
||||
return _return_or_raise(
|
||||
Mxfp4MoeBackend.XPU,
|
||||
config,
|
||||
kMxfp4Static,
|
||||
None,
|
||||
activation_format,
|
||||
)
|
||||
|
||||
if current_platform.is_cuda() or current_platform.is_rocm():
|
||||
raise NotImplementedError(
|
||||
@@ -741,6 +752,16 @@ def convert_to_mxfp4_moe_kernel_format(
|
||||
w13_bias,
|
||||
w2_bias,
|
||||
)
|
||||
elif mxfp4_backend == Mxfp4MoeBackend.XPU:
|
||||
# No additional transformation needed for XPU backend
|
||||
return (
|
||||
w13_weight,
|
||||
w2_weight,
|
||||
w13_weight_scale,
|
||||
w2_weight_scale,
|
||||
w13_bias,
|
||||
w2_bias,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported mxfp4_backend: {mxfp4_backend}: "
|
||||
|
||||
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8DynamicTensorSym,
|
||||
kFp8StaticTensorSym,
|
||||
kMxfp4Static,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -38,6 +39,7 @@ class XPUExperts(mk.FusedMoEExpertsModular):
|
||||
num_dispatchers,
|
||||
)
|
||||
self.is_fp8 = False
|
||||
self.is_mxfp4 = False
|
||||
|
||||
@property
|
||||
def expects_unquantized_inputs(self) -> bool:
|
||||
@@ -137,6 +139,7 @@ class XPUExperts(mk.FusedMoEExpertsModular):
|
||||
ep_size=self.moe_config.ep_size,
|
||||
output=output,
|
||||
is_fp8=self.is_fp8,
|
||||
is_mxfp4=self.is_mxfp4,
|
||||
)
|
||||
|
||||
|
||||
@@ -155,3 +158,30 @@ class XPUExpertsFp8(XPUExperts):
|
||||
num_dispatchers,
|
||||
)
|
||||
self.is_fp8 = True
|
||||
|
||||
|
||||
class XPUExpertsMXFp4(XPUExperts):
    """XPU experts variant that runs with the MXFP4 flag switched on.

    Construction is forwarded unchanged to ``XPUExperts``; afterwards
    ``is_mxfp4`` is set to ``True``.
    """

    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
        max_num_tokens: int | None = None,
        num_dispatchers: int | None = None,
    ):
        base_args = (moe_config, quant_config, max_num_tokens, num_dispatchers)
        super().__init__(*base_args)
        # Set after the parent constructor so this value is the one that sticks.
        self.is_mxfp4 = True

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Return True only for the ``(kMxfp4Static, None)`` key pair."""
        requested = (weight_key, activation_key)
        supported_pairs = ((kMxfp4Static, None),)
        return requested in supported_pairs
|
||||
|
||||
@@ -10,7 +10,6 @@ from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoE,
|
||||
FusedMoEConfig,
|
||||
FusedMoEMethodBase,
|
||||
MoEActivation,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
@@ -33,7 +32,6 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
|
||||
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -80,10 +78,7 @@ class Mxfp4Config(QuantizationConfig):
|
||||
)
|
||||
return UnquantizedLinearMethod()
|
||||
elif isinstance(layer, FusedMoE):
|
||||
if current_platform.is_xpu():
|
||||
return XpuMxfp4MoEMethod(layer.moe_config)
|
||||
else:
|
||||
return Mxfp4MoEMethod(layer.moe_config)
|
||||
return Mxfp4MoEMethod(layer.moe_config)
|
||||
elif isinstance(layer, Attention):
|
||||
logger.debug_once(
|
||||
"MXFP4 attention layer is not implemented. "
|
||||
@@ -420,96 +415,3 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
|
||||
class XpuMxfp4MoEMethod(Mxfp4MoEMethod):
    """MXFP4 fused-MoE method that does routing and experts in one XPU call.

    Weight creation is inherited from ``Mxfp4MoEMethod``. Weights are left
    untouched after loading, and the forward pass goes through
    ``xpu_fused_moe`` from ``vllm_xpu_kernels``.
    """

    def __init__(self, moe_config: FusedMoEConfig):
        super().__init__(moe_config)
        self.moe_config = moe_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create the weights via the parent class and record ``hidden_size``."""
        super().create_weights(
            layer,
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            params_dtype,
            **extra_weight_attrs,
        )
        self.original_hidden_size = hidden_size

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Do nothing: this method applies no post-load weight processing."""

    @property
    def is_monolithic(self) -> bool:
        """Always ``True``; the forward pass is ``apply_monolithic``."""
        return True

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor:
        """Select experts for ``x`` and run the XPU fused MoE kernel.

        Args:
            layer: The MoE layer supplying weights, scales, biases and the
                routing settings read below.
            x: Input activations; must be 2-D (the size is unpacked into
                exactly two dimensions).
            router_logits: Router scores handed to the top-k op.

        Returns:
            Whatever ``xpu_fused_moe`` returns for these inputs.

        Raises:
            AssertionError: If ``layer.activation`` is not
                ``MoEActivation.SWIGLUOAI``.
        """
        assert layer.activation == MoEActivation.SWIGLUOAI, (
            "Only swiglu_oai activation is supported for "
            f"XPU MXFP4 MoE, not {layer.activation}."
        )
        # Imported lazily, inside the call, as in the original implementation.
        from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe

        # Unpacking happens before the branch so a non-2-D input is rejected
        # on both routing paths.
        num_tokens, _ = x.size()

        if layer.use_grouped_topk:
            # This op returns its own result tensors, so no output buffers
            # are allocated on this path.
            routing_weights, selected_experts = torch.ops._moe_C.fused_grouped_topk(
                x,
                router_logits,
                layer.top_k,
                layer.renormalize,
                n_expert_group=layer.num_expert_group,
                n_topk_group=layer.topk_group,
                scoring_func=layer.scoring_func,
                routed_scaling_factor=layer.routed_scaling_factor,
                bias=layer.e_score_correction_bias,
            )
        else:
            # topk_softmax is given caller-provided buffers, so they are
            # allocated only on this path.
            routing_weights = torch.empty(
                num_tokens, layer.top_k, dtype=torch.float32, device=x.device
            )
            selected_experts = torch.empty(
                num_tokens, layer.top_k, dtype=torch.int32, device=x.device
            )
            token_expert_indices = torch.empty(
                num_tokens, layer.top_k, dtype=torch.int32, device=x.device
            )
            torch.ops._moe_C.topk_softmax(
                routing_weights,
                selected_experts,
                token_expert_indices,
                router_logits,
                layer.renormalize,
                layer.e_score_correction_bias,
            )

        return xpu_fused_moe(
            hidden_states=x,
            w13=layer.w13_weight,
            w13_bias=layer.w13_bias if self.moe.has_bias else None,
            w13_scales=layer.w13_weight_scale,
            w2=layer.w2_weight,
            w2_bias=layer.w2_bias if self.moe.has_bias else None,
            w2_scales=layer.w2_weight_scale,
            topk_weights=routing_weights,
            topk_ids=selected_experts,
            n_experts_per_token=layer.top_k,
            activation=layer.activation.value,
            num_experts=layer.local_num_experts,
            is_mxfp4=True,
        )
|
||||
|
||||
Reference in New Issue
Block a user