[Feat]: Add support for Dynamic Quant 4 bit CPU kleidiai kernels (#17112)

Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Nikhil Gupta
2025-07-28 21:55:15 +01:00
committed by GitHub
parent c6f36cfa26
commit 89ac266b26
5 changed files with 269 additions and 11 deletions
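Background on the scheme (an aside, hedged: the exact checkpoint-side quantization is not shown in this commit): the weights are stored as 4-bit integers with per-group scales, while activations are quantized dynamically, at inference time, inside the KleidiAI micro-kernels. A minimal sketch of group-wise int4 weight quantization, assuming a common symmetric scale choice and purely illustrative values:

import torch

# One output channel, 8 input features, group_size = 4 -> 2 scales.
w = torch.tensor([0.10, -0.31, 0.22, 0.05, 1.20, -0.80, 0.40, 2.00])
group_size = 4
groups = w.view(-1, group_size)                       # shape [2, 4]
# Symmetric scales: map each group's max magnitude to int4 value 7.
scales = groups.abs().amax(dim=1, keepdim=True) / 7.0
q = torch.clamp(torch.round(groups / scales), -8, 7).to(torch.int8)
# What the 4-bit matmul effectively computes against:
w_hat = (q.float() * scales).view(-1)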

vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py

@@ -10,6 +10,8 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas imp
     BitBLASLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision.conch import (  # noqa: E501
     ConchLinearKernel)
+from vllm.model_executor.layers.quantization.kernels.mixed_precision.dynamic_4bit import (  # noqa: E501
+    Dynamic4bitLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import (  # noqa: E501
     ExllamaLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import (  # noqa: E501
@@ -25,6 +27,7 @@ _POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
     MacheteLinearKernel,
     AllSparkLinearKernel,
     MarlinLinearKernel,
+    Dynamic4bitLinearKernel,
     BitBLASLinearKernel,
     ConchLinearKernel,
     ExllamaLinearKernel,
@@ -56,7 +59,8 @@ def choose_mp_linear_kernel(
         if current_platform is None:
             raise ValueError("Cannot determine compute capability")
         _cc = current_platform.get_device_capability()
-        compute_capability = _cc[0] * 10 + _cc[1]
+        if _cc is not None:
+            compute_capability = _cc[0] * 10 + _cc[1]
 
     failure_reasons = []
     for kernel in _POSSIBLE_KERNELS:
@@ -64,12 +68,12 @@
             failure_reasons.append(
                 f' {kernel.__name__} disabled by environment variable')
             continue
 
-        if kernel.get_min_capability() > compute_capability:
+        if (compute_capability is not None
+                and kernel.get_min_capability() > compute_capability):
             failure_reasons.append(
                 f"{kernel.__name__} requires capability "
-                f"{kernel.get_min_capability()}, current compute capability "
-                f"is {compute_capability}")
+                f"{kernel.get_min_capability()}, current compute "
+                f" capability is {compute_capability}")
             continue
 
         can_implement, failure_reason = kernel.can_implement(config)

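Why the None-guard above matters: CPU platforms report no CUDA-style compute capability, so current_platform.get_device_capability() yields None, and the previous code would fail before a CPU kernel such as Dynamic4bitLinearKernel could ever be selected. A standalone sketch of the guarded check (a simplified stand-in, not the vLLM function itself):

from typing import Optional

def capability_ok(min_capability: int,
                  compute_capability: Optional[int]) -> bool:
    # None means "platform has no device capability" (e.g. CPU); skip the
    # check instead of comparing an int against None.
    if compute_capability is None:
        return True
    return min_capability <= compute_capability

assert capability_ok(80, None)    # CPU: capability gate is bypassed
assert not capability_ok(80, 75)  # GPU below the kernel's minimum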
vllm/model_executor/layers/quantization/kernels/mixed_precision/dynamic_4bit.py

@@ -0,0 +1,92 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import torch

from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.platforms import CpuArchEnum, current_platform
from vllm.scalar_type import scalar_types

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


class Dynamic4bitLinearKernel(MPLinearKernel):
    SUPPORTED_QUANT_TYPES = [scalar_types.int4]

    @classmethod
    def get_min_capability(cls) -> int:
        return 1

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
        if not current_platform.is_cpu():
            return False, "Only CPU is supported"
        if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
            return False, f"Unsupported quant type {c.weight_type}"
        if current_platform.get_cpu_architecture() == CpuArchEnum.ARM \
                and c.act_type not in [torch.float32]:
            return False, "Dynamic4bitLinearKernel on Arm requires" \
                " Float32 activations"
        if c.full_weight_shape[0] % c.group_size != 0:
            return False, \
                f"Group size ({c.group_size}) does not evenly divide" \
                " the number of input features " \
                f"({c.full_weight_shape[0]})"
        if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
            try:
                # Attempt to retrieve the operation
                _ = torch.ops.aten._dyn_quant_matmul_4bit
            except AttributeError:
                return False, \
                    f"PyTorch {torch.__version__} does not support" \
                    " _dyn_quant_matmul_4bit. Install a newer version"
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module):
        c = self.config
        packed_weight = getattr(layer, self.w_q_name)
        # Shift signed int4 values from [-8, 7] to [0, 15], then pack two
        # values per byte: even columns in the low nibble, odd in the high.
        packed_weight = packed_weight.add(8)
        uint8_packed = (packed_weight[::, 1::2] << 4
                        | packed_weight[::, ::2]).to(torch.uint8)

        scales = getattr(layer, self.w_s_name)
        block_size = c.group_size

        # Handle scaling factors for partitioned weights
        if block_size == c.partition_weight_shape[0]:
            # A single group spans all input features: channel-wise scheme.
            # Float32 & Bfloat16 variants require float32 scales.
            scales = scales.to(torch.float32)
            scales = scales.view(-1, 1)  # Channel-wise scales
            if layer.bias is not None:
                # Float32 & Bfloat16 variants require float32 bias.
                layer.bias = layer.bias.to(torch.float32)
        else:
            # KleidiAI kernel requires bfloat16 scales with group-wise scheme
            scales = scales.to(torch.bfloat16)

        # Repack weights (folding in scales and bias) as the kernel requires
        w = torch.ops.aten._dyn_quant_pack_4bit_weight(
            uint8_packed, scales, layer.bias, block_size,
            c.partition_weight_shape[0], c.partition_weight_shape[1])
        replace_parameter(layer, self.w_q_name,
                          torch.nn.Parameter(w, requires_grad=False))
        setattr(layer, self.w_s_name, None)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        c = self.config
        # Flatten any leading dimensions; the op expects a 2-D input.
        x_2d = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
        w_q = getattr(layer, self.w_q_name)
        # Activations are quantized dynamically inside the kernel; bias was
        # already folded into the packed weight above.
        output = torch.ops.aten._dyn_quant_matmul_4bit(
            x_2d, w_q, c.group_size, c.partition_weight_shape[0],
            c.partition_weight_shape[1])
        return output.reshape(out_shape)
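
The two aten ops above can also be exercised outside vLLM. A minimal standalone sketch, assuming a PyTorch build that actually ships _dyn_quant_pack_4bit_weight / _dyn_quant_matmul_4bit (recent PyTorch on AArch64; the can_implement probe above checks exactly this). Sizes are arbitrary examples; the nibble packing mirrors process_weights_after_loading:

import torch

in_features, out_features, group_size = 128, 64, 32

# Signed int4 weights in [-8, 7], one value per int8 element.
w_int4 = torch.randint(-8, 8, (out_features, in_features), dtype=torch.int8)

# Shift to unsigned [0, 15] and pack two values per byte (low nibble
# first), exactly as the kernel does before repacking.
w_u = (w_int4 + 8).to(torch.uint8)
uint8_packed = (w_u[:, 1::2] << 4) | w_u[:, ::2]

# Group-wise scheme: one bfloat16 scale per group of `group_size` inputs.
scales = torch.rand(out_features, in_features // group_size).to(torch.bfloat16)

packed = torch.ops.aten._dyn_quant_pack_4bit_weight(
    uint8_packed, scales, None, group_size, in_features, out_features)

x = torch.randn(4, in_features, dtype=torch.float32)
out = torch.ops.aten._dyn_quant_matmul_4bit(
    x, packed, group_size, in_features, out_features)
print(out.shape)  # torch.Size([4, 64])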