[Kernel][Core] Add AWQ support to the Marlin kernel (#6612)
This commit is contained in:
committed by
GitHub
parent
25e778aa16
commit
396d92d5e0
@@ -106,6 +106,67 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
|
||||
)
|
||||
|
||||
|
||||
def quantize_weights_with_zp(w: torch.Tensor, num_bits: int, group_size: int):
    """Quantize ``w`` group-wise to unsigned ints with a per-group zero-point.

    ``w`` is unpacked as a 2-D ``(size_k, size_n)`` floating-point tensor.
    Groups are formed along the first dimension; a ``group_size`` of -1 is
    treated as ``size_k`` (a single group per column).

    Returns a 4-tuple, every element moved back to the device of ``w``:
        w_ref: dequantized reference weights, ``(size_k, size_n)``.
        q_w:   quantized integer weights in ``[0, 2**num_bits - 1]``,
               ``(size_k, size_n)``.
        s:     per-group scales, reshaped to ``(-1, size_n)``.
        zp:    per-group integer zero-points, reshaped to ``(-1, size_n)``.

    Raises:
        AssertionError: if ``w`` is not floating point, or ``num_bits`` /
            ``group_size`` is not among the supported values.
    """
    src_device = w.device
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"
    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    q_lo = 0
    q_hi = 2**num_bits - 1
    is_grouped = group_size < size_k

    # Lay the tensor out as [group_size, -1] so that each column holds
    # exactly one quantization group.
    if is_grouped:
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2).reshape((group_size, -1))

    # Per-group scale: the value range spread over the integer range. The
    # range is floored at 1e-5 so a constant group does not give a zero scale.
    group_max = torch.max(w, 0, keepdim=True)[0]
    group_min = torch.min(w, 0, keepdim=True)[0]
    s = (group_max - group_min).clamp(min=1e-5) / q_hi

    # Per-group zero-point: the integer that the real value 0 maps to.
    zp = (-torch.round(group_min / s)).clamp(q_lo, q_hi).int()

    # Quantize, then dequantize again to obtain the reference weights.
    q_w = torch.clamp(torch.round(w / s).int() + zp, q_lo, q_hi)
    w_ref = (q_w - zp).half() * s

    # Undo the grouping layout so outputs are (size_k, size_n) again.
    if is_grouped:

        def _ungroup(t):
            t = t.reshape((group_size, -1, size_n)).permute(1, 0, 2)
            return t.reshape((size_k, size_n)).contiguous()

        q_w = _ungroup(q_w)
        w_ref = _ungroup(w_ref)

    s = s.reshape((-1, size_n)).contiguous()
    zp = zp.reshape((-1, size_n)).contiguous()

    return tuple(t.to(device=src_device) for t in (w_ref, q_w, s, zp))
|
||||
|
||||
|
||||
def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
|
||||
orig_device = q_w.device
|
||||
|
||||
@@ -122,7 +183,7 @@ def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
|
||||
)
|
||||
|
||||
|
||||
def gptq_pack(
|
||||
def pack_rows(
|
||||
q_w: torch.Tensor,
|
||||
num_bits: int,
|
||||
size_k: int,
|
||||
@@ -144,3 +205,90 @@ def gptq_pack(
|
||||
|
||||
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
|
||||
return q_res
|
||||
|
||||
|
||||
def pack_cols(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack ``num_bits``-wide values along the column dimension into int32.

    ``q_w`` must have shape ``(size_k, size_n)`` and ``size_n`` must be a
    multiple of the pack factor returned by ``get_pack_factor(num_bits)``.
    Output column ``j`` holds input columns ``j * pack_factor + i`` for
    ``i`` in ``range(pack_factor)``, with value ``i`` shifted left by
    ``num_bits * i`` bits.

    Returns a contiguous int32 tensor of shape
    ``(size_k, size_n // pack_factor)`` on the device of ``q_w``.
    """
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0

    orig_device = q_w.device

    # The bit manipulation is done in numpy on unsigned 32-bit words.
    unpacked = q_w.cpu().numpy().astype(numpy.uint32)
    packed = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)

    shift = 0
    for slot in range(pack_factor):
        # Every pack_factor-th column, starting at `slot`, lands in the same
        # bit field of the output words.
        packed |= unpacked[:, slot::pack_factor] << shift
        shift += num_bits

    as_int32 = packed.astype(numpy.int32)
    return torch.from_numpy(as_int32).to(orig_device).contiguous()
|
||||
|
||||
|
||||
def unpack_cols(
    packed_q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Expand column-packed int32 words into one value per column.

    ``packed_q_w`` must have shape ``(size_k, size_n // pack_factor)`` where
    ``pack_factor = get_pack_factor(num_bits)``. The ``i``-th ``num_bits``-wide
    field (counting from the least significant bits) of packed column ``j``
    is written to output column ``j * pack_factor + i``.

    Returns a contiguous int32 tensor of shape ``(size_k, size_n)`` on the
    device of ``packed_q_w``.
    """
    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0
    assert packed_q_w.shape == (
        size_k, size_n // pack_factor
    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
        packed_q_w.shape, size_k, size_n, pack_factor)

    orig_device = packed_q_w.device

    # Work on an unsigned numpy copy; it is consumed by shifting in the loop.
    remaining = packed_q_w.cpu().numpy().astype(numpy.uint32)
    unpacked = numpy.zeros((size_k, size_n), dtype=numpy.uint32)

    field_mask = (1 << num_bits) - 1
    for slot in range(pack_factor):
        # Take the lowest field, then shift the next one into its place.
        unpacked[:, slot::pack_factor] = remaining & field_mask
        remaining >>= num_bits

    as_int32 = unpacked.astype(numpy.int32)
    return torch.from_numpy(as_int32).to(orig_device).contiguous()
|
||||
|
||||
|
||||
def gptq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack quantized weights for GPTQ by delegating to ``pack_rows``.

    All arguments are forwarded unchanged, and the result of ``pack_rows``
    is returned as-is.
    """
    packed = pack_rows(q_w, num_bits, size_k, size_n)
    return packed
|
||||
|
||||
|
||||
def awq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Interleave the columns of ``q_w`` for AWQ and pack them into int32.

    ``q_w`` must have shape ``(size_k, size_n)``. Within each consecutive run
    of 8 values (4-bit) or 4 values (8-bit), the elements are reordered with
    a fixed interleave pattern (for the dequantize code), and the result is
    packed along the column dimension by ``pack_cols``.

    Returns whatever ``pack_cols`` returns for the interleaved tensor.

    Raises:
        AssertionError: if ``q_w`` does not have shape ``(size_k, size_n)``.
        ValueError: if ``num_bits`` is neither 4 nor 8.
    """
    assert q_w.shape == (size_k, size_n)

    # Interleave column dim (for the dequantize code) and pack it to int32
    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        # ValueError is the idiomatic type for a bad argument value; being a
        # subclass of Exception, it is still caught by `except Exception`.
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    # Reorder within each run of len(interleave) values, then restore the
    # original (size_k, size_n) layout.
    q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel()
    q_w = q_w.reshape((-1, size_n)).contiguous()

    return pack_cols(q_w, num_bits, size_k, size_n)
|
||||
|
||||
Reference in New Issue
Block a user