[Kernel] (1/N) Machete - Hopper Optimized Mixed Precision Linear Kernel (#7174)
This commit is contained in:
@@ -81,7 +81,8 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
|
||||
def quantize_weights(w: torch.Tensor,
|
||||
quant_type: ScalarType,
|
||||
group_size: int,
|
||||
zero_points: bool = False):
|
||||
zero_points: bool = False,
|
||||
ref_zero_points_after_scales: bool = False):
|
||||
assert quant_type.is_integer(), \
|
||||
"Floating point quantization may work but has not been tested"
|
||||
|
||||
@@ -126,7 +127,13 @@ def quantize_weights(w: torch.Tensor,
|
||||
w_q = torch.clamp(w_q, min_q_val, max_q_val)
|
||||
|
||||
# Compute ref (dequantized)
|
||||
w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
|
||||
# For some kernels (namely Machete) the zero-points are applied after the
|
||||
# scales are applied, for this case computing the reference in similar way
|
||||
# allows us to use tighter error tolerances in our unit tests.
|
||||
if ref_zero_points_after_scales and zero_points:
|
||||
w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
|
||||
else:
|
||||
w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
|
||||
|
||||
if quant_type.has_bias():
|
||||
w_q += quant_type.bias
|
||||
|
||||
Reference in New Issue
Block a user