[Misc] Disambiguate quantized types via a new ScalarType (#6396)
@@ -10,11 +10,11 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    apply_awq_marlin_linear, awq_to_marlin_zero_points,
-    check_awq_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace,
-    marlin_permute_scales, replace_tensor, verify_awq_marlin_supported,
-    verify_marlin_supports_shape)
+    apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
+    marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
+    replace_tensor, verify_marlin_supported, verify_marlin_supports_shape)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.scalar_type import scalar_types
 
 logger = init_logger(__name__)
 
@@ -22,20 +22,31 @@ logger = init_logger(__name__)
 class AWQMarlinConfig(QuantizationConfig):
     """Config class for AWQ Marlin"""
 
+    # num_bits -> type
+    TYPE_MAP = {
+        4: scalar_types.uint4,
+        8: scalar_types.uint8,
+    }
+
     def __init__(self, weight_bits: int, group_size: int, has_zp: bool,
                  lm_head_quantized: bool) -> None:
-        self.weight_bits = weight_bits
-        self.pack_factor = 32 // self.weight_bits  # packed into 32bits
+        self.pack_factor = 32 // weight_bits  # packed into int32
         self.group_size = group_size
         self.has_zp = has_zp
         self.lm_head_quantized = lm_head_quantized
 
-        verify_awq_marlin_supported(num_bits=self.weight_bits,
-                                    group_size=self.group_size,
-                                    has_zp=self.has_zp)
+        if weight_bits not in self.TYPE_MAP:
+            raise ValueError(f"Unsupported num_bits = {weight_bits}. "
+                             f"Supported num_bits = {self.TYPE_MAP.keys()}")
+
+        self.quant_type = self.TYPE_MAP[weight_bits]
+
+        verify_marlin_supported(self.quant_type,
+                                group_size=self.group_size,
+                                has_zp=self.has_zp)
 
     def __repr__(self) -> str:
-        return (f"AWQMarlinConfig(weight_bits={self.weight_bits}, "
+        return (f"AWQMarlinConfig(quant_type={self.quant_type}, "
                 f"group_size={self.group_size}, "
                 f"has_zp={self.has_zp}, "
                 f"lm_head_quantized={self.lm_head_quantized})")
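The hunk above is the heart of the commit: the config stops storing a bare num_bits integer and instead resolves it to an explicit ScalarType through TYPE_MAP at construction time. A minimal stand-alone sketch of that resolution path, assuming a vLLM checkout that includes this change (resolve_quant_type is a hypothetical helper name, used only for illustration):

from vllm.scalar_type import scalar_types

# The same mapping the config class now carries: raw bit width -> explicit type.
TYPE_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8}

def resolve_quant_type(weight_bits: int):
    # Mirrors AWQMarlinConfig.__init__: an unknown width fails fast here
    # instead of flowing through the stack as an unchecked integer.
    if weight_bits not in TYPE_MAP:
        raise ValueError(f"Unsupported num_bits = {weight_bits}. "
                         f"Supported num_bits = {TYPE_MAP.keys()}")
    return TYPE_MAP[weight_bits]

quant_type = resolve_quant_type(4)
print(quant_type.size_bits)  # 4 -- the width is now a property of the type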
@@ -110,11 +121,13 @@ class AWQMarlinConfig(QuantizationConfig):
         if (num_bits is None or group_size is None or has_zp is None):
             return False
 
-        return check_awq_marlin_supported(
-            num_bits=num_bits,
-            group_size=group_size,
-            has_zp=has_zp,
-            min_capability=cls.get_min_capability())
+        if num_bits not in cls.TYPE_MAP:
+            return False
+
+        return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits],
+                                      group_size=group_size,
+                                      has_zp=has_zp,
+                                      min_capability=cls.get_min_capability())
 
 
 class AWQMarlinLinearMethod(LinearMethodBase):
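The compatibility check in the hunk above applies the same idea: membership in TYPE_MAP now gates the capability probe, so an unmapped width returns False before check_marlin_supported ever runs. A hedged sketch of that guard in isolation (width_is_compatible is a hypothetical name; the real logic lives in the classmethod shown above):

from vllm.scalar_type import scalar_types

TYPE_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8}

def width_is_compatible(num_bits) -> bool:
    # Mirrors the new early exit: a None or unmapped width can never be
    # Marlin-compatible, so no GPU capability query is attempted for it.
    if num_bits is None:
        return False
    return num_bits in TYPE_MAP

assert width_is_compatible(4)
assert not width_is_compatible(3)  # unmapped widths short-circuit to False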
@@ -226,7 +239,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
             layer.qweight,
             size_k=layer.input_size_per_partition,
             size_n=layer.output_size_per_partition,
-            num_bits=self.quant_config.weight_bits)
+            num_bits=self.quant_config.quant_type.size_bits)
         replace_tensor(layer, "qweight", marlin_qweight)
 
         # Permute scales from AWQ format to marlin format.
@@ -242,7 +255,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
             layer.qzeros,
             size_k=layer.num_groups,
             size_n=layer.output_size_per_partition,
-            num_bits=self.quant_config.weight_bits)
+            num_bits=self.quant_config.quant_type.size_bits)
         replace_tensor(layer, "qzeros", marlin_zp)
 
         # Not-used
@@ -263,7 +276,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
             g_idx=layer.g_idx,
             g_idx_sort_indices=layer.g_idx_sort_indices,
             workspace=layer.workspace,
-            num_bits=self.quant_config.weight_bits,
+            quant_type=self.quant_config.quant_type,
             output_size_per_partition=layer.output_size_per_partition,
             input_size_per_partition=layer.input_size_per_partition,
             bias=bias)
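The final hunk is where the commit title pays off: apply_awq_marlin_linear now receives the ScalarType itself rather than a bare width, while the repacking calls in the two hunks before it recover the width on demand via quant_type.size_bits. A small sketch of what the callee gains, using only the two types this file registers (describe is a hypothetical helper; the f-string rendering of a ScalarType is the same one the new __repr__ above relies on):

from vllm.scalar_type import scalar_types

def describe(quant_type) -> str:
    # A ScalarType lets the callee recover the width *and* keep formats
    # distinguishable; a bare int such as 4 carries only the width.
    return f"{quant_type} ({quant_type.size_bits} bits)"

print(describe(scalar_types.uint4))
print(describe(scalar_types.uint8))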