[XPU][6/N] add xpu scaled_mm kernel (#34117)
Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>
This commit is contained in:
@@ -39,6 +39,7 @@ docker run \
|
||||
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
|
||||
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
|
||||
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --attention-backend=TRITON_ATTN
|
||||
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --quantization fp8
|
||||
python3 examples/offline_inference/basic/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --block-size 64 --enforce-eager
|
||||
python3 examples/offline_inference/basic/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2
|
||||
python3 examples/offline_inference/basic/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 --enable-expert-parallel
|
||||
|
||||
@@ -180,18 +180,9 @@ class Fp8Config(QuantizationConfig):
|
||||
weight_block_size=weight_block_size,
|
||||
)
|
||||
|
||||
def get_xpu_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> "QuantizeMethodBase | None":
|
||||
raise NotImplementedError(
|
||||
"FP8 quantization is not supported during xpu kernel migration."
|
||||
)
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> "QuantizeMethodBase | None":
|
||||
if current_platform.is_xpu():
|
||||
return self.get_xpu_quant_method(layer, prefix)
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped(
|
||||
prefix=prefix,
|
||||
@@ -300,7 +291,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
or envs.VLLM_TEST_FORCE_FP8_MARLIN
|
||||
)
|
||||
# Disable marlin for rocm
|
||||
if current_platform.is_rocm():
|
||||
if current_platform.is_rocm() or current_platform.is_xpu():
|
||||
self.use_marlin = False
|
||||
if vllm_is_batch_invariant():
|
||||
self.use_marlin = False
|
||||
|
||||
@@ -39,6 +39,9 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
|
||||
TritonInt8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.xpu import (
|
||||
XPUFP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
||||
from vllm.platforms import PlatformEnum, current_platform
|
||||
|
||||
@@ -72,6 +75,9 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
|
||||
PerTensorTorchFP8ScaledMMLinearKernel,
|
||||
ChannelWiseTorchFP8ScaledMMLinearKernel,
|
||||
],
|
||||
PlatformEnum.XPU: [
|
||||
XPUFP8ScaledMMLinearKernel,
|
||||
],
|
||||
}
|
||||
|
||||
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel)
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class XPUFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_xpu():
|
||||
return False, "XPUFP8ScaledMM only support on XPU"
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if c.weight_quant_key.dtype not in {torch.float8_e5m2, torch.float8_e4m3fn}:
|
||||
return False, "XPUFP8ScaledMM only support FP8 weight dtype"
|
||||
return True, None
|
||||
|
||||
def __init__(
|
||||
self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
|
||||
) -> None:
|
||||
assert self.can_implement(c)[0]
|
||||
assert self.is_supported()[0]
|
||||
self.config = c
|
||||
self.layer_param_names = layer_param_names
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
weight = layer.weight
|
||||
weight_scale = layer.weight_scale
|
||||
return torch.ops._xpu_C.fp8_gemm_w8a16(x, weight, weight_scale, bias)
|
||||
|
||||
def apply_scaled_mm(
|
||||
self,
|
||||
*,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor | None,
|
||||
output_shape: list,
|
||||
) -> torch.Tensor:
|
||||
pass
|
||||
Reference in New Issue
Block a user