[Feat] Enable CompressedTensorW4A8Int for XPU (#37207)

Signed-off-by: Li, Tianmu <tianmu.li@intel.com>
This commit is contained in:
Tianmu Li
2026-03-19 19:34:28 -07:00
committed by GitHub
parent 269bf46d99
commit 47b7af0d87
4 changed files with 172 additions and 0 deletions

View File

@@ -37,6 +37,26 @@ if hasattr(torch.ops._xpu_C, "fp8_gemm_w8a16"):
return torch.empty((M, N), dtype=input.dtype, device=input.device)
if hasattr(torch.ops._xpu_C, "int4_gemm_w4a8"):
@register_fake("_xpu_C::int4_gemm_w4a8")
def _int4_gemm_w4a8_fake(
input: torch.Tensor,
input_scales: torch.Tensor,
input_zero_points: torch.Tensor,
q_weight: torch.Tensor,
weight_scale: torch.Tensor,
weight_zp: torch.Tensor,
group_size: int,
g_idx: torch.Tensor | None = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
input_2d = input.view(-1, input.shape[-1])
M = input_2d.size(0)
N = q_weight.size(1)
return torch.empty((M, N), dtype=torch.float16, device=input.device)
if hasattr(torch.ops._xpu_C, "int4_gemm_w4a16"):
@register_fake("_xpu_C::int4_gemm_w4a16")
@@ -87,6 +107,40 @@ _OPS_REGISTERED = False
class xpu_ops:
@staticmethod
@torch.compile
def dynamic_per_token_int8_quant_ref(
input: torch.Tensor, use_sym_quant: bool, bits: int
):
original_sizes = input.size()
# view is not safe in torch.compile if input is not contiguous
input = input.reshape(
-1, original_sizes[-1]
) # Flatten except for the last dimension
qmin = -(2 ** (bits - 1)) if use_sym_quant else 0
qmax = 2 ** (bits - 1) - 1 if use_sym_quant else 2**bits - 1
min_val = torch.min(input, dim=-1)[0].to(dtype=torch.float32).unsqueeze(-1)
max_val = torch.max(input, dim=-1)[0].to(dtype=torch.float32).unsqueeze(-1)
if use_sym_quant:
scale = (
torch.maximum(torch.abs(min_val), torch.abs(max_val)) / qmax
).clamp(min=1e-5)
zero_point = torch.zeros_like(scale).to(dtype=torch.int32)
else:
scale = ((max_val - min_val) / qmax).clamp(min=1e-5)
zero_point = -1 * torch.round(min_val / scale).to(dtype=torch.int32)
scale = scale.to(dtype=input.dtype)
quantized = torch.clamp(
torch.round(input / scale.to(dtype=torch.float32) + zero_point),
qmin,
qmax,
).to(dtype=torch.int8 if use_sym_quant else torch.uint8)
return (
quantized.view(original_sizes),
scale.view(original_sizes[:-1] + (1,)),
zero_point.view(original_sizes[:-1] + (1,)),
)
@staticmethod
def flash_attn_varlen_func(
q: torch.Tensor,

View File

@@ -48,6 +48,7 @@ from vllm.model_executor.kernels.linear.mixed_precision.marlin import (
MarlinLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
XPUW4A8IntLinearKernel,
XPUwNa16LinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm import (
@@ -138,6 +139,7 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
ExllamaLinearKernel,
],
PlatformEnum.XPU: [
XPUW4A8IntLinearKernel,
XPUwNa16LinearKernel,
],
PlatformEnum.CPU: [
@@ -391,5 +393,6 @@ __all__ = [
"ExllamaLinearKernel",
"MacheteLinearKernel",
"MarlinLinearKernel",
"XPUW4A8IntLinearKernel",
"XPUwNa16LinearKernel",
]

View File

@@ -30,6 +30,7 @@ from vllm.model_executor.kernels.linear.mixed_precision.MPLinearKernel import (
MPLinearLayerConfig,
)
from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
XPUW4A8IntLinearKernel,
XPUwNa16LinearKernel,
)
@@ -44,5 +45,6 @@ __all__ = [
"ExllamaLinearKernel",
"MacheteLinearKernel",
"MarlinLinearKernel",
"XPUW4A8IntLinearKernel",
"XPUwNa16LinearKernel",
]

View File

@@ -5,6 +5,8 @@
import torch
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
@@ -12,6 +14,8 @@ from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
_XPUWNA16_SUPPORTED_QUANT_TYPES = (scalar_types.uint4, scalar_types.uint4b8)
logger = init_logger(__name__)
class XPUwNa16LinearKernel(MPLinearKernel):
@classmethod
@@ -86,3 +90,112 @@ class XPUwNa16LinearKernel(MPLinearKernel):
layer.g_idx,
)
return out
class XPUW4A8IntLinearKernel(MPLinearKernel):
"""XPU kernel for W4A8 integer quantization using oneDNN int4_gemm_w4a8.
Weights are symmetric group-quantized int4 packed as uint4.
Activations are dynamically quantized per-token to symmetric int8.
"""
@classmethod
def get_min_capability(cls) -> int:
return -1
@classmethod
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_xpu():
return False, "XPUW4A8Int only supported on XPU"
if c.act_type not in (torch.bfloat16, torch.float16):
return False, "XPUW4A8Int requires BF16/FP16 activations"
if c.weight_type != scalar_types.int4:
return (
False,
f"XPUW4A8Int requires int4 weights, got {c.weight_type}",
)
if c.zero_points:
return False, "XPUW4A8Int only supports symmetric weight quantization"
if c.group_size != -1 and c.group_size % 32 != 0:
return (
False,
f"Group size ({c.group_size}) not supported by XPUW4A8Int, "
"must be a multiple of 32",
)
in_size, out_size = c.partition_weight_shape
if in_size % 8 != 0 or out_size % 8 != 0:
return (
False,
f"in/out sizes ({in_size}, {out_size}) must be multiples of 8",
)
if c.act_type != torch.float16:
logger.warning_once(
"XPUW4A8IntLinearKernel is running with model dtype %s, "
"but int4_gemm_w4a8 produces float16 output. Recommend "
"setting --dtype float16 for best performance.",
c.act_type,
)
return True, None
def _pack_int4_weight(self, w: torch.Tensor) -> torch.Tensor:
# w is [N, K] int8 with values in [-8, 7]
w_u4 = w.to(torch.int32) + 8 # shift to [0, 15]
w_u4 = w_u4.reshape(w.shape[0], w.shape[1] // 8, 8) # [N, K/8, 8]
shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=w.device)
packed = ((w_u4 & 0xF) << shifts[None, None, :]).sum(dim=2).to(torch.int32)
return packed
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight_scale.data = layer.weight_scale.data.t().contiguous()
device = layer.weight_packed.device
# TODO: support asymmetric quantization
weight_zero_point = torch.tensor([8], dtype=torch.int8, device=device)
layer.weight_zero_point = Parameter(weight_zero_point, requires_grad=False)
# weight_packed is [out, in] int8, signed int4 values in [-8, 7]
w = layer.weight_packed.data # [out, in]
# TODO: implement asym case
packed = self._pack_int4_weight(w) # [out, in/8] packed uint4
replace_parameter(
layer,
self.w_q_name,
torch.nn.Parameter(packed, requires_grad=False),
)
# Free the original unpacked int8 weight (still registered as "weight")
# to avoid double-storing both int8 [N, K] and int32 [N, K/8] in memory.
layer.register_parameter("weight", None)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1]) # [M, K]
from vllm._xpu_ops import xpu_ops as ops
# TODO: static and asymmetric quantization case
# Common code for CompressedTensorsW4A8Int does not read act symmetry data
quant_x, x_scale, x_zero = ops.dynamic_per_token_int8_quant_ref(
reshaped_x, True, 8
)
out = torch.ops._xpu_C.int4_gemm_w4a8(
quant_x,
x_scale,
x_zero,
layer.weight_packed.t(),
layer.weight_scale,
layer.weight_zero_point,
self.config.group_size,
None, # g_idx not currently supported
bias,
)
return out.to(x.dtype)