[XPU][7/N] enable xpu fp8 moe (#34202)

Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>
Author: zofia
Committed by: GitHub
Date: 2026-02-11 11:33:59 +08:00
Parent: 1485396abb
Commit: b482f71e9f

4 changed files with 52 additions and 5 deletions
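
In brief, as the hunks below show: the change adds an XPU value to the Fp8MoeBackend enum, routes it to a new XPUExpertsFp8 kernel class in backend_to_kernel_cls, whitelists it in the backend-selection and kernel-format-conversion paths, and re-exports the class from the fused_moe package. XPUExpertsFp8 subclasses the existing XPUExperts and differs only in setting is_fp8 = True, which apply() forwards to the xpu_fused_moe kernel.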

View File

@@ -102,6 +102,7 @@ if HAS_TRITON:
     )
     from vllm.model_executor.layers.fused_moe.xpu_fused_moe import (
         XPUExperts,
+        XPUExpertsFp8,
     )

     __all__ += [
@@ -121,6 +122,7 @@ if HAS_TRITON:
         "BatchedDeepGemmExperts",
         "TritonOrDeepGemmExperts",
         "XPUExperts",
+        "XPUExpertsFp8",
     ]
 else:
     # Some model classes directly use the custom ops. Add placeholders
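
Both hunks above sit inside the if HAS_TRITON: branch, so the new class is imported and re-exported only when Triton is available; the else branch with its placeholder ops is unchanged. A one-line usage sketch of what this export enables (assumed, not part of the diff):

from vllm.model_executor.layers.fused_moe import XPUExpertsFp8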

View File

@@ -52,6 +52,7 @@ class Fp8MoeBackend(Enum):
     AITER = "AITER"
     VLLM_CUTLASS = "VLLM_CUTLASS"
     BATCHED_VLLM_CUTLASS = "BATCHED_VLLM_CUTLASS"
+    XPU = "XPU"


 def backend_to_kernel_cls(
@@ -123,6 +124,13 @@ def backend_to_kernel_cls(
         return CutlassBatchedExpertsFp8
+    elif backend == Fp8MoeBackend.XPU:
+        from vllm.model_executor.layers.fused_moe.xpu_fused_moe import (
+            XPUExpertsFp8,
+        )
+
+        return XPUExpertsFp8
     else:
         raise ValueError(f"Unknown FP8 MoE backend: {backend.value}")
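
A sketch of what the new branch does (assumed usage, not part of the diff; the full signature of backend_to_kernel_cls is cut off in this excerpt, so the single-argument call below is an assumption):

from vllm.model_executor.layers.fused_moe.xpu_fused_moe import XPUExpertsFp8

kernel_cls = backend_to_kernel_cls(Fp8MoeBackend.XPU)  # takes the XPU branch above
assert kernel_cls is XPUExpertsFp8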
@@ -154,6 +162,7 @@ def select_fp8_moe_backend(
         Fp8MoeBackend.TRITON,
         Fp8MoeBackend.BATCHED_TRITON,
         Fp8MoeBackend.MARLIN,
+        Fp8MoeBackend.XPU,
     ]

     # NOTE(rob): We need to peek into the P/F selection to determine
@@ -393,6 +402,7 @@ def convert_to_fp8_moe_kernel_format(
         Fp8MoeBackend.BATCHED_TRITON,
         Fp8MoeBackend.VLLM_CUTLASS,
         Fp8MoeBackend.BATCHED_VLLM_CUTLASS,
+        Fp8MoeBackend.XPU,
     ]:
         raise ValueError(f"Unsupported FP8 MoE backend: {fp8_backend.value}")

View File

@@ -4,13 +4,16 @@
 import torch

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
+    FusedMoEParallelConfig,
     FusedMoEQuantConfig,
 )
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceNoOP,
 )
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    QuantKey,
+    kFp8DynamicTensorSym,
     kFp8StaticTensorSym,
 )
 from vllm.platforms import current_platform
@@ -20,6 +23,21 @@ if current_platform.is_xpu():

 class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
+    def __init__(
+        self,
+        moe_config: FusedMoEConfig,
+        quant_config: FusedMoEQuantConfig,
+        max_num_tokens: int | None = None,
+        num_dispatchers: int | None = None,
+    ):
+        super().__init__(
+            moe_config,
+            quant_config,
+            max_num_tokens,
+            num_dispatchers,
+        )
+        self.is_fp8 = False
+
     @property
     def expects_unquantized_inputs(self) -> bool:
         return True
@@ -49,10 +67,10 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
         weight_key: QuantKey | None,
         activation_key: QuantKey | None,
     ) -> bool:
+        # TODO: dispatch based on device.
         SUPPORTED_W_A = [
             (None, None),
+            (kFp8StaticTensorSym, None),
+            (kFp8StaticTensorSym, kFp8DynamicTensorSym),
         ]
         return (weight_key, activation_key) in SUPPORTED_W_A
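
The table now admits two FP8 schemes besides the unquantized path. A standalone restatement of the check (the enclosing method's name is not visible in this excerpt, so _supports_w_a below is a hypothetical stand-in):

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    QuantKey,
    kFp8DynamicTensorSym,
    kFp8StaticTensorSym,
)

def _supports_w_a(weight_key: QuantKey | None,
                  activation_key: QuantKey | None) -> bool:
    # Mirrors SUPPORTED_W_A above: unquantized; static per-tensor FP8
    # weights with unquantized activations; or static per-tensor FP8
    # weights with dynamically quantized per-tensor FP8 activations.
    return (weight_key, activation_key) in [
        (None, None),
        (kFp8StaticTensorSym, None),
        (kFp8StaticTensorSym, kFp8DynamicTensorSym),
    ]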
@@ -103,10 +121,10 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
         xpu_fused_moe(
             hidden_states=hidden_states,
             w13=w1,
-            w13_scales=a1q_scale,
+            w13_scales=self.w1_scale,
             w13_bias=self.w1_bias,
             w2=w2,
-            w2_scales=a2_scale,
+            w2_scales=self.w2_scale,
             w2_bias=self.w2_bias,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
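
The two keyword-argument replacements above read as a bug fix independent of the new class: the call previously passed the activation quantization scales (a1q_scale, a2_scale) as w13_scales/w2_scales, whose names indicate weight scales; it now passes the layer's weight scales (self.w1_scale, self.w2_scale).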
@@ -116,5 +134,22 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
             ep_rank=self.moe_config.ep_rank,
             ep_size=self.moe_config.ep_size,
             output=output,
+            is_fp8=self.is_fp8,
         )
         return
+
+
+class XPUExpertsFp8(XPUExperts):
+    def __init__(
+        self,
+        moe_config: FusedMoEConfig,
+        quant_config: FusedMoEQuantConfig,
+        max_num_tokens: int | None = None,
+        num_dispatchers: int | None = None,
+    ):
+        super().__init__(
+            moe_config,
+            quant_config,
+            max_num_tokens,
+            num_dispatchers,
+        )
+        self.is_fp8 = True
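
Finally, a minimal construction sketch (assumed, not part of the diff; moe_config and quant_config would be built by the surrounding MoE layer):

experts = XPUExpertsFp8(moe_config, quant_config)
assert isinstance(experts, XPUExperts)
assert experts.is_fp8  # the base class initializes this to False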