Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""This file is used for /tests and /benchmarks"""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from types import MappingProxyType
|
||||
@@ -31,8 +32,8 @@ class GroupShape(_GroupShape):
|
||||
"""
|
||||
|
||||
# Aliases for common quantization group shapes
|
||||
PER_TENSOR: ClassVar['GroupShape']
|
||||
PER_TOKEN: ClassVar['GroupShape']
|
||||
PER_TENSOR: ClassVar["GroupShape"]
|
||||
PER_TOKEN: ClassVar["GroupShape"]
|
||||
|
||||
def is_per_tensor(self) -> bool:
|
||||
return self.row == -1 and self.col == -1
|
||||
@@ -56,18 +57,26 @@ class ScaleDesc:
|
||||
static: static scale if True, dynamic if False
|
||||
group_shape: group shape of the scale
|
||||
"""
|
||||
|
||||
dtype: torch.dtype
|
||||
static: bool
|
||||
group_shape: GroupShape
|
||||
|
||||
def __str__(self):
|
||||
group_shape = ('per_tensor'
|
||||
if self.group_shape == GroupShape.PER_TENSOR else
|
||||
('per_token' if self.group_shape == GroupShape.PER_TOKEN
|
||||
else str(self.group_shape)))
|
||||
group_shape = (
|
||||
"per_tensor"
|
||||
if self.group_shape == GroupShape.PER_TENSOR
|
||||
else (
|
||||
"per_token"
|
||||
if self.group_shape == GroupShape.PER_TOKEN
|
||||
else str(self.group_shape)
|
||||
)
|
||||
)
|
||||
|
||||
return (f"{fx.graph.dtype_abbrs[self.dtype]},"
|
||||
f"{'static' if self.static else 'dynamic'},{group_shape}")
|
||||
return (
|
||||
f"{fx.graph.dtype_abbrs[self.dtype]},"
|
||||
f"{'static' if self.static else 'dynamic'},{group_shape}"
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -79,6 +88,7 @@ class QuantKey:
|
||||
scale2: second-level scale descriptor
|
||||
symmetric: symmetric if True, asymmetric if False
|
||||
"""
|
||||
|
||||
dtype: torch.dtype
|
||||
scale: ScaleDesc
|
||||
scale2: Optional[ScaleDesc] = None
|
||||
@@ -86,9 +96,11 @@ class QuantKey:
|
||||
|
||||
def __str__(self):
|
||||
scale2_str = f"scale2({self.scale2})," if self.scale2 else ""
|
||||
return (f"QuantKey({fx.graph.dtype_abbrs[self.dtype]},"
|
||||
f"scale({self.scale}),{scale2_str}"
|
||||
f"{'a' if not self.symmetric else ''}symmetric)")
|
||||
return (
|
||||
f"QuantKey({fx.graph.dtype_abbrs[self.dtype]},"
|
||||
f"scale({self.scale}),{scale2_str}"
|
||||
f"{'a' if not self.symmetric else ''}symmetric)"
|
||||
)
|
||||
|
||||
|
||||
kStaticTensorScale = ScaleDesc(torch.float32, True, GroupShape.PER_TENSOR)
|
||||
@@ -101,16 +113,16 @@ kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN)
|
||||
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True)
|
||||
|
||||
kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16))
|
||||
kNvfp4Quant = QuantKey(FP4_DTYPE,
|
||||
scale=kNvfp4GroupScale,
|
||||
scale2=kStaticTensorScale)
|
||||
kNvfp4Quant = QuantKey(FP4_DTYPE, scale=kNvfp4GroupScale, scale2=kStaticTensorScale)
|
||||
|
||||
|
||||
# Normalize the group_shape to the full extent for any dims that are -1
|
||||
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
|
||||
# -1 means full extent
|
||||
return (group_shape[0] if group_shape[0] > 0 else x.shape[-2],
|
||||
group_shape[1] if group_shape[1] > 0 else x.shape[-1])
|
||||
return (
|
||||
group_shape[0] if group_shape[0] > 0 else x.shape[-2],
|
||||
group_shape[1] if group_shape[1] > 0 else x.shape[-1],
|
||||
)
|
||||
|
||||
|
||||
# Useful when treating N-dimensional group scaling as extended numpy-style
|
||||
@@ -131,9 +143,11 @@ def group_broadcast(t, shape):
|
||||
for i, s in enumerate(shape):
|
||||
if t.shape[i] != s and t.shape[i] != 1:
|
||||
assert s % t.shape[i] == 0
|
||||
t = t.unsqueeze(i + 1)\
|
||||
.expand(*t.shape[:i+1], s // t.shape[i], *t.shape[i+1:])\
|
||||
t = (
|
||||
t.unsqueeze(i + 1)
|
||||
.expand(*t.shape[: i + 1], s // t.shape[i], *t.shape[i + 1 :])
|
||||
.flatten(i, i + 1)
|
||||
)
|
||||
return t
|
||||
|
||||
|
||||
@@ -151,9 +165,10 @@ def scaled_quantize(
|
||||
quant_dtype: torch.dtype,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
group_shape = _normalize_quant_group_shape(x, group_shape)
|
||||
assert quant_dtype.is_floating_point, \
|
||||
"currently `scaled_quantize` only supports floating point dtypes " \
|
||||
assert quant_dtype.is_floating_point, (
|
||||
"currently `scaled_quantize` only supports floating point dtypes "
|
||||
"but could be extended to support other dtypes"
|
||||
)
|
||||
|
||||
finfo = torch.finfo(quant_dtype)
|
||||
|
||||
@@ -175,11 +190,13 @@ def scaled_quantize(
|
||||
|
||||
# Apply scale and convert form:
|
||||
# (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N)
|
||||
x_scl_sat = (x_blkd_permd * scale.unsqueeze(-1))\
|
||||
.clamp(min=finfo.min, max=finfo.max)\
|
||||
.reshape(blk_m, blk_n, group_shape[0], group_shape[1])\
|
||||
.permute(0, 2, 1, 3)\
|
||||
x_scl_sat = (
|
||||
(x_blkd_permd * scale.unsqueeze(-1))
|
||||
.clamp(min=finfo.min, max=finfo.max)
|
||||
.reshape(blk_m, blk_n, group_shape[0], group_shape[1])
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(x.shape)
|
||||
)
|
||||
|
||||
return x_scl_sat.to(quant_dtype).contiguous(), scale.float().reciprocal()
|
||||
|
||||
@@ -200,7 +217,8 @@ def scaled_dequantize(
|
||||
if group_shape is None:
|
||||
raise AssertionError(
|
||||
"if x_s is 1D tensor, group_shape must be provided otherwise "
|
||||
"its ambiguous which dimension to broadcast x_s to")
|
||||
"its ambiguous which dimension to broadcast x_s to"
|
||||
)
|
||||
# unsqueeze the scales for the dimension where we want to broadcast
|
||||
# across the full extent
|
||||
if group_shape[0] == x_q.shape[-2]:
|
||||
@@ -210,7 +228,8 @@ def scaled_dequantize(
|
||||
else:
|
||||
raise AssertionError(
|
||||
"if x_s is a vector we should be broadcasting it to the full "
|
||||
"extent of one of the dimensions")
|
||||
"extent of one of the dimensions"
|
||||
)
|
||||
|
||||
if group_shape is not None:
|
||||
assert x_s.shape[-1] == x_q.shape[-1] // group_shape[1]
|
||||
@@ -219,9 +238,9 @@ def scaled_dequantize(
|
||||
return (x_q.to(torch.float32) * x_s).to(out_dtype)
|
||||
|
||||
|
||||
def pack_quantized_values_into_int32(w_q: torch.Tensor,
|
||||
wtype: ScalarType,
|
||||
packed_dim: int = 0):
|
||||
def pack_quantized_values_into_int32(
|
||||
w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
|
||||
):
|
||||
# move dim to pack to the end
|
||||
perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
|
||||
inv_perm = tuple(perm.index(i) for i in range(len(perm)))
|
||||
@@ -241,9 +260,9 @@ def pack_quantized_values_into_int32(w_q: torch.Tensor,
|
||||
return res.permute(inv_perm)
|
||||
|
||||
|
||||
def unpack_quantized_values_into_int32(w_q: torch.Tensor,
|
||||
wtype: ScalarType,
|
||||
packed_dim: int = 0):
|
||||
def unpack_quantized_values_into_int32(
|
||||
w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
|
||||
):
|
||||
# move dim to pack to the end
|
||||
perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
|
||||
inv_perm = tuple(perm.index(i) for i in range(len(perm)))
|
||||
@@ -265,7 +284,7 @@ def unpack_quantized_values_into_int32(w_q: torch.Tensor,
|
||||
def is_layer_skipped(
|
||||
prefix: str,
|
||||
ignored_layers: list[str],
|
||||
fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
|
||||
fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
|
||||
) -> bool:
|
||||
# prefix: model.layers.0.self_attn.q_proj
|
||||
# proj_name: q_proj
|
||||
@@ -291,12 +310,16 @@ def is_layer_skipped(
|
||||
raise ValueError(
|
||||
f"Detected some but not all shards of {prefix} "
|
||||
"are quantized. All shards of fused layers "
|
||||
"to have the same precision.")
|
||||
"to have the same precision."
|
||||
)
|
||||
elif "experts" in prefix:
|
||||
return any([
|
||||
prefix in layer_name for layer_name in ignored_layers
|
||||
if "experts" in layer_name
|
||||
])
|
||||
return any(
|
||||
[
|
||||
prefix in layer_name
|
||||
for layer_name in ignored_layers
|
||||
if "experts" in layer_name
|
||||
]
|
||||
)
|
||||
else:
|
||||
is_skipped = prefix in ignored_layers
|
||||
|
||||
@@ -309,16 +332,18 @@ def get_pack_factor(num_bits):
|
||||
return 32 // num_bits
|
||||
|
||||
|
||||
def permute_rows(q_w: torch.Tensor,
|
||||
w_ref: torch.Tensor,
|
||||
group_size: int,
|
||||
test_perm: Optional[torch.Tensor] = None):
|
||||
def permute_rows(
|
||||
q_w: torch.Tensor,
|
||||
w_ref: torch.Tensor,
|
||||
group_size: int,
|
||||
test_perm: Optional[torch.Tensor] = None,
|
||||
):
|
||||
assert q_w.shape == w_ref.shape
|
||||
|
||||
orig_device = q_w.device
|
||||
k_size, _ = q_w.shape
|
||||
|
||||
g_idx = torch.zeros((k_size, ), dtype=torch.int32)
|
||||
g_idx = torch.zeros((k_size,), dtype=torch.int32)
|
||||
for i in range(k_size):
|
||||
g_idx[i] = i // group_size
|
||||
|
||||
@@ -337,16 +362,20 @@ def permute_rows(q_w: torch.Tensor,
|
||||
)
|
||||
|
||||
|
||||
def quantize_weights(w: torch.Tensor,
|
||||
quant_type: ScalarType,
|
||||
group_size: Optional[int],
|
||||
zero_points: bool = False,
|
||||
ref_zero_points_after_scales: bool = False):
|
||||
assert quant_type.is_integer(), \
|
||||
def quantize_weights(
|
||||
w: torch.Tensor,
|
||||
quant_type: ScalarType,
|
||||
group_size: Optional[int],
|
||||
zero_points: bool = False,
|
||||
ref_zero_points_after_scales: bool = False,
|
||||
):
|
||||
assert quant_type.is_integer(), (
|
||||
"Floating point quantization may work but has not been tested"
|
||||
assert not zero_points or group_size is not None, \
|
||||
"to have group zero points, group_size must be provided "\
|
||||
)
|
||||
assert not zero_points or group_size is not None, (
|
||||
"to have group zero points, group_size must be provided "
|
||||
"(-1 group_size is channelwise)"
|
||||
)
|
||||
|
||||
orig_device = w.device
|
||||
orig_type = w.dtype
|
||||
@@ -376,14 +405,16 @@ def quantize_weights(w: torch.Tensor,
|
||||
if zero_points:
|
||||
assert not quant_type.is_signed() and quant_type.max() > 0
|
||||
w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
|
||||
maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \
|
||||
.clamp(min_q_val, max_q_val).int()
|
||||
maybe_w_zp = (
|
||||
torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
|
||||
)
|
||||
else:
|
||||
# If the bias is such that there are no possible negative/positive
|
||||
# values, set the max value to inf to avoid divide by 0
|
||||
w_s = torch.max(
|
||||
abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
|
||||
abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)))
|
||||
abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
|
||||
)
|
||||
|
||||
# Quantize
|
||||
w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
|
||||
@@ -430,19 +461,22 @@ SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
||||
|
||||
|
||||
def gptq_quantize_weights(w: torch.Tensor,
|
||||
quant_type: ScalarType,
|
||||
group_size: int,
|
||||
act_order: bool,
|
||||
test_perm: Optional[torch.Tensor] = None):
|
||||
def gptq_quantize_weights(
|
||||
w: torch.Tensor,
|
||||
quant_type: ScalarType,
|
||||
group_size: int,
|
||||
act_order: bool,
|
||||
test_perm: Optional[torch.Tensor] = None,
|
||||
):
|
||||
size_k, _ = w.shape
|
||||
|
||||
assert w.is_floating_point(), "w must be float"
|
||||
assert quant_type in SUPPORTED_GPTQ_QUANT_TYPES, \
|
||||
assert quant_type in SUPPORTED_GPTQ_QUANT_TYPES, (
|
||||
f"Unsupported gptq type = {quant_type}"
|
||||
assert group_size in SUPPORTED_GROUP_SIZES + [
|
||||
size_k
|
||||
], f"Unsupported groupsize = {group_size}"
|
||||
)
|
||||
assert group_size in SUPPORTED_GROUP_SIZES + [size_k], (
|
||||
f"Unsupported groupsize = {group_size}"
|
||||
)
|
||||
|
||||
w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)
|
||||
|
||||
@@ -450,13 +484,13 @@ def gptq_quantize_weights(w: torch.Tensor,
|
||||
g_idx = torch.empty(0, dtype=torch.int, device=w.device)
|
||||
rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
|
||||
if act_order:
|
||||
assert (
|
||||
group_size < size_k
|
||||
), "For act_order, groupsize = {} must be less than size_k = {}".format(
|
||||
group_size, size_k)
|
||||
assert group_size < size_k, (
|
||||
"For act_order, groupsize = {} must be less than size_k = {}".format(
|
||||
group_size, size_k
|
||||
)
|
||||
)
|
||||
|
||||
w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size,
|
||||
test_perm)
|
||||
w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm)
|
||||
|
||||
return w_ref, w_q, w_s, g_idx, rand_perm
|
||||
|
||||
@@ -464,8 +498,7 @@ def gptq_quantize_weights(w: torch.Tensor,
|
||||
def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
|
||||
orig_device = q_w.device
|
||||
|
||||
sort_indices = torch.argsort(g_idx).to(
|
||||
dtype=torch.int32) # Sort based on g_idx
|
||||
sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx
|
||||
|
||||
g_idx = g_idx[sort_indices].contiguous()
|
||||
q_w = q_w[sort_indices, :].contiguous()
|
||||
@@ -535,10 +568,11 @@ def unpack_cols(
|
||||
):
|
||||
pack_factor = get_pack_factor(num_bits)
|
||||
assert size_n % pack_factor == 0
|
||||
assert packed_q_w.shape == (
|
||||
size_k, size_n // pack_factor
|
||||
), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
|
||||
packed_q_w.shape, size_k, size_n, pack_factor)
|
||||
assert packed_q_w.shape == (size_k, size_n // pack_factor), (
|
||||
"packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
|
||||
packed_q_w.shape, size_k, size_n, pack_factor
|
||||
)
|
||||
)
|
||||
|
||||
orig_device = packed_q_w.device
|
||||
|
||||
@@ -604,7 +638,8 @@ def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
assert scale.dtype == torch.float8_e4m3fn, (
|
||||
"swizzle_blockscale expects the input tensor to be in "
|
||||
"torch.float8_e4m3fn format.")
|
||||
"torch.float8_e4m3fn format."
|
||||
)
|
||||
|
||||
scale_ndim = scale.ndim
|
||||
if scale_ndim == 2:
|
||||
@@ -619,9 +654,9 @@ def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
|
||||
M_padded = _round_up(M, 128)
|
||||
K_padded = _round_up(K, 4)
|
||||
|
||||
padded = torch.zeros((B, M_padded, K_padded),
|
||||
dtype=scale.dtype,
|
||||
device=scale.device)
|
||||
padded = torch.zeros(
|
||||
(B, M_padded, K_padded), dtype=scale.dtype, device=scale.device
|
||||
)
|
||||
padded[:B, :M, :K] = scale
|
||||
|
||||
# Reshape / permute to the layout required by the kernel.
|
||||
|
||||
Reference in New Issue
Block a user