From 9bdb06b4368e304bc5e23c8df2dff8f8b2ccf0f6 Mon Sep 17 00:00:00 2001 From: zofia <110436990+zufangzhu@users.noreply.github.com> Date: Mon, 9 Feb 2026 20:17:35 +0800 Subject: [PATCH] [XPU][6/N] add xpu scaled_mm kernel (#34117) Signed-off-by: Zhu, Zufang --- .../scripts/hardware_ci/run-xpu-test.sh | 1 + .../model_executor/layers/quantization/fp8.py | 11 +--- .../kernels/scaled_mm/__init__.py | 6 ++ .../quantization/kernels/scaled_mm/xpu.py | 59 +++++++++++++++++++ 4 files changed, 67 insertions(+), 10 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/xpu.py diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh index 56676ee28..b52dd7826 100644 --- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh @@ -39,6 +39,7 @@ docker run \ python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --attention-backend=TRITON_ATTN + python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --quantization fp8 python3 examples/offline_inference/basic/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --block-size 64 --enforce-eager python3 examples/offline_inference/basic/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 python3 examples/offline_inference/basic/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 --enable-expert-parallel diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index a61239706..80348edcc 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -180,18 +180,9 @@ class Fp8Config(QuantizationConfig): weight_block_size=weight_block_size, ) - def get_xpu_quant_method( - self, layer: torch.nn.Module, prefix: str - ) -> "QuantizeMethodBase | None": - raise NotImplementedError( - "FP8 quantization is not supported during xpu kernel migration." 
-        )
-
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> "QuantizeMethodBase | None":
-        if current_platform.is_xpu():
-            return self.get_xpu_quant_method(layer, prefix)
         if isinstance(layer, LinearBase):
             if is_layer_skipped(
                 prefix=prefix,
@@ -300,7 +291,7 @@ class Fp8LinearMethod(LinearMethodBase):
             or envs.VLLM_TEST_FORCE_FP8_MARLIN
         )
         # Disable marlin for rocm
-        if current_platform.is_rocm():
+        if current_platform.is_rocm() or current_platform.is_xpu():
             self.use_marlin = False
         if vllm_is_batch_invariant():
             self.use_marlin = False
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
index e5401ff81..bbd43dd10 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
@@ -39,6 +39,9 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
     TritonInt8ScaledMMLinearKernel,
 )
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.xpu import (
+    XPUFP8ScaledMMLinearKernel,
+)
 from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
 from vllm.platforms import PlatformEnum, current_platform
 
@@ -72,6 +75,9 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
         PerTensorTorchFP8ScaledMMLinearKernel,
         ChannelWiseTorchFP8ScaledMMLinearKernel,
     ],
+    PlatformEnum.XPU: [
+        XPUFP8ScaledMMLinearKernel,
+    ],
 }
 
 _KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel)
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xpu.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xpu.py
new file mode 100644
index 000000000..5b816a3f5
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xpu.py
@@ -0,0 +1,59 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Sequence
+
+import torch
+
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (  # noqa: E501
+    FP8ScaledMMLinearKernel,
+    FP8ScaledMMLinearLayerConfig,
+)
+from vllm.platforms import current_platform
+
+
+class XPUFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
+    @classmethod
+    def is_supported(
+        cls, compute_capability: int | None = None
+    ) -> tuple[bool, str | None]:
+        if not current_platform.is_xpu():
+            return False, "XPUFP8ScaledMM is only supported on XPU"
+        return True, None
+
+    @classmethod
+    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+        if c.weight_quant_key.dtype not in {torch.float8_e5m2, torch.float8_e4m3fn}:
+            return False, "XPUFP8ScaledMM only supports FP8 weight dtypes"
+        return True, None
+
+    def __init__(
+        self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
+    ) -> None:
+        assert self.can_implement(c)[0]
+        assert self.is_supported()[0]
+        self.config = c
+        self.layer_param_names = layer_param_names
+
+    def apply_weights(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        weight = layer.weight
+        weight_scale = layer.weight_scale
+        return torch.ops._xpu_C.fp8_gemm_w8a16(x, weight, weight_scale, bias)
+
+    def apply_scaled_mm(
+        self,
+        *,
+        A: torch.Tensor,
+        B: torch.Tensor,
+        out_dtype: torch.dtype,
+        As: torch.Tensor,
+        Bs: torch.Tensor,
+        bias: torch.Tensor | None,
+        output_shape: list,
+    ) -> torch.Tensor:
+        raise NotImplementedError  # unused: apply_weights calls the fused XPU op directly
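
Note (not part of the patch): a minimal smoke test of the new path, mirroring the fp8 command added to run-xpu-test.sh above. This is a sketch that assumes an XPU build of vLLM with the _xpu_C extension loaded; with quantization="fp8", Fp8LinearMethod now disables Marlin on XPU and FP8 linear layers should resolve to XPUFP8ScaledMMLinearKernel, whose apply_weights dispatches to torch.ops._xpu_C.fp8_gemm_w8a16.

# Minimal sketch: exercise the XPU FP8 w8a16 path offline.
# Assumes vLLM was built for XPU and the _xpu_C extension is available.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",
    quantization="fp8",   # selects Fp8Config / Fp8LinearMethod
    block_size=64,
    enforce_eager=True,
)
outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0.0, max_tokens=16),
)
print(outputs[0].outputs[0].text)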