[Quant] Make static quant support all group shapes (#30833)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson
2026-01-09 15:49:27 -05:00
committed by GitHub
parent f9e2a75a1e
commit 0a0aa07747
7 changed files with 338 additions and 46 deletions

View File

@@ -11,6 +11,10 @@ from tests.kernels.quant_utils import (
ref_dynamic_per_token_quant,
)
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.quantization.utils.quant_utils import (
scaled_quantize,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
DTYPES = [torch.bfloat16, torch.float]
@@ -21,10 +25,18 @@ SEEDS = [0]
def opcheck_fp8_quant(
output, input, scale=None, scale_ub=None, use_per_token_if_dynamic=False
output,
input,
scale=None,
scale_ub=None,
use_per_token_if_dynamic=False,
group_shape=None,
):
if scale is not None:
opcheck(torch.ops._C.static_scaled_fp8_quant, (output, input, scale))
opcheck(
torch.ops._C.static_scaled_fp8_quant,
(output, input, scale, group_shape),
)
elif use_per_token_if_dynamic:
scale = torch.empty(
(input.shape[0], 1), device=input.device, dtype=torch.float32
@@ -118,3 +130,92 @@ def test_fp8_quant_large(seed: int) -> None:
ops_out = ops_out.to(dtype=dtype)
torch.testing.assert_close(ref_out, ops_out)
# Test static FP8 quantization with 2D group scales.
# Each (group_m, group_n) pair is the scale-group extent along the token
# and hidden dimensions respectively; -1 means the group spans the full
# extent of that dimension (see the normalization in the tests below).
GROUP_SHAPES_2D = [
    (-1, -1),  # Per-tensor
    (-1, 1),  # Per-channel
    (1, -1),  # Per-token
    (-1, 128),  # Per-head quantization
    (1, 128),  # DeepSeek-style per-token-per-group (group_m=1, group_n=128)
    (128, 128),  # DeepSeek-style block quantization
    (1, 64),  # Smaller group size
    (1, 16),  # Small group (scalar path in kernel)
    (4, 256),  # Non-trivial both dimensions
]
# Use sizes divisible by all group shapes
NUM_TOKENS_GROUP = [128, 512]
HIDDEN_SIZES_GROUP = [256, 1024, 2048]
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_GROUP)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES_GROUP)
@pytest.mark.parametrize("group_shape", GROUP_SHAPES_2D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_static_fp8_quant_group_2d(
    num_tokens: int,
    hidden_size: int,
    group_shape: tuple[int, int],
    dtype: torch.dtype,
    seed: int,
) -> None:
    """Test static FP8 quantization with 2D group scales using scaled_quantize.

    Compares the CUDA kernel (ops.scaled_fp8_quant) against the reference
    scaled_quantize implementation for each (group_m, group_n) shape, where
    -1 means the group spans the full extent of that dimension.
    """
    # Normalize group_shape: -1 means the group covers the whole dimension.
    norm_group_m = num_tokens if group_shape[0] == -1 else group_shape[0]
    norm_group_n = hidden_size if group_shape[1] == -1 else group_shape[1]
    # The kernel requires whole groups, so skip indivisible combinations.
    if num_tokens % norm_group_m != 0 or hidden_size % norm_group_n != 0:
        pytest.skip(
            f"Skipping: ({num_tokens}, {hidden_size}) not divisible by "
            f"group_shape ({group_shape[0]}, {group_shape[1]})"
        )
    current_platform.seed_everything(seed)
    x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
    ref_out, scale = scaled_quantize(
        x, group_shape, FP8_DTYPE, compute_dtype=torch.float32
    )
    ops_out, ops_scale = ops.scaled_fp8_quant(x, scale=scale, group_shape=group_shape)
    torch.testing.assert_close(scale, ops_scale)
    # Loose rtol: FP8 rounding can differ in the last mantissa bit between
    # the fp32 reference and the kernel.
    torch.testing.assert_close(ref_out.float(), ops_out.float(), rtol=0.12, atol=0.0)
    # Fix: forward group_shape so opcheck validates the op with the same
    # arguments the kernel was actually invoked with (previously it was
    # omitted and defaulted to None, unlike test_static_fp8_quant_1d_scale).
    opcheck_fp8_quant(ops_out, x, scale=scale, group_shape=group_shape)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_GROUP)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES_GROUP)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("group_shape", [(1, -1), (-1, 1)])  # per-token, per-channel
@torch.inference_mode()
def test_static_fp8_quant_1d_scale(
    num_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
    seed: int,
    group_shape: tuple[int, int],
) -> None:
    """Static FP8 quantization when the scale is passed as a 1D tensor.

    Exercises the per-token / per-channel path where callers hand the
    kernel a flattened scale instead of its natural 2D group layout.
    """
    current_platform.seed_everything(seed)
    inputs = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
    expected_out, grouped_scale = scaled_quantize(
        inputs, group_shape, FP8_DTYPE, compute_dtype=torch.float32
    )
    # Collapse the 2D group scale to 1D to hit the 1D-scale code path.
    flat_scale = grouped_scale.flatten()
    actual_out, returned_scale = ops.scaled_fp8_quant(
        inputs, scale=flat_scale, group_shape=group_shape
    )
    torch.testing.assert_close(flat_scale, returned_scale)
    torch.testing.assert_close(
        expected_out.float(), actual_out.float(), rtol=0.12, atol=0.0
    )
    opcheck_fp8_quant(actual_out, inputs, scale=flat_scale, group_shape=group_shape)