[Misc] Disambiguate quantized types via a new ScalarType (#6396)

This commit is contained in:
Lucas Wilkinson
2024-08-02 16:51:58 -04:00
committed by GitHub
parent b482b9a5b1
commit a8d604ca2a
29 changed files with 1111 additions and 356 deletions

View File

@@ -9,6 +9,7 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
@@ -17,9 +18,10 @@ GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 64
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [
scalar_types.uint4b8, scalar_types.uint8b128
]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
GPTQ_MARLIN_24_SUPPORTED_SYM = [True]
class GPTQMarlin24Config(QuantizationConfig):
@@ -31,14 +33,19 @@ class GPTQMarlin24Config(QuantizationConfig):
weight_bits: int,
group_size: int,
) -> None:
self.weight_bits = weight_bits
quant_type = {
4: scalar_types.uint4b8,
8: scalar_types.uint8b128,
}.get(weight_bits)
self.group_size = group_size
# Verify
if self.weight_bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
if quant_type is None or \
quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES:
raise ValueError(
f"Marlin_24 does not support weight_bits = {self.weight_bits}. "
f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_NUM_BITS} "
f"Marlin_24 does not support quant_type = {quant_type}. "
f"Only quant_types = {GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES} "
"are supported.")
if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
raise ValueError(
@@ -46,8 +53,10 @@ class GPTQMarlin24Config(QuantizationConfig):
f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} "
"are supported.")
self.quant_type = quant_type
# 4 Bits packed into 32 bit datatype.
self.pack_factor = 32 // self.weight_bits
self.pack_factor = 32 // self.quant_type.size_bits
# Tile size used by marlin kernels.
self.tile_size = 16
@@ -66,8 +75,8 @@ class GPTQMarlin24Config(QuantizationConfig):
self.perm_len = 1024
def __repr__(self) -> str:
return "Marlin24Config(weight_bits={}, group_size={})".format(
self.weight_bits, self.group_size)
return "Marlin24Config(quant_type={}, group_size={})".format(
self.quant_type, self.group_size)
@classmethod
def get_name(cls) -> str:
@@ -279,7 +288,7 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
workspace,
self.quant_config.weight_bits,
self.quant_config.quant_type,
size_m, size_n, size_k)
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))