[CPU] Refactor CPU WNA16 (#28826)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
@@ -2702,6 +2702,31 @@ def cpu_attention_with_kv_cache(
    )


def cpu_gemm_wna16(
    input: torch.Tensor,
    q_weight: torch.Tensor,
    scales: torch.Tensor,
    zeros: torch.Tensor | None,
    g_idx: torch.Tensor | None,
    bias: torch.Tensor | None,
    pack_factor: int,
    isa_hint: str,
) -> torch.Tensor:
    output = torch.empty((input.size(0), scales.size(1)), dtype=input.dtype)
    torch.ops._C.cpu_gemm_wna16(
        input,
        q_weight,
        output,
        scales,
        zeros,
        g_idx,
        bias,
        pack_factor,
        isa_hint,
    )
    return output


if hasattr(torch.ops._qutlass_C, "matmul_mxf4_bf16_tn"):

    @register_fake("_qutlass_C::matmul_mxf4_bf16_tn")
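A shape-level sketch of calling the new wrapper. The sizes are hypothetical, the int32 filler only illustrates shapes (not the repacked layout the kernel expects), and it assumes the `_C` extension was built with the CPU WNA16 kernel; it is not meant to run against the real kernel without properly repacked weights:

    import torch
    from vllm import _custom_ops as ops

    M, K, N, pack_factor = 4, 128, 256, 8  # 8 x 4-bit weights per int32
    x = torch.randn(M, K, dtype=torch.bfloat16)
    q_weight = torch.zeros(K // pack_factor, N, dtype=torch.int32)  # filler only
    scales = torch.ones(1, N, dtype=torch.bfloat16)  # one group (group_size=K)

    out = ops.cpu_gemm_wna16(
        input=x,
        q_weight=q_weight,
        scales=scales,
        zeros=None,      # symmetric (GPTQ sym=True) path
        g_idx=None,      # no activation reordering
        bias=None,
        pack_factor=pack_factor,
        isa_hint="vec",  # "amx" on AMX-capable CPUs with bf16
    )
    assert out.shape == (M, N)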
@@ -1020,6 +1020,8 @@ class ModelConfig:
            # Ensure heavy backends are probed last to avoid unnecessary
            # imports during override detection (e.g., MXFP4 imports Triton)
            "mxfp4",
            "cpu_gptq",
            "cpu_awq",
        ]
        quantization_methods = [
            q for q in supported_quantization if q not in overrides
@@ -50,7 +50,6 @@ if TYPE_CHECKING:
    VLLM_CPU_KVCACHE_SPACE: int | None = 0
    VLLM_CPU_OMP_THREADS_BIND: str = ""
    VLLM_CPU_NUM_OF_RESERVED_CPU: int | None = None
    VLLM_CPU_MOE_PREPACK: bool = True
    VLLM_CPU_SGL_KERNEL: bool = False
    VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
    VLLM_XLA_CHECK_RECOMPILATION: bool = False
@@ -665,10 +664,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
    )
    if "VLLM_CPU_NUM_OF_RESERVED_CPU" in os.environ
    else None,
    # (CPU backend only) whether to use prepack for MoE layer. This will be
    # passed to ipex.llm.modules.GatedMLPMOE. On unsupported CPUs, you might
    # need to set this to "0" (False).
    "VLLM_CPU_MOE_PREPACK": lambda: bool(int(os.getenv("VLLM_CPU_MOE_PREPACK", "1"))),
    # (CPU backend only) whether to use SGL kernels, optimized for small batch.
    "VLLM_CPU_SGL_KERNEL": lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))),
    # If the env var is set, Ray Compiled Graph uses the specified
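Given the parsing lambdas above, these flags accept any integer string, with nonzero meaning true; for example:

    import os

    # Disable MoE prepack on CPUs where IPEX prepack is unsupported.
    os.environ["VLLM_CPU_MOE_PREPACK"] = "0"
    assert bool(int(os.getenv("VLLM_CPU_MOE_PREPACK", "1"))) is False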
@@ -6,7 +6,6 @@ import torch
from torch.nn import functional as F

from vllm import _custom_ops as ops
from vllm import envs


def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
@@ -130,54 +129,6 @@ def select_experts(
    )


class IPEXFusedMOE:
    def __init__(self, layer: torch.nn.Module) -> None:
        import intel_extension_for_pytorch as ipex

        layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
            layer.w13_weight,
            layer.w2_weight,
            use_prepack=envs.VLLM_CPU_MOE_PREPACK,
        )

    def __call__(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ) -> torch.Tensor:
        assert activation == "silu", f"{activation} is not supported."
        assert not apply_router_weight_on_input
        assert routed_scaling_factor == 1.0, (
            f"routed_scaling_factor {routed_scaling_factor} is not supported."
        )
        return layer.ipex_fusion(
            x,
            use_grouped_topk,
            top_k,
            router_logits,
            renormalize,
            topk_group,
            num_expert_group,
            custom_routing_function,
            scoring_func,
            e_score_correction_bias,
        )


class SGLFusedMOE:
    def __init__(self, layer: torch.nn.Module) -> None:
        pass
@@ -260,7 +260,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                layer.w2_weight.copy_(packed_w2_weight)
                layer.cpu_fused_moe = cpu_fused_moe.SGLFusedMOE(layer)
            else:
                layer.cpu_fused_moe = cpu_fused_moe.IPEXFusedMOE(layer)
                layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
        else:
            layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
@@ -38,6 +38,8 @@ QuantizationMethods = Literal[
    "inc",
    "mxfp4",
    "petit_nvfp4",
    "cpu_gptq",
    "cpu_awq",
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
@@ -107,6 +109,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
    from .compressed_tensors.compressed_tensors import (
        CompressedTensorsConfig,
    )
    from .cpu_wna16 import CPUAWQConfig, CPUGPTQConfig
    from .deepspeedfp import DeepSpeedFPConfig
    from .experts_int8 import ExpertsInt8Config
    from .fbgemm_fp8 import FBGEMMFp8Config
@@ -159,6 +162,8 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
        "inc": INCConfig,
        "mxfp4": Mxfp4Config,
        "petit_nvfp4": PetitNvFp4Config,
        "cpu_gptq": CPUGPTQConfig,
        "cpu_awq": CPUAWQConfig,
    }
    # Update the `method_to_config` with customized quantization methods.
    method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
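With these registrations, the new CPU methods resolve through the same lookup as every other backend; a minimal sketch using get_quantization_config from the hunk above:

    from vllm.model_executor.layers.quantization import get_quantization_config

    cfg_cls = get_quantization_config("cpu_gptq")  # -> CPUGPTQConfig
    print(cfg_cls.get_name())  # "cpu_gptq"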
625 vllm/model_executor/layers/quantization/cpu_wna16.py Normal file
@@ -0,0 +1,625 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any, Optional

import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE

from vllm._custom_ops import (
    cpu_gemm_wna16,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
    get_linear_quant_method,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    marlin_repeat_scales_on_all_ranks,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped,
    pack_cols,
    unpack_cols,
)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.utils import WeightsMapper
from vllm.model_executor.parameter import (
    ChannelQuantScaleParameter,
    GroupQuantScaleParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    RowvLLMParameter,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_safetensors_params_metadata
from vllm.utils.collection_utils import is_list_of

logger = init_logger(__name__)

class CPUGPTQConfig(QuantizationConfig):
    """Config class for CPU GPTQ quant"""

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        desc_act: bool,
        is_sym: bool,
        lm_head_quantized: bool,
        dynamic: dict[str, dict[str, int | bool]],
        full_config: dict[str, Any],
        modules_in_block_to_quantize: list[str] | None = None,
    ) -> None:
        super().__init__()
        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel)
            desc_act = False

        # GPTQModel uses the `dynamic` config property to allow per-module
        # quantization config, so each module can be individually optimized.
        # Format is dict[str, dict] where the key is a regex string that can
        # perform both positive ("+:" prefixed) or negative ("-:" prefixed)
        # matching of a module.
        # Defaults to positive match, overriding the base quant config mode,
        # if no prefix is used. The value is a dict of field keys and override
        # values.
        # Negative matching will skip quantization init for this module
        # entirely: non-quantized inference. More details and quantization
        # examples can be found at: https://github.com/ModelCloud/GPTQModel
        # Example:
        # # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
        # # last 1/4 of the layers 16-21 has 8bit and group_size 64
        # dynamic = {
        #     # `.*\.` matches the layers_node prefix
        #     # positive match layer 10-15
        #     r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
        #     # positive match layer 16-21
        #     r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
        #     r"-:.*\.moe\..*": {},  # negative match (skip) all `moe` layers
        # }
        assert weight_bits == 4
        self.dynamic = dynamic
        self.weight_bits = weight_bits
        self.is_sym = is_sym
        self.pack_factor = 32 // weight_bits  # packed into int32
        self.group_size = group_size
        self.desc_act = desc_act
        self.lm_head_quantized = lm_head_quantized
        self.full_config = full_config
        self.modules_in_block_to_quantize = modules_in_block_to_quantize or []

    def __repr__(self) -> str:
        return (
            f"CPUGPTQConfig("
            f"group_size={self.group_size}, "
            f"desc_act={self.desc_act}, "
            f"lm_head_quantized={self.lm_head_quantized}, "
            f"dynamic={self.dynamic}, "
            f"modules_in_block_to_quantize={self.modules_in_block_to_quantize})"
        )
    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "cpu_gptq"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return -1

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "CPUGPTQConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        desc_act = cls.get_from_keys_or(config, ["desc_act"], default=False)
        dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
        group_size = cls.get_from_keys(config, ["group_size"])
        is_sym = cls.get_from_keys(config, ["sym"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
        modules_in_block_to_quantize = cls.get_from_keys_or(
            config, ["modules_in_block_to_quantize"], default=None
        )
        return cls(
            weight_bits,
            group_size,
            desc_act,
            is_sym,
            lm_head_quantized,
            dynamic,
            config,
            modules_in_block_to_quantize,
        )

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        quant_method = hf_quant_cfg.get("quant_method", "").lower()
        if current_platform.is_cpu() and (quant_method == "gptq"):
            return cls.get_name()
        return None
    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        return get_linear_quant_method(self, layer, prefix, CPUGPTQLinearMethod)  # type: ignore

    def apply_vllm_mapper(self, hf_to_vllm_mapper):
        if self.modules_in_block_to_quantize is not None:
            self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list(
                self.modules_in_block_to_quantize
            )

    def maybe_update_config(self, model_name: str, revision: str | None = None):
        if self.modules_in_block_to_quantize:
            if is_list_of(self.modules_in_block_to_quantize, list):
                # original modules_in_block_to_quantize: list[list[str]]
                # flatten original modules_in_block_to_quantize
                self.modules_in_block_to_quantize = [
                    item
                    for sublist in self.modules_in_block_to_quantize
                    for item in sublist
                ]
            return

        unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
        metadata = get_safetensors_params_metadata(model_name, revision=revision)
        quant_layers: set[str] = {
            param_name.rsplit(".", 1)[0]
            for param_name, info in metadata.items()
            if (dtype := info.get("dtype", None))
            and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
        }
        self.modules_in_block_to_quantize = list(quant_layers)
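A toy illustration of the detection rule above, with hypothetical metadata (safetensors reports dtype strings such as "I32" and "BF16"): layers whose parameters are stored in a non-float dtype are treated as quantized.

    # Hypothetical metadata, mimicking get_safetensors_params_metadata() output.
    metadata = {
        "model.layers.0.self_attn.q_proj.qweight": {"dtype": "I32"},
        "model.layers.0.self_attn.q_proj.scales": {"dtype": "BF16"},
        "model.embed_tokens.weight": {"dtype": "BF16"},
    }
    unquant = {"F16", "BF16", "F32"}
    quant_layers = {
        name.rsplit(".", 1)[0]
        for name, info in metadata.items()
        if (dtype := info.get("dtype")) and dtype not in unquant
    }
    print(quant_layers)  # {'model.layers.0.self_attn.q_proj'}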
class CPUGPTQLinearMethod(LinearMethodBase):
    """Linear method for GPTQ on CPU.

    Args:
        quant_config: The CPU GPTQ quantization config.
    """

    def __init__(self, quant_config: CPUGPTQConfig) -> None:
        self.quant_config = quant_config
        assert self.quant_config.is_sym, "GPTQ asym quant is not supported on CPU"
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        output_size_per_partition = sum(output_partition_sizes)
        assert output_size_per_partition * self.quant_config.weight_bits % 32 == 0
        assert output_size_per_partition % 32 == 0
        assert input_size_per_partition % 32 == 0

        is_row_parallel = input_size != input_size_per_partition
        weight_loader = extra_weight_attrs.get("weight_loader")

        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        # Determine sharding
        if marlin_repeat_scales_on_all_ranks(
            self.quant_config.desc_act, self.quant_config.group_size, is_row_parallel
        ):
            # By setting scale_dim == None, weight_loader will
            # repeat the scales on each rank in TP>1 case.
            scales_and_zp_input_dim = None
            scales_and_zp_size = input_size // group_size
        else:
            # By setting scale_dim == 0, weight_loader will
            # shard the scales in TP>1 case.
            scales_and_zp_input_dim = 0
            scales_and_zp_size = input_size_per_partition // group_size

        # Quantized weights
        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=0,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader,
        )

        # Activation order
        g_idx = RowvLLMParameter(
            data=torch.empty(
                input_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            weight_loader=weight_loader,
        )
        set_weight_attrs(
            g_idx,
            {"ignore_warning": True},
        )

        qzeros_args = {
            "data": torch.empty(
                scales_and_zp_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            "weight_loader": weight_loader,
        }
        weight_scale_args = {
            "data": torch.empty(
                scales_and_zp_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            "weight_loader": weight_loader,
        }

        if scales_and_zp_input_dim is None:
            scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
            qzeros = PackedColumnParameter(
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args,
            )
        else:
            scales = GroupQuantScaleParameter(
                output_dim=1, input_dim=0, **weight_scale_args
            )
            qzeros = PackedvLLMParameter(
                input_dim=0,
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args,
            )

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("scales", scales)
        layer.register_parameter("qzeros", qzeros)
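For concreteness, a shape check of the parameters created above, assuming a hypothetical shard with input_size_per_partition=4096, output_size_per_partition=4096, 4-bit weights (pack_factor=8), and group_size=128 with sharded scales (scales_and_zp_input_dim=0):

    K, N, pack_factor, group_size = 4096, 4096, 8, 128
    assert (K // pack_factor, N) == (512, 4096)  # qweight, packed along input dim
    assert K // group_size == 32                 # quantization groups per shard
    assert (32, N // pack_factor) == (32, 512)   # qzeros, packed along output dim
    assert (32, N) == (32, 4096)                 # scales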
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        packed_weight = layer.qweight.data
        bits = self.quant_config.weight_bits
        pack_factor = int(self.quant_config.pack_factor)
        p_w_k, p_w_n = packed_weight.size()
        input_size = p_w_k * pack_factor
        output_size = p_w_n
        isa_hint = _get_isa_hint(layer.scales.dtype)
        layer.isa_hint = isa_hint

        layer.qzeros = None
        if not self.quant_config.desc_act:
            layer.g_idx = None

        # convert input-dim packing to output-dim packing
        weight = unpack_cols(packed_weight, bits, p_w_k, p_w_n * pack_factor).view(
            p_w_k, p_w_n, pack_factor
        )
        weight = weight.permute(0, 2, 1).reshape(input_size, output_size).contiguous()
        weight = pack_cols(weight, bits, input_size, output_size)
        # group every 16 output channels into a block and transpose to make
        # each block contiguous
        weight = (
            weight.view(input_size, -1, 16 // pack_factor)
            .permute(1, 0, 2)
            .reshape(-1, input_size * 16 // pack_factor)
            .contiguous()
        )
        layer.qweight.data = weight
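For reference, a shape-level sketch of the repack above, under the assumption that pack_cols/unpack_cols keep their usual vLLM signatures (tensor, num_bits, size_k, size_n); the sizes are hypothetical and the zero filler stands in for real packed weights:

    import torch
    from vllm.model_executor.layers.quantization.utils.quant_utils import (
        pack_cols,
        unpack_cols,
    )

    bits, pack_factor = 4, 8
    K, N = 64, 32  # input/output size of one shard
    packed = torch.zeros(K // pack_factor, N, dtype=torch.int32)  # GPTQ layout

    w = unpack_cols(packed, bits, K // pack_factor, N * pack_factor)
    w = w.view(K // pack_factor, N, pack_factor).permute(0, 2, 1).reshape(K, N)
    w = pack_cols(w.contiguous(), bits, K, N)  # now packed along output dim
    blocked = (
        w.view(K, -1, 16 // pack_factor)  # 16 output channels per block
        .permute(1, 0, 2)
        .reshape(-1, K * 16 // pack_factor)
    )
    assert blocked.shape == (N // 16, K * 16 // pack_factor)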
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        x = cpu_gemm_wna16(
            input=x,
            q_weight=layer.qweight,
            scales=layer.scales,
            zeros=layer.qzeros,
            g_idx=layer.g_idx,
            bias=bias,
            pack_factor=8,
            isa_hint=layer.isa_hint,
        )
        return x
class CPUAWQConfig(QuantizationConfig):
    """Config class for CPU AWQ"""

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
        lm_head_quantized: bool,
        modules_to_not_convert: list[str] | None,
        full_config: dict[str, Any],
    ) -> None:
        super().__init__()
        assert weight_bits == 4
        self.pack_factor = 32 // weight_bits  # packed into int32
        self.group_size = group_size
        self.zero_point = zero_point
        self.lm_head_quantized = lm_head_quantized
        self.weight_bits = weight_bits
        self.modules_to_not_convert = modules_to_not_convert or []
        self.full_config = full_config

    def __repr__(self) -> str:
        return (
            f"CPUAWQConfig("
            f"group_size={self.group_size}, "
            f"zero_point={self.zero_point}, "
            f"lm_head_quantized={self.lm_head_quantized}, "
            f"modules_to_not_convert={self.modules_to_not_convert})"
        )
    @classmethod
    def get_name(cls) -> "QuantizationMethods":
        return "cpu_awq"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return -1

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "CPUAWQConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
        modules_to_not_convert = cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None
        )
        return cls(
            weight_bits,
            group_size,
            zero_point,
            lm_head_quantized,
            modules_to_not_convert,
            config,
        )

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> Optional["QuantizationMethods"]:
        quant_method = hf_quant_cfg.get("quant_method", "").lower()
        if current_platform.is_cpu() and (quant_method == "awq"):
            return cls.get_name()
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, LinearBase) or (
            isinstance(layer, ParallelLMHead) and self.lm_head_quantized
        ):
            if is_layer_skipped(
                prefix,
                self.modules_to_not_convert,
                self.packed_modules_mapping,
                skip_with_substr=True,
            ):
                return UnquantizedLinearMethod()
            return CPUAWQLinearMethod(self)
        return None

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        if self.modules_to_not_convert:
            self.modules_to_not_convert = hf_to_vllm_mapper.apply_list(
                self.modules_to_not_convert
            )

    def maybe_update_config(self, model_name: str, revision: str | None = None):
        if self.modules_to_not_convert:
            return

        unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
        metadata = get_safetensors_params_metadata(model_name, revision=revision)
        layers = {param_name.rsplit(".", 1)[0] for param_name in metadata}
        quant_layers: set[str] = {
            param_name.rsplit(".", 1)[0]
            for param_name, info in metadata.items()
            if (dtype := info.get("dtype", None))
            and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
        }
        self.modules_to_not_convert = list(layers - quant_layers)
class CPUAWQLinearMethod(LinearMethodBase):
    """Linear method for CPU AWQ.

    Args:
        quant_config: The CPU AWQ quantization config.
    """

    def __init__(self, quant_config: CPUAWQConfig) -> None:
        self.quant_config = quant_config
        assert self.quant_config.zero_point
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        del output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader,
        )

        num_groups = input_size_per_partition // group_size

        qzeros = PackedvLLMParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader,
        )

        scales = GroupQuantScaleParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=0,
            output_dim=1,
            weight_loader=weight_loader,
        )

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        packed_weight = layer.qweight.data
        packed_zeros = layer.qzeros.data
        group_num = packed_zeros.size(0)
        bits = self.quant_config.weight_bits
        pack_factor = int(self.quant_config.pack_factor)
        input_size, packed_output_size = packed_weight.size()
        output_size = packed_output_size * pack_factor
        isa_hint = _get_isa_hint(layer.scales.dtype)
        layer.isa_hint = isa_hint

        interleave_map = (0, 4, 1, 5, 2, 6, 3, 7)
        weight = unpack_cols(
            packed_weight,
            bits,
            input_size,
            output_size,
        )
        zeros = unpack_cols(
            packed_zeros,
            bits,
            group_num,
            output_size,
        )
        weight = (
            weight.view(input_size, -1, pack_factor)[:, :, interleave_map]
            .reshape(input_size, output_size)
            .contiguous()
        )
        zeros = (
            zeros.view(group_num, -1, pack_factor)[:, :, interleave_map]
            .reshape(group_num, output_size)
            .contiguous()
        )

        zeros = pack_cols(zeros, bits, group_num, output_size).contiguous()
        # group every 16 output channels into a block and transpose to make
        # each block contiguous
        weight = pack_cols(weight, bits, input_size, output_size)
        weight = (
            weight.view(input_size, -1, 16 // pack_factor)
            .permute(1, 0, 2)
            .reshape(-1, input_size * 16 // pack_factor)
            .contiguous()
        )
        layer.qweight.data = weight
        layer.qzeros.data = zeros
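As a quick sanity check of interleave_map: assuming AWQ's usual nibble layout, where the 8 values in each int32 are stored in logical-channel order 0, 2, 4, 6, 1, 3, 5, 7, indexing that stored order with interleave_map restores 0..7, i.e. the map is the inverse permutation:

    import torch

    awq_pack_order = [0, 2, 4, 6, 1, 3, 5, 7]  # assumed AWQ nibble layout
    interleave_map = (0, 4, 1, 5, 2, 6, 3, 7)  # the inverse used above
    stored = torch.tensor(awq_pack_order)
    assert stored[list(interleave_map)].tolist() == list(range(8))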
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        x = cpu_gemm_wna16(
            input=x,
            q_weight=layer.qweight,
            scales=layer.scales,
            zeros=layer.qzeros,
            g_idx=None,
            bias=bias,
            pack_factor=8,
            isa_hint=layer.isa_hint,
        )
        return x
def _get_isa_hint(dtype: torch.dtype) -> str:
    supports_amx = torch._C._cpu._is_amx_tile_supported()
    if supports_amx and dtype in (torch.bfloat16,):
        return "amx"
    else:
        return "vec"
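A small check of the dispatch rule (assuming _get_isa_hint as defined above): fp16 scales always take the vectorized path, since the AMX path is gated on bf16.

    import torch

    # On any machine, fp16 falls back to the generic vectorized kernel.
    assert _get_isa_hint(torch.float16) == "vec"
    # bf16 selects "amx" only when the CPU reports AMX tile support.
    assert _get_isa_hint(torch.bfloat16) in ("amx", "vec")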
@@ -134,7 +134,7 @@ class IPEXConfig(QuantizationConfig):
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        if not current_platform.is_cpu() and not current_platform.is_xpu():
        if not current_platform.is_xpu():
            return None

        quant_method = hf_quant_cfg.get("quant_method", "").lower()