[Quant] Make static quant support all group shapes (#30833)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson
2026-01-09 15:49:27 -05:00
committed by GitHub
parent f9e2a75a1e
commit 0a0aa07747
7 changed files with 338 additions and 46 deletions

View File

@@ -10,6 +10,7 @@ from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
get_fp8_min_max,
group_broadcast,
)
from vllm.platforms import current_platform
@@ -22,7 +23,7 @@ _FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)
@CustomOp.register("quant_fp8")
class QuantFP8(CustomOp):
"""
Quantize input tensor to FP8 (per-tensor, per-token, or per-group).
Quantize input tensor to FP8 (per-tensor, per-token, per-channel, or per-group).
This CustomOp supports both static and dynamic quantization.
"""
@@ -57,14 +58,14 @@ class QuantFP8(CustomOp):
self.is_group_quant = group_shape.is_per_group()
if self.is_group_quant:
assert not static, "Group quantization only supports dynamic mode"
self.group_size = group_shape.col
else:
assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR}
assert not static or group_shape == GroupShape.PER_TENSOR, (
"Only per-tensor scales supported for static quantization."
)
self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
if not static:
assert group_shape in (GroupShape.PER_TOKEN, GroupShape.PER_TENSOR), (
"Only per-token or per-tensor scales are supported for dynamic "
"non-group quantization."
)
def forward_cuda(
self,
@@ -72,8 +73,8 @@ class QuantFP8(CustomOp):
scale: torch.Tensor | None = None,
scale_ub: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if self.is_group_quant:
assert scale is None, "Group quantization is always dynamic"
if self.is_group_quant and not self.static:
assert scale is None, "Dynamic group quantization does not use scale"
from vllm.model_executor.layers.quantization.utils import fp8_utils
return fp8_utils.per_token_group_quant_fp8(
@@ -90,12 +91,14 @@ class QuantFP8(CustomOp):
and self.group_shape == GroupShape.PER_TOKEN
and scale_ub.numel() == 1
)
return ops.scaled_fp8_quant(
x,
scale,
num_token_padding=self.num_token_padding,
scale_ub=scale_ub,
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
group_shape=self.group_shape if self.static else None,
)
def forward_hip(
@@ -131,8 +134,8 @@ class QuantFP8(CustomOp):
scale: torch.Tensor | None = None,
scale_ub: torch.Tensor | None = None,
):
if self.is_group_quant:
assert scale is None, "Group quantization is always dynamic"
if self.is_group_quant and not self.static:
assert scale is None, "Dynamic group quantization does not use scale"
return self._quantize_group_native(x)
assert (scale is not None) == self.static
@@ -155,7 +158,10 @@ class QuantFP8(CustomOp):
# Even for dynamic per-token scales,
# reciprocal performs slightly better than division
out = x.to(torch.float32) * scale.reciprocal()
out = (
x.to(torch.float32)
* group_broadcast(scale.to(torch.float32), x.shape[-2:]).reciprocal()
)
out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
# This currently generates an extra Triton kernel in compilation.

View File

@@ -158,11 +158,14 @@ def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
# with an extent of 1, since this can be done implicitly by pytorch
def group_broadcast(t: torch.Tensor, shape: tuple[int, ...]) -> torch.Tensor:
    """Expand a per-group tensor so each group value repeats across its group.

    For every leading dimension ``i`` of ``shape``, if ``t`` has extent
    ``g`` there with ``g != shape[i]`` and ``g != 1``, each entry along that
    dimension is repeated ``shape[i] // g`` times in place (entry 0 first,
    then entry 1, ...), so the result has extent ``shape[i]``.

    Dimensions that already match, or that have extent 1, are left untouched:
    a later elementwise op can broadcast those implicitly.

    Args:
        t: Tensor of per-group values (e.g. quantization scales).
        shape: Target extents, compared against ``t``'s leading dimensions.

    Returns:
        ``t`` itself when no dimension needs expanding, otherwise a tensor
        whose expanded dimensions have the target extents.

    Raises:
        AssertionError: If a target extent is not a multiple of the
            corresponding extent of ``t``.
    """
    for i, s in enumerate(shape):
        # `t` may have fewer dimensions than `shape` (e.g. a 0-d per-tensor
        # scale). A missing dimension is treated as extent 1 and skipped;
        # no dimension is added to `t` for it. Guarding on `t.ndim` avoids
        # the IndexError that an unconditional `t.shape[i]` would raise.
        t_dim_size = t.shape[i] if i < t.ndim else 1
        if t_dim_size != s and t_dim_size != 1:
            assert s % t_dim_size == 0, (
                f"cannot broadcast extent {t_dim_size} to {s} along dim {i}: "
                "target extent is not a multiple of the group count"
            )
            # Insert a repeat axis right after dim i, expand it to the number
            # of elements per group, then merge it back into dim i.
            # `t.shape` in the expand arguments is still the pre-unsqueeze
            # shape, since `t` is only rebound after the whole expression.
            t = (
                t.unsqueeze(i + 1)
                .expand(*t.shape[: i + 1], s // t_dim_size, *t.shape[i + 1 :])
                .flatten(i, i + 1)
            )
    return t
@@ -180,7 +183,16 @@ def scaled_quantize(
x: torch.Tensor,
group_shape: GroupShape,
quant_dtype: torch.dtype,
compute_dtype: torch.dtype | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x: Input tensor to quantize
group_shape: Shape of quantization groups
quant_dtype: Target quantized dtype (e.g., torch.float8_e4m3fn)
compute_dtype: Optional dtype for intermediate computations.
If None, uses input dtype. Use torch.float32 for higher precision.
"""
group_shape = _normalize_quant_group_shape(x, group_shape)
assert quant_dtype.is_floating_point, (
"currently `scaled_quantize` only supports floating point dtypes "
@@ -189,11 +201,14 @@ def scaled_quantize(
finfo = torch.finfo(quant_dtype)
# Convert to compute dtype if specified
x_compute = x if compute_dtype is None else x.to(compute_dtype)
# Reshape (M, N) into (BLK_M, BLOCK_SIZE_M, BLK_N, BLOCK_SIZE_N)
assert x.ndim == 2
assert x.shape[0] % group_shape[0] == 0 and x.shape[1] % group_shape[1] == 0
blk_m, blk_n = x.shape[0] // group_shape[0], x.shape[1] // group_shape[1]
x_blkd = x.reshape(blk_m, group_shape[0], blk_n, group_shape[1])
x_blkd = x_compute.reshape(blk_m, group_shape[0], blk_n, group_shape[1])
# Permute to (BLK_M, BLK_N, BLOCK_SIZE_M, BLOCK_SIZE_N)
x_blkd_permd = x_blkd.permute(0, 2, 1, 3)