[Quant] Make static quant support all group shapes (#30833)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Author: Lucas Wilkinson
Date: 2026-01-09 15:49:27 -05:00
Committed by: GitHub
Parent: f9e2a75a1e
Commit: 0a0aa07747
7 changed files with 338 additions and 46 deletions


@@ -158,11 +158,14 @@ def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
 # with an extent of 1, since this can be done implicitly by pytorch
 def group_broadcast(t, shape):
     for i, s in enumerate(shape):
-        if t.shape[i] != s and t.shape[i] != 1:
-            assert s % t.shape[i] == 0
+        # If tensor has fewer dimensions than target shape, treat missing
+        # dimensions as size 1 (standard PyTorch broadcasting behavior)
+        t_dim_size = t.shape[i] if i < t.ndim else 1
+        if t_dim_size != s and t_dim_size != 1:
+            assert s % t_dim_size == 0
             t = (
                 t.unsqueeze(i + 1)
-                .expand(*t.shape[: i + 1], s // t.shape[i], *t.shape[i + 1 :])
+                .expand(*t.shape[: i + 1], s // t_dim_size, *t.shape[i + 1 :])
                 .flatten(i, i + 1)
             )
     return t
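A quick illustration of what this change buys (a sketch, not part of the patch; variable names and shapes are made up): scales with fewer dimensions than the target shape no longer index past t.ndim.

    import torch

    # Block scales of shape (2, 4) expanded to a (4, 8) target: each scale
    # is repeated twice along both dims via the unsqueeze/expand/flatten path.
    scales = torch.arange(8, dtype=torch.float32).reshape(2, 4)
    expanded = group_broadcast(scales, (4, 8))
    assert expanded.shape == (4, 8)

    # New with this patch: a 1-D scale against a 2-D target no longer raises
    # IndexError on t.shape[1]; the missing dim is treated as size 1 and left
    # to implicit PyTorch broadcasting.
    per_tensor = torch.tensor([0.5])
    assert group_broadcast(per_tensor, (4, 8)).shape == (1,)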
@@ -180,7 +183,16 @@ def scaled_quantize(
     x: torch.Tensor,
     group_shape: GroupShape,
     quant_dtype: torch.dtype,
+    compute_dtype: torch.dtype | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Args:
+        x: Input tensor to quantize
+        group_shape: Shape of quantization groups
+        quant_dtype: Target quantized dtype (e.g., torch.float8_e4m3fn)
+        compute_dtype: Optional dtype for intermediate computations.
+            If None, uses input dtype. Use torch.float32 for higher precision.
+    """
     group_shape = _normalize_quant_group_shape(x, group_shape)
     assert quant_dtype.is_floating_point, (
         "currently `scaled_quantize` only supports floating point dtypes "
@@ -189,11 +201,14 @@ def scaled_quantize(
     finfo = torch.finfo(quant_dtype)
+    # Convert to compute dtype if specified
+    x_compute = x if compute_dtype is None else x.to(compute_dtype)
     # Reshape (M, N) into (BLK_M, BLOCK_SIZE_M, BLK_N, BLOCK_SIZE_N)
     assert x.ndim == 2
     assert x.shape[0] % group_shape[0] == 0 and x.shape[1] % group_shape[1] == 0
     blk_m, blk_n = x.shape[0] // group_shape[0], x.shape[1] // group_shape[1]
-    x_blkd = x.reshape(blk_m, group_shape[0], blk_n, group_shape[1])
+    x_blkd = x_compute.reshape(blk_m, group_shape[0], blk_n, group_shape[1])
     # Permute to (BLK_M, BLK_N, BLOCK_SIZE_M, BLOCK_SIZE_N)
     x_blkd_permd = x_blkd.permute(0, 2, 1, 3)
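For reference, the reshape/permute scheme makes each quantization group contiguous in the last two dims, so a single reduction yields per-group maxima. A standalone sketch of the same idea (an illustrative helper, not part of the patch):

    import torch

    def per_group_absmax(x: torch.Tensor, gm: int, gn: int) -> torch.Tensor:
        # (M, N) -> (BLK_M, gm, BLK_N, gn) -> (BLK_M, BLK_N, gm, gn), then
        # reduce each (gm, gn) tile to its absolute maximum.
        m, n = x.shape
        assert m % gm == 0 and n % gn == 0
        x_blkd = x.reshape(m // gm, gm, n // gn, gn).permute(0, 2, 1, 3)
        return x_blkd.abs().amax(dim=(-2, -1))  # shape (BLK_M, BLK_N)

    x = torch.randn(8, 16)
    print(per_group_absmax(x, 4, 8).shape)  # torch.Size([2, 2])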