[Kernel] Increase precision of GPTQ/AWQ Marlin kernel (#6795)
This commit is contained in:
committed by
GitHub
parent
fad5576c58
commit
75acdaa4b6
@@ -16,6 +16,11 @@ GPTQ_MARLIN_MAX_PARALLEL = 16
|
||||
MARLIN_SUPPORTED_NUM_BITS = [4, 8]
|
||||
MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
||||
|
||||
# In case there is a performance issue with Marlin, the variable below can be
|
||||
# changed to False, which allows Marlin to perform global reductions in fp16
|
||||
# precision (instead of fp32), and therefore, save on some memory movements.
|
||||
USE_FP32_REDUCE_DEFAULT = True
|
||||
|
||||
|
||||
def _check_marlin_supported(num_bits: int, group_size: int, is_sym: bool,
|
||||
min_capability: Optional[int],
|
||||
@@ -244,7 +249,8 @@ def apply_gptq_marlin_linear(
|
||||
output_size_per_partition: int,
|
||||
input_size_per_partition: int,
|
||||
is_k_full: bool,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
|
||||
reshaped_x = input.reshape(-1, input.shape[-1])
|
||||
out_shape = input.shape[:-1] + (output_size_per_partition, )
|
||||
|
||||
@@ -260,7 +266,8 @@ def apply_gptq_marlin_linear(
|
||||
size_n=output_size_per_partition,
|
||||
size_k=input_size_per_partition,
|
||||
is_k_full=is_k_full,
|
||||
has_zp=False)
|
||||
has_zp=False,
|
||||
use_fp32_reduce=use_fp32_reduce)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
@@ -279,7 +286,8 @@ def apply_awq_marlin_linear(
|
||||
num_bits: int,
|
||||
output_size_per_partition: int,
|
||||
input_size_per_partition: int,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
|
||||
reshaped_x = input.reshape(-1, input.shape[-1])
|
||||
out_shape = input.shape[:-1] + (output_size_per_partition, )
|
||||
|
||||
@@ -295,7 +303,8 @@ def apply_awq_marlin_linear(
|
||||
size_n=output_size_per_partition,
|
||||
size_k=input_size_per_partition,
|
||||
is_k_full=True,
|
||||
has_zp=True)
|
||||
has_zp=True,
|
||||
use_fp32_reduce=use_fp32_reduce)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
Reference in New Issue
Block a user