[Misc] Disambiguate quantized types via a new ScalarType (#6396)

Author: Lucas Wilkinson
Date: 2024-08-02 16:51:58 -04:00
Committed by: GitHub
Parent: b482b9a5b1
Commit: a8d604ca2a
29 changed files with 1111 additions and 356 deletions
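At a high level, the change swaps bare num_bits integers for a structured quant_type, so that formats sharing a bit-width (signed int4 versus unsigned uint4, say) can no longer be conflated. As a rough sketch of what such a type might carry (only size_bits is actually used in the diff below; the signed field and __str__ are illustrative assumptions, not the real vllm.scalar_type API):

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class ScalarType:
        # size_bits is the one attribute the diff relies on (.size_bits);
        # the signedness flag is an assumed field that motivates the change.
        size_bits: int
        signed: bool

        def __str__(self) -> str:
            return f"{'int' if self.signed else 'uint'}{self.size_bits}"

    class scalar_types:
        # Mirrors the names imported from vllm.scalar_type in the diff.
        uint4 = ScalarType(size_bits=4, signed=False)
        uint8 = ScalarType(size_bits=8, signed=False)
        int4 = ScalarType(size_bits=4, signed=True)  # example sibling type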


@@ -10,11 +10,11 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    apply_awq_marlin_linear, awq_to_marlin_zero_points,
-    check_awq_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace,
-    marlin_permute_scales, replace_tensor, verify_awq_marlin_supported,
-    verify_marlin_supports_shape)
+    apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
+    marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
+    replace_tensor, verify_marlin_supported, verify_marlin_supports_shape)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.scalar_type import scalar_types

 logger = init_logger(__name__)
@@ -22,20 +22,31 @@ logger = init_logger(__name__)
 class AWQMarlinConfig(QuantizationConfig):
     """Config class for AWQ Marlin"""

+    # num_bits -> type
+    TYPE_MAP = {
+        4: scalar_types.uint4,
+        8: scalar_types.uint8,
+    }
+
     def __init__(self, weight_bits: int, group_size: int, has_zp: bool,
                  lm_head_quantized: bool) -> None:
-        self.weight_bits = weight_bits
-        self.pack_factor = 32 // self.weight_bits  # packed into 32bits
+        self.pack_factor = 32 // weight_bits  # packed into int32
         self.group_size = group_size
         self.has_zp = has_zp
         self.lm_head_quantized = lm_head_quantized
-        verify_awq_marlin_supported(num_bits=self.weight_bits,
-                                    group_size=self.group_size,
-                                    has_zp=self.has_zp)
+
+        if weight_bits not in self.TYPE_MAP:
+            raise ValueError(f"Unsupported num_bits = {weight_bits}. "
+                             f"Supported num_bits = {self.TYPE_MAP.keys()}")
+
+        self.quant_type = self.TYPE_MAP[weight_bits]
+
+        verify_marlin_supported(self.quant_type,
+                                group_size=self.group_size,
+                                has_zp=self.has_zp)

     def __repr__(self) -> str:
-        return (f"AWQMarlinConfig(weight_bits={self.weight_bits}, "
+        return (f"AWQMarlinConfig(quant_type={self.quant_type}, "
                 f"group_size={self.group_size}, "
                 f"has_zp={self.has_zp}, "
                 f"lm_head_quantized={self.lm_head_quantized})")
@@ -110,11 +121,13 @@ class AWQMarlinConfig(QuantizationConfig):
         if (num_bits is None or group_size is None or has_zp is None):
             return False

-        return check_awq_marlin_supported(
-            num_bits=num_bits,
-            group_size=group_size,
-            has_zp=has_zp,
-            min_capability=cls.get_min_capability())
+        if num_bits not in cls.TYPE_MAP:
+            return False
+
+        return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits],
+                                      group_size=group_size,
+                                      has_zp=has_zp,
+                                      min_capability=cls.get_min_capability())


 class AWQMarlinLinearMethod(LinearMethodBase):
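The compatibility override now fails fast on unknown bit-widths before running the heavier Marlin capability check. The same pattern in isolation, continuing the ScalarType sketch above; is_awq_marlin_compatible and the stubbed check_marlin_supported are illustrative stand-ins, not the real helpers:

    # Continues the ScalarType sketch above; check_marlin_supported is
    # stubbed, since the real helper lives in marlin_utils.
    TYPE_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8}

    def check_marlin_supported(quant_type, group_size, has_zp, min_capability):
        return True  # placeholder for the real GPU-capability check

    def is_awq_marlin_compatible(num_bits, group_size, has_zp):
        if num_bits not in TYPE_MAP:
            return False  # unknown bit-width: reject before any hardware query
        return check_marlin_supported(TYPE_MAP[num_bits], group_size, has_zp,
                                      min_capability=80)

    print(is_awq_marlin_compatible(3, 128, True))  # False, never hits the stub
    print(is_awq_marlin_compatible(4, 128, True))  # True via the stub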
@@ -226,7 +239,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
             layer.qweight,
             size_k=layer.input_size_per_partition,
             size_n=layer.output_size_per_partition,
-            num_bits=self.quant_config.weight_bits)
+            num_bits=self.quant_config.quant_type.size_bits)
         replace_tensor(layer, "qweight", marlin_qweight)

         # Permute scales from AWQ format to marlin format.
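Call sites that still need a raw bit count, like the repack helper here, now derive it from the type via size_bits rather than carrying a second integer around. Continuing the earlier sketch:

    quant_type = scalar_types.uint4           # from the sketch at the top
    num_bits = quant_type.size_bits           # 4, what awq_marlin_repack receives
    pack_factor = 32 // quant_type.size_bits  # 8 four-bit values per int32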
@@ -242,7 +255,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
             layer.qzeros,
             size_k=layer.num_groups,
             size_n=layer.output_size_per_partition,
-            num_bits=self.quant_config.weight_bits)
+            num_bits=self.quant_config.quant_type.size_bits)
         replace_tensor(layer, "qzeros", marlin_zp)

         # Not-used
@@ -263,7 +276,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
             g_idx=layer.g_idx,
             g_idx_sort_indices=layer.g_idx_sort_indices,
             workspace=layer.workspace,
-            num_bits=self.quant_config.weight_bits,
+            quant_type=self.quant_config.quant_type,
             output_size_per_partition=layer.output_size_per_partition,
             input_size_per_partition=layer.input_size_per_partition,
             bias=bias)
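Passing the whole quant_type down to apply_awq_marlin_linear, instead of just a bit count, lets lower layers distinguish types that share a width. A hypothetical dispatcher showing why that matters (the kernel names and the signed field are invented for illustration):

    def select_marlin_kernel(quant_type):
        # Hypothetical: two 4-bit types map to different kernels; with only
        # num_bits=4 this distinction would be impossible to express.
        if quant_type.size_bits == 4 and not quant_type.signed:
            return "marlin_gemm_uint4"  # invented name
        if quant_type.size_bits == 4 and quant_type.signed:
            return "marlin_gemm_int4"   # invented name
        raise NotImplementedError(f"no kernel for {quant_type}")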