if hasattr(torch.ops._xpu_C, "int4_gemm_w4a8"):

    @register_fake("_xpu_C::int4_gemm_w4a8")
    def _int4_gemm_w4a8_fake(
        input: torch.Tensor,
        input_scales: torch.Tensor,
        input_zero_points: torch.Tensor,
        q_weight: torch.Tensor,
        weight_scale: torch.Tensor,
        weight_zp: torch.Tensor,
        group_size: int,
        g_idx: torch.Tensor | None = None,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Shape-only stand-in for ``_xpu_C::int4_gemm_w4a8``.

        Returns an uninitialised float16 tensor of shape ``[M, N]`` on the
        device of ``input``. ``M`` is the row count of ``input`` once all
        leading dimensions are collapsed and ``N`` is ``q_weight.size(1)``.
        No other argument is inspected.
        """
        rows = input.view(-1, input.shape[-1]).size(0)
        cols = q_weight.size(1)
        return torch.empty((rows, cols), dtype=torch.float16, device=input.device)


class xpu_ops:
    # NOTE(review): only the member below is reproduced; the remaining
    # xpu_ops members (e.g. flash_attn_varlen_func) lie outside this excerpt.

    @staticmethod
    @torch.compile
    def dynamic_per_token_int8_quant_ref(
        input: torch.Tensor, use_sym_quant: bool, bits: int
    ):
        """Quantize ``input`` row by row along its last dimension.

        All leading dimensions are flattened so that every row gets its own
        scale and zero point.

        Args:
            input: Tensor to quantize.
            use_sym_quant: If True, use the signed range
                ``[-2**(bits-1), 2**(bits-1) - 1]`` with a zero point of 0
                and emit int8. Otherwise use ``[0, 2**bits - 1]`` with
                ``zero_point = -round(row_min / scale)`` and emit uint8.
            bits: Bit width that defines the quantization range.

        Returns:
            ``(quantized, scale, zero_point)``. ``quantized`` has the shape
            of ``input``; ``scale`` (in ``input``'s dtype) and ``zero_point``
            (int32) have ``input``'s shape with the last dimension set to 1.
        """
        orig_shape = input.size()
        # reshape rather than view: view is not safe under torch.compile when
        # the input is not contiguous.
        rows = input.reshape(-1, orig_shape[-1])

        if use_sym_quant:
            qmin, qmax = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
            out_dtype = torch.int8
        else:
            qmin, qmax = 0, 2**bits - 1
            out_dtype = torch.uint8

        # Row-wise extrema, promoted to float32 for the scale arithmetic.
        row_min = rows.min(dim=-1).values.to(dtype=torch.float32).unsqueeze(-1)
        row_max = rows.max(dim=-1).values.to(dtype=torch.float32).unsqueeze(-1)

        if use_sym_quant:
            abs_max = torch.maximum(row_min.abs(), row_max.abs())
            scale = (abs_max / qmax).clamp(min=1e-5)
            zero_point = torch.zeros_like(scale).to(dtype=torch.int32)
        else:
            scale = ((row_max - row_min) / qmax).clamp(min=1e-5)
            zero_point = -1 * torch.round(row_min / scale).to(dtype=torch.int32)

        # The returned scale is stored in the input dtype; the division below
        # uses that stored value widened back to float32.
        scale = scale.to(dtype=rows.dtype)
        shifted = rows / scale.to(dtype=torch.float32) + zero_point
        quantized = torch.round(shifted).clamp(qmin, qmax).to(dtype=out_dtype)

        per_row_shape = orig_shape[:-1] + (1,)
        return (
            quantized.view(orig_shape),
            scale.view(per_row_shape),
            zero_point.view(per_row_shape),
        )
class XPUW4A8IntLinearKernel(MPLinearKernel):
    """XPU kernel for W4A8 integer quantization using oneDNN int4_gemm_w4a8.

    Weights are symmetric group-quantized int4 packed as uint4.
    Activations are dynamically quantized per-token to symmetric int8.
    """

    @classmethod
    def get_min_capability(cls) -> int:
        # NOTE(review): -1 presumably means "no device-capability gate";
        # confirm against the MPLinearKernel contract.
        return -1

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        """Report whether this kernel can serve the layer described by ``c``.

        Returns:
            ``(True, None)`` when supported, otherwise ``(False, reason)``.
            A one-time warning is logged when the activation dtype is not
            float16, because the GEMM output is float16 and gets cast back.
        """
        if not current_platform.is_xpu():
            return False, "XPUW4A8Int only supported on XPU"
        if c.act_type not in (torch.bfloat16, torch.float16):
            return False, "XPUW4A8Int requires BF16/FP16 activations"
        if c.weight_type != scalar_types.int4:
            return (
                False,
                f"XPUW4A8Int requires int4 weights, got {c.weight_type}",
            )
        if c.zero_points:
            return False, "XPUW4A8Int only supports symmetric weight quantization"
        # apply_weights always passes g_idx=None to the GEMM, so a layer that
        # carries an activation-reordering index would be computed with the
        # reordering silently dropped. Refuse it up front instead.
        # getattr: the config class is defined elsewhere; a missing field is
        # treated as "no g_idx".
        if getattr(c, "has_g_idx", False):
            return False, "XPUW4A8Int does not support activation reordering (g_idx)"
        if c.group_size != -1 and c.group_size % 32 != 0:
            return (
                False,
                f"Group size ({c.group_size}) not supported by XPUW4A8Int, "
                "must be a multiple of 32",
            )
        in_size, out_size = c.partition_weight_shape
        # _pack_int4_weight folds 8 input-dim values into one int32 word.
        if in_size % 8 != 0 or out_size % 8 != 0:
            return (
                False,
                f"in/out sizes ({in_size}, {out_size}) must be multiples of 8",
            )

        if c.act_type != torch.float16:
            logger.warning_once(
                "XPUW4A8IntLinearKernel is running with model dtype %s, "
                "but int4_gemm_w4a8 produces float16 output. Recommend "
                "setting --dtype float16 for best performance.",
                c.act_type,
            )

        return True, None

    def _pack_int4_weight(self, w: torch.Tensor) -> torch.Tensor:
        """Pack signed int4 values (one per int8 element) into int32 words.

        Args:
            w: ``[N, K]`` int8 tensor with values in ``[-8, 7]``; ``K`` must
                be a multiple of 8.

        Returns:
            ``[N, K // 8]`` int32 tensor. Each word holds 8 consecutive values
            along ``K``, biased by +8 into ``[0, 15]``, with element ``i`` of
            the group in bits ``4*i .. 4*i + 3``.
        """
        # w is [N, K] int8 with values in [-8, 7]
        w_u4 = w.to(torch.int32) + 8  # shift to [0, 15]
        w_u4 = w_u4.reshape(w.shape[0], w.shape[1] // 8, 8)  # [N, K/8, 8]
        shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=w.device)
        # The nibbles occupy disjoint bit ranges, so summing them is a
        # bitwise OR. torch.sum accumulates integer inputs in int64; the
        # final cast relies on the narrowing conversion keeping the low 32
        # bits.
        packed = ((w_u4 & 0xF) << shifts[None, None, :]).sum(dim=2).to(torch.int32)
        return packed

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the loaded parameters of ``layer`` to the GEMM's layout.

        Transposes ``weight_scale``, installs a constant ``weight_zero_point``
        of 8 (the bias applied when packing), replaces the ``self.w_q_name``
        parameter with the packed int32 weight and clears ``weight``.
        """
        layer.weight_scale.data = layer.weight_scale.data.t().contiguous()

        device = layer.weight_packed.device
        # TODO: support asymmetric quantization
        weight_zero_point = torch.tensor([8], dtype=torch.int8, device=device)
        layer.weight_zero_point = Parameter(weight_zero_point, requires_grad=False)

        # weight_packed is [out, in] int8, signed int4 values in [-8, 7]
        w = layer.weight_packed.data  # [out, in]

        # TODO: implement asym case
        packed = self._pack_int4_weight(w)  # [out, in/8] packed uint4

        replace_parameter(
            layer,
            self.w_q_name,
            torch.nn.Parameter(packed, requires_grad=False),
        )

        # Free the original unpacked int8 weight (still registered as "weight")
        # to avoid double-storing both int8 [N, K] and int32 [N, K/8] in memory.
        layer.register_parameter("weight", None)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the W4A8 GEMM for ``x`` against the packed weights of ``layer``.

        ``x`` is flattened to ``[M, K]``, quantized per row to symmetric int8
        and multiplied by the packed int4 weight. The result of the GEMM is
        cast to ``x.dtype`` and returned without restoring the leading
        dimensions of ``x``.
        """
        reshaped_x = x.reshape(-1, x.shape[-1])  # [M, K]
        from vllm._xpu_ops import xpu_ops as ops

        # TODO: static and asymmetric quantization case
        # Common code for CompressedTensorsW4A8Int does not read act symmetry data
        quant_x, x_scale, x_zero = ops.dynamic_per_token_int8_quant_ref(
            reshaped_x, True, 8
        )

        out = torch.ops._xpu_C.int4_gemm_w4a8(
            quant_x,
            x_scale,
            x_zero,
            layer.weight_packed.t(),
            layer.weight_scale,
            layer.weight_zero_point,
            self.config.group_size,
            None,  # g_idx not currently supported
            bias,
        )

        return out.to(x.dtype)