[Misc] Disambiguate quantized types via a new ScalarType (#6396)
This commit is contained in:
@@ -4,7 +4,11 @@ from typing import List
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
SUPPORTED_NUM_BITS = [4, 8]
|
||||
from vllm.model_executor.layers.quantization.qqq import (
|
||||
MARLIN_QQQ_SUPPORTED_NUM_BITS)
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
||||
|
||||
# Note: this is a hack. We should update each model to register the
|
||||
@@ -45,7 +49,7 @@ def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
|
||||
|
||||
|
||||
def get_pack_factor(num_bits):
|
||||
assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
|
||||
assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
|
||||
return 32 // num_bits
|
||||
|
||||
|
||||
@@ -74,24 +78,23 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
|
||||
)
|
||||
|
||||
|
||||
def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
|
||||
act_order: bool):
|
||||
def quantize_weights(w: torch.Tensor,
|
||||
quant_type: ScalarType,
|
||||
group_size: int,
|
||||
zero_points: bool = False):
|
||||
assert quant_type.is_integer(), \
|
||||
"Floating point quantization may work but has not been tested"
|
||||
|
||||
orig_device = w.device
|
||||
orig_type = w.dtype
|
||||
size_k, size_n = w.shape
|
||||
|
||||
assert w.is_floating_point(), "w must be float"
|
||||
assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
|
||||
assert group_size in SUPPORTED_GROUP_SIZES + [
|
||||
size_k
|
||||
], f"Unsupported groupsize = {group_size}"
|
||||
|
||||
if group_size == -1:
|
||||
group_size = size_k
|
||||
assert group_size <= size_k
|
||||
|
||||
max_q_val = 2**num_bits - 1
|
||||
half_q_val = (max_q_val + 1) // 2
|
||||
|
||||
# Reshape to [groupsize, -1]
|
||||
if group_size < size_k:
|
||||
w = w.reshape((-1, group_size, size_n))
|
||||
@@ -99,16 +102,34 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
|
||||
w = w.reshape((group_size, -1))
|
||||
|
||||
# Compute scale for each group
|
||||
s = torch.max(torch.abs(w), 0, keepdim=True)[0]
|
||||
s *= 2 / max_q_val # 2 => symmetric
|
||||
max_val = torch.max(w, 0, keepdim=True).values
|
||||
min_val = torch.min(w, 0, keepdim=True).values
|
||||
|
||||
max_q_val = quant_type.max()
|
||||
min_q_val = quant_type.min()
|
||||
|
||||
if zero_points:
|
||||
assert not quant_type.is_signed() and quant_type.max() > 0
|
||||
w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
|
||||
maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \
|
||||
.clamp(min_q_val, max_q_val).int()
|
||||
else:
|
||||
# If the bias is such that there are no possible negative/positive
|
||||
# values, set the max value to inf to avoid divide by 0
|
||||
w_s = torch.max(
|
||||
abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
|
||||
abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)))
|
||||
maybe_w_zp = None
|
||||
|
||||
# Quantize
|
||||
q_w = torch.round(w / s).int()
|
||||
q_w += half_q_val
|
||||
q_w = torch.clamp(q_w, 0, max_q_val)
|
||||
w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
|
||||
w_q = torch.clamp(w_q, min_q_val, max_q_val)
|
||||
|
||||
# Compute ref (dequantized)
|
||||
w_ref = (q_w - half_q_val).half() * s
|
||||
w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
|
||||
|
||||
if quant_type.has_bias():
|
||||
w_q += quant_type.bias
|
||||
|
||||
# Restore original shapes
|
||||
if group_size < size_k:
|
||||
@@ -119,10 +140,35 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
|
||||
w = w.reshape((size_k, size_n)).contiguous()
|
||||
return w
|
||||
|
||||
q_w = reshape_w(q_w)
|
||||
w_q = reshape_w(w_q)
|
||||
w_ref = reshape_w(w_ref)
|
||||
|
||||
s = s.reshape((-1, size_n)).contiguous()
|
||||
w_s = w_s.reshape((-1, size_n)).contiguous()
|
||||
|
||||
if zero_points:
|
||||
maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
|
||||
maybe_w_zp = maybe_w_zp.to(device=orig_device)
|
||||
|
||||
return (
|
||||
w_ref.to(device=orig_device),
|
||||
w_q.to(device=orig_device),
|
||||
w_s.to(device=orig_device),
|
||||
maybe_w_zp,
|
||||
)
|
||||
|
||||
|
||||
def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType,
|
||||
group_size: int, act_order: bool):
|
||||
size_k, _ = w.shape
|
||||
|
||||
assert w.is_floating_point(), "w must be float"
|
||||
assert quant_type in SUPPORTED_GPTQ_QUANT_TYPES, \
|
||||
f"Unsupported gptq type = {quant_type}"
|
||||
assert group_size in SUPPORTED_GROUP_SIZES + [
|
||||
size_k
|
||||
], f"Unsupported groupsize = {group_size}"
|
||||
|
||||
w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)
|
||||
|
||||
# Apply act_order
|
||||
g_idx = torch.empty(0, dtype=torch.int, device=w.device)
|
||||
@@ -133,76 +179,9 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
|
||||
), "For act_order, groupsize = {} must be less than size_k = {}".format(
|
||||
group_size, size_k)
|
||||
|
||||
w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size)
|
||||
w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size)
|
||||
|
||||
return (
|
||||
w_ref.to(device=orig_device),
|
||||
q_w.to(device=orig_device),
|
||||
s.to(device=orig_device),
|
||||
g_idx.to(device=orig_device),
|
||||
rand_perm.to(device=orig_device),
|
||||
)
|
||||
|
||||
|
||||
def quantize_weights_with_zp(w: torch.Tensor, num_bits: int, group_size: int):
|
||||
orig_device = w.device
|
||||
size_k, size_n = w.shape
|
||||
|
||||
assert w.is_floating_point(), "w must be float"
|
||||
assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
|
||||
assert group_size in SUPPORTED_GROUP_SIZES + [
|
||||
size_k
|
||||
], f"Unsupported groupsize = {group_size}"
|
||||
|
||||
if group_size == -1:
|
||||
group_size = size_k
|
||||
assert group_size <= size_k
|
||||
|
||||
max_q_val = 2**num_bits - 1
|
||||
min_q_val = 0
|
||||
|
||||
# Reshape to [groupsize, -1]
|
||||
if group_size < size_k:
|
||||
w = w.reshape((-1, group_size, size_n))
|
||||
w = w.permute(1, 0, 2)
|
||||
w = w.reshape((group_size, -1))
|
||||
|
||||
# Compute scale for each group
|
||||
max = torch.max(w, 0, keepdim=True)[0]
|
||||
min = torch.min(w, 0, keepdim=True)[0]
|
||||
s = (max - min).clamp(min=1e-5) / max_q_val
|
||||
|
||||
# Compute zero-point for each group
|
||||
zp = (-torch.round(min / s)).clamp(min_q_val, max_q_val).int()
|
||||
|
||||
# Quantize
|
||||
q_w = torch.round(w / s).int() + zp
|
||||
q_w = torch.clamp(q_w, min_q_val, max_q_val)
|
||||
|
||||
# Compute ref (dequantized)
|
||||
w_ref = (q_w - zp).half() * s
|
||||
|
||||
# Restore original shapes
|
||||
if group_size < size_k:
|
||||
|
||||
def reshape_w(w):
|
||||
w = w.reshape((group_size, -1, size_n))
|
||||
w = w.permute(1, 0, 2)
|
||||
w = w.reshape((size_k, size_n)).contiguous()
|
||||
return w
|
||||
|
||||
q_w = reshape_w(q_w)
|
||||
w_ref = reshape_w(w_ref)
|
||||
|
||||
s = s.reshape((-1, size_n)).contiguous()
|
||||
zp = zp.reshape((-1, size_n)).contiguous()
|
||||
|
||||
return (
|
||||
w_ref.to(device=orig_device),
|
||||
q_w.to(device=orig_device),
|
||||
s.to(device=orig_device),
|
||||
zp.to(device=orig_device),
|
||||
)
|
||||
return w_ref, w_q, w_s, g_idx, rand_perm
|
||||
|
||||
|
||||
# QQQ employs different quant schemes for per-group and
|
||||
@@ -212,7 +191,8 @@ def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
|
||||
size_k, size_n = w.shape
|
||||
|
||||
assert w.is_floating_point(), "w must be float"
|
||||
assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
|
||||
assert num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS, \
|
||||
f"Unsupported num_bits = {num_bits}"
|
||||
assert group_size in SUPPORTED_GROUP_SIZES + [
|
||||
size_k
|
||||
], f"Unsupported groupsize = {group_size}"
|
||||
|
||||
Reference in New Issue
Block a user