Files
vllm/vllm/model_executor/layers/quantization/utils/quant_utils.py

415 lines
12 KiB
Python
Raw Normal View History

"""This file is used for /tests and /benchmarks"""
from typing import List
import numpy
import torch
# Weight bit-widths accepted by the quantize/pack helpers in this file.
SUPPORTED_NUM_BITS = [4, 8]
# Group sizes accepted by the quantize_* helpers. -1 is replaced by the full
# K dimension (size_k) inside those helpers; size_k itself is also accepted.
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
# Note: this is a hack. We should update each model to register the
# stacked params and get it from there instead in a future PR.
# fused_name: List[shard_name]
FUSED_LAYER_NAME_MAPPING = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"]
}


def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
    """Return whether the layer named by ``prefix`` is in ``ignored_layers``.

    For a fused layer (a final path component that is a key of
    FUSED_LAYER_NAME_MAPPING) the decision is made on its unfused shards:
    every shard must be listed in ``ignored_layers`` or none of them.

    Args:
        prefix: Dotted layer path, e.g. ``model.layers.0.self_attn.q_proj``.
        ignored_layers: Dotted paths of the layers to skip.

    Returns:
        True if the layer (or all of its shards) is listed, else False.

    Raises:
        ValueError: If only some shards of a fused layer are listed.
    """
    # prefix: model.layers.0.self_attn.q_proj
    # proj_name: q_proj
    parent, sep, proj_name = prefix.rpartition(".")

    shard_names = FUSED_LAYER_NAME_MAPPING.get(proj_name)
    if shard_names is None:
        return prefix in ignored_layers

    # Rebuild each shard path from the parent path so that only the final
    # component is swapped. A plain str.replace would also rewrite any
    # earlier component that happens to contain the fused name.
    shard_states = {
        f"{parent}{sep}{shard_name}" in ignored_layers
        for shard_name in shard_names
    }
    if len(shard_states) > 1:
        raise ValueError(
            f"Detected some but not all shards of {prefix} "
            "are quantized. All shards of fused layers "
            "must have the same precision.")

    # An empty shard list would leave nothing to base the decision on.
    assert shard_states, f"No shards registered for fused layer {proj_name}"
    return shard_states.pop()
def get_pack_factor(num_bits):
    """Return how many ``num_bits``-wide values fit into one 32-bit word.

    ``num_bits`` is asserted to be one of SUPPORTED_NUM_BITS.
    """
    is_supported = num_bits in SUPPORTED_NUM_BITS
    assert is_supported, f"Unsupported num_bits = {num_bits}"

    word_bits = 32
    return word_bits // num_bits
def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
    """Apply one random row permutation to ``q_w`` and ``w_ref``.

    Simulates act_order: the rows (first dimension) of both tensors are
    shuffled with the same permutation, and each shuffled row keeps the
    index of the group it belonged to before shuffling.

    Args:
        q_w: 2-D tensor whose rows are permuted.
        w_ref: Tensor with the same shape as ``q_w``, permuted identically.
        group_size: Number of consecutive rows per group; row ``i`` belongs
            to group ``i // group_size``.

    Returns:
        Tuple ``(w_ref, q_w, g_idx, rand_perm)`` on ``q_w``'s original
        device. ``g_idx`` is int32 and holds the group index of each
        permuted row; ``rand_perm`` is the permutation that was applied.
    """
    assert q_w.shape == w_ref.shape

    orig_device = q_w.device
    k_size, _ = q_w.shape

    # Group index of every row, computed in one vectorized op instead of a
    # per-row Python loop.
    g_idx = torch.div(torch.arange(k_size, dtype=torch.int32),
                      group_size,
                      rounding_mode="floor")

    # Simulate act_order by doing a random permutation on K
    rand_perm = torch.randperm(k_size)

    g_idx = g_idx[rand_perm].contiguous()
    q_w = q_w[rand_perm, :].contiguous()
    w_ref = w_ref[rand_perm, :].contiguous()

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        g_idx.to(device=orig_device),
        rand_perm.to(device=orig_device),
    )
def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
                     act_order: bool):
    """Quantize ``w`` symmetrically per group and build a dequantized copy.

    Args:
        w: Floating-point 2-D weight of shape (size_k, size_n).
        num_bits: Bit-width; asserted to be in SUPPORTED_NUM_BITS.
        group_size: Rows per quantization group along the first dimension;
            asserted to be in SUPPORTED_GROUP_SIZES or equal to size_k.
            -1 means one group spanning all size_k rows.
        act_order: If True, the rows of the results are shuffled through
            ``permute_rows`` (requires group_size < size_k).

    Returns:
        Tuple ``(w_ref, q_w, s, g_idx, rand_perm)`` on ``w``'s original
        device: the dequantized reference, the unsigned integer codes in
        [0, 2**num_bits - 1], the scales reshaped to (-1, size_n), and the
        group index / permutation from ``permute_rows`` (both empty when
        ``act_order`` is False).
    """
    source_device = w.device
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"
    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    q_max = 2**num_bits - 1
    q_mid = (q_max + 1) // 2
    grouped = group_size < size_k

    # Reshape to [groupsize, -1] so every column holds exactly one group.
    if grouped:
        w = (w.reshape((-1, group_size, size_n)).permute(1, 0, 2).reshape(
            (group_size, -1)))

    # One scale per column (= per group).
    scales = w.abs().max(dim=0, keepdim=True).values
    scales.mul_(2 / q_max)  # 2 => symmetric

    # Quantize: round, shift into the unsigned range, then clamp.
    codes = torch.clamp((w / scales).round().int() + q_mid, 0, q_max)

    # Dequantized reference.
    dequant = (codes - q_mid).half() * scales

    # Undo the group-major layout.
    if grouped:
        codes, dequant = (t.reshape(
            (group_size, -1, size_n)).permute(1, 0, 2).reshape(
                (size_k, size_n)).contiguous() for t in (codes, dequant))

    scales = scales.reshape((-1, size_n)).contiguous()

    if act_order:
        assert (
            group_size < size_k
        ), "For act_order, groupsize = {} must be less than size_k = {}".format(
            group_size, size_k)
        dequant, codes, g_idx, rand_perm = permute_rows(
            codes, dequant, group_size)
    else:
        g_idx = torch.empty(0, dtype=torch.int, device=w.device)
        rand_perm = torch.empty(0, dtype=torch.int, device=w.device)

    results = (dequant, codes, scales, g_idx, rand_perm)
    return tuple(t.to(device=source_device) for t in results)
def quantize_weights_with_zp(w: torch.Tensor, num_bits: int, group_size: int):
    """Quantize ``w`` per group with a scale and an integer zero-point.

    Args:
        w: Floating-point 2-D weight of shape (size_k, size_n).
        num_bits: Bit-width; asserted to be in SUPPORTED_NUM_BITS.
        group_size: Rows per quantization group along the first dimension;
            asserted to be in SUPPORTED_GROUP_SIZES or equal to size_k.
            -1 means one group spanning all size_k rows.

    Returns:
        Tuple ``(w_ref, q_w, s, zp)`` on ``w``'s original device: the
        dequantized reference, the integer codes in [0, 2**num_bits - 1],
        and the scales and zero-points, each reshaped to (-1, size_n).
    """
    source_device = w.device
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"
    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    q_max = 2**num_bits - 1
    q_min = 0
    grouped = group_size < size_k

    # Reshape to [groupsize, -1] so every column holds exactly one group.
    if grouped:
        w = (w.reshape((-1, group_size, size_n)).permute(1, 0, 2).reshape(
            (group_size, -1)))

    # Scale per group from the value range; the range is clamped from below
    # before dividing.
    col_max = w.max(dim=0, keepdim=True).values
    col_min = w.min(dim=0, keepdim=True).values
    scales = (col_max - col_min).clamp(min=1e-5) / q_max

    # Zero-point per group.
    zero_points = (-(col_min / scales).round()).clamp(q_min, q_max).int()

    # Quantize, then build the dequantized reference.
    codes = torch.clamp((w / scales).round().int() + zero_points, q_min,
                        q_max)
    dequant = (codes - zero_points).half() * scales

    # Undo the group-major layout.
    if grouped:
        codes, dequant = (t.reshape(
            (group_size, -1, size_n)).permute(1, 0, 2).reshape(
                (size_k, size_n)).contiguous() for t in (codes, dequant))

    scales = scales.reshape((-1, size_n)).contiguous()
    zero_points = zero_points.reshape((-1, size_n)).contiguous()

    results = (dequant, codes, scales, zero_points)
    return tuple(t.to(device=source_device) for t in results)
# QQQ employs different quant schemes for per-group and
# per-channel quantization.
def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
    """Quantize ``w`` with the per-group or the per-channel QQQ scheme.

    When ``group_size`` is smaller than size_k the weight is quantized per
    group to unsigned codes, the dequantized result is re-quantized to int8
    per channel, and the group scales are divided by the channel scales.
    Otherwise a single signed per-channel quantization is applied.

    Args:
        w: Floating-point 2-D weight of shape (size_k, size_n).
        num_bits: Bit-width; asserted to be in SUPPORTED_NUM_BITS.
        group_size: Rows per quantization group along the first dimension;
            asserted to be in SUPPORTED_GROUP_SIZES or equal to size_k.
            -1 means size_k (per-channel).

    Returns:
        Tuple ``(w_ref, q_w, s_group, s_channel)`` on ``w``'s original
        device. ``s_group`` is an empty half tensor in the per-channel case;
        ``s_channel`` is float32.
    """
    source_device = w.device
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"
    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    if group_size >= size_k:
        # Per-channel scheme: signed codes in [-q_max, q_max].
        q_max = 2**(num_bits - 1) - 1

        # Compute scale for each channel
        s_channel = w.abs().max(dim=0, keepdim=True).values / q_max

        # Quantize
        q_w = torch.clamp((w / s_channel).round().int(), -q_max, q_max)

        # Compute ref (dequantized) before the scale is adjusted below.
        w_ref = q_w.half() * s_channel

        s_group = torch.tensor([], dtype=torch.half)

        # div 2 ** (8 - self.bits)) to offset right shift in unpacking
        s_channel = s_channel / (2**(8 - num_bits))
        s_channel = s_channel.reshape(-1,
                                      size_n).contiguous().to(torch.float)
    else:
        # Per-group scheme.
        # Reshape to [groupsize, -1] so every column holds exactly one group.
        w = (w.reshape((-1, group_size, size_n)).permute(1, 0, 2).reshape(
            (group_size, -1)))

        q_max = 2**num_bits - 1
        q_mid = (q_max + 1) // 2

        # Compute scale for each group
        s_group = w.abs().max(dim=0, keepdim=True).values
        s_group.mul_(2 / q_max)  # 2 => symmetric

        # Quantize to unsigned codes.
        q_w = torch.clamp((w / s_group).round().int() + q_mid, 0, q_max)

        # Compute ref (dequantized)
        w_ref = (q_w - q_mid).half() * s_group

        # Restore original shapes
        def to_k_by_n(t):
            t = t.reshape((group_size, -1, size_n)).permute(1, 0, 2)
            return t.reshape((size_k, size_n)).contiguous()

        q_w, w_ref = to_k_by_n(q_w), to_k_by_n(w_ref)

        # Compute int8 quantization scale for each channel
        s_channel = w_ref.abs().max(dim=0, keepdim=True).values / 127.0
        t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8)
        w_ref = t_int8.half() * s_channel
        s_channel = s_channel.reshape(1, -1).to(dtype=torch.float)

        # Fuse scales
        s_group = (s_group.reshape(-1, size_n).contiguous() /
                   s_channel).to(dtype=torch.half)

    results = (w_ref, q_w, s_group, s_channel)
    return tuple(t.to(device=source_device) for t in results)
def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
    """Reorder the rows of ``q_w`` so that ``g_idx`` becomes ascending.

    Returns:
        Tuple ``(q_w, g_idx, sort_indices)`` on ``q_w``'s original device:
        the reordered rows, the sorted group indices, and the int32 indices
        that were used for the reordering.
    """
    target_device = q_w.device

    # Sort based on g_idx
    order = g_idx.argsort().to(dtype=torch.int32)

    sorted_g_idx = g_idx[order].contiguous()
    sorted_q_w = q_w[order, :].contiguous()

    results = (sorted_q_w, sorted_g_idx, order)
    return tuple(t.to(device=target_device) for t in results)
def pack_rows(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack ``num_bits``-wide values of ``q_w`` into int32 words along rows.

    With ``pack_factor = get_pack_factor(num_bits)``, input row
    ``r * pack_factor + lane`` is stored in bits
    ``[lane * num_bits, (lane + 1) * num_bits)`` of output row ``r``.

    Returns:
        int32 tensor of shape (size_k // pack_factor, size_n) on ``q_w``'s
        original device.
    """
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_k % pack_factor == 0

    target_device = q_w.device

    # Work on unsigned 32-bit integers in NumPy for the shift/or arithmetic.
    values = q_w.cpu().numpy().astype(numpy.uint32)
    packed = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)

    for lane in range(pack_factor):
        shift = num_bits * lane
        packed |= values[lane::pack_factor, :] << shift

    return torch.from_numpy(packed.astype(numpy.int32)).to(target_device)
def pack_cols(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack ``num_bits``-wide values of ``q_w`` into int32 words along columns.

    With ``pack_factor = get_pack_factor(num_bits)``, input column
    ``c * pack_factor + lane`` is stored in bits
    ``[lane * num_bits, (lane + 1) * num_bits)`` of output column ``c``.

    Returns:
        Contiguous int32 tensor of shape (size_k, size_n // pack_factor) on
        ``q_w``'s original device.
    """
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0

    target_device = q_w.device

    # Work on unsigned 32-bit integers in NumPy for the shift/or arithmetic.
    values = q_w.cpu().numpy().astype(numpy.uint32)
    packed = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)

    for lane in range(pack_factor):
        shift = num_bits * lane
        packed |= values[:, lane::pack_factor] << shift

    result = torch.from_numpy(packed.astype(numpy.int32)).to(target_device)
    return result.contiguous()
def unpack_cols(
    packed_q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Expand column-packed int32 words into ``num_bits``-wide values.

    With ``pack_factor = get_pack_factor(num_bits)``, bits
    ``[lane * num_bits, (lane + 1) * num_bits)`` of packed column ``c`` are
    written to output column ``c * pack_factor + lane``.

    Returns:
        Contiguous int32 tensor of shape (size_k, size_n) on
        ``packed_q_w``'s original device.
    """
    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0
    assert packed_q_w.shape == (
        size_k, size_n // pack_factor
    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
        packed_q_w.shape, size_k, size_n, pack_factor)

    target_device = packed_q_w.device

    # Unsigned words so that the right shift does not sign-extend.
    words = packed_q_w.cpu().numpy().astype(numpy.uint32)
    unpacked = numpy.zeros((size_k, size_n), dtype=numpy.uint32)

    field_mask = (1 << num_bits) - 1
    for lane in range(pack_factor):
        shift = num_bits * lane
        unpacked[:, lane::pack_factor] = (words >> shift) & field_mask

    result = torch.from_numpy(unpacked.astype(numpy.int32)).to(target_device)
    return result.contiguous()
def gptq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack ``q_w`` for GPTQ by delegating to ``pack_rows`` unchanged."""
    packed = pack_rows(q_w, num_bits, size_k, size_n)
    return packed
def awq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Interleave the columns of ``q_w`` and pack them with ``pack_cols``.

    The values are regrouped into runs of 8 (4-bit) or 4 (8-bit) elements,
    each run is reordered with a fixed interleave pattern, and the result is
    reshaped back to rows of ``size_n`` columns before packing.

    Args:
        q_w: Tensor of shape (size_k, size_n).
        num_bits: Bit-width of each value; must be 4 or 8.
        size_k: Number of rows of ``q_w``.
        size_n: Number of columns of ``q_w``.

    Returns:
        The tensor returned by ``pack_cols`` for the interleaved input.

    Raises:
        ValueError: If ``num_bits`` is neither 4 nor 8.
    """
    assert q_w.shape == (size_k, size_n)

    # Interleave column dim (for the dequantize code) and pack it to int32
    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        # ValueError is a subclass of Exception, so existing handlers that
        # catch Exception still catch it.
        raise ValueError(f"num_bits must be 4 or 8, got {num_bits}")

    q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel()
    q_w = q_w.reshape((-1, size_n)).contiguous()

    return pack_cols(q_w, num_bits, size_k, size_n)