[Quant] Make static quant support all group shapes (#30833)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
@@ -11,6 +11,10 @@ from tests.kernels.quant_utils import (
|
||||
ref_dynamic_per_token_quant,
|
||||
)
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
scaled_quantize,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
DTYPES = [torch.bfloat16, torch.float]
|
||||
@@ -21,10 +25,18 @@ SEEDS = [0]
|
||||
|
||||
|
||||
def opcheck_fp8_quant(
|
||||
output, input, scale=None, scale_ub=None, use_per_token_if_dynamic=False
|
||||
output,
|
||||
input,
|
||||
scale=None,
|
||||
scale_ub=None,
|
||||
use_per_token_if_dynamic=False,
|
||||
group_shape=None,
|
||||
):
|
||||
if scale is not None:
|
||||
opcheck(torch.ops._C.static_scaled_fp8_quant, (output, input, scale))
|
||||
opcheck(
|
||||
torch.ops._C.static_scaled_fp8_quant,
|
||||
(output, input, scale, group_shape),
|
||||
)
|
||||
elif use_per_token_if_dynamic:
|
||||
scale = torch.empty(
|
||||
(input.shape[0], 1), device=input.device, dtype=torch.float32
|
||||
@@ -118,3 +130,92 @@ def test_fp8_quant_large(seed: int) -> None:
|
||||
ops_out = ops_out.to(dtype=dtype)
|
||||
|
||||
torch.testing.assert_close(ref_out, ops_out)
|
||||
|
||||
|
||||
# Test static FP8 quantization with 2D group scales.
# Each entry is (group_m, group_n); -1 means the group spans the full
# extent of that dimension.
GROUP_SHAPES_2D = [
    (-1, -1),  # Per-tensor
    (-1, 1),  # Per-channel
    (1, -1),  # Per-token
    (-1, 128),  # Per-head quantization
    (1, 128),  # DeepSeek-style per-token-per-group (group_m=1, group_n=128)
    (128, 128),  # DeepSeek-style block quantization
    (1, 64),  # Smaller group size
    (1, 16),  # Small group (scalar path in kernel)
    (4, 256),  # Non-trivial both dimensions
]
# Use sizes divisible by all group shapes so no parametrization is skipped.
NUM_TOKENS_GROUP = [128, 512]
HIDDEN_SIZES_GROUP = [256, 1024, 2048]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_GROUP)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES_GROUP)
@pytest.mark.parametrize("group_shape", GROUP_SHAPES_2D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_static_fp8_quant_group_2d(
    num_tokens: int,
    hidden_size: int,
    group_shape: tuple[int, int],
    dtype: torch.dtype,
    seed: int,
) -> None:
    """Test static FP8 quantization with 2D group scales using scaled_quantize.

    The reference output/scales come from ``scaled_quantize`` and are compared
    against the custom op exposed via ``ops.scaled_fp8_quant``.
    """
    # Normalize group_shape: -1 means the group spans the full extent of
    # that dimension (per-tensor / per-channel / per-token semantics).
    norm_group_m = num_tokens if group_shape[0] == -1 else group_shape[0]
    norm_group_n = hidden_size if group_shape[1] == -1 else group_shape[1]

    # Skip if sizes are not evenly tiled by the group shape; defensive, since
    # the parametrized sizes above are chosen to be divisible.
    if num_tokens % norm_group_m != 0 or hidden_size % norm_group_n != 0:
        pytest.skip(
            f"Skipping: ({num_tokens}, {hidden_size}) not divisible by "
            f"group_shape ({group_shape[0]}, {group_shape[1]})"
        )

    current_platform.seed_everything(seed)

    x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
    ref_out, scale = scaled_quantize(
        x, group_shape, FP8_DTYPE, compute_dtype=torch.float32
    )
    ops_out, ops_scale = ops.scaled_fp8_quant(x, scale=scale, group_shape=group_shape)

    # Static quant must return the caller-provided scales untouched.
    torch.testing.assert_close(scale, ops_scale)
    # FP8 is low-precision; compare in float32 with a generous relative tolerance.
    torch.testing.assert_close(ref_out.float(), ops_out.float(), rtol=0.12, atol=0.0)

    # Fix: forward group_shape so opcheck validates the grouped static-quant op
    # signature (previously omitted here, unlike the 1D-scale test which
    # passes it; opcheck_fp8_quant forwards it into torch.ops._C.static_scaled_fp8_quant).
    opcheck_fp8_quant(ops_out, x, scale=scale, group_shape=group_shape)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_GROUP)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES_GROUP)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("group_shape", [(1, -1), (-1, 1)])  # per-token, per-channel
@torch.inference_mode()
def test_static_fp8_quant_1d_scale(
    num_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
    seed: int,
    group_shape: tuple[int, int],
) -> None:
    """Static FP8 quantization with a flattened 1D scale.

    Per-token and per-channel group shapes yield a scale with one trivial
    dimension; flattening it to 1D exercises the kernel's flat-scale path.
    """
    current_platform.seed_everything(seed)

    inp = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")

    # Reference quantization produces a 2D scale tensor.
    expected, grouped_scale = scaled_quantize(
        inp, group_shape, FP8_DTYPE, compute_dtype=torch.float32
    )

    # Collapse the scale to 1D before handing it to the op.
    flat_scale = grouped_scale.flatten()
    actual, actual_scale = ops.scaled_fp8_quant(
        inp, scale=flat_scale, group_shape=group_shape
    )

    torch.testing.assert_close(flat_scale, actual_scale)
    torch.testing.assert_close(expected.float(), actual.float(), rtol=0.12, atol=0.0)

    opcheck_fp8_quant(actual, inp, scale=flat_scale, group_shape=group_shape)
|
||||
|
||||
Reference in New Issue
Block a user