[Perf] Fix and reapply move apply w8a8 block fp8 linear to class (#25696)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <elizaw.9289@gmail.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Author: ElizaWszola
Date: 2025-10-02 21:35:13 +02:00 (committed by GitHub)
Commit: 502640c3f9 (parent: 3d5f1c8640)
13 changed files with 412 additions and 200 deletions
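For context on the new use_ue8m0 knob exercised in the tests below: UE8M0 is an exponent-only scale format, so dynamic per-group scales are rounded up to the next power of two before quantization. The following is a minimal sketch of that idea in plain PyTorch, not vLLM's QuantFP8 implementation; the helper name group_quant_fp8 and its shapes are illustrative assumptions.

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def group_quant_fp8(x: torch.Tensor, group_size: int, use_ue8m0: bool):
    # One scale per contiguous group of `group_size` elements on the last dim.
    x_grouped = x.unflatten(-1, (-1, group_size))
    amax = x_grouped.abs().amax(dim=-1, keepdim=True).float()
    scale = (amax / FP8_MAX).clamp(min=torch.finfo(torch.float32).tiny)
    if use_ue8m0:
        # UE8M0 keeps only an 8-bit exponent: round the scale up to the
        # next power of two, so quantized magnitudes can only shrink.
        scale = torch.exp2(torch.ceil(torch.log2(scale)))
    x_q = (x_grouped / scale).to(torch.float8_e4m3fn)
    return x_q.flatten(-2), scale.squeeze(-1)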


@@ -20,9 +20,11 @@ from vllm.platforms import current_platform
     (8, 513, 64),  # Non-divisible (native only)
 ])
 @pytest.mark.parametrize("seed", [42])
+@pytest.mark.parametrize("use_ue8m0", [True, False])
 @torch.inference_mode()
 def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
-                                      group_size: int, seed: int) -> None:
+                                      group_size: int, seed: int,
+                                      use_ue8m0: bool) -> None:
     """Test QuantFP8 group quantization with various configurations.
 
     Tests both CUDA and native implementations, column-major scales,
@@ -38,7 +40,8 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
     group_shape = GroupShape(1, group_size)
     quant_op = QuantFP8(static=False,
                         group_shape=group_shape,
-                        column_major_scales=False)
+                        column_major_scales=False,
+                        use_ue8m0=use_ue8m0)
 
     # 1. Test native implementation (always available)
     x_quant_native, scales_native = quant_op.forward_native(x.clone())
@@ -48,9 +51,15 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
     # 2. Test column-major scales configuration
     quant_op_col = QuantFP8(static=False,
                             group_shape=group_shape,
-                            column_major_scales=True)
+                            column_major_scales=True,
+                            use_ue8m0=use_ue8m0)
     _, scales_col = quant_op_col.forward_native(x.clone())
-    assert scales_col.shape == (expected_num_groups, batch_size)
+    assert scales_col.shape == (batch_size, expected_num_groups)
+    assert scales_col.stride(0) == 1
+    assert scales_col.stride(1) == batch_size
+
+    # Test column-major scales consistency
+    assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8)
 
     # 3. Test CUDA implementation (only for divisible dimensions)
     if is_divisible:
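The reworked asserts above check a transposed memory layout rather than a transposed logical shape: the column-major scales now keep the (batch_size, expected_num_groups) shape, with stride(0) == 1 marking the column-major storage. A hypothetical snippet showing the layout those asserts describe (the dimension values are illustrative):

import torch

batch_size, expected_num_groups = 8, 16

# Allocate row-major as (groups, batch), then take a transposed view:
# the logical shape becomes (batch, groups) while memory stays column-major.
scales_col = torch.randn(expected_num_groups, batch_size).t()

assert scales_col.shape == (batch_size, expected_num_groups)
assert scales_col.stride(0) == 1
assert scales_col.stride(1) == batch_size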
@@ -68,21 +77,23 @@
 @pytest.mark.parametrize("seed", [42])
+@pytest.mark.parametrize("use_ue8m0", [True, False])
 @torch.inference_mode()
-def test_quantfp8_group_multidimensional(seed: int) -> None:
+def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
     current_platform.seed_everything(seed)
 
     group_size = 64
 
     # Test with 3D input
-    batch1, batch2, hidden_dim = 4, 8, 512
+    batch1, batch2, hidden_dim = 4, 8, 1024
     x_3d = torch.randn(
         (batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8
 
     group_shape = GroupShape(1, group_size)
     quant_op = QuantFP8(static=False,
                         group_shape=group_shape,
-                        column_major_scales=False)
+                        column_major_scales=False,
+                        use_ue8m0=use_ue8m0)
 
     x_quant, scales = quant_op.forward_native(x_3d.clone())
     assert x_quant.shape == x_3d.shape
@@ -91,9 +102,10 @@ def test_quantfp8_group_multidimensional(seed: int) -> None:
     # Test column_major_scales with multi-dim
     quant_op_col = QuantFP8(static=False,
                             group_shape=group_shape,
-                            column_major_scales=True)
+                            column_major_scales=True,
+                            use_ue8m0=use_ue8m0)
     _, scales_col = quant_op_col.forward_native(x_3d.clone())
-    assert scales_col.shape == (batch1, hidden_dim // group_size, batch2)
+    assert scales_col.shape == (batch1, batch2, hidden_dim // group_size)
 
     # Test with 4D input
     batch1, batch2, batch3, hidden_dim = 2, 3, 4, 256
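The multi-dim assert follows the same convention as the 2D case: batch dims stay first and the group dim moves last. A small sketch of that layout, under the assumption (suggested by the 2D stride asserts above, not stated in this diff) that only the last two dims are stored transposed:

import torch

batch1, batch2, hidden_dim, group_size = 4, 8, 1024, 64
num_groups = hidden_dim // group_size

# Allocate with the group dim second-to-last, then swap the last two dims:
# batch dims stay leading, groups land last, memory stays column-major
# over the trailing pair.
scales_col = torch.randn(batch1, num_groups, batch2).transpose(-2, -1)

assert scales_col.shape == (batch1, batch2, num_groups)
assert scales_col.stride(-2) == 1  # last two dims are column-major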