[XPU][7/N] enable xpu fp8 moe (#34202)
Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>
@@ -102,6 +102,7 @@ if HAS_TRITON:
     )
     from vllm.model_executor.layers.fused_moe.xpu_fused_moe import (
         XPUExperts,
+        XPUExpertsFp8,
     )

     __all__ += [
@@ -121,6 +122,7 @@ if HAS_TRITON:
         "BatchedDeepGemmExperts",
         "TritonOrDeepGemmExperts",
         "XPUExperts",
+        "XPUExpertsFp8",
     ]
 else:
     # Some model classes directly use the custom ops. Add placeholders

@@ -52,6 +52,7 @@ class Fp8MoeBackend(Enum):
     AITER = "AITER"
     VLLM_CUTLASS = "VLLM_CUTLASS"
     BATCHED_VLLM_CUTLASS = "BATCHED_VLLM_CUTLASS"
+    XPU = "XPU"


 def backend_to_kernel_cls(
@@ -123,6 +124,13 @@ def backend_to_kernel_cls(

         return CutlassBatchedExpertsFp8

+    elif backend == Fp8MoeBackend.XPU:
+        from vllm.model_executor.layers.fused_moe.xpu_fused_moe import (
+            XPUExpertsFp8,
+        )
+
+        return XPUExpertsFp8
+
     else:
         raise ValueError(f"Unknown FP8 MoE backend: {backend.value}")

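
With this branch, resolving Fp8MoeBackend.XPU yields the new fp8 expert class. A self-contained sketch of the same enum-to-class dispatch pattern, using placeholder names rather than the vLLM API (the full signature of backend_to_kernel_cls is not shown in this hunk):

    from enum import Enum

    class Backend(Enum):
        TRITON = "TRITON"
        XPU = "XPU"

    class TritonExpertsStub: ...
    class XPUExpertsFp8Stub: ...

    def backend_to_cls(backend: Backend):
        # Mirrors the new elif branch above: the XPU backend maps to the
        # fp8-capable experts class; unknown backends raise ValueError.
        if backend == Backend.TRITON:
            return TritonExpertsStub
        elif backend == Backend.XPU:
            return XPUExpertsFp8Stub
        else:
            raise ValueError(f"Unknown FP8 MoE backend: {backend.value}")

    assert backend_to_cls(Backend.XPU) is XPUExpertsFp8Stub
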
@@ -154,6 +162,7 @@ def select_fp8_moe_backend(
         Fp8MoeBackend.TRITON,
         Fp8MoeBackend.BATCHED_TRITON,
         Fp8MoeBackend.MARLIN,
+        Fp8MoeBackend.XPU,
     ]

     # NOTE(rob): We need to peak into the P/F selection to determine
@@ -393,6 +402,7 @@ def convert_to_fp8_moe_kernel_format(
         Fp8MoeBackend.BATCHED_TRITON,
         Fp8MoeBackend.VLLM_CUTLASS,
         Fp8MoeBackend.BATCHED_VLLM_CUTLASS,
+        Fp8MoeBackend.XPU,
     ]:
         raise ValueError(f"Unsupported FP8 MoE backend: {fp8_backend.value}")

@@ -4,13 +4,16 @@ import torch

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEParallelConfig,
     FusedMoEQuantConfig,
 )
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceNoOP,
 )
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
+    kFp8DynamicTensorSym,
+    kFp8StaticTensorSym,
 )
 from vllm.platforms import current_platform
@@ -20,6 +23,21 @@ if current_platform.is_xpu():


 class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
+    def __init__(
+        self,
+        moe_config: FusedMoEConfig,
+        quant_config: FusedMoEQuantConfig,
+        max_num_tokens: int | None = None,
+        num_dispatchers: int | None = None,
+    ):
+        super().__init__(
+            moe_config,
+            quant_config,
+            max_num_tokens,
+            num_dispatchers,
+        )
+        self.is_fp8 = False
+
     @property
     def expects_unquantized_inputs(self) -> bool:
         return True
@@ -49,10 +67,10 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
         weight_key: QuantKey | None,
         activation_key: QuantKey | None,
     ) -> bool:
         # TODO: dispatch based on device.
         SUPPORTED_W_A = [
             (None, None),
+            (kFp8StaticTensorSym, None),
+            (kFp8StaticTensorSym, kFp8DynamicTensorSym),
         ]
         return (weight_key, activation_key) in SUPPORTED_W_A

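
The support check is a plain membership test over (weight_key, activation_key) pairs; with this change the XPU path also accepts fp8 static per-tensor weights with either unquantized or dynamically quantized activations. A self-contained sketch of the same pattern, with string placeholders standing in for the QuantKey constants:

    # Placeholders for kFp8StaticTensorSym / kFp8DynamicTensorSym (illustration only).
    FP8_STATIC_TENSOR = "fp8-static-tensor-sym"
    FP8_DYNAMIC_TENSOR = "fp8-dynamic-tensor-sym"

    SUPPORTED_W_A = [
        (None, None),                             # unquantized weights and activations
        (FP8_STATIC_TENSOR, None),                # fp8 weights, unquantized activations
        (FP8_STATIC_TENSOR, FP8_DYNAMIC_TENSOR),  # fp8 weights, dynamically quantized activations
    ]

    def supports_scheme(weight_key, activation_key) -> bool:
        return (weight_key, activation_key) in SUPPORTED_W_A

    assert supports_scheme(FP8_STATIC_TENSOR, FP8_DYNAMIC_TENSOR)
    assert not supports_scheme(None, FP8_DYNAMIC_TENSOR)
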
@@ -103,10 +121,10 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
         xpu_fused_moe(
             hidden_states=hidden_states,
             w13=w1,
-            w13_scales=a1q_scale,
+            w13_scales=self.w1_scale,
             w13_bias=self.w1_bias,
             w2=w2,
-            w2_scales=a2_scale,
+            w2_scales=self.w2_scale,
             w2_bias=self.w2_bias,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
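
Note that w13_scales and w2_scales are now taken from the layer's weight scales (self.w1_scale, self.w2_scale) rather than the activation quantization scales (a1q_scale, a2_scale), which appears to be what the fp8 weight path expects.
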
@@ -116,5 +134,22 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
             ep_rank=self.moe_config.ep_rank,
             ep_size=self.moe_config.ep_size,
             output=output,
+            is_fp8=self.is_fp8,
         )
         return


+class XPUExpertsFp8(XPUExperts):
+    def __init__(
+        self,
+        moe_config: FusedMoEConfig,
+        quant_config: FusedMoEQuantConfig,
+        max_num_tokens: int | None = None,
+        num_dispatchers: int | None = None,
+    ):
+        super().__init__(
+            moe_config,
+            quant_config,
+            max_num_tokens,
+            num_dispatchers,
+        )
+        self.is_fp8 = True
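
XPUExpertsFp8 differs from XPUExperts only by setting is_fp8 = True, and that flag is forwarded to the kernel via is_fp8=self.is_fp8 in the call above. A toy, self-contained illustration of this subclass-flag pattern (placeholder classes, not the vLLM types):

    class ExpertsBase:
        def __init__(self) -> None:
            self.is_fp8 = False

        def run_kernel(self) -> str:
            # The base class passes its flag straight to the kernel call,
            # so the subclass only needs to change the flag.
            return f"xpu_fused_moe(is_fp8={self.is_fp8})"

    class ExpertsFp8(ExpertsBase):
        def __init__(self) -> None:
            super().__init__()
            self.is_fp8 = True

    assert ExpertsBase().run_kernel() == "xpu_fused_moe(is_fp8=False)"
    assert ExpertsFp8().run_kernel() == "xpu_fused_moe(is_fp8=True)"
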