[Quant] Make static quant support all group shapes (#30833)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
@@ -10,6 +10,7 @@ from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape,
     get_fp8_min_max,
+    group_broadcast,
 )
 from vllm.platforms import current_platform

@@ -22,7 +23,7 @@ _FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)
 @CustomOp.register("quant_fp8")
 class QuantFP8(CustomOp):
     """
-    Quantize input tensor to FP8 (per-tensor, per-token, or per-group).
+    Quantize input tensor to FP8 (per-tensor, per-token, per-channel, or per-group).
     This CustomOp supports both static and dynamic quantization.
     """

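The docstring now names four granularities. As a quick orientation, here is how each one maps to a scale shape for an (M, N) activation; the per-tensor and per-token cases appear elsewhere in the diff, while the per-channel and per-group shapes below are illustrative assumptions, not code from this PR:

```python
# Illustrative only: scale shapes implied by each granularity for an (M, N)
# input, with column groups of size g in the per-group case.
M, N, g = 4, 8, 4
scale_shapes = {
    "per-tensor":  (1, 1),       # one scale for the whole tensor
    "per-token":   (M, 1),       # one scale per row (token)
    "per-channel": (1, N),       # one scale per column (channel), assumed
    "per-group":   (M, N // g),  # one scale per (1, g) block, assumed
}
for name, shape in scale_shapes.items():
    print(f"{name:11s} -> scale shape {shape}")
```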
@@ -57,14 +58,14 @@ class QuantFP8(CustomOp):

         self.is_group_quant = group_shape.is_per_group()
         if self.is_group_quant:
-            assert not static, "Group quantization only supports dynamic mode"
             self.group_size = group_shape.col
         else:
-            assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR}
-            assert not static or group_shape == GroupShape.PER_TENSOR, (
-                "Only per-tensor scales supported for static quantization."
-            )
             self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
+            if not static:
+                assert group_shape in (GroupShape.PER_TOKEN, GroupShape.PER_TENSOR), (
+                    "Only per-token or per-tensor scales are supported for dynamic "
+                    "non-group quantization."
+                )

     def forward_cuda(
         self,
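The net effect of the reshuffled asserts: static mode no longer rejects group shapes, and the per-token/per-tensor restriction now applies only to dynamic non-group quantization. A standalone sketch of the resulting accept/reject behavior, using a hand-rolled GroupShape stand-in (the -1 encodings for the sentinel shapes are assumptions, not taken from this diff):

```python
from typing import NamedTuple

class GroupShape(NamedTuple):  # stand-in, not the vLLM class
    row: int
    col: int

PER_TENSOR, PER_TOKEN = GroupShape(-1, -1), GroupShape(1, -1)

def validate(static: bool, group_shape: GroupShape, is_group: bool) -> None:
    # Mirrors the rewritten __init__: group shapes pass in both modes;
    # only dynamic non-group quantization is restricted.
    if not is_group and not static:
        assert group_shape in (PER_TOKEN, PER_TENSOR), (
            "Only per-token or per-tensor scales are supported for dynamic "
            "non-group quantization."
        )

validate(static=True, group_shape=GroupShape(1, 128), is_group=True)   # newly allowed
validate(static=False, group_shape=GroupShape(1, 128), is_group=True)  # as before
```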
@@ -72,8 +73,8 @@ class QuantFP8(CustomOp):
         scale: torch.Tensor | None = None,
         scale_ub: torch.Tensor | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        if self.is_group_quant:
-            assert scale is None, "Group quantization is always dynamic"
+        if self.is_group_quant and not self.static:
+            assert scale is None, "Dynamic group quantization does not use scale"
             from vllm.model_executor.layers.quantization.utils import fp8_utils

             return fp8_utils.per_token_group_quant_fp8(
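For reference, the dynamic path delegates to fp8_utils.per_token_group_quant_fp8. The math it implements is roughly the following pure-PyTorch sketch; this is not the fused kernel, and the epsilon guard and clamping details are assumptions:

```python
import torch

def per_token_group_quant_sketch(x: torch.Tensor, group_size: int):
    # One scale per (token, group) pair, derived from the group's absmax.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    M, N = x.shape
    assert N % group_size == 0
    xg = x.float().reshape(M, N // group_size, group_size)
    scale = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    q = (xg / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q.reshape(M, N), scale.squeeze(-1)

x = torch.randn(4, 256)
q, s = per_token_group_quant_sketch(x, group_size=128)
print(q.shape, s.shape)  # torch.Size([4, 256]) torch.Size([4, 2])
```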
@@ -90,12 +91,14 @@ class QuantFP8(CustomOp):
             and self.group_shape == GroupShape.PER_TOKEN
             and scale_ub.numel() == 1
         )

         return ops.scaled_fp8_quant(
             x,
             scale,
             num_token_padding=self.num_token_padding,
             scale_ub=scale_ub,
             use_per_token_if_dynamic=self.use_per_token_if_dynamic,
+            group_shape=self.group_shape if self.static else None,
         )

     def forward_hip(
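The new group_shape argument lets ops.scaled_fp8_quant apply a precomputed grouped scale on the static path. The underlying math is a broadcasted divide plus clamp; a pure-PyTorch mirror of it (this is not the kernel itself):

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

M, N, gc = 4, 256, 128                    # GroupShape(1, 128)-style groups
x = torch.randn(M, N)
scale = torch.rand(M, N // gc) + 0.5      # static scale, e.g. from a checkpoint
scale_full = scale.repeat_interleave(gc, dim=1)  # expand to x's shape
q = (x / scale_full).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
print(q.shape)  # torch.Size([4, 256])
```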
@@ -131,8 +134,8 @@ class QuantFP8(CustomOp):
         scale: torch.Tensor | None = None,
         scale_ub: torch.Tensor | None = None,
     ):
-        if self.is_group_quant:
-            assert scale is None, "Group quantization is always dynamic"
+        if self.is_group_quant and not self.static:
+            assert scale is None, "Dynamic group quantization does not use scale"
             return self._quantize_group_native(x)

         assert (scale is not None) == self.static
@@ -155,7 +158,10 @@ class QuantFP8(CustomOp):

         # Even for dynamic per-token scales,
         # reciprocal performs slightly better than division
-        out = x.to(torch.float32) * scale.reciprocal()
+        out = (
+            x.to(torch.float32)
+            * group_broadcast(scale.to(torch.float32), x.shape[-2:]).reciprocal()
+        )
         out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)

         # This currently generates an extra Triton kernel in compilation.

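Broadcasting the scale up to x's trailing two dimensions lets this single line serve per-tensor, per-token, and grouped scales alike while keeping the reciprocal-multiply trick. A quick equivalence check in plain PyTorch, using repeat_interleave in place of group_broadcast:

```python
import torch

x = torch.randn(4, 256, dtype=torch.float32)
scale = torch.rand(4, 2) + 0.5                    # one scale per (1, 128) group
scale_full = scale.repeat_interleave(128, dim=1)  # stands in for group_broadcast
torch.testing.assert_close(
    x * scale_full.reciprocal(),  # what the new code computes
    x / scale_full,               # plain division
    rtol=1e-6,
    atol=1e-6,
)
```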
@@ -158,11 +158,14 @@ def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
 # with an extent of 1, since this can be done implicitly by pytorch
 def group_broadcast(t, shape):
     for i, s in enumerate(shape):
-        if t.shape[i] != s and t.shape[i] != 1:
-            assert s % t.shape[i] == 0
+        # If tensor has fewer dimensions than target shape, treat missing
+        # dimensions as size 1 (standard PyTorch broadcasting behavior)
+        t_dim_size = t.shape[i] if i < t.ndim else 1
+        if t_dim_size != s and t_dim_size != 1:
+            assert s % t_dim_size == 0
             t = (
                 t.unsqueeze(i + 1)
-                .expand(*t.shape[: i + 1], s // t.shape[i], *t.shape[i + 1 :])
+                .expand(*t.shape[: i + 1], s // t_dim_size, *t.shape[i + 1 :])
                 .flatten(i, i + 1)
             )
     return t
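The updated helper now also accepts a scale with fewer dimensions than the target shape, deferring to PyTorch's implicit broadcasting instead of raising an IndexError. A standalone copy of the new version for demonstration:

```python
import torch

def group_broadcast(t, shape):  # verbatim copy of the updated helper
    for i, s in enumerate(shape):
        t_dim_size = t.shape[i] if i < t.ndim else 1
        if t_dim_size != s and t_dim_size != 1:
            assert s % t_dim_size == 0
            t = (
                t.unsqueeze(i + 1)
                .expand(*t.shape[: i + 1], s // t_dim_size, *t.shape[i + 1 :])
                .flatten(i, i + 1)
            )
    return t

# Explicit expansion, as before: each scale is repeated along its group.
scale = torch.tensor([[2.0, 3.0], [4.0, 5.0]])
print(group_broadcast(scale, (2, 4)))
# tensor([[2., 2., 3., 3.], [4., 4., 5., 5.]])

# New: a 0-D per-tensor scale passes through untouched and broadcasts
# implicitly; previously t.shape[i] raised IndexError here.
x = torch.randn(2, 4)
s = torch.tensor(0.5)
print((x * group_broadcast(s, x.shape).reciprocal()).shape)  # torch.Size([2, 4])
```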
@@ -180,7 +183,16 @@ def scaled_quantize(
     x: torch.Tensor,
     group_shape: GroupShape,
     quant_dtype: torch.dtype,
+    compute_dtype: torch.dtype | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Args:
+        x: Input tensor to quantize
+        group_shape: Shape of quantization groups
+        quant_dtype: Target quantized dtype (e.g., torch.float8_e4m3fn)
+        compute_dtype: Optional dtype for intermediate computations.
+            If None, uses input dtype. Use torch.float32 for higher precision.
+    """
     group_shape = _normalize_quant_group_shape(x, group_shape)
     assert quant_dtype.is_floating_point, (
         "currently `scaled_quantize` only supports floating point dtypes "
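Hypothetical call sites for the new compute_dtype argument; the import path is inferred from the quant_utils imports above and the return convention from the annotated signature, so treat the exact invocation as a sketch:

```python
import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
    scaled_quantize,
)

x = torch.randn(256, 512, dtype=torch.bfloat16)
# Default: intermediates stay in the input dtype (bf16 here).
x_q, scale = scaled_quantize(x, GroupShape(128, 128), torch.float8_e4m3fn)
# Higher precision: promote intermediates to fp32 before scaling.
x_q, scale = scaled_quantize(
    x, GroupShape(128, 128), torch.float8_e4m3fn, compute_dtype=torch.float32
)
```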
@@ -189,11 +201,14 @@ def scaled_quantize(
     finfo = torch.finfo(quant_dtype)

+    # Convert to compute dtype if specified
+    x_compute = x if compute_dtype is None else x.to(compute_dtype)
+
     # Reshape (M, N) into (BLK_M, BLOCK_SIZE_M, BLK_N, BLOCK_SIZE_N)
     assert x.ndim == 2
     assert x.shape[0] % group_shape[0] == 0 and x.shape[1] % group_shape[1] == 0
     blk_m, blk_n = x.shape[0] // group_shape[0], x.shape[1] // group_shape[1]
-    x_blkd = x.reshape(blk_m, group_shape[0], blk_n, group_shape[1])
+    x_blkd = x_compute.reshape(blk_m, group_shape[0], blk_n, group_shape[1])

     # Permute to (BLK_M, BLK_N, BLOCK_SIZE_M, BLOCK_SIZE_N)
     x_blkd_permd = x_blkd.permute(0, 2, 1, 3)
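The reshape-then-permute arranges each quantization group contiguously in the last two dimensions, so a single amax reduction yields one scale per group. A small pure-PyTorch illustration of the layout:

```python
import torch

M, N, gm, gn = 8, 8, 4, 4
x = torch.arange(M * N, dtype=torch.float32).reshape(M, N)
x_blkd = x.reshape(M // gm, gm, N // gn, gn)
x_blkd_permd = x_blkd.permute(0, 2, 1, 3)        # (BLK_M, BLK_N, gm, gn)
group_absmax = x_blkd_permd.abs().amax(dim=(-2, -1))
print(group_absmax)  # one max per (4, 4) block
```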