Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
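Note: the change below is purely mechanical. ruff's Black-style formatter uses hanging indents closed by a magic trailing comma instead of yapf's parenthesis-aligned continuation lines, and the previously wrapped GroupShape import now fits on a single line under ruff's longer line-length limit. A minimal before/after sketch, using a constructor call that recurs throughout the diff:

# Before (yapf): continuation lines aligned with the opening parenthesis
quant_op = QuantFP8(static=False,
                    group_shape=group_shape,
                    column_major_scales=False,
                    use_ue8m0=use_ue8m0)

# After (ruff): hanging indent, one argument per line, magic trailing comma
quant_op = QuantFP8(
    static=False,
    group_shape=group_shape,
    column_major_scales=False,
    use_ue8m0=use_ue8m0,
)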
@@ -6,8 +6,7 @@ import pytest
 import torch
 
 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
-from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    GroupShape)
+from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
 from vllm.platforms import current_platform
 
 
@@ -18,13 +17,14 @@ from vllm.platforms import current_platform
         (64, 1024, 64),  # Medium
         (128, 2048, 128),  # Large
         (8, 513, 64),  # Non-divisible (native only)
-    ])
+    ],
+)
 @pytest.mark.parametrize("seed", [42])
 @pytest.mark.parametrize("use_ue8m0", [True, False])
 @torch.inference_mode()
-def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
-                                      group_size: int, seed: int,
-                                      use_ue8m0: bool) -> None:
+def test_quantfp8_group_functionality(
+    batch_size: int, hidden_dim: int, group_size: int, seed: int, use_ue8m0: bool
+) -> None:
     """Test QuantFP8 group quantization with various configurations.
 
     Tests both CUDA and native implementations, column-major scales,
@@ -32,16 +32,17 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
     """
     current_platform.seed_everything(seed)
 
-    x = torch.randn(
-        (batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8
+    x = torch.randn((batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8
     expected_num_groups = (hidden_dim + group_size - 1) // group_size
     is_divisible = hidden_dim % group_size == 0
 
     group_shape = GroupShape(1, group_size)
-    quant_op = QuantFP8(static=False,
-                        group_shape=group_shape,
-                        column_major_scales=False,
-                        use_ue8m0=use_ue8m0)
+    quant_op = QuantFP8(
+        static=False,
+        group_shape=group_shape,
+        column_major_scales=False,
+        use_ue8m0=use_ue8m0,
+    )
 
     # 1. Test native implementation (always available)
     x_quant_native, scales_native = quant_op.forward_native(x.clone())
@@ -49,10 +50,12 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
     assert scales_native.shape == (batch_size, expected_num_groups)
 
     # 2. Test column-major scales configuration
-    quant_op_col = QuantFP8(static=False,
-                            group_shape=group_shape,
-                            column_major_scales=True,
-                            use_ue8m0=use_ue8m0)
+    quant_op_col = QuantFP8(
+        static=False,
+        group_shape=group_shape,
+        column_major_scales=True,
+        use_ue8m0=use_ue8m0,
+    )
     _, scales_col = quant_op_col.forward_native(x.clone())
     assert scales_col.shape == (batch_size, expected_num_groups)
     assert scales_col.stride(0) == 1
@@ -86,41 +89,48 @@ def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
 
     # Test with 3D input
     batch1, batch2, hidden_dim = 4, 8, 1024
-    x_3d = torch.randn(
-        (batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8
+    x_3d = (
+        torch.randn((batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda")
+        * 8
+    )
 
     group_shape = GroupShape(1, group_size)
-    quant_op = QuantFP8(static=False,
-                        group_shape=group_shape,
-                        column_major_scales=False,
-                        use_ue8m0=use_ue8m0)
+    quant_op = QuantFP8(
+        static=False,
+        group_shape=group_shape,
+        column_major_scales=False,
+        use_ue8m0=use_ue8m0,
+    )
 
     x_quant, scales = quant_op.forward_native(x_3d.clone())
     assert x_quant.shape == x_3d.shape
     assert scales.shape == (batch1, batch2, hidden_dim // group_size)
 
     # Test column_major_scales with multi-dim
-    quant_op_col = QuantFP8(static=False,
-                            group_shape=group_shape,
-                            column_major_scales=True,
-                            use_ue8m0=use_ue8m0)
+    quant_op_col = QuantFP8(
+        static=False,
+        group_shape=group_shape,
+        column_major_scales=True,
+        use_ue8m0=use_ue8m0,
+    )
     _, scales_col = quant_op_col.forward_native(x_3d.clone())
     assert scales_col.shape == (batch1, batch2, hidden_dim // group_size)
 
     # Test with 4D input
     batch1, batch2, batch3, hidden_dim = 2, 3, 4, 256
-    x_4d = torch.randn((batch1, batch2, batch3, hidden_dim),
-                       dtype=torch.bfloat16,
-                       device="cuda") * 8
+    x_4d = (
+        torch.randn(
+            (batch1, batch2, batch3, hidden_dim), dtype=torch.bfloat16, device="cuda"
+        )
+        * 8
+    )
 
     x_quant_4d, scales_4d = quant_op.forward_native(x_4d.clone())
     assert x_quant_4d.shape == x_4d.shape
-    assert scales_4d.shape == (batch1, batch2, batch3,
-                               hidden_dim // group_size)
+    assert scales_4d.shape == (batch1, batch2, batch3, hidden_dim // group_size)
 
     _, scales_4d_col = quant_op_col.forward_native(x_4d.clone())
-    assert scales_4d_col.shape == (batch1, batch2, hidden_dim // group_size,
-                                   batch3)
+    assert scales_4d_col.shape == (batch1, batch2, hidden_dim // group_size, batch3)
 
 
 @pytest.mark.parametrize("seed", [42])
@@ -132,30 +142,24 @@ def test_quantfp8_group_edge_cases(seed: int) -> None:
     group_size = 64
 
     # Test with single group (group_size >= hidden_dim)
-    x_small = torch.randn(
-        (batch_size, 32), dtype=torch.bfloat16, device="cuda") * 8
+    x_small = torch.randn((batch_size, 32), dtype=torch.bfloat16, device="cuda") * 8
     group_shape = GroupShape(1, group_size)
-    quant_op = QuantFP8(static=False,
-                        group_shape=group_shape,
-                        column_major_scales=False)
+    quant_op = QuantFP8(
+        static=False, group_shape=group_shape, column_major_scales=False
+    )
 
     x_quant_small, scales_small = quant_op.forward_native(x_small.clone())
     assert x_quant_small.shape == x_small.shape
     assert scales_small.shape == (batch_size, 1)
 
     # Test with zero inputs
-    x_zero = torch.zeros((batch_size, 256),
-                         dtype=torch.bfloat16,
-                         device="cuda")
+    x_zero = torch.zeros((batch_size, 256), dtype=torch.bfloat16, device="cuda")
     x_quant_zero, scales_zero = quant_op.forward_native(x_zero.clone())
     assert x_quant_zero.shape == x_zero.shape
     assert (scales_zero > 0).all(), "Scales should be clamped to minimum"
 
     # Test very large values
-    x_large = torch.full((batch_size, 256),
-                         1000.0,
-                         dtype=torch.bfloat16,
-                         device="cuda")
+    x_large = torch.full((batch_size, 256), 1000.0, dtype=torch.bfloat16, device="cuda")
     x_quant_large, scales_large = quant_op.forward_native(x_large.clone())
     assert x_quant_large.shape == x_large.shape
     # FP8 max is typically 448 or 224, so scales should be > 1
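A note on the shape arithmetic asserted above: expected_num_groups is computed as (hidden_dim + group_size - 1) // group_size, i.e. ceiling division, which is what keeps the non-divisible (8, 513, 64) case well defined. A quick standalone sketch of that arithmetic:

# Ceiling division yields a group count even when hidden_dim is not a
# multiple of group_size; the final group is then partial.
hidden_dim, group_size = 513, 64
expected_num_groups = (hidden_dim + group_size - 1) // group_size
assert expected_num_groups == 9  # 8 full groups of 64, plus one group of 1
assert hidden_dim % group_size != 0  # the "non-divisible (native only)" case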
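The scales_col.stride(0) == 1 assertion checks memory layout rather than shape: with column_major_scales=True the scales keep their logical (batch, num_groups) shape but are stored transposed. A minimal plain-PyTorch illustration of the same property, independent of QuantFP8:

import torch

# Allocate transposed and transpose back, so the logical (batch, num_groups)
# view has a contiguous batch dimension, i.e. stride(0) == 1.
batch, num_groups = 4, 16
scales = torch.empty(num_groups, batch).t()
assert scales.shape == (batch, num_groups)
assert scales.stride(0) == 1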
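Finally, on the trailing FP8 comment: float8_e4m3fn saturates at 448, so a tensor filled with 1000.0 needs a dynamic scale greater than 1 to fit. A back-of-the-envelope check, assuming the common scale = amax / fp8_max convention (the exact formula inside QuantFP8 may differ):

import torch

# 1000 / 448 ≈ 2.23, hence the expectation that the large-value scales
# exceed 1.
fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
scale = 1000.0 / fp8_max
assert scale > 1.0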