[Kernel] Increase precision of GPTQ/AWQ Marlin kernel (#6795)

This commit is contained in:
Alexander Matveev
2024-07-27 17:52:33 -04:00
committed by GitHub
parent fad5576c58
commit 75acdaa4b6
6 changed files with 168 additions and 44 deletions

View File

@@ -16,6 +16,11 @@ GPTQ_MARLIN_MAX_PARALLEL = 16
MARLIN_SUPPORTED_NUM_BITS = [4, 8]
MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
# In case there is a performance issue with Marlin, the variable below can be
# changed to False, which allows Marlin to perform global reductions in fp16
# precision (instead of fp32), and therefore, save on some memory movements.
USE_FP32_REDUCE_DEFAULT = True
def _check_marlin_supported(num_bits: int, group_size: int, is_sym: bool,
min_capability: Optional[int],
@@ -244,7 +249,8 @@ def apply_gptq_marlin_linear(
output_size_per_partition: int,
input_size_per_partition: int,
is_k_full: bool,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
bias: Optional[torch.Tensor] = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition, )
@@ -260,7 +266,8 @@ def apply_gptq_marlin_linear(
size_n=output_size_per_partition,
size_k=input_size_per_partition,
is_k_full=is_k_full,
has_zp=False)
has_zp=False,
use_fp32_reduce=use_fp32_reduce)
if bias is not None:
output.add_(bias) # In-place add
@@ -279,7 +286,8 @@ def apply_awq_marlin_linear(
num_bits: int,
output_size_per_partition: int,
input_size_per_partition: int,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
bias: Optional[torch.Tensor] = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition, )
@@ -295,7 +303,8 @@ def apply_awq_marlin_linear(
size_n=output_size_per_partition,
size_k=input_size_per_partition,
is_k_full=True,
has_zp=True)
has_zp=True,
use_fp32_reduce=use_fp32_reduce)
if bias is not None:
output.add_(bias) # In-place add