[Kernel] Initial Machete W4A8 support + Refactors (#9855)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
Lucas Wilkinson
2024-11-18 14:59:29 -05:00
committed by GitHub
parent c2170a5b39
commit 96d999fbe8
28 changed files with 2616 additions and 1694 deletions

View File

@@ -126,11 +126,14 @@ def permute_rows(q_w: torch.Tensor,
def quantize_weights(w: torch.Tensor,
quant_type: ScalarType,
group_size: int,
group_size: Optional[int],
zero_points: bool = False,
ref_zero_points_after_scales: bool = False):
assert quant_type.is_integer(), \
"Floating point quantization may work but has not been tested"
assert not zero_points or group_size is not None, \
"to have group zero points, group_size must be provided "\
"(-1 group_size is channelwise)"
orig_device = w.device
orig_type = w.dtype
@@ -140,10 +143,9 @@ def quantize_weights(w: torch.Tensor,
if group_size == -1:
group_size = size_k
assert group_size <= size_k
# Reshape to [groupsize, -1]
if group_size < size_k:
if group_size is not None and group_size < size_k:
w = w.reshape((-1, group_size, size_n))
w = w.permute(1, 0, 2)
w = w.reshape((group_size, -1))
@@ -155,18 +157,20 @@ def quantize_weights(w: torch.Tensor,
max_q_val = quant_type.max()
min_q_val = quant_type.min()
if zero_points:
assert not quant_type.is_signed() and quant_type.max() > 0
w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \
.clamp(min_q_val, max_q_val).int()
else:
# If the bias is such that there are no possible negative/positive
# values, set the max value to inf to avoid divide by 0
w_s = torch.max(
abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)))
maybe_w_zp = None
w_s = torch.Tensor([1.0]).to(w.device) # unscaled case
maybe_w_zp = None
if group_size is not None:
if zero_points:
assert not quant_type.is_signed() and quant_type.max() > 0
w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \
.clamp(min_q_val, max_q_val).int()
else:
# If the bias is such that there are no possible negative/positive
# values, set the max value to inf to avoid divide by 0
w_s = torch.max(
abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)))
# Quantize
w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
@@ -176,7 +180,7 @@ def quantize_weights(w: torch.Tensor,
# For some kernels (namely Machete) the zero-points are applied after the
# scales are applied, for this case computing the reference in similar way
# allows us to use tighter error tolerances in our unit tests.
if ref_zero_points_after_scales and zero_points:
if ref_zero_points_after_scales and maybe_w_zp is not None:
w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
else:
w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
@@ -185,7 +189,7 @@ def quantize_weights(w: torch.Tensor,
w_q += quant_type.bias
# Restore original shapes
if group_size < size_k:
if group_size is not None and group_size < size_k:
def reshape_w(w):
w = w.reshape((group_size, -1, size_n))
@@ -195,17 +199,16 @@ def quantize_weights(w: torch.Tensor,
w_q = reshape_w(w_q)
w_ref = reshape_w(w_ref)
w_s = w_s.reshape((-1, size_n)).contiguous()
w_s = w_s.reshape((-1, size_n)).contiguous()
if zero_points:
if maybe_w_zp is not None:
maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
maybe_w_zp = maybe_w_zp.to(device=orig_device)
return (
w_ref.to(device=orig_device),
w_q.to(device=orig_device),
w_s.to(device=orig_device),
w_s if group_size is not None else None,
maybe_w_zp,
)