[CPU][Feat] Enable KleidiAI accelerated int4 dynamic quant with BF16 activations on Arm CPUs (#33122)

Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
This commit is contained in:
Fadi Arafeh
2026-01-31 07:16:22 +00:00
committed by GitHub
parent f3888aca83
commit 1618e25492

View File

@@ -11,6 +11,18 @@ from vllm.scalar_type import scalar_types
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
# This implementation is for the KleidiAI-accelerated w4a8int quantization
# scheme on Arm CPUs:
# torch.ops.aten._dyn_quant_matmul_4bit performs dynamic quantized matmul
# it takes:
# - int4 weights packed along with bias/scales by
# torch.ops.aten._dyn_quant_pack_4bit_weight
# - float32/bfloat16 activations
# then it leverages KleidiAI micro-kernels (ukernels) that:
# - dynamically quantize the activations to int8
# - unpack the int4 weights to int8
# - perform int8 x int8 -> int32 matmul
# - dequantize the int32 results to float32/bfloat16 outputs
class Dynamic4bitLinearKernel(MPLinearKernel):
SUPPORTED_QUANT_TYPES = [scalar_types.int4]
@@ -29,9 +41,14 @@ class Dynamic4bitLinearKernel(MPLinearKernel):
and c.act_type
not in [
torch.float32,
torch.bfloat16,
]
):
return False, "Dynamic4bitLinearKernel on Arm requires Float32 activations"
return (
False,
"Dynamic4bitLinearKernel on Arm requires Float32 or"
" BFloat16 activations",
)
if c.full_weight_shape[0] % c.group_size != 0:
return (
False,