[Feat] Enable CompressedTensorW4A8Int for XPU (#37207)
Signed-off-by: Li, Tianmu <tianmu.li@intel.com>
This commit is contained in:
@@ -37,6 +37,26 @@ if hasattr(torch.ops._xpu_C, "fp8_gemm_w8a16"):
|
||||
return torch.empty((M, N), dtype=input.dtype, device=input.device)
|
||||
|
||||
|
||||
if hasattr(torch.ops._xpu_C, "int4_gemm_w4a8"):
|
||||
|
||||
@register_fake("_xpu_C::int4_gemm_w4a8")
|
||||
def _int4_gemm_w4a8_fake(
|
||||
input: torch.Tensor,
|
||||
input_scales: torch.Tensor,
|
||||
input_zero_points: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
weight_zp: torch.Tensor,
|
||||
group_size: int,
|
||||
g_idx: torch.Tensor | None = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
input_2d = input.view(-1, input.shape[-1])
|
||||
M = input_2d.size(0)
|
||||
N = q_weight.size(1)
|
||||
return torch.empty((M, N), dtype=torch.float16, device=input.device)
|
||||
|
||||
|
||||
if hasattr(torch.ops._xpu_C, "int4_gemm_w4a16"):
|
||||
|
||||
@register_fake("_xpu_C::int4_gemm_w4a16")
|
||||
@@ -87,6 +107,40 @@ _OPS_REGISTERED = False
|
||||
|
||||
|
||||
class xpu_ops:
|
||||
@staticmethod
|
||||
@torch.compile
|
||||
def dynamic_per_token_int8_quant_ref(
|
||||
input: torch.Tensor, use_sym_quant: bool, bits: int
|
||||
):
|
||||
original_sizes = input.size()
|
||||
# view is not safe in torch.compile if input is not contiguous
|
||||
input = input.reshape(
|
||||
-1, original_sizes[-1]
|
||||
) # Flatten except for the last dimension
|
||||
qmin = -(2 ** (bits - 1)) if use_sym_quant else 0
|
||||
qmax = 2 ** (bits - 1) - 1 if use_sym_quant else 2**bits - 1
|
||||
min_val = torch.min(input, dim=-1)[0].to(dtype=torch.float32).unsqueeze(-1)
|
||||
max_val = torch.max(input, dim=-1)[0].to(dtype=torch.float32).unsqueeze(-1)
|
||||
if use_sym_quant:
|
||||
scale = (
|
||||
torch.maximum(torch.abs(min_val), torch.abs(max_val)) / qmax
|
||||
).clamp(min=1e-5)
|
||||
zero_point = torch.zeros_like(scale).to(dtype=torch.int32)
|
||||
else:
|
||||
scale = ((max_val - min_val) / qmax).clamp(min=1e-5)
|
||||
zero_point = -1 * torch.round(min_val / scale).to(dtype=torch.int32)
|
||||
scale = scale.to(dtype=input.dtype)
|
||||
quantized = torch.clamp(
|
||||
torch.round(input / scale.to(dtype=torch.float32) + zero_point),
|
||||
qmin,
|
||||
qmax,
|
||||
).to(dtype=torch.int8 if use_sym_quant else torch.uint8)
|
||||
return (
|
||||
quantized.view(original_sizes),
|
||||
scale.view(original_sizes[:-1] + (1,)),
|
||||
zero_point.view(original_sizes[:-1] + (1,)),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def flash_attn_varlen_func(
|
||||
q: torch.Tensor,
|
||||
|
||||
@@ -48,6 +48,7 @@ from vllm.model_executor.kernels.linear.mixed_precision.marlin import (
|
||||
MarlinLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
|
||||
XPUW4A8IntLinearKernel,
|
||||
XPUwNa16LinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm import (
|
||||
@@ -138,6 +139,7 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
|
||||
ExllamaLinearKernel,
|
||||
],
|
||||
PlatformEnum.XPU: [
|
||||
XPUW4A8IntLinearKernel,
|
||||
XPUwNa16LinearKernel,
|
||||
],
|
||||
PlatformEnum.CPU: [
|
||||
@@ -391,5 +393,6 @@ __all__ = [
|
||||
"ExllamaLinearKernel",
|
||||
"MacheteLinearKernel",
|
||||
"MarlinLinearKernel",
|
||||
"XPUW4A8IntLinearKernel",
|
||||
"XPUwNa16LinearKernel",
|
||||
]
|
||||
|
||||
@@ -30,6 +30,7 @@ from vllm.model_executor.kernels.linear.mixed_precision.MPLinearKernel import (
|
||||
MPLinearLayerConfig,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
|
||||
XPUW4A8IntLinearKernel,
|
||||
XPUwNa16LinearKernel,
|
||||
)
|
||||
|
||||
@@ -44,5 +45,6 @@ __all__ = [
|
||||
"ExllamaLinearKernel",
|
||||
"MacheteLinearKernel",
|
||||
"MarlinLinearKernel",
|
||||
"XPUW4A8IntLinearKernel",
|
||||
"XPUwNa16LinearKernel",
|
||||
]
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
@@ -12,6 +14,8 @@ from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
|
||||
|
||||
_XPUWNA16_SUPPORTED_QUANT_TYPES = (scalar_types.uint4, scalar_types.uint4b8)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class XPUwNa16LinearKernel(MPLinearKernel):
|
||||
@classmethod
|
||||
@@ -86,3 +90,112 @@ class XPUwNa16LinearKernel(MPLinearKernel):
|
||||
layer.g_idx,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
class XPUW4A8IntLinearKernel(MPLinearKernel):
|
||||
"""XPU kernel for W4A8 integer quantization using oneDNN int4_gemm_w4a8.
|
||||
|
||||
Weights are symmetric group-quantized int4 packed as uint4.
|
||||
Activations are dynamically quantized per-token to symmetric int8.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return -1
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_xpu():
|
||||
return False, "XPUW4A8Int only supported on XPU"
|
||||
if c.act_type not in (torch.bfloat16, torch.float16):
|
||||
return False, "XPUW4A8Int requires BF16/FP16 activations"
|
||||
if c.weight_type != scalar_types.int4:
|
||||
return (
|
||||
False,
|
||||
f"XPUW4A8Int requires int4 weights, got {c.weight_type}",
|
||||
)
|
||||
if c.zero_points:
|
||||
return False, "XPUW4A8Int only supports symmetric weight quantization"
|
||||
if c.group_size != -1 and c.group_size % 32 != 0:
|
||||
return (
|
||||
False,
|
||||
f"Group size ({c.group_size}) not supported by XPUW4A8Int, "
|
||||
"must be a multiple of 32",
|
||||
)
|
||||
in_size, out_size = c.partition_weight_shape
|
||||
if in_size % 8 != 0 or out_size % 8 != 0:
|
||||
return (
|
||||
False,
|
||||
f"in/out sizes ({in_size}, {out_size}) must be multiples of 8",
|
||||
)
|
||||
|
||||
if c.act_type != torch.float16:
|
||||
logger.warning_once(
|
||||
"XPUW4A8IntLinearKernel is running with model dtype %s, "
|
||||
"but int4_gemm_w4a8 produces float16 output. Recommend "
|
||||
"setting --dtype float16 for best performance.",
|
||||
c.act_type,
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
def _pack_int4_weight(self, w: torch.Tensor) -> torch.Tensor:
|
||||
# w is [N, K] int8 with values in [-8, 7]
|
||||
w_u4 = w.to(torch.int32) + 8 # shift to [0, 15]
|
||||
w_u4 = w_u4.reshape(w.shape[0], w.shape[1] // 8, 8) # [N, K/8, 8]
|
||||
shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=w.device)
|
||||
packed = ((w_u4 & 0xF) << shifts[None, None, :]).sum(dim=2).to(torch.int32)
|
||||
return packed
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.weight_scale.data = layer.weight_scale.data.t().contiguous()
|
||||
|
||||
device = layer.weight_packed.device
|
||||
# TODO: support asymmetric quantization
|
||||
weight_zero_point = torch.tensor([8], dtype=torch.int8, device=device)
|
||||
layer.weight_zero_point = Parameter(weight_zero_point, requires_grad=False)
|
||||
|
||||
# weight_packed is [out, in] int8, signed int4 values in [-8, 7]
|
||||
w = layer.weight_packed.data # [out, in]
|
||||
|
||||
# TODO: implement asym case
|
||||
packed = self._pack_int4_weight(w) # [out, in/8] packed uint4
|
||||
|
||||
replace_parameter(
|
||||
layer,
|
||||
self.w_q_name,
|
||||
torch.nn.Parameter(packed, requires_grad=False),
|
||||
)
|
||||
|
||||
# Free the original unpacked int8 weight (still registered as "weight")
|
||||
# to avoid double-storing both int8 [N, K] and int32 [N, K/8] in memory.
|
||||
layer.register_parameter("weight", None)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
reshaped_x = x.reshape(-1, x.shape[-1]) # [M, K]
|
||||
from vllm._xpu_ops import xpu_ops as ops
|
||||
|
||||
# TODO: static and asymmetric quantization case
|
||||
# Common code for CompressedTensorsW4A8Int does not read act symmetry data
|
||||
quant_x, x_scale, x_zero = ops.dynamic_per_token_int8_quant_ref(
|
||||
reshaped_x, True, 8
|
||||
)
|
||||
|
||||
out = torch.ops._xpu_C.int4_gemm_w4a8(
|
||||
quant_x,
|
||||
x_scale,
|
||||
x_zero,
|
||||
layer.weight_packed.t(),
|
||||
layer.weight_scale,
|
||||
layer.weight_zero_point,
|
||||
self.config.group_size,
|
||||
None, # g_idx not currently supported
|
||||
bias,
|
||||
)
|
||||
|
||||
return out.to(x.dtype)
|
||||
|
||||
Reference in New Issue
Block a user