[Kernel] Initial Machete W4A8 support + Refactors (#9855)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
@@ -126,11 +126,14 @@ def permute_rows(q_w: torch.Tensor,
|
||||
|
||||
def quantize_weights(w: torch.Tensor,
|
||||
quant_type: ScalarType,
|
||||
group_size: int,
|
||||
group_size: Optional[int],
|
||||
zero_points: bool = False,
|
||||
ref_zero_points_after_scales: bool = False):
|
||||
assert quant_type.is_integer(), \
|
||||
"Floating point quantization may work but has not been tested"
|
||||
assert not zero_points or group_size is not None, \
|
||||
"to have group zero points, group_size must be provided "\
|
||||
"(-1 group_size is channelwise)"
|
||||
|
||||
orig_device = w.device
|
||||
orig_type = w.dtype
|
||||
@@ -140,10 +143,9 @@ def quantize_weights(w: torch.Tensor,
|
||||
|
||||
if group_size == -1:
|
||||
group_size = size_k
|
||||
assert group_size <= size_k
|
||||
|
||||
# Reshape to [groupsize, -1]
|
||||
if group_size < size_k:
|
||||
if group_size is not None and group_size < size_k:
|
||||
w = w.reshape((-1, group_size, size_n))
|
||||
w = w.permute(1, 0, 2)
|
||||
w = w.reshape((group_size, -1))
|
||||
@@ -155,18 +157,20 @@ def quantize_weights(w: torch.Tensor,
|
||||
max_q_val = quant_type.max()
|
||||
min_q_val = quant_type.min()
|
||||
|
||||
if zero_points:
|
||||
assert not quant_type.is_signed() and quant_type.max() > 0
|
||||
w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
|
||||
maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \
|
||||
.clamp(min_q_val, max_q_val).int()
|
||||
else:
|
||||
# If the bias is such that there are no possible negative/positive
|
||||
# values, set the max value to inf to avoid divide by 0
|
||||
w_s = torch.max(
|
||||
abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
|
||||
abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)))
|
||||
maybe_w_zp = None
|
||||
w_s = torch.Tensor([1.0]).to(w.device) # unscaled case
|
||||
maybe_w_zp = None
|
||||
if group_size is not None:
|
||||
if zero_points:
|
||||
assert not quant_type.is_signed() and quant_type.max() > 0
|
||||
w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
|
||||
maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \
|
||||
.clamp(min_q_val, max_q_val).int()
|
||||
else:
|
||||
# If the bias is such that there are no possible negative/positive
|
||||
# values, set the max value to inf to avoid divide by 0
|
||||
w_s = torch.max(
|
||||
abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
|
||||
abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)))
|
||||
|
||||
# Quantize
|
||||
w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
|
||||
@@ -176,7 +180,7 @@ def quantize_weights(w: torch.Tensor,
|
||||
# For some kernels (namely Machete) the zero-points are applied after the
|
||||
# scales are applied; in this case, computing the reference in a similar way
|
||||
# allows us to use tighter error tolerances in our unit tests.
|
||||
if ref_zero_points_after_scales and zero_points:
|
||||
if ref_zero_points_after_scales and maybe_w_zp is not None:
|
||||
w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
|
||||
else:
|
||||
w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
|
||||
@@ -185,7 +189,7 @@ def quantize_weights(w: torch.Tensor,
|
||||
w_q += quant_type.bias
|
||||
|
||||
# Restore original shapes
|
||||
if group_size < size_k:
|
||||
if group_size is not None and group_size < size_k:
|
||||
|
||||
def reshape_w(w):
|
||||
w = w.reshape((group_size, -1, size_n))
|
||||
@@ -195,17 +199,16 @@ def quantize_weights(w: torch.Tensor,
|
||||
|
||||
w_q = reshape_w(w_q)
|
||||
w_ref = reshape_w(w_ref)
|
||||
w_s = w_s.reshape((-1, size_n)).contiguous()
|
||||
|
||||
w_s = w_s.reshape((-1, size_n)).contiguous()
|
||||
|
||||
if zero_points:
|
||||
if maybe_w_zp is not None:
|
||||
maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
|
||||
maybe_w_zp = maybe_w_zp.to(device=orig_device)
|
||||
|
||||
return (
|
||||
w_ref.to(device=orig_device),
|
||||
w_q.to(device=orig_device),
|
||||
w_s.to(device=orig_device),
|
||||
w_s if group_size is not None else None,
|
||||
maybe_w_zp,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user