[XPU][7/N] enable xpu fp8 moe (#34202)
Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>
This commit is contained in:
@@ -15,4 +15,4 @@ torch==2.10.0+xpu
|
|||||||
torchaudio
|
torchaudio
|
||||||
torchvision
|
torchvision
|
||||||
|
|
||||||
vllm_xpu_kernels @ https://github.com/vllm-project/vllm-xpu-kernels/releases/download/v0.1.1/vllm_xpu_kernels-0.1.1-cp312-cp312-linux_x86_64.whl
|
vllm_xpu_kernels @ https://github.com/vllm-project/vllm-xpu-kernels/releases/download/v0.1.2/vllm_xpu_kernels-0.1.2-cp312-cp312-linux_x86_64.whl
|
||||||
|
|||||||
@@ -102,6 +102,7 @@ if HAS_TRITON:
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.xpu_fused_moe import (
|
from vllm.model_executor.layers.fused_moe.xpu_fused_moe import (
|
||||||
XPUExperts,
|
XPUExperts,
|
||||||
|
XPUExpertsFp8,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ += [
|
__all__ += [
|
||||||
@@ -121,6 +122,7 @@ if HAS_TRITON:
|
|||||||
"BatchedDeepGemmExperts",
|
"BatchedDeepGemmExperts",
|
||||||
"TritonOrDeepGemmExperts",
|
"TritonOrDeepGemmExperts",
|
||||||
"XPUExperts",
|
"XPUExperts",
|
||||||
|
"XPUExpertsFp8",
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
# Some model classes directly use the custom ops. Add placeholders
|
# Some model classes directly use the custom ops. Add placeholders
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ class Fp8MoeBackend(Enum):
|
|||||||
AITER = "AITER"
|
AITER = "AITER"
|
||||||
VLLM_CUTLASS = "VLLM_CUTLASS"
|
VLLM_CUTLASS = "VLLM_CUTLASS"
|
||||||
BATCHED_VLLM_CUTLASS = "BATCHED_VLLM_CUTLASS"
|
BATCHED_VLLM_CUTLASS = "BATCHED_VLLM_CUTLASS"
|
||||||
|
XPU = "XPU"
|
||||||
|
|
||||||
|
|
||||||
def backend_to_kernel_cls(
|
def backend_to_kernel_cls(
|
||||||
@@ -123,6 +124,13 @@ def backend_to_kernel_cls(
|
|||||||
|
|
||||||
return CutlassBatchedExpertsFp8
|
return CutlassBatchedExpertsFp8
|
||||||
|
|
||||||
|
elif backend == Fp8MoeBackend.XPU:
|
||||||
|
from vllm.model_executor.layers.fused_moe.xpu_fused_moe import (
|
||||||
|
XPUExpertsFp8,
|
||||||
|
)
|
||||||
|
|
||||||
|
return XPUExpertsFp8
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown FP8 MoE backend: {backend.value}")
|
raise ValueError(f"Unknown FP8 MoE backend: {backend.value}")
|
||||||
|
|
||||||
@@ -154,6 +162,7 @@ def select_fp8_moe_backend(
|
|||||||
Fp8MoeBackend.TRITON,
|
Fp8MoeBackend.TRITON,
|
||||||
Fp8MoeBackend.BATCHED_TRITON,
|
Fp8MoeBackend.BATCHED_TRITON,
|
||||||
Fp8MoeBackend.MARLIN,
|
Fp8MoeBackend.MARLIN,
|
||||||
|
Fp8MoeBackend.XPU,
|
||||||
]
|
]
|
||||||
|
|
||||||
# NOTE(rob): We need to peek into the P/F selection to determine
|
# NOTE(rob): We need to peek into the P/F selection to determine
|
||||||
@@ -393,6 +402,7 @@ def convert_to_fp8_moe_kernel_format(
|
|||||||
Fp8MoeBackend.BATCHED_TRITON,
|
Fp8MoeBackend.BATCHED_TRITON,
|
||||||
Fp8MoeBackend.VLLM_CUTLASS,
|
Fp8MoeBackend.VLLM_CUTLASS,
|
||||||
Fp8MoeBackend.BATCHED_VLLM_CUTLASS,
|
Fp8MoeBackend.BATCHED_VLLM_CUTLASS,
|
||||||
|
Fp8MoeBackend.XPU,
|
||||||
]:
|
]:
|
||||||
raise ValueError(f"Unsupported FP8 MoE backend: {fp8_backend.value}")
|
raise ValueError(f"Unsupported FP8 MoE backend: {fp8_backend.value}")
|
||||||
|
|
||||||
|
|||||||
@@ -4,13 +4,16 @@ import torch
|
|||||||
|
|
||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm.model_executor.layers.fused_moe.config import (
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
FusedMoEConfig,
|
||||||
FusedMoEParallelConfig,
|
FusedMoEParallelConfig,
|
||||||
|
FusedMoEQuantConfig,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||||
TopKWeightAndReduceNoOP,
|
TopKWeightAndReduceNoOP,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
QuantKey,
|
QuantKey,
|
||||||
|
kFp8DynamicTensorSym,
|
||||||
kFp8StaticTensorSym,
|
kFp8StaticTensorSym,
|
||||||
)
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@@ -20,6 +23,21 @@ if current_platform.is_xpu():
|
|||||||
|
|
||||||
|
|
||||||
class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
moe_config: FusedMoEConfig,
|
||||||
|
quant_config: FusedMoEQuantConfig,
|
||||||
|
max_num_tokens: int | None = None,
|
||||||
|
num_dispatchers: int | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
moe_config,
|
||||||
|
quant_config,
|
||||||
|
max_num_tokens,
|
||||||
|
num_dispatchers,
|
||||||
|
)
|
||||||
|
self.is_fp8 = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def expects_unquantized_inputs(self) -> bool:
|
def expects_unquantized_inputs(self) -> bool:
|
||||||
return True
|
return True
|
||||||
@@ -49,10 +67,10 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
weight_key: QuantKey | None,
|
weight_key: QuantKey | None,
|
||||||
activation_key: QuantKey | None,
|
activation_key: QuantKey | None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
# TODO: dispatch based on device.
|
|
||||||
SUPPORTED_W_A = [
|
SUPPORTED_W_A = [
|
||||||
(None, None),
|
(None, None),
|
||||||
(kFp8StaticTensorSym, None),
|
(kFp8StaticTensorSym, None),
|
||||||
|
(kFp8StaticTensorSym, kFp8DynamicTensorSym),
|
||||||
]
|
]
|
||||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||||
|
|
||||||
@@ -103,10 +121,10 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
xpu_fused_moe(
|
xpu_fused_moe(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
w13=w1,
|
w13=w1,
|
||||||
w13_scales=a1q_scale,
|
w13_scales=self.w1_scale,
|
||||||
w13_bias=self.w1_bias,
|
w13_bias=self.w1_bias,
|
||||||
w2=w2,
|
w2=w2,
|
||||||
w2_scales=a2_scale,
|
w2_scales=self.w2_scale,
|
||||||
w2_bias=self.w2_bias,
|
w2_bias=self.w2_bias,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
@@ -116,5 +134,22 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
ep_rank=self.moe_config.ep_rank,
|
ep_rank=self.moe_config.ep_rank,
|
||||||
ep_size=self.moe_config.ep_size,
|
ep_size=self.moe_config.ep_size,
|
||||||
output=output,
|
output=output,
|
||||||
|
is_fp8=self.is_fp8,
|
||||||
)
|
)
|
||||||
return
|
|
||||||
|
|
||||||
|
class XPUExpertsFp8(XPUExperts):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
moe_config: FusedMoEConfig,
|
||||||
|
quant_config: FusedMoEQuantConfig,
|
||||||
|
max_num_tokens: int | None = None,
|
||||||
|
num_dispatchers: int | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
moe_config,
|
||||||
|
quant_config,
|
||||||
|
max_num_tokens,
|
||||||
|
num_dispatchers,
|
||||||
|
)
|
||||||
|
self.is_fp8 = True
|
||||||
|
|||||||
Reference in New Issue
Block a user