[Misc] Disambiguate quantized types via a new ScalarType (#6396)
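The hunks below appear to come from vllm/model_executor/layers/quantization/gptq_marlin_24.py. The point of the change: a bare bit-width like weight_bits = 4 cannot distinguish between different 4-bit encodings, so the config now carries a ScalarType value instead. A minimal sketch of the idea, with field names assumed from vllm/scalar_type.py (the real class carries more metadata and helper constructors):

    # A toy stand-in for vllm.scalar_type.ScalarType -- an assumption-laden
    # sketch, not the real implementation.
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class ScalarType:
        size_bits: int  # storage width of one element, in bits
        bias: int       # logical value == stored value - bias

        def __str__(self) -> str:
            return f"uint{self.size_bits}b{self.bias}"

    # "uint4b8": 4-bit unsigned storage with bias 8, so logical values span
    # [-8, 7]. An unbiased 4-bit uint would be a different type even though
    # both are "4 bits wide" -- exactly the ambiguity being removed.
    uint4b8 = ScalarType(size_bits=4, bias=8)
    uint8b128 = ScalarType(size_bits=8, bias=128)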
@@ -9,6 +9,7 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.scalar_type import scalar_types
 
 logger = init_logger(__name__)
 
@@ -17,9 +18,10 @@ GPTQ_MARLIN_24_MIN_THREAD_N = 128
 GPTQ_MARLIN_24_MIN_THREAD_K = 128
 GPTQ_MARLIN_24_MAX_PARALLEL = 64
 
-GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
+GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [
+    scalar_types.uint4b8, scalar_types.uint8b128
+]
 GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
 GPTQ_MARLIN_24_SUPPORTED_SYM = [True]
 
 
 class GPTQMarlin24Config(QuantizationConfig):
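With the list holding full types, the membership test checks type identity rather than width. A toy check, reusing the ScalarType sketch above:

    SUPPORTED = [uint4b8, uint8b128]
    assert uint4b8 in SUPPORTED
    # 4 bits wide, but a different encoding -- rejected under the new scheme,
    # indistinguishable under the old [4, 8] bit-width list.
    assert ScalarType(size_bits=4, bias=0) not in SUPPORTED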
@@ -31,14 +33,19 @@ class GPTQMarlin24Config(QuantizationConfig):
         weight_bits: int,
         group_size: int,
     ) -> None:
-        self.weight_bits = weight_bits
+        quant_type = {
+            4: scalar_types.uint4b8,
+            8: scalar_types.uint8b128,
+        }.get(weight_bits)
+
         self.group_size = group_size
 
         # Verify
-        if self.weight_bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
+        if quant_type is None or \
+                quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES:
             raise ValueError(
-                f"Marlin_24 does not support weight_bits = {self.weight_bits}. "
-                f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_NUM_BITS} "
+                f"Marlin_24 does not support quant_type = {quant_type}. "
+                f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES} "
                 "are supported.")
         if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
             raise ValueError(
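Note how the dict lookup in __init__ degrades gracefully: an unsupported bit-width maps to None, so the single check covers both "unknown width" and "known width, unsupported type". A hypothetical check with string stand-ins for the real types:

    mapping = {4: "uint4b8", 8: "uint8b128"}
    assert mapping.get(3) is None        # unknown width -> None -> ValueError
    assert mapping.get(4) == "uint4b8"   # known width -> concrete type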
@@ -46,8 +53,10 @@ class GPTQMarlin24Config(QuantizationConfig):
                 f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} "
                 "are supported.")
 
+        self.quant_type = quant_type
+
         # 4 Bits packed into 32 bit datatype.
-        self.pack_factor = 32 // self.weight_bits
+        self.pack_factor = 32 // self.quant_type.size_bits
 
         # Tile size used by marlin kernels.
         self.tile_size = 16
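The pack_factor arithmetic is unchanged in effect, since size_bits equals the old weight_bits for both supported types: 32 // 4 = 8 four-bit values per 32-bit word, 32 // 8 = 4 eight-bit values. A toy illustration (not vLLM code) of what the factor means:

    def pack_int4(vals):
        """Pack eight 4-bit values (pack_factor = 32 // 4) into one 32-bit word."""
        assert len(vals) == 8 and all(0 <= v < 16 for v in vals)
        word = 0
        for i, v in enumerate(vals):
            word |= v << (4 * i)
        return word

    assert pack_int4([1] * 8) == 0x11111111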
@@ -66,8 +75,8 @@ class GPTQMarlin24Config(QuantizationConfig):
         self.perm_len = 1024
 
     def __repr__(self) -> str:
-        return "Marlin24Config(weight_bits={}, group_size={})".format(
-            self.weight_bits, self.group_size)
+        return "Marlin24Config(quant_type={}, group_size={})".format(
+            self.quant_type, self.group_size)
 
     @classmethod
     def get_name(cls) -> str:
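For context, a hypothetical interactive check of the new repr (the module path is assumed, and str(scalar_types.uint4b8) is assumed to render as "uint4b8"):

    from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
        GPTQMarlin24Config)

    cfg = GPTQMarlin24Config(weight_bits=4, group_size=128)
    print(cfg)  # expected: Marlin24Config(quant_type=uint4b8, group_size=128)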
@@ -279,7 +288,7 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
 
         output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
                                             workspace,
-                                            self.quant_config.weight_bits,
+                                            self.quant_config.quant_type,
                                             size_m, size_n, size_k)
 
         output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))