[CPU][Feat] Enable KleidiAI INT8_W4A8 for all input dtypes (#34890)
Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
@@ -42,12 +42,13 @@ class Dynamic4bitLinearKernel(MPLinearKernel):
             not in [
                 torch.float32,
                 torch.bfloat16,
+                torch.float16,
             ]
         ):
             return (
                 False,
                 "Dynamic4bitLinearKernel on Arm requires Float32 or"
-                " BFloat16 activations",
+                " BFloat16 or Float16 activations",
             )
         if c.full_weight_shape[0] % c.group_size != 0:
             return (
@@ -118,8 +119,30 @@ class Dynamic4bitLinearKernel(MPLinearKernel):
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ) -> torch.Tensor:
+        # PyTorch / KleidiAI kernels natively support the following configs:
+        # - channelwise with bfloat16 / float32 activations
+        # - groupwise with float32 activations
+        # To support:
+        # - groupwise with bfloat16/float16 activations: we need to upcast
+        #   activations to float32 before matmul and downcast back to bfloat16/float16
+        # - channelwise with float16 activations, we need to upcast activations to
+        #   float32 before matmul and downcast back to float16
+        # Note: these activations will be dynamically quantized to int8 by the kernel.
+
         c = self.config
+        is_groupwise = c.group_size != c.partition_weight_shape[0]
+        # dtype of activations before they get dynamically quantized to int8
+        original_pre_quant_act_dtype = x.dtype
+        pre_quant_act_dtype = original_pre_quant_act_dtype
+        if (
+            is_groupwise and pre_quant_act_dtype == torch.bfloat16
+        ) or pre_quant_act_dtype == torch.float16:
+            pre_quant_act_dtype = torch.float32
+
         x_2d = x.reshape(-1, x.shape[-1])
+        if pre_quant_act_dtype != original_pre_quant_act_dtype:
+            x_2d = x_2d.to(pre_quant_act_dtype)
+
         out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)
 
         w_q = getattr(layer, self.w_q_name)
@@ -129,5 +152,8 @@ class Dynamic4bitLinearKernel(MPLinearKernel):
             c.group_size,
             c.partition_weight_shape[0],
             c.partition_weight_shape[1],
-        )
-        return output.reshape(out_shape)
+        ).reshape(out_shape)
+
+        if pre_quant_act_dtype != original_pre_quant_act_dtype:
+            output = output.to(original_pre_quant_act_dtype)
+        return output
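
For readers skimming the diff, here is a minimal, self-contained sketch of the activation-dtype handling this change introduces. It is not the vLLM implementation: apply_int4_weights, _select_compute_dtype, and the plain matmul are hypothetical stand-ins for the layer method and the dynamically quantizing 4-bit kernel; only the fallback rule (groupwise bfloat16 and all float16 activations are upcast to float32, then the output is cast back) is taken from the diff above.

import torch


def _select_compute_dtype(act_dtype: torch.dtype, is_groupwise: bool) -> torch.dtype:
    # Fallback rule from the diff: groupwise bfloat16 and all float16
    # activations are computed in float32.
    if (is_groupwise and act_dtype == torch.bfloat16) or act_dtype == torch.float16:
        return torch.float32
    return act_dtype


def apply_int4_weights(
    x: torch.Tensor, weight: torch.Tensor, group_size: int, in_features: int
) -> torch.Tensor:
    is_groupwise = group_size != in_features
    original_dtype = x.dtype
    compute_dtype = _select_compute_dtype(original_dtype, is_groupwise)

    x_2d = x.reshape(-1, x.shape[-1])
    if compute_dtype != original_dtype:
        x_2d = x_2d.to(compute_dtype)  # upcast before the matmul

    # A plain matmul stands in for the dynamically quantizing 4-bit kernel.
    output = (x_2d @ weight.to(compute_dtype)).reshape(
        x.shape[:-1] + (weight.shape[-1],)
    )

    if compute_dtype != original_dtype:
        output = output.to(original_dtype)  # downcast back to the activation dtype
    return output


# Example: float16 activations with a groupwise (group_size < in_features) layout.
x = torch.randn(2, 8, 64, dtype=torch.float16)
w = torch.randn(64, 128)
y = apply_int4_weights(x, w, group_size=32, in_features=64)
assert y.dtype == torch.float16 and y.shape == (2, 8, 128)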