Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -8,15 +8,27 @@ from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
|
||||
allow_module_level=True)
|
||||
pytest.skip(
|
||||
reason="Nvfp4 Requires compute capability of 10 or above.",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)]
|
||||
PAD_SHAPES = [(90, 64), (150, 64), (128, 48), (128, 80), (150, 80), (90, 48),
|
||||
(90, 128), (150, 128), (150, 48), (90, 80)]
|
||||
PAD_SHAPES = [
|
||||
(90, 64),
|
||||
(150, 64),
|
||||
(128, 48),
|
||||
(128, 80),
|
||||
(150, 80),
|
||||
(90, 48),
|
||||
(90, 128),
|
||||
(150, 128),
|
||||
(150, 48),
|
||||
(90, 80),
|
||||
]
|
||||
SEEDS = [42]
|
||||
CUDA_DEVICES = ['cuda:0']
|
||||
CUDA_DEVICES = ["cuda:0"]
|
||||
|
||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
@@ -31,7 +43,22 @@ FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
# 0001 -> 0.5
|
||||
# 0000 -> 0
|
||||
E2M1_TO_FLOAT32 = [
|
||||
0., 0.5, 1., 1.5, 2., 3., 4., 6., 0., -0.5, -1., -1.5, -2., -3., -4., -6.
|
||||
0.0,
|
||||
0.5,
|
||||
1.0,
|
||||
1.5,
|
||||
2.0,
|
||||
3.0,
|
||||
4.0,
|
||||
6.0,
|
||||
0.0,
|
||||
-0.5,
|
||||
-1.0,
|
||||
-1.5,
|
||||
-2.0,
|
||||
-3.0,
|
||||
-4.0,
|
||||
-6.0,
|
||||
]
|
||||
BLOCK_SIZE = 16
|
||||
|
||||
@@ -74,8 +101,7 @@ def ref_nvfp4_quant(x, global_scale):
|
||||
assert x.ndim == 2
|
||||
m, n = x.shape
|
||||
x = torch.reshape(x, (m, n // BLOCK_SIZE, BLOCK_SIZE))
|
||||
vec_max = torch.max(torch.abs(x), dim=-1,
|
||||
keepdim=True)[0].to(torch.float32)
|
||||
vec_max = torch.max(torch.abs(x), dim=-1, keepdim=True)[0].to(torch.float32)
|
||||
scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX))
|
||||
scale = scale.to(torch.float8_e4m3fn).to(torch.float32)
|
||||
output_scale = get_reciprocal(scale * get_reciprocal(global_scale))
|
||||
@@ -131,7 +157,7 @@ def test_quantize_to_fp4(
|
||||
def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
|
||||
dtype = torch.float16
|
||||
current_platform.seed_everything(42)
|
||||
torch.set_default_device('cuda:0')
|
||||
torch.set_default_device("cuda:0")
|
||||
|
||||
m, n = pad_shape
|
||||
|
||||
|
||||
Reference in New Issue
Block a user