[Performance] Move apply_w8a8_block_fp8_linear to an op class (#24666)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <elizaw.9289@gmail.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
This commit is contained in:
ElizaWszola
2025-09-23 21:03:10 +02:00
committed by GitHub
parent 8c1c81a3de
commit 63400259d0
14 changed files with 341 additions and 201 deletions

View File

@@ -20,9 +20,11 @@ from vllm.platforms import current_platform
(8, 513, 64), # Non-divisible (native only)
])
@pytest.mark.parametrize("seed", [42])
@pytest.mark.parametrize("use_ue8m0", [True, False])
@torch.inference_mode()
def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
group_size: int, seed: int) -> None:
group_size: int, seed: int,
use_ue8m0: bool) -> None:
"""Test QuantFP8 group quantization with various configurations.
Tests both CUDA and native implementations, column-major scales,
@@ -38,7 +40,8 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
group_shape = GroupShape(1, group_size)
quant_op = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=False)
column_major_scales=False,
use_ue8m0=use_ue8m0)
# 1. Test native implementation (always available)
x_quant_native, scales_native = quant_op.forward_native(x.clone())
@@ -48,9 +51,15 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
# 2. Test column-major scales configuration
quant_op_col = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=True)
column_major_scales=True,
use_ue8m0=use_ue8m0)
_, scales_col = quant_op_col.forward_native(x.clone())
assert scales_col.shape == (expected_num_groups, batch_size)
assert scales_col.shape == (batch_size, expected_num_groups)
assert scales_col.stride(0) == 1
assert scales_col.stride(1) == batch_size
# Test column-major scales consistency
assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8)
# 3. Test CUDA implementation (only for divisible dimensions)
if is_divisible:
@@ -68,8 +77,9 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
@pytest.mark.parametrize("seed", [42])
@pytest.mark.parametrize("use_ue8m0", [True, False])
@torch.inference_mode()
def test_quantfp8_group_multidimensional(seed: int) -> None:
def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
current_platform.seed_everything(seed)
group_size = 64
@@ -82,7 +92,8 @@ def test_quantfp8_group_multidimensional(seed: int) -> None:
group_shape = GroupShape(1, group_size)
quant_op = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=False)
column_major_scales=False,
use_ue8m0=use_ue8m0)
x_quant, scales = quant_op.forward_native(x_3d.clone())
assert x_quant.shape == x_3d.shape
@@ -91,7 +102,8 @@ def test_quantfp8_group_multidimensional(seed: int) -> None:
# Test column_major_scales with multi-dim
quant_op_col = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=True)
column_major_scales=True,
use_ue8m0=use_ue8m0)
_, scales_col = quant_op_col.forward_native(x_3d.clone())
assert scales_col.shape == (batch1, hidden_dim // group_size, batch2)