[XPU][5/N] add wna16 xpu kernel (#33973)

Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>
This commit is contained in:
zofia
2026-02-06 23:59:53 +08:00
committed by GitHub
parent cd8b405bd0
commit 2ce9fe4ad0
2 changed files with 57 additions and 65 deletions

View File

@@ -39,6 +39,7 @@ docker run \
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --attention-backend=TRITON_ATTN
python3 examples/offline_inference/basic/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --block-size 64 --enforce-eager
python3 examples/offline_inference/basic/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2
python3 examples/offline_inference/basic/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 --enable-expert-parallel
cd tests

View File

@@ -3,88 +3,71 @@
import torch
from torch.nn.parameter import Parameter
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
_XPUWNA16_SUPPORTED_QUANT_TYPES = (scalar_types.uint4, scalar_types.uint4b8)
class XPUwNa16LinearKernel(MPLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 0
return -1
@classmethod
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_xpu():
return False, "IPEX wNa16 only supported on XPU/CPU devices"
return False, "XPUwNa16 only supported on XPU"
# TODO: (yiliu30) relax these restrictions in later PRs
if c.zero_points:
return False, "Zero points not supported for Now"
if c.act_type != torch.bfloat16 and c.act_type != torch.float16:
return False, "XPUwNa16 only supports BF16/FP16 activations"
if c.weight_type not in _XPUWNA16_SUPPORTED_QUANT_TYPES:
return (
False,
f"Quant type ({c.weight_type}) not supported by "
"XPUwNa16, supported types are: "
f"{_XPUWNA16_SUPPORTED_QUANT_TYPES}",
)
if c.group_size != -1 and c.group_size % 32 != 0:
return (
False,
f"Group size ({c.group_size}) not supported by "
"XPUwNa16, supported group sizes are multiples of 32",
)
if c.partition_weight_shape[0] % 32 != 0:
return (
False,
f"Input size ({c.partition_weight_shape[0]}) not supported by "
"XPUwNa16, supported sizes are multiples of 32",
)
if c.partition_weight_shape[1] % 32 != 0:
return (
False,
f"Output size ({c.partition_weight_shape[1]}) not supported by "
"XPUWNA16, supported sizes are multiples of 32",
)
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
from packaging import version
def process_weights_after_loading(self, layer: torch.nn.Module):
layer.weight_scale.data = layer.weight_scale.t().contiguous()
MIN_IPEX_VERSION = "2.6.0"
bias = layer.bias if not layer.skip_bias_add else None
try:
import intel_extension_for_pytorch as ipex
if version.parse(ipex.__version__) < version.parse(MIN_IPEX_VERSION):
raise ImportError(
"intel_extension_for_pytorch version is "
"wrong. Please install "
f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}."
)
except ImportError as err:
raise ImportError(
"Please install "
f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via "
f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`"
" to use IPEX-AWQ linear method."
) from err
# Using the compute dtype (lowp_mode) as INT8 to leverage instructions
# with better performance.
lowp_mode = ipex.quantization.WoqLowpMode.INT8
# The weight will be de-packed from INT4 to INT8.
weight_dtype = ipex.quantization.WoqWeightDtype.INT4
# The float activation will be quantized (dynamic, per-token) to INT8.
act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
weight_dtype=weight_dtype,
lowp_mode=lowp_mode,
act_quant_mode=act_quant_mode,
group_size=self.config.group_size,
weight_qscheme=ipex.quantization.WoqWeightQScheme.SYMMETRIC,
)
qweight = layer.weight_packed
g_idx = layer.weight_g_idx if self.config.has_g_idx else None
scales = layer.weight_scale
qzeros = None
if self.config.zero_points:
qzeros = layer.weight_zero_point.contiguous()
qweight = qweight.t().contiguous()
scales = scales.t().contiguous()
layer.ipex_output_size = self.config.partition_weight_shape[1]
layer.ipex_qlinear = (
ipex.llm.quantization.woq_linear.IPEXWeightOnlyQuantizedLinear.from_weight(
qweight,
scales,
qzeros,
in_features=self.config.partition_weight_shape[0],
out_features=self.config.partition_weight_shape[1],
qconfig=qconfig,
g_idx=g_idx,
bias=bias,
group_size=self.config.group_size,
quant_method=0, # `0` stands for the IPEX GPTQ
)
)
layer.weight_zero_point.data = layer.weight_zero_point.t().contiguous()
else:
weight_zero_point = torch.Tensor([8]).to(torch.int8).to("xpu")
layer.weight_zero_point = Parameter(weight_zero_point, requires_grad=False)
if self.config.has_g_idx:
layer.g_idx.data = layer.g_idx.t().contiguous()
else:
layer.g_idx = None
def apply_weights(
self,
@@ -93,5 +76,13 @@ class XPUwNa16LinearKernel(MPLinearKernel):
bias: torch.Tensor | None = None,
) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1])
out = layer.ipex_qlinear(reshaped_x)
return out.reshape(x.shape[:-1] + (layer.ipex_output_size,))
out = torch.ops._xpu_C.int4_gemm_w4a16(
reshaped_x,
layer.weight_packed.t(),
bias,
layer.weight_scale,
layer.weight_zero_point,
self.config.group_size,
layer.g_idx,
)
return out