[Performance] Move apply_w8a8_block_fp8_linear to an op class (#24666)
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <elizaw.9289@gmail.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
@@ -12,7 +12,7 @@ from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     cutlass_scaled_mm, get_col_major_tma_aligned_tensor,
-    per_token_group_quant_fp8, w8a8_block_fp8_matmul)
+    per_token_group_quant_fp8, w8a8_triton_block_scaled_mm)
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_gemm
 from vllm.utils.deep_gemm import fp8_gemm_nt, per_block_cast_to_fp8
@@ -90,7 +90,8 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
 
     ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
                                        out_dtype)
-    out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
+    out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size,
+                                      out_dtype)
 
     rel_diff = (torch.mean(
         torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
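For context, the rename from w8a8_block_fp8_matmul to w8a8_triton_block_scaled_mm does not change what this test checks. A minimal pure-PyTorch sketch of the block-scaled fp8 matmul being verified, assuming the layout used in these tests (A_fp8 is (M, K) with one scale per (1, block_k) activation group in As; B_fp8 is (N, K) with one scale per (block_n, block_k) weight block in Bs); the helper name is illustrative, not a vLLM API:

import torch

def ref_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size, out_dtype):
    # Dequantize-then-matmul reference under the layout assumed above.
    block_n, block_k = block_size
    A = A_fp8.to(torch.float32)
    B = B_fp8.to(torch.float32)
    # Broadcast each (1, block_k) activation-group scale over its group.
    A = A * As.to(torch.float32).repeat_interleave(block_k, dim=1)[:, :A.shape[1]]
    # Broadcast each (block_n, block_k) weight-block scale over its block.
    Bs_f = Bs.to(torch.float32).repeat_interleave(block_n, dim=0)[:B.shape[0]]
    Bs_f = Bs_f.repeat_interleave(block_k, dim=1)[:, :B.shape[1]]
    return (A @ (B * Bs_f).T).to(out_dtype)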
@@ -20,9 +20,11 @@ from vllm.platforms import current_platform
     (8, 513, 64),  # Non-divisible (native only)
 ])
 @pytest.mark.parametrize("seed", [42])
+@pytest.mark.parametrize("use_ue8m0", [True, False])
 @torch.inference_mode()
 def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
-                                      group_size: int, seed: int) -> None:
+                                      group_size: int, seed: int,
+                                      use_ue8m0: bool) -> None:
     """Test QuantFP8 group quantization with various configurations.
 
     Tests both CUDA and native implementations, column-major scales,
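The new use_ue8m0 parametrization toggles whether group scales are constrained to the UE8M0 format: an unsigned 8-bit exponent with no mantissa, so every representable scale is a power of two. A sketch of the rounding this implies, assuming round-up so quantized values stay in range (illustrative, not the vLLM implementation):

import torch

def to_ue8m0(scales: torch.Tensor) -> torch.Tensor:
    # Round each scale up to the nearest power of two; a larger scale
    # can only shrink |x| / scale, so the fp8 cast cannot overflow.
    return torch.exp2(torch.ceil(torch.log2(scales)))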
@@ -38,7 +40,8 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
     group_shape = GroupShape(1, group_size)
     quant_op = QuantFP8(static=False,
                         group_shape=group_shape,
-                        column_major_scales=False)
+                        column_major_scales=False,
+                        use_ue8m0=use_ue8m0)
 
     # 1. Test native implementation (always available)
     x_quant_native, scales_native = quant_op.forward_native(x.clone())
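The native path exercised here amounts to dynamic per-(1, group_size) absmax quantization: split the last dimension into groups, scale each group by its own absolute maximum, and cast to fp8. A minimal sketch of that computation (the helper name and clamp epsilon are assumptions, not vLLM internals):

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def group_quant_fp8(x: torch.Tensor, group_size: int):
    # Works for any leading shape, e.g. (batch_size, hidden_dim).
    *lead, hidden = x.shape
    xg = x.reshape(*lead, hidden // group_size, group_size)
    amax = xg.abs().amax(dim=-1, keepdim=True).float()
    scales = (amax / FP8_MAX).clamp(min=1e-12)
    x_q = (xg / scales).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return x_q.reshape(x.shape), scales.squeeze(-1)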
@@ -48,9 +51,15 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
     # 2. Test column-major scales configuration
     quant_op_col = QuantFP8(static=False,
                             group_shape=group_shape,
-                            column_major_scales=True)
+                            column_major_scales=True,
+                            use_ue8m0=use_ue8m0)
     _, scales_col = quant_op_col.forward_native(x.clone())
-    assert scales_col.shape == (expected_num_groups, batch_size)
+    assert scales_col.shape == (batch_size, expected_num_groups)
+    assert scales_col.stride(0) == 1
+    assert scales_col.stride(1) == batch_size
+
+    # Test column-major scales consistency
+    assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8)
 
     # 3. Test CUDA implementation (only for divisible dimensions)
     if is_divisible:
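The strengthened assertions pin down the physical layout as well as the shape: the scales keep their logical (batch_size, expected_num_groups) shape but are stored column-major, so consecutive tokens within one group are contiguous in memory. A standalone illustration of that layout in plain PyTorch:

import torch

batch_size, num_groups = 4, 8
row_major = torch.randn(batch_size, num_groups)
# Transposing a contiguous transpose yields the same logical tensor with
# column-major strides: element (i, j) lives at offset i + j * batch_size.
col_major = row_major.t().contiguous().t()
assert col_major.shape == (batch_size, num_groups)
assert col_major.stride() == (1, batch_size)
assert torch.equal(col_major, row_major)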
@@ -68,8 +77,9 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
 
 
 @pytest.mark.parametrize("seed", [42])
+@pytest.mark.parametrize("use_ue8m0", [True, False])
 @torch.inference_mode()
-def test_quantfp8_group_multidimensional(seed: int) -> None:
+def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
     current_platform.seed_everything(seed)
 
     group_size = 64
@@ -82,7 +92,8 @@ def test_quantfp8_group_multidimensional(seed: int) -> None:
     group_shape = GroupShape(1, group_size)
     quant_op = QuantFP8(static=False,
                         group_shape=group_shape,
-                        column_major_scales=False)
+                        column_major_scales=False,
+                        use_ue8m0=use_ue8m0)
 
     x_quant, scales = quant_op.forward_native(x_3d.clone())
     assert x_quant.shape == x_3d.shape
@@ -91,7 +102,8 @@ def test_quantfp8_group_multidimensional(seed: int) -> None:
     # Test column_major_scales with multi-dim
     quant_op_col = QuantFP8(static=False,
                             group_shape=group_shape,
-                            column_major_scales=True)
+                            column_major_scales=True,
+                            use_ue8m0=use_ue8m0)
     _, scales_col = quant_op_col.forward_native(x_3d.clone())
     assert scales_col.shape == (batch1, hidden_dim // group_size, batch2)
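For the 3-D input, the final assertion implies that the column-major variant swaps the trailing two dimensions of the row-major scales. A short illustration under that assumption:

import torch

batch1, batch2, hidden_dim, group_size = 2, 3, 256, 64
scales = torch.randn(batch1, batch2, hidden_dim // group_size)
scales_col = scales.transpose(-2, -1)  # (batch1, num_groups, batch2)
assert scales_col.shape == (batch1, hidden_dim // group_size, batch2)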