Marlin 24 prefill performance improvement (about 25% better on average) (#4983)

This commit is contained in:
Alexander Matveev
2024-05-23 02:39:27 -04:00
committed by GitHub
parent ee3eea0a1b
commit 6066253296
4 changed files with 107 additions and 32 deletions

View File

@@ -15,7 +15,7 @@ logger = init_logger(__name__)
GPTQ_MARLIN_24_TILE = 16
GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 16
GPTQ_MARLIN_24_MAX_PARALLEL = 64
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
@@ -53,14 +53,14 @@ class GPTQMarlin24Config(QuantizationConfig):
self.tile_size = 16
# Min out_features dim
self.min_n_threads = 128
self.min_n_threads = GPTQ_MARLIN_24_MIN_THREAD_N
# Min in_features dim
self.min_k_threads = 128
self.min_k_threads = GPTQ_MARLIN_24_MIN_THREAD_K
# Max parallel problems to solve at once (improves large
# batch performance)
self.max_parallel = 16
self.max_parallel = GPTQ_MARLIN_24_MAX_PARALLEL
# Permutation length used by the marlin kernels.
self.perm_len = 1024