[XPU][6/N] add xpu scaled_mm kernel (#34117)
Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>
(cherry picked from commit 9bdb06b436)
@@ -39,6 +39,7 @@ docker run \
 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --attention-backend=TRITON_ATTN
+python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --quantization fp8
 python3 examples/offline_inference/basic/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --block-size 64 --enforce-eager
 python3 examples/offline_inference/basic/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2
 python3 examples/offline_inference/basic/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 --enable-expert-parallel
@@ -180,18 +180,9 @@ class Fp8Config(QuantizationConfig):
             weight_block_size=weight_block_size,
         )
 
-    def get_xpu_quant_method(
-        self, layer: torch.nn.Module, prefix: str
-    ) -> "QuantizeMethodBase | None":
-        raise NotImplementedError(
-            "FP8 quantization is not supported during xpu kernel migration."
-        )
-
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> "QuantizeMethodBase | None":
-        if current_platform.is_xpu():
-            return self.get_xpu_quant_method(layer, prefix)
         if isinstance(layer, LinearBase):
             if is_layer_skipped(
                 prefix=prefix,
@@ -300,7 +291,7 @@ class Fp8LinearMethod(LinearMethodBase):
             or envs.VLLM_TEST_FORCE_FP8_MARLIN
         )
         # Disable marlin for rocm
-        if current_platform.is_rocm():
+        if current_platform.is_rocm() or current_platform.is_xpu():
             self.use_marlin = False
         if vllm_is_batch_invariant():
             self.use_marlin = False
@@ -39,6 +39,9 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
     TritonInt8ScaledMMLinearKernel,
 )
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.xpu import (
+    XPUFP8ScaledMMLinearKernel,
+)
 from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
 from vllm.platforms import PlatformEnum, current_platform
 
@@ -72,6 +75,9 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
         PerTensorTorchFP8ScaledMMLinearKernel,
         ChannelWiseTorchFP8ScaledMMLinearKernel,
     ],
+    PlatformEnum.XPU: [
+        XPUFP8ScaledMMLinearKernel,
+    ],
 }
 
 _KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel)
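For orientation (this helper is not part of the commit): the per-platform registry above is meant to be walked by the kernel chooser, which keeps the first class whose `is_supported()` and `can_implement()` checks both pass. A minimal sketch of that selection pattern follows; the function name `pick_fp8_kernel` and its signature are assumptions for illustration, not the module's real dispatcher.

```python
# Illustrative sketch only: walk the platform's kernel list and return the
# first class that is supported and can implement the given config.
def pick_fp8_kernel(registry, platform, config):
    reasons = []
    for kernel_cls in registry.get(platform, []):
        supported, why = kernel_cls.is_supported()
        if not supported:
            reasons.append(f"{kernel_cls.__name__}: {why}")
            continue
        ok, why = kernel_cls.can_implement(config)
        if ok:
            return kernel_cls
        reasons.append(f"{kernel_cls.__name__}: {why}")
    raise ValueError("no usable FP8 scaled_mm kernel: " + "; ".join(reasons))
```

With the registration added above, such a lookup against `_POSSIBLE_FP8_KERNELS` for `PlatformEnum.XPU` would resolve to `XPUFP8ScaledMMLinearKernel` on an XPU build.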
@@ -0,0 +1,59 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Sequence
+
+import torch
+
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (  # noqa: E501
+    FP8ScaledMMLinearKernel,
+    FP8ScaledMMLinearLayerConfig,
+)
+from vllm.platforms import current_platform
+
+
+class XPUFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
+    @classmethod
+    def is_supported(
+        cls, compute_capability: int | None = None
+    ) -> tuple[bool, str | None]:
+        if not current_platform.is_xpu():
+            return False, "XPUFP8ScaledMM only support on XPU"
+        return True, None
+
+    @classmethod
+    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+        if c.weight_quant_key.dtype not in {torch.float8_e5m2, torch.float8_e4m3fn}:
+            return False, "XPUFP8ScaledMM only support FP8 weight dtype"
+        return True, None
+
+    def __init__(
+        self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
+    ) -> None:
+        assert self.can_implement(c)[0]
+        assert self.is_supported()[0]
+        self.config = c
+        self.layer_param_names = layer_param_names
+
+    def apply_weights(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        weight = layer.weight
+        weight_scale = layer.weight_scale
+        return torch.ops._xpu_C.fp8_gemm_w8a16(x, weight, weight_scale, bias)
+
+    def apply_scaled_mm(
+        self,
+        *,
+        A: torch.Tensor,
+        B: torch.Tensor,
+        out_dtype: torch.dtype,
+        As: torch.Tensor,
+        Bs: torch.Tensor,
+        bias: torch.Tensor | None,
+        output_shape: list,
+    ) -> torch.Tensor:
+        pass
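As a rough functional reference for the fused op called in `apply_weights` above: the name `fp8_gemm_w8a16` suggests a W8A16-style GEMM, i.e. FP8-stored weights multiplied against unquantized fp16/bf16 activations. The eager-mode sketch below is an assumption for illustration; the `[N, K]` weight layout and per-output-channel (or per-tensor) `weight_scale` are not spelled out in the diff, and the real XPU kernel fuses the dequantization instead of materializing a full-precision weight copy.

```python
import torch


def fp8_gemm_w8a16_reference(
    x: torch.Tensor,             # [M, K] activations, fp16/bf16
    weight: torch.Tensor,        # [N, K] weights stored as float8 (assumed layout)
    weight_scale: torch.Tensor,  # per-tensor scalar or [N] per-channel scales
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    # Dequantize the FP8 weights to the activation dtype, then run a plain matmul.
    w = weight.to(x.dtype) * weight_scale.to(x.dtype).reshape(-1, 1)
    out = x @ w.t()
    if bias is not None:
        out = out + bias
    return out
```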