[CPU][Feat] Enable KleidiAI INT8_W4A8 for all input dtypes (#34890)

Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
This commit is contained in:
Fadi Arafeh
2026-02-26 05:00:10 +00:00
committed by GitHub
parent 13025e71e8
commit 4171ff6dd9

View File

@@ -42,12 +42,13 @@ class Dynamic4bitLinearKernel(MPLinearKernel):
not in [
torch.float32,
torch.bfloat16,
torch.float16,
]
):
return (
False,
"Dynamic4bitLinearKernel on Arm requires Float32 or"
" BFloat16 activations",
" BFloat16 or Float16 activations",
)
if c.full_weight_shape[0] % c.group_size != 0:
return (
@@ -118,8 +119,30 @@ class Dynamic4bitLinearKernel(MPLinearKernel):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
# PyTorch / KleidiAI kernels natively support the following configs:
# - channelwise with bfloat16 / float32 activations
# - groupwise with float32 activations
# The remaining configs are supported by casting around the matmul:
# - groupwise with bfloat16 / float16 activations: upcast the activations
# to float32 before the matmul, then downcast the output back to the
# original bfloat16 / float16 dtype
# - channelwise with float16 activations: upcast the activations to
# float32 before the matmul, then downcast the output back to float16
# Note: these activations will be dynamically quantized to int8 by the kernel.
c = self.config
is_groupwise = c.group_size != c.partition_weight_shape[0]
# dtype of activations before they get dynamically quantized to int8
original_pre_quant_act_dtype = x.dtype
pre_quant_act_dtype = original_pre_quant_act_dtype
if (
is_groupwise and pre_quant_act_dtype == torch.bfloat16
) or pre_quant_act_dtype == torch.float16:
pre_quant_act_dtype = torch.float32
x_2d = x.reshape(-1, x.shape[-1])
if pre_quant_act_dtype != original_pre_quant_act_dtype:
x_2d = x_2d.to(pre_quant_act_dtype)
out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)
w_q = getattr(layer, self.w_q_name)
@@ -129,5 +152,8 @@ class Dynamic4bitLinearKernel(MPLinearKernel):
c.group_size,
c.partition_weight_shape[0],
c.partition_weight_shape[1],
)
return output.reshape(out_shape)
).reshape(out_shape)
if pre_quant_act_dtype != original_pre_quant_act_dtype:
output = output.to(original_pre_quant_act_dtype)
return output