[Kernel] Add marlin_24 unit tests (#4901)
This commit is contained in:
committed by
GitHub
parent
f68470e803
commit
27ce85476e
@@ -12,6 +12,15 @@ from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
GPTQ_MARLIN_24_TILE = 16
|
||||
GPTQ_MARLIN_24_MIN_THREAD_N = 128
|
||||
GPTQ_MARLIN_24_MIN_THREAD_K = 128
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL = 16
|
||||
|
||||
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
|
||||
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
|
||||
GPTQ_MARLIN_24_SUPPORTED_SYM = [True]
|
||||
|
||||
|
||||
class GPTQMarlin24Config(QuantizationConfig):
|
||||
"""Config class for Marlin24.
|
||||
@@ -25,15 +34,17 @@ class GPTQMarlin24Config(QuantizationConfig):
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
|
||||
if self.weight_bits != 4 and self.weight_bits != 8:
|
||||
raise ValueError("weight_bits must be 4 or 8. Got = {}".format(
|
||||
self.weight_bits))
|
||||
|
||||
if self.group_size != 128 and self.group_size != -1:
|
||||
# Verify
|
||||
if self.weight_bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
|
||||
raise ValueError(
|
||||
"Currently, only group size 128 and -1 (channelwise) "
|
||||
"is supported for Marlin24, but got group_size of "
|
||||
f"{self.group_size}")
|
||||
f"Marlin_24 does not support weight_bits = {self.weight_bits}. "
|
||||
f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_NUM_BITS} "
|
||||
"are supported.")
|
||||
if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
|
||||
raise ValueError(
|
||||
f"Marlin_24 does not support group_size = {self.group_size}. "
|
||||
f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} "
|
||||
"are supported.")
|
||||
|
||||
# Number of weight_bits-wide values packed into one 32-bit datatype (8 for 4-bit, 4 for 8-bit).
|
||||
self.pack_factor = 32 // self.weight_bits
|
||||
|
||||
Reference in New Issue
Block a user