[Kernel] Add marlin_24 unit tests (#4901)

This commit is contained in:
Alexander Matveev
2024-05-19 11:37:34 -04:00
committed by GitHub
parent f68470e803
commit 27ce85476e
6 changed files with 649 additions and 103 deletions

View File

@@ -12,6 +12,15 @@ from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
GPTQ_MARLIN_24_TILE = 16
GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 16
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
GPTQ_MARLIN_24_SUPPORTED_SYM = [True]
class GPTQMarlin24Config(QuantizationConfig):
"""Config class for Marlin24.
@@ -25,15 +34,17 @@ class GPTQMarlin24Config(QuantizationConfig):
self.weight_bits = weight_bits
self.group_size = group_size
-        if self.weight_bits != 4 and self.weight_bits != 8:
-            raise ValueError("weight_bits must be 4 or 8. Got = {}".format(
-                self.weight_bits))
-        if self.group_size != 128 and self.group_size != -1:
-            raise ValueError(
-                "Currently, only group size 128 and -1 (channelwise) "
-                "is supported for Marlin24, but got group_size of "
-                f"{self.group_size}")
+        # Verify
+        if self.weight_bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
+            raise ValueError(
+                f"Marlin_24 does not support weight_bits = {self.weight_bits}. "
+                f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_NUM_BITS} "
+                "are supported.")
+        if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
+            raise ValueError(
+                f"Marlin_24 does not support group_size = {self.group_size}. "
+                f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} "
+                "are supported.")
# 4 Bits packed into 32 bit datatype.
self.pack_factor = 32 // self.weight_bits