From fd03538bf97cd7f4fedd6da4584c89635878174f Mon Sep 17 00:00:00 2001 From: Fadi Arafeh <115173828+fadara01@users.noreply.github.com> Date: Thu, 5 Feb 2026 06:26:09 +0000 Subject: [PATCH] [CPU][BugFix] Allow w8a8 oneDNN quantized matmul to support 3D inputs (#33727) Signed-off-by: Fadi Arafeh --- .../layers/quantization/kernels/scaled_mm/cpu.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py index b82f5781c..3d67a73af 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py @@ -182,6 +182,8 @@ class CPUInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel): x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: + x_shape = x.shape + x = x.reshape(-1, x_shape[-1]) if len(x_shape) > 2 else x w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer) # ops.scaled_int8_quant supports both dynamic and static quant: @@ -195,7 +197,7 @@ class CPUInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel): n = self.dnnl_handler.n out = torch.empty((m, n), dtype=x.dtype) ops.onednn_scaled_mm(self.dnnl_handler, x_q, out, x_s, x_zp, azp_adj, bias) - + out = out.reshape(x_shape[:-1] + (n,)) if len(x_shape) > 2 else out return out def _apply_weights_sgl(