[CPU][BugFix] Allow w8a8 oneDNN quantized matmul to support 3D inputs (#33727)
Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
This commit is contained in:
@@ -182,6 +182,8 @@ class CPUInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: torch.Tensor | None = None,
|
bias: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
x_shape = x.shape
|
||||||
|
x = x.reshape(-1, x_shape[-1]) if len(x_shape) > 2 else x
|
||||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
|
w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
|
||||||
|
|
||||||
# ops.scaled_int8_quant supports both dynamic and static quant:
|
# ops.scaled_int8_quant supports both dynamic and static quant:
|
||||||
@@ -195,7 +197,7 @@ class CPUInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
|||||||
n = self.dnnl_handler.n
|
n = self.dnnl_handler.n
|
||||||
out = torch.empty((m, n), dtype=x.dtype)
|
out = torch.empty((m, n), dtype=x.dtype)
|
||||||
ops.onednn_scaled_mm(self.dnnl_handler, x_q, out, x_s, x_zp, azp_adj, bias)
|
ops.onednn_scaled_mm(self.dnnl_handler, x_q, out, x_s, x_zp, azp_adj, bias)
|
||||||
|
out = out.reshape(x_shape[:-1] + (n,)) if len(x_shape) > 2 else out
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def _apply_weights_sgl(
|
def _apply_weights_sgl(
|
||||||
|
|||||||
Reference in New Issue
Block a user