[Quantization][Refactor] Move CPU GPTQ kernel into MP linear (#31801)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
Signed-off-by: Li, Jiang <bigpyj64@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Li, Jiang
2026-01-07 03:10:18 +08:00
committed by GitHub
parent c07163663d
commit 8becf146bd
9 changed files with 171 additions and 332 deletions

View File

@@ -10,6 +10,7 @@ if not current_platform.is_cpu():
MODELS = [
"TheBloke/TinyLlama-1.1B-Chat-v1.0-AWQ",
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", # with g_idx
"Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int4", # without g_idx
]
DTYPE = ["bfloat16"]

View File

@@ -876,7 +876,6 @@ class ModelConfig:
# Ensure heavy backends are probed last to avoid unnecessary
# imports during override detection (e.g., MXFP4 imports Triton)
"mxfp4",
"cpu_gptq",
"cpu_awq",
]
quantization_methods = [

View File

@@ -38,7 +38,6 @@ QuantizationMethods = Literal[
"inc",
"mxfp4",
"petit_nvfp4",
"cpu_gptq",
"cpu_awq",
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
@@ -109,7 +108,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .compressed_tensors.compressed_tensors import (
CompressedTensorsConfig,
)
from .cpu_wna16 import CPUAWQConfig, CPUGPTQConfig
from .cpu_wna16 import CPUAWQConfig
from .deepspeedfp import DeepSpeedFPConfig
from .experts_int8 import ExpertsInt8Config
from .fbgemm_fp8 import FBGEMMFp8Config
@@ -162,7 +161,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"inc": INCConfig,
"mxfp4": Mxfp4Config,
"petit_nvfp4": PetitNvFp4Config,
"cpu_gptq": CPUGPTQConfig,
"cpu_awq": CPUAWQConfig,
}
# Update the `method_to_config` with customized quantization methods.

View File

@@ -20,12 +20,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_linear_quant_method,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_repeat_scales_on_all_ranks,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped,
pack_cols,
@@ -34,335 +28,15 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.utils import WeightsMapper
from vllm.model_executor.parameter import (
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_safetensors_params_metadata
from vllm.utils.collection_utils import is_list_of
logger = init_logger(__name__)
class CPUGPTQConfig(QuantizationConfig):
"""Config class for CPU GPTQ quant"""
def __init__(
self,
weight_bits: int,
group_size: int,
desc_act: bool,
is_sym: bool,
lm_head_quantized: bool,
dynamic: dict[str, dict[str, int | bool]],
full_config: dict[str, Any],
modules_in_block_to_quantize: list[str] | None = None,
) -> None:
super().__init__()
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
desc_act = False
# GPTQModel use `dynamic` config property to allow per module
# quantization config so each module can be individually optimized.
# Format is dict[str, dict] where key is a regex string that can
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
# matching of a module.
# Default to positive match, override base quant config mode, if no
# prefix is used. Value is in dict format of field key and override
# value.
# Negative matching will skip quantization init for this module
# entirely:
# non-quantized inference. More details and quantization examples can be
# found at: https://github.com/ModelCloud/GPTQModel
# Example:
# # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
# # last 1/4 of the layers 16-21 has 8bit and group_size 64
# dynamic = {
# #`.*\.` matches the layers_node prefix
# # positive match layer 10-15
# r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
# # positive match layer 16-21
# r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
# r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
# }
assert weight_bits == 4
self.dynamic = dynamic
self.weight_bits = weight_bits
self.is_sym = is_sym
self.pack_factor = 32 // weight_bits # packed into int32
self.group_size = group_size
self.desc_act = desc_act
self.lm_head_quantized = lm_head_quantized
self.full_config = full_config
self.modules_in_block_to_quantize = modules_in_block_to_quantize or []
def __repr__(self) -> str:
return (
f"CPUWNA16Config("
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}, "
f"lm_head_quantized={self.lm_head_quantized}, "
f"dynamic={self.dynamic}, "
f"modules_in_block_to_quantize={self.modules_in_block_to_quantize})"
)
@classmethod
def get_name(cls) -> QuantizationMethods:
return "cpu_gptq"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
return -1
@classmethod
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: dict[str, Any]) -> "CPUGPTQConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
desc_act = cls.get_from_keys_or(config, ["desc_act"], default=False)
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
group_size = cls.get_from_keys(config, ["group_size"])
is_sym = cls.get_from_keys(config, ["sym"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
modules_in_block_to_quantize = cls.get_from_keys_or(
config, ["modules_in_block_to_quantize"], default=None
)
return cls(
weight_bits,
group_size,
desc_act,
is_sym,
lm_head_quantized,
dynamic,
config,
modules_in_block_to_quantize,
)
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> QuantizationMethods | None:
quant_method = hf_quant_cfg.get("quant_method", "").lower()
if current_platform.is_cpu() and (quant_method == "gptq"):
return cls.get_name()
return None
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
return get_linear_quant_method(self, layer, prefix, CPUGPTQLinearMethod) # type: ignore
def apply_vllm_mapper(self, hf_to_vllm_mapper):
if self.modules_in_block_to_quantize is not None:
self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list(
self.modules_in_block_to_quantize
)
def maybe_update_config(self, model_name: str, revision: str | None = None):
if self.modules_in_block_to_quantize:
if is_list_of(self.modules_in_block_to_quantize, list):
# original modules_in_block_to_quantize: list[list[str]]
# flatten original modules_in_block_to_quantize
self.modules_in_block_to_quantize = [
item
for sublist in self.modules_in_block_to_quantize
for item in sublist
]
return
unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
metadata = get_safetensors_params_metadata(model_name, revision=revision)
quant_layers: set[str] = {
param_name.rsplit(".", 1)[0]
for param_name, info in metadata.items()
if (dtype := info.get("dtype", None))
and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
}
self.modules_in_block_to_quantize = list(quant_layers)
class CPUGPTQLinearMethod(LinearMethodBase):
"""Linear method for GPTQ on CPU.
Args:
quant_config: The CPUWNA16 quantization config.
"""
def __init__(self, quant_config: CPUGPTQConfig) -> None:
self.quant_config = quant_config
assert self.quant_config.is_sym, "GPTQ asym quant is not supported on CPU"
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
output_size_per_partition = sum(output_partition_sizes)
assert output_size_per_partition * self.quant_config.weight_bits % 32 == 0
assert output_size_per_partition % 32 == 0
assert input_size_per_partition % 32 == 0
is_row_parallel = input_size != input_size_per_partition
weight_loader = extra_weight_attrs.get("weight_loader")
# Normalize group_size
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
# Determine sharding
if marlin_repeat_scales_on_all_ranks(
self.quant_config.desc_act, self.quant_config.group_size, is_row_parallel
):
# By setting scale_dim == None, weight_loader will
# repeat the scales on each rank in TP>1 case.
scales_and_zp_input_dim = None
scales_and_zp_size = input_size // group_size
else:
# By setting scale_dim == 0, weight_loader will
# shard the scales in TP>1 case.
scales_and_zp_input_dim = 0
scales_and_zp_size = input_size_per_partition // group_size
# Quantized weights
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=0,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader,
)
# Activation order
g_idx = RowvLLMParameter(
data=torch.empty(
input_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader,
)
set_weight_attrs(
g_idx,
{"ignore_warning": True},
)
qzeros_args = {
"data": torch.empty(
scales_and_zp_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
"weight_loader": weight_loader,
}
weight_scale_args = {
"data": torch.empty(
scales_and_zp_size,
output_size_per_partition,
dtype=params_dtype,
),
"weight_loader": weight_loader,
}
if scales_and_zp_input_dim is None:
scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
qzeros = PackedColumnParameter(
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args,
)
else:
scales = GroupQuantScaleParameter(
output_dim=1, input_dim=0, **weight_scale_args
)
qzeros = PackedvLLMParameter(
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args,
)
layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("scales", scales)
layer.register_parameter("qzeros", qzeros)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
torch.set_printoptions(profile="full", linewidth=5000, sci_mode=False)
packed_weight = layer.qweight.data
bits = self.quant_config.weight_bits
pack_factor = int(self.quant_config.pack_factor)
p_w_k, p_w_n = packed_weight.size()
input_size = p_w_k * pack_factor
output_size = p_w_n
isa_hint = _get_isa_hint(layer.scales.dtype)
layer.isa_hint = isa_hint
layer.qzeros = None
if not self.quant_config.desc_act:
layer.g_idx = None
# convert input dim packed to output dim packed
weight = unpack_cols(packed_weight, bits, p_w_k, p_w_n * pack_factor).view(
p_w_k, p_w_n, pack_factor
)
weight = weight.permute(0, 2, 1).reshape(input_size, output_size).contiguous()
weight = pack_cols(weight, bits, input_size, output_size)
# make 16 output channel as a block and transpose to the make
# the block contigous
weight = (
weight.view(input_size, -1, 16 // pack_factor)
.permute(1, 0, 2)
.reshape(-1, input_size * 16 // pack_factor)
.contiguous()
)
layer.qweight.data = weight
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
x = cpu_gemm_wna16(
input=x,
q_weight=layer.qweight,
scales=layer.scales,
zeros=layer.qzeros,
g_idx=layer.g_idx,
bias=bias,
pack_factor=8,
isa_hint=layer.isa_hint,
)
return x
class CPUAWQConfig(QuantizationConfig):
"""Config class for CPU AWQ"""

View File

@@ -276,7 +276,7 @@ class GPTQMarlinConfig(QuantizationConfig):
sym = quant_config.get("sym")
desc_act = quant_config.get("desc_act")
if not current_platform.is_cuda():
if not (current_platform.is_cuda() or current_platform.is_cpu()):
return False
if quant_method != "gptq":

View File

@@ -11,6 +11,9 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas imp
from vllm.model_executor.layers.quantization.kernels.mixed_precision.conch import ( # noqa: E501
ConchLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.cpu import ( # noqa: E501
CPUWNA16LinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.cutlass import ( # noqa: E501
CutlassW4A8LinearKernel,
)
@@ -46,6 +49,7 @@ _POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
ConchLinearKernel,
ExllamaLinearKernel,
XPUwNa16LinearKernel,
CPUWNA16LinearKernel,
]

View File

@@ -0,0 +1,126 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_quantized_values_into_int32,
unpack_quantized_values_into_int32,
)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
_CPUWNA16_SUPPORTED_QUANT_TYPES = (scalar_types.uint4, scalar_types.uint4b8)
class CPUWNA16LinearKernel(MPLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return -1
@classmethod
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_cpu():
return False, "CPUWNA16 only supported on CPU"
if c.weight_type not in _CPUWNA16_SUPPORTED_QUANT_TYPES:
return (
False,
f"Quant type ({c.weight_type}) not supported by "
"CPUWNA16, supported types are: "
f"{_CPUWNA16_SUPPORTED_QUANT_TYPES}",
)
if c.group_size != -1 and c.group_size % 2 != 0:
return (
False,
f"Group size ({c.group_size}) not supported by "
"CPUWNA16, supported group sizes are multiples of 2",
)
if c.partition_weight_shape[0] % 32 != 0:
return (
False,
f"Input size ({c.partition_weight_shape[0]}) not supported by "
"CPUWNA16, supported sizes are multiples of 32",
)
if c.partition_weight_shape[1] % 32 != 0:
return (
False,
f"Output size ({c.partition_weight_shape[1]}) not supported by "
"CPUWNA16, supported sizes are multiples of 32",
)
return True, None
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
# `weight_zp` is: {input_dim = 0, output_dim = 1, packed_dim = 1}
def _process_gptq_weights(self, layer: torch.nn.Module):
packed_weight = layer.qweight.data
bits = self.config.weight_type.mantissa
pack_factor = 32 // bits
p_w_k, p_w_n = packed_weight.size()
input_size = p_w_k * pack_factor
output_size = p_w_n
isa_hint = _get_isa_hint(layer.scales.dtype)
layer.isa_hint = isa_hint
layer.qzeros = None
if not self.config.has_g_idx:
layer.g_idx = None
# convert input dim packed to output dim packed
weight = unpack_quantized_values_into_int32(
packed_weight, self.config.weight_type, 1
).view(p_w_k, p_w_n, pack_factor)
weight = weight.permute(0, 2, 1).reshape(input_size, output_size).contiguous()
weight = pack_quantized_values_into_int32(weight, self.config.weight_type, 1)
# make 16 output channel as a block and transpose to the make
# the block contigous
weight = (
weight.view(input_size, -1, 16 // pack_factor)
.permute(1, 0, 2)
.reshape(-1, input_size * 16 // pack_factor)
.contiguous()
)
layer.qweight.data = weight
def process_weights_after_loading(self, layer: torch.nn.Module):
if not self.config.zero_points:
# GPTQ
self._process_gptq_weights(layer)
else:
# AWQ
raise NotImplementedError("AWQ is not supported in CPUWNA16LinearKernel")
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
x = ops.cpu_gemm_wna16(
input=x,
q_weight=layer.qweight,
scales=layer.scales,
zeros=layer.qzeros,
g_idx=layer.g_idx,
bias=bias,
pack_factor=8, # 32 // 4
isa_hint=layer.isa_hint,
)
return x
def _get_isa_hint(dtype: torch.dtype) -> str:
supports_amx = torch._C._cpu._is_amx_tile_supported()
if supports_amx and dtype in (torch.bfloat16,):
return "amx"
else:
return "vec"

View File

@@ -9,6 +9,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_quantized_values_into_int32,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
@@ -25,6 +26,12 @@ class ExllamaLinearKernel(MPLinearKernel):
@classmethod
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_cuda_alike():
return (
False,
"Exllama is only supported on CUDA and ROCm",
)
if c.has_g_idx and c.partition_weight_shape[0] != c.full_weight_shape[0]:
return (
False,
@@ -104,7 +111,7 @@ class ExllamaLinearKernel(MPLinearKernel):
# indices
return torch.argsort(x).to(torch.int)
self._transform_param(layer, self.w_gidx_name, transform_w_g_idx)
self._transform_param(layer, self.w_gidx_name, transform_w_g_idx) # type: ignore
else:
self.w_gidx_name = "g_idx"
empty_g_idx = torch.nn.Parameter(

View File

@@ -42,6 +42,9 @@ def query_marlin_supported_quant_types(
include_fp_type: bool = True,
device_capability: int | None = None,
):
if current_platform.is_cpu():
return _query_cpu_marlin_supported_quant_types(has_zp, include_fp_type)
if device_capability is None:
capability_tuple = current_platform.get_device_capability()
device_capability = (
@@ -74,6 +77,33 @@ def query_marlin_supported_quant_types(
return res
def _query_cpu_marlin_supported_quant_types(
has_zp: bool | None = None,
include_fp_type: bool = True,
):
# - has_zp is True: return quant_types that has zero points
# - has_zp is False: return quant_types that has not zero points
# - has_zp is None: both
if has_zp is None:
types0 = _query_cpu_marlin_supported_quant_types(
False,
include_fp_type,
)
types1 = _query_cpu_marlin_supported_quant_types(
True,
include_fp_type,
)
return types0 + types1
if has_zp:
# AWQ style, unsigned + runtime zero-point
return [scalar_types.uint4]
else:
# GPTQ style, unsigned + symmetric bias, only supports 4-bits for now
res = [scalar_types.uint4b8]
return res
def _check_marlin_supported(
quant_type: ScalarType,
group_size: int | None,