[CPU] Support CT W4A16 on CPU MP kernel (#38219)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
Li, Jiang
2026-03-27 14:15:28 +08:00
committed by GitHub
parent a8eab8f30d
commit becaed6ec8
2 changed files with 42 additions and 20 deletions

View File

@@ -11,6 +11,7 @@ MODELS = [
"TheBloke/TinyLlama-1.1B-Chat-v1.0-AWQ", "TheBloke/TinyLlama-1.1B-Chat-v1.0-AWQ",
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", # with g_idx "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", # with g_idx
"Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int4", # without g_idx "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int4", # without g_idx
"RedHatAI/Qwen3-1.7B-quantized.w4a16", # with zp
] ]
DTYPE = ["bfloat16"] DTYPE = ["bfloat16"]

View File

@@ -58,28 +58,29 @@ class CPUWNA16LinearKernel(MPLinearKernel):
return True, None return True, None
# note assumes that # note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} (marlin)
# `weight_scale` is: {input_dim = 0, output_dim = 1} # or: {input_dim = 1, output_dim = 0, packed_dim = 1} (CT)
# `weight_zp` is: {input_dim = 0, output_dim = 1, packed_dim = 1} # `weight_scale` is: {input_dim = 0, output_dim = 1} (marlin)
# or: {input_dim = 1, output_dim = 0} (CT)
# `weight_zp` is: {input_dim = 0, output_dim = 1, packed_dim = 1} (marlin)
# or: {input_dim = 1, output_dim = 0, packed_dim = 0} (CT)
def _process_gptq_weights(self, layer: torch.nn.Module): def _process_gptq_weights(self, layer: torch.nn.Module):
packed_weight = layer.qweight.data packed_weight = getattr(layer, self.w_q_name)
assert packed_weight.input_dim == packed_weight.packed_dim
is_ct_format = packed_weight.input_dim == 1
if is_ct_format:
packed_weight = packed_weight.t()
bits = self.config.weight_type.mantissa bits = self.config.weight_type.mantissa
pack_factor = 32 // bits pack_factor = 32 // bits
p_w_k, p_w_n = packed_weight.size() p_w_k, _ = packed_weight.size()
input_size = p_w_k * pack_factor input_size = p_w_k * pack_factor
output_size = p_w_n isa_hint = _get_isa_hint(getattr(layer, self.w_s_name).dtype)
isa_hint = _get_isa_hint(layer.scales.dtype)
layer.isa_hint = isa_hint layer.isa_hint = isa_hint
layer.qzeros = None
if not self.config.has_g_idx:
layer.g_idx = None
# convert input dim packed to output dim packed # convert input dim packed to output dim packed
weight = unpack_quantized_values_into_int32( weight = unpack_quantized_values_into_int32(
packed_weight, self.config.weight_type, 1 packed_weight, self.config.weight_type, 0
).view(p_w_k, p_w_n, pack_factor) )
weight = weight.permute(0, 2, 1).reshape(input_size, output_size).contiguous()
weight = pack_quantized_values_into_int32(weight, self.config.weight_type, 1) weight = pack_quantized_values_into_int32(weight, self.config.weight_type, 1)
# make 16 output channels as a block and transpose to make # the block contiguous
# the block contiguous # the block contiguous
@@ -89,10 +90,29 @@ class CPUWNA16LinearKernel(MPLinearKernel):
.reshape(-1, input_size * 16 // pack_factor) .reshape(-1, input_size * 16 // pack_factor)
.contiguous() .contiguous()
) )
layer.qweight.data = weight getattr(layer, self.w_q_name).data = weight
# transpose scale, zp for CT format
if is_ct_format:
scales = getattr(layer, self.w_s_name)
scales.data = scales.t().contiguous()
if self.config.zero_points:
assert self.w_zp_name
zp = getattr(layer, self.w_zp_name)
zp.data = zp.t().contiguous()
def process_weights_after_loading(self, layer: torch.nn.Module): def process_weights_after_loading(self, layer: torch.nn.Module):
if not self.config.zero_points: if (not self.config.zero_points) and (self.w_zp_name is not None):
setattr(layer, self.w_zp_name, None)
if (not self.config.has_g_idx) and (self.w_gidx_name is not None):
setattr(layer, self.w_gidx_name, None)
w_input_dim = getattr(layer, self.w_q_name).input_dim
w_pack_dim = getattr(layer, self.w_q_name).packed_dim
quant_method = "gptq" if w_pack_dim == w_input_dim else "awq"
if quant_method == "gptq":
# GPTQ # GPTQ
self._process_gptq_weights(layer) self._process_gptq_weights(layer)
else: else:
@@ -105,12 +125,13 @@ class CPUWNA16LinearKernel(MPLinearKernel):
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)
x = ops.cpu_gemm_wna16( x = ops.cpu_gemm_wna16(
input=x, input=x,
q_weight=layer.qweight, q_weight=w_q,
scales=layer.scales, scales=w_s,
zeros=layer.qzeros, zeros=w_zp,
g_idx=layer.g_idx, g_idx=w_gidx,
bias=bias, bias=bias,
pack_factor=8, # 32 // 4 pack_factor=8, # 32 // 4
isa_hint=layer.isa_hint, isa_hint=layer.isa_hint,