[XPU][7/N] enable xpu fp8 moe (#34202)

Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>
This commit is contained in:
zofia
2026-02-11 11:33:59 +08:00
committed by GitHub
parent 1485396abb
commit b482f71e9f
4 changed files with 52 additions and 5 deletions

View File

@@ -15,4 +15,4 @@ torch==2.10.0+xpu
torchaudio torchaudio
torchvision torchvision
vllm_xpu_kernels @ https://github.com/vllm-project/vllm-xpu-kernels/releases/download/v0.1.1/vllm_xpu_kernels-0.1.1-cp312-cp312-linux_x86_64.whl vllm_xpu_kernels @ https://github.com/vllm-project/vllm-xpu-kernels/releases/download/v0.1.2/vllm_xpu_kernels-0.1.2-cp312-cp312-linux_x86_64.whl

View File

@@ -102,6 +102,7 @@ if HAS_TRITON:
) )
from vllm.model_executor.layers.fused_moe.xpu_fused_moe import ( from vllm.model_executor.layers.fused_moe.xpu_fused_moe import (
XPUExperts, XPUExperts,
XPUExpertsFp8,
) )
__all__ += [ __all__ += [
@@ -121,6 +122,7 @@ if HAS_TRITON:
"BatchedDeepGemmExperts", "BatchedDeepGemmExperts",
"TritonOrDeepGemmExperts", "TritonOrDeepGemmExperts",
"XPUExperts", "XPUExperts",
"XPUExpertsFp8",
] ]
else: else:
# Some model classes directly use the custom ops. Add placeholders # Some model classes directly use the custom ops. Add placeholders

View File

@@ -52,6 +52,7 @@ class Fp8MoeBackend(Enum):
AITER = "AITER" AITER = "AITER"
VLLM_CUTLASS = "VLLM_CUTLASS" VLLM_CUTLASS = "VLLM_CUTLASS"
BATCHED_VLLM_CUTLASS = "BATCHED_VLLM_CUTLASS" BATCHED_VLLM_CUTLASS = "BATCHED_VLLM_CUTLASS"
XPU = "XPU"
def backend_to_kernel_cls( def backend_to_kernel_cls(
@@ -123,6 +124,13 @@ def backend_to_kernel_cls(
return CutlassBatchedExpertsFp8 return CutlassBatchedExpertsFp8
elif backend == Fp8MoeBackend.XPU:
from vllm.model_executor.layers.fused_moe.xpu_fused_moe import (
XPUExpertsFp8,
)
return XPUExpertsFp8
else: else:
raise ValueError(f"Unknown FP8 MoE backend: {backend.value}") raise ValueError(f"Unknown FP8 MoE backend: {backend.value}")
@@ -154,6 +162,7 @@ def select_fp8_moe_backend(
Fp8MoeBackend.TRITON, Fp8MoeBackend.TRITON,
Fp8MoeBackend.BATCHED_TRITON, Fp8MoeBackend.BATCHED_TRITON,
Fp8MoeBackend.MARLIN, Fp8MoeBackend.MARLIN,
Fp8MoeBackend.XPU,
] ]
# NOTE(rob): We need to peak into the P/F selection to determine # NOTE(rob): We need to peak into the P/F selection to determine
@@ -393,6 +402,7 @@ def convert_to_fp8_moe_kernel_format(
Fp8MoeBackend.BATCHED_TRITON, Fp8MoeBackend.BATCHED_TRITON,
Fp8MoeBackend.VLLM_CUTLASS, Fp8MoeBackend.VLLM_CUTLASS,
Fp8MoeBackend.BATCHED_VLLM_CUTLASS, Fp8MoeBackend.BATCHED_VLLM_CUTLASS,
Fp8MoeBackend.XPU,
]: ]:
raise ValueError(f"Unsupported FP8 MoE backend: {fp8_backend.value}") raise ValueError(f"Unsupported FP8 MoE backend: {fp8_backend.value}")

View File

@@ -4,13 +4,16 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig, FusedMoEParallelConfig,
FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP, TopKWeightAndReduceNoOP,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
kFp8DynamicTensorSym,
kFp8StaticTensorSym, kFp8StaticTensorSym,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
@@ -20,6 +23,21 @@ if current_platform.is_xpu():
class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute): class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
max_num_tokens: int | None = None,
num_dispatchers: int | None = None,
):
"""Initialize the XPU fused-MoE expert kernel wrapper.

Args:
    moe_config: MoE layer configuration forwarded to the base class.
    quant_config: Quantization configuration forwarded to the base class.
    max_num_tokens: Optional token-count bound forwarded to the base class.
    num_dispatchers: Optional dispatcher count forwarded to the base class.
"""
super().__init__(
moe_config,
quant_config,
max_num_tokens,
num_dispatchers,
)
# Unquantized path by default; the XPUExpertsFp8 subclass sets this to
# True so the shared apply() can pass is_fp8 through to xpu_fused_moe.
self.is_fp8 = False
@property @property
def expects_unquantized_inputs(self) -> bool: def expects_unquantized_inputs(self) -> bool:
return True return True
@@ -49,10 +67,10 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
weight_key: QuantKey | None, weight_key: QuantKey | None,
activation_key: QuantKey | None, activation_key: QuantKey | None,
) -> bool: ) -> bool:
# TODO: dispatch based on device.
SUPPORTED_W_A = [ SUPPORTED_W_A = [
(None, None), (None, None),
(kFp8StaticTensorSym, None), (kFp8StaticTensorSym, None),
(kFp8StaticTensorSym, kFp8DynamicTensorSym),
] ]
return (weight_key, activation_key) in SUPPORTED_W_A return (weight_key, activation_key) in SUPPORTED_W_A
@@ -103,10 +121,10 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
xpu_fused_moe( xpu_fused_moe(
hidden_states=hidden_states, hidden_states=hidden_states,
w13=w1, w13=w1,
w13_scales=a1q_scale, w13_scales=self.w1_scale,
w13_bias=self.w1_bias, w13_bias=self.w1_bias,
w2=w2, w2=w2,
w2_scales=a2_scale, w2_scales=self.w2_scale,
w2_bias=self.w2_bias, w2_bias=self.w2_bias,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
@@ -116,5 +134,22 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
ep_rank=self.moe_config.ep_rank, ep_rank=self.moe_config.ep_rank,
ep_size=self.moe_config.ep_size, ep_size=self.moe_config.ep_size,
output=output, output=output,
is_fp8=self.is_fp8,
) )
return
class XPUExpertsFp8(XPUExperts):
"""FP8 variant of :class:`XPUExperts`.

Identical to the base implementation except that it flags the fused-MoE
kernel to run in FP8 mode (``is_fp8=True``), so ``xpu_fused_moe`` is
invoked with the FP8 weight/activation path enabled.
"""
def __init__(
self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
max_num_tokens: int | None = None,
num_dispatchers: int | None = None,
):
"""Initialize via the base class, then enable the FP8 kernel path.

Args:
    moe_config: MoE layer configuration forwarded to the base class.
    quant_config: Quantization configuration forwarded to the base class.
    max_num_tokens: Optional token-count bound forwarded to the base class.
    num_dispatchers: Optional dispatcher count forwarded to the base class.
"""
super().__init__(
moe_config,
quant_config,
max_num_tokens,
num_dispatchers,
)
# Overrides the base-class default (False) set in XPUExperts.__init__.
self.is_fp8 = True