[XPU][1/N] Deprecate ipex and switch to vllm-xpu-kernels for xpu platform (#33379)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
@@ -129,12 +129,8 @@ class SiluAndMul(CustomOp):

     def __init__(self, *, compile_native: bool = True):
         super().__init__(compile_native=compile_native)
-        if current_platform.is_cuda_alike():
+        if current_platform.is_cuda_alike() or current_platform.is_xpu():
             self.op = torch.ops._C.silu_and_mul
-        elif current_platform.is_xpu():
-            from vllm._ipex_ops import ipex_ops
-
-            self.op = ipex_ops.silu_and_mul
         elif current_platform.is_cpu():
             self._forward_method = self.forward_native

@@ -152,11 +148,7 @@ class SiluAndMul(CustomOp):
         return out

     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        d = x.shape[-1] // 2
-        output_shape = x.shape[:-1] + (d,)
-        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        self.op(out, x)
-        return out
+        return self.forward_cuda(x)


 # --8<-- [start:mul_and_silu]
@@ -175,12 +167,8 @@ class MulAndSilu(CustomOp):

     def __init__(self):
         super().__init__()
-        if current_platform.is_cuda_alike():
+        if current_platform.is_cuda_alike() or current_platform.is_xpu():
             self.op = torch.ops._C.mul_and_silu
-        elif current_platform.is_xpu():
-            from vllm._ipex_ops import ipex_ops
-
-            self.op = ipex_ops.silu_and_mul
         elif current_platform.is_cpu():
             self._forward_method = self.forward_native

@@ -196,8 +184,8 @@ class MulAndSilu(CustomOp):
         self.op(out, x)
         return out

-    # TODO implement forward_xpu for MulAndSilu
-    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+        return self.forward_cuda(x)


 # --8<-- [start:gelu_and_mul_sparse]
@@ -278,7 +266,11 @@ class GeluAndMul(CustomOp):
         self.approximate = approximate
         if approximate not in ("none", "tanh"):
             raise ValueError(f"Unknown approximate mode: {approximate}")
-        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+        if (
+            current_platform.is_cuda_alike()
+            or current_platform.is_cpu()
+            or current_platform.is_xpu()
+        ):
             if approximate == "none":
                 self.op = torch.ops._C.gelu_and_mul
             elif approximate == "tanh":
@@ -289,13 +281,6 @@ class GeluAndMul(CustomOp):
                 "with torch.compile. For native implementation, fallback to 'none' "
                 "approximation. The custom kernel implementation is unaffected."
             )
-        elif current_platform.is_xpu():
-            from vllm._ipex_ops import ipex_ops
-
-            if approximate == "none":
-                self.op = ipex_ops.gelu_and_mul
-            else:
-                self.op = ipex_ops.gelu_tanh_and_mul

     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
@@ -314,11 +299,7 @@ class GeluAndMul(CustomOp):
         return out

     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        d = x.shape[-1] // 2
-        output_shape = x.shape[:-1] + (d,)
-        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        self.op(out, x)
-        return out
+        return self.forward_cuda(x)

     def extra_repr(self) -> str:
         return f"approximate={repr(self.approximate)}"
@@ -401,12 +382,12 @@ class NewGELU(CustomOp):

     def __init__(self):
         super().__init__()
-        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+        if (
+            current_platform.is_cuda_alike()
+            or current_platform.is_cpu()
+            or current_platform.is_xpu()
+        ):
             self.op = torch.ops._C.gelu_new
-        elif current_platform.is_xpu():
-            from vllm._ipex_ops import ipex_ops
-
-            self.op = ipex_ops.gelu_new

     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
@@ -419,7 +400,7 @@ class NewGELU(CustomOp):
         return out

     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        return self.op(x)
+        return self.forward_cuda(x)


 # --8<-- [start:gelu_fast]
@@ -429,12 +410,12 @@ class FastGELU(CustomOp):

     def __init__(self):
         super().__init__()
-        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+        if (
+            current_platform.is_cuda_alike()
+            or current_platform.is_cpu()
+            or current_platform.is_xpu()
+        ):
             self.op = torch.ops._C.gelu_fast
-        elif current_platform.is_xpu():
-            from vllm._ipex_ops import ipex_ops
-
-            self.op = ipex_ops.gelu_fast

     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
@@ -446,7 +427,7 @@ class FastGELU(CustomOp):
         return out

     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        return self.op(x)
+        return self.forward_cuda(x)


 # --8<-- [start:quick_gelu]
@@ -457,12 +438,12 @@ class QuickGELU(CustomOp):

     def __init__(self):
         super().__init__()
-        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+        if (
+            current_platform.is_cuda_alike()
+            or current_platform.is_cpu()
+            or current_platform.is_xpu()
+        ):
             self.op = torch.ops._C.gelu_quick
-        elif current_platform.is_xpu():
-            from vllm._ipex_ops import ipex_ops
-
-            self.op = ipex_ops.gelu_quick

     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
@@ -474,12 +455,7 @@ class QuickGELU(CustomOp):
         return out

     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        out = torch.empty_like(x)
-        self.op(out, x)
-        return out
-
-    # TODO implement forward_xpu for QuickGELU
-    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+        return self.forward_cuda(x)


 # --8<-- [start:relu2]
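Note on the activation hunks above: they all apply the same pattern, in which the XPU platform now takes the CUDA-alike branch and calls the torch.ops._C custom ops (supplied by vllm-xpu-kernels) instead of ipex_ops, and forward_xpu simply delegates to forward_cuda. The following is a minimal illustrative sketch of that dispatch, not the literal vLLM class; the class name and the explicit forward() are placeholders, and the surrounding layout is assumed from vllm/model_executor/layers/activation.py.

import torch
from vllm.platforms import current_platform


class SiluAndMulSketch(torch.nn.Module):
    """Sketch only: mirrors the dispatch SiluAndMul uses after this change."""

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_xpu():
            # CUDA, ROCm and XPU now share the same registered custom op.
            self.op = torch.ops._C.silu_and_mul
        else:
            self.op = None  # other platforms fall back to the native path

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Pure-PyTorch reference: silu on the first half, gated by the second half.
        d = x.shape[-1] // 2
        return torch.nn.functional.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # Out-parameter convention of the custom op, as in the diff: op(out, x).
        d = x.shape[-1] // 2
        out = torch.empty(x.shape[:-1] + (d,), dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # No ipex_ops fallback anymore; XPU reuses the CUDA path.
        return self.forward_cuda(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.op is None:
            return self.forward_native(x)
        return self.forward_cuda(x)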
@@ -231,24 +231,7 @@ class RMSNorm(CustomOp):
         x: torch.Tensor,
         residual: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if self.variance_size_override is not None:
-            return self.forward_native(x, residual)
-
-        from vllm._ipex_ops import ipex_ops as ops
-
-        if residual is not None:
-            ops.fused_add_rms_norm(
-                x,
-                residual,
-                self.weight.data,
-                self.variance_epsilon,
-            )
-            return x, residual
-        return ops.rms_norm(
-            x,
-            self.weight.data,
-            self.variance_epsilon,
-        )
+        return self.forward_cuda(x, residual)

     def extra_repr(self) -> str:
         s = f"hidden_size={self.weight.data.size(0)}"
@@ -60,8 +60,6 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "ModelOptFp8LinearMethod",
     "ModelOptFp8PcPtLinearMethod",
     "ModelOptFp8PbWoLinearMethod",
-    "IPEXAWQLinearMethod",
-    "IPEXGPTQLinearMethod",
     "QuarkLinearMethod",
     "ModelOptNvFp4LinearMethod",
     "PetitNvFp4LinearMethod",
@@ -24,7 +24,6 @@ QuantizationMethods = Literal[
     "compressed-tensors",
     "bitsandbytes",
     "experts_int8",
-    "ipex",
     "quark",
     "moe_wna16",
     "torchao",
@@ -41,7 +40,6 @@ DEPRECATED_QUANTIZATION_METHODS = [
     "fbgemm_fp8",
     "fp_quant",
     "experts_int8",
-    "ipex",
     "petit_nvfp4",
 ]

@@ -121,7 +119,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
     from .gptq import GPTQConfig
     from .gptq_marlin import GPTQMarlinConfig
     from .inc import INCConfig
-    from .ipex_quant import IPEXConfig
     from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
     from .moe_wna16 import MoeWNA16Config
     from .mxfp4 import Mxfp4Config
@@ -144,7 +141,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
         "bitsandbytes": BitsAndBytesConfig,
         "ptpc_fp8": PTPCFp8Config,
         "experts_int8": ExpertsInt8Config,
-        "ipex": IPEXConfig,
         "quark": QuarkConfig,
         "moe_wna16": MoeWNA16Config,
         "torchao": TorchAOConfig,
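With "ipex" removed from the QuantizationMethods Literal and from the method-to-config mapping above, requesting it by name now falls through to the registry's unknown-method handling. A small hypothetical check, assuming get_quantization_config raises ValueError for unregistered names:

from vllm.model_executor.layers.quantization import get_quantization_config

try:
    get_quantization_config("ipex")
except ValueError as exc:
    # Assumption: unregistered quantization methods raise ValueError here.
    print(f"ipex is no longer a registered method: {exc}")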
@@ -184,39 +184,10 @@ class Fp8Config(QuantizationConfig):
     def get_xpu_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> "QuantizeMethodBase | None":
-        from vllm.model_executor.layers.quantization.ipex_quant import (
-            XPUFp8LinearMethod,
-            XPUFp8MoEMethod,
+        raise NotImplementedError(
+            "FP8 quantization is not supported during xpu kernel migration."
         )
-
-        fp8_config = Fp8Config(
-            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
-            activation_scheme=self.activation_scheme,
-            ignored_layers=self.ignored_layers,
-            weight_block_size=self.weight_block_size,
-        )
-
-        if isinstance(layer, LinearBase):
-            if is_layer_skipped(
-                prefix=prefix,
-                ignored_layers=self.ignored_layers,
-                fused_mapping=self.packed_modules_mapping,
-            ):
-                return UnquantizedLinearMethod()
-            return XPUFp8LinearMethod(fp8_config)
-        elif isinstance(layer, FusedMoE):
-            if is_layer_skipped(
-                prefix=prefix,
-                ignored_layers=self.ignored_layers,
-                fused_mapping=self.packed_modules_mapping,
-            ):
-                return UnquantizedFusedMoEMethod(layer.moe_config)
-
-            return XPUFp8MoEMethod(fp8_config, layer)
-        elif isinstance(layer, Attention):
-            return Fp8KVCacheMethod(self)
-        return None

     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> "QuantizeMethodBase | None":
@@ -38,7 +38,6 @@ class INCConfig(QuantizationConfig):
         "awq",
         "awq:marlin",
         "marlin",
-        "ipex",
     }

     def __init__(
@@ -410,31 +409,10 @@ class INCConfig(QuantizationConfig):
                 return UnquantizedLinearMethod()
             else:
                 return None
-        from vllm.model_executor.layers.quantization.ipex_quant import (
-            IPEXAWQLinearMethod,
-            IPEXConfig,
-            IPEXGPTQLinearMethod,
+        raise NotImplementedError(
+            "INC quantization is not supported during xpu kernel migration."
         )
-
-        if isinstance(layer, (LinearBase, ParallelLMHead)):
-            if "awq" in self.packing_format:
-                config = IPEXConfig(
-                    method="awq", weight_bits=weight_bits, group_size=group_size
-                )
-                return IPEXAWQLinearMethod(config)
-            elif "gptq" in self.packing_format:
-                config = IPEXConfig(
-                    method="gptq", weight_bits=weight_bits, group_size=group_size
-                )
-                return IPEXGPTQLinearMethod(config)
-            else:
-                raise ValueError(
-                    f"ipex backend only supports awq "
-                    f"and gptq format,but got {self.packing_format}"
-                )
-        else:
-            return None

     def get_quant_method(self, layer: torch.nn.Module, prefix: str):
         if prefix and self.extra_config:
             for layer_name in self.extra_config:
@@ -1,403 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from typing import Any
-
-import torch
-from packaging import version
-from torch.nn import Module
-
-from vllm._ipex_ops import ipex_ops as ops
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.linear import (
-    LinearBase,
-    LinearMethodBase,
-    UnquantizedLinearMethod,
-)
-from vllm.model_executor.layers.quantization import (
-    QuantizationConfig,
-    QuantizationMethods,
-)
-from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
-from vllm.model_executor.layers.quantization.fp8 import (
-    Fp8Config,
-    Fp8LinearMethod,
-    Fp8OnlineMoEMethod,
-)
-from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
-from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
-from vllm.model_executor.utils import replace_parameter
-from vllm.platforms import current_platform
-
-MIN_IPEX_VERSION = "2.6.0"
-
-
-class IPEXConfig(QuantizationConfig):
-    """INT8 quantization config class using IPEX for the CPU/XPU backend,
-    including AWQ, GPTQ.
-    """
-
-    IPEX_QUANT_METHOD_MAP = {
-        "awq": 1,
-        "gptq": 0,
-    }
-
-    def __init__(
-        self,
-        method: str,
-        weight_bits: int,
-        group_size: int,
-        modules_to_not_convert: list[str] | None = None,
-        desc_act: bool | None = None,
-        lm_head_quantized: bool | None = None,
-        is_sym: bool | None = None,
-    ) -> None:
-        super().__init__()
-        self.method = method
-        self.weight_bits = weight_bits
-        self.group_size = group_size
-        self.modules_to_not_convert = modules_to_not_convert or []
-        self.desc_act = desc_act
-        self.lm_head_quantized = lm_head_quantized
-        self.is_sym = is_sym
-        self.pack_factor = 32 // self.weight_bits
-
-        if self.weight_bits not in [4]:
-            raise ValueError(
-                f"IPEX quantization supports weight bits [4], "
-                f"but got {self.weight_bits}."
-            )
-
-        if self.method not in ["awq", "gptq"]:
-            raise ValueError(
-                f"IPEX quantization supports [awq, gptq], but got {self.method}."
-            )
-
-    def __repr__(self) -> str:
-        return (
-            f"IPEXConfig(method={self.method},"
-            f"weight_bits={self.weight_bits}, "
-            f"group_size={self.group_size})"
-        )
-
-    @classmethod
-    def get_name(cls) -> QuantizationMethods:
-        return "ipex"
-
-    @classmethod
-    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
-        return [torch.bfloat16, torch.float16]
-
-    @classmethod
-    def get_min_capability(cls) -> int:
-        return -1
-
-    @staticmethod
-    def get_config_filenames() -> list[str]:
-        return [
-            "quant_config.json",
-            "quantize_config.json",
-        ]
-
-    @classmethod
-    def from_config(cls, config: dict[str, Any]) -> "IPEXConfig":
-        method = cls.get_from_keys(config, ["quant_method"]).lower()
-        if method == "awq":
-            weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
-            group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
-            modules_to_not_convert = cls.get_from_keys_or(
-                config, ["modules_to_not_convert"], None
-            )
-            is_sym = not cls.get_from_keys_or(config, ["zero_point"], default=False)
-            return cls(
-                method,
-                weight_bits,
-                group_size,
-                modules_to_not_convert,
-                False,
-                False,
-                is_sym,
-            )
-        # otherwise for gptq
-        weight_bits = cls.get_from_keys(config, ["bits"])
-        group_size = cls.get_from_keys(config, ["group_size"])
-        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
-        desc_act = cls.get_from_keys_or(config, ["desc_act"], default=False)
-        is_sym = cls.get_from_keys_or(config, ["sym"], default=True)
-        return cls(
-            method, weight_bits, group_size, [], desc_act, lm_head_quantized, is_sym
-        )
-
-    @classmethod
-    def override_quantization_method(
-        cls, hf_quant_cfg, user_quant
-    ) -> QuantizationMethods | None:
-        if not current_platform.is_xpu():
-            return None
-
-        quant_method = hf_quant_cfg.get("quant_method", "").lower()
-
-        if quant_method in ["awq", "gptq"]:
-            return cls.get_name()
-
-        return None
-
-    def get_quant_method(
-        self, layer: torch.nn.Module, prefix: str
-    ) -> "LinearMethodBase | None":
-        if isinstance(layer, LinearBase):
-            if self.method == "awq":
-                if is_layer_skipped(
-                    prefix,
-                    self.modules_to_not_convert,
-                    self.packed_modules_mapping,
-                    skip_with_substr=True,
-                ):
-                    return UnquantizedLinearMethod()
-                return IPEXAWQLinearMethod(self)
-            if self.method == "gptq":
-                return IPEXGPTQLinearMethod(self)
-        return None
-
-
-class IPEXGPTQLinearMethod(GPTQLinearMethod):
-    """GPTQ linear method using IPEX for the CPU/XPU backend."""
-
-    def __init__(self, quant_config: IPEXConfig):
-        self.quant_config = quant_config  # type: ignore
-
-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        bias = layer.bias if not layer.skip_bias_add else None
-
-        try:
-            import intel_extension_for_pytorch as ipex
-
-            if version.parse(ipex.__version__) < version.parse(MIN_IPEX_VERSION):
-                raise ImportError(
-                    "intel_extension_for_pytorch version is "
-                    "wrong. Please install "
-                    f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}."
-                )
-        except ImportError as err:
-            raise ImportError(
-                "Please install "
-                f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via "
-                f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`"
-                " to use IPEX-AWQ linear method."
-            ) from err
-        # Using the compute dtype (lowp_mode) as INT8 to leverage instructions
-        # with better performance.
-        lowp_mode = ipex.quantization.WoqLowpMode.INT8
-        # The weight will be de-packed from INT4 to INT8.
-        weight_dtype = ipex.quantization.WoqWeightDtype.INT4
-        # The float activation will be quantized (dynamic, per-token) to INT8.
-        act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK
-
-        assert isinstance(self.quant_config, IPEXConfig)
-        qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
-            weight_dtype=weight_dtype,
-            lowp_mode=lowp_mode,
-            act_quant_mode=act_quant_mode,
-            group_size=self.quant_config.group_size,
-        )
-        layer.ipex_output_size = layer.qweight.shape[-1]
-        g_idx = layer.g_idx if self.quant_config.desc_act else None
-        layer.ipex_qlinear = (
-            ipex.llm.quantization.woq_linear.IPEXWeightOnlyQuantizedLinear.from_weight(
-                layer.qweight,
-                layer.scales,
-                layer.qzeros,
-                layer.qweight.size(0),
-                layer.ipex_output_size,
-                qconfig=qconfig,
-                g_idx=g_idx,
-                bias=bias,
-                group_size=self.quant_config.group_size,
-                quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["gptq"],
-                weight_qscheme="sym" if self.quant_config.is_sym else "asym",
-            )
-        )
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        bias: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        reshaped_x = x.reshape(-1, x.shape[-1])
-        out = layer.ipex_qlinear(reshaped_x)
-        return out.reshape(x.shape[:-1] + (layer.ipex_output_size,))
-
-
-class IPEXAWQLinearMethod(AWQLinearMethod):
-    """AWQ linear method using IPEX for the CPU/XPU backend."""
-
-    def __init__(self, quant_config: IPEXConfig):
-        self.quant_config = quant_config  # type: ignore
-
-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        super().process_weights_after_loading(layer=layer)
-
-        bias = layer.bias if not layer.skip_bias_add else None
-
-        try:
-            import intel_extension_for_pytorch as ipex
-
-            if version.parse(ipex.__version__) < version.parse(MIN_IPEX_VERSION):
-                raise ImportError(
-                    "intel_extension_for_pytorch version is "
-                    "wrong. Please install "
-                    f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}."
-                )
-        except ImportError as err:
-            raise ImportError(
-                "Please install "
-                f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via "
-                f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`"
-                " to use IPEX-AWQ linear method."
-            ) from err
-
-        # Using the compute dtype (lowp_mode) as INT8 to leverage instructions
-        # with better performance.
-        lowp_mode = ipex.quantization.WoqLowpMode.INT8
-        # The weight will be de-packed from INT4 to INT8.
-        weight_dtype = ipex.quantization.WoqWeightDtype.INT4
-        # The float activation will be quantized (dynamic, per-token) to INT8.
-        act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH
-
-        assert isinstance(self.quant_config, IPEXConfig)
-        qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
-            weight_dtype=weight_dtype,
-            lowp_mode=lowp_mode,
-            act_quant_mode=act_quant_mode,
-            group_size=self.quant_config.group_size,
-        )
-
-        layer.ipex_output_size = layer.qweight.size(1) * self.quant_config.pack_factor
-        layer.ipex_qlinear = (
-            ipex.llm.quantization.woq_linear.IPEXWeightOnlyQuantizedLinear.from_weight(
-                layer.qweight,
-                layer.scales,
-                layer.qzeros,
-                layer.qweight.size(0),
-                layer.ipex_output_size,
-                qconfig=qconfig,
-                bias=bias,
-                group_size=self.quant_config.group_size,
-                quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["awq"],  # type: ignore
-                weight_qscheme="sym" if self.quant_config.is_sym else "asym",
-            )
-        )
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        bias: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        reshaped_x = x.reshape(-1, x.shape[-1])
-        out = layer.ipex_qlinear(reshaped_x)
-        return out.reshape(x.shape[:-1] + (layer.ipex_output_size,))
-
-
-class XPUFp8LinearMethod(Fp8LinearMethod):
-    def __init__(self, quant_config: Fp8Config):
-        super().__init__(quant_config)
-
-    def process_weights_after_loading(self, layer: Module) -> None:
-        if getattr(layer, "_already_called_process_weights_after_loading", False):
-            return
-        # If checkpoint not serialized fp8, quantize the weights.
-        if not self.quant_config.is_checkpoint_fp8_serialized:
-            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
-            # Update the layer with the new values.
-            replace_parameter(layer, "weight", qweight.data)
-            replace_parameter(layer, "weight_scale", weight_scale.data)
-            layer.input_scale = None
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        bias: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        weight = layer.weight.data
-        weight_scale = layer.weight_scale.data
-        output = torch.ops.torch_ipex.fp8_gemm_w8a16(
-            x, weight, True, weight_scale, bias
-        )
-        return output
-
-
-class XPUFp8MoEMethod(Fp8OnlineMoEMethod):
-    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
-        super().__init__(quant_config, layer)
-        self.quant_config = quant_config
-
-    def process_weights_after_loading(self, layer: Module) -> None:
-        if getattr(layer, "_already_called_process_weights_after_loading", False):
-            return
-        if not self.quant_config.is_checkpoint_fp8_serialized:
-            fp8_dtype = current_platform.fp8_dtype()
-            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
-            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
-
-            # Re-initialize w13_scale because we directly quantize
-            # merged w13 weights and generate a single scaling factor.
-            layer.w13_weight_scale = torch.nn.Parameter(
-                torch.ones(
-                    layer.local_num_experts,
-                    dtype=torch.float32,
-                    device=w13_weight.device,
-                ),
-                requires_grad=False,
-            )
-            for expert in range(layer.local_num_experts):
-                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                )
-                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                    ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                )
-            replace_parameter(layer, "w13_weight", w13_weight)
-            replace_parameter(layer, "w2_weight", w2_weight)
-
-        import intel_extension_for_pytorch as ipex
-
-        ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts
-        layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
-            layer.w13_weight,
-            layer.w2_weight,
-            w1_scale_inv=layer.w13_weight_scale,
-            w2_scale_inv=layer.w2_weight_scale,
-            a1_scale_inv=layer.w13_input_scale,
-            a2_scale_inv=layer.w2_input_scale,
-            use_prepack=True,
-            experts_start_id=ep_rank_start,
-        )
-
-    def get_fused_moe_quant_config(
-        self, layer: torch.nn.Module
-    ) -> FusedMoEQuantConfig | None:
-        return None
-
-    @property
-    def is_monolithic(self) -> bool:
-        return True
-
-    def apply_monolithic(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        router_logits: torch.Tensor,
-    ) -> torch.Tensor:
-        return layer.ipex_fusion(
-            x,
-            layer.use_grouped_topk,
-            layer.top_k,
-            router_logits,
-            layer.renormalize,
-            layer.topk_group,
-            layer.num_expert_group,
-            custom_routing_function=layer.custom_routing_function,
-        )
@@ -232,17 +232,14 @@ class RotaryEmbedding(RotaryEmbeddingBase):
         query: torch.Tensor,
         key: torch.Tensor | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        from vllm._ipex_ops import ipex_ops as ops
-
         self._match_cos_sin_cache_dtype(query)
         # ops.rotary_embedding() is an in-place operation
         # that updates the query and key tensors.
         if key is None:
-            # XPU kernel doesn't support key=None so fall back to native impl
-            # TODO(sarckk): add support for optional key in
-            # ipex.llm.functional.rotary_embedding_batched
            return self.forward_native(positions, query, key)
         else:
+            from vllm import _custom_ops as ops
+
             ops.rotary_embedding(
                 positions,
                 query,