Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author:       Harry Mellor
Date:         2025-10-05 15:06:22 +01:00
Committed by: GitHub
Parent:       17edd8a807
Commit:       d6953beb91
1508 changed files with 115244 additions and 94146 deletions
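
ruff format produces Black-compatible output: double quotes, trailing commas, and call arguments split one-per-line when a magic trailing comma is present; ruff's isort-compatible import-sorting rules take over the work previously done by isort itself. The representative hunks below, from one of the reformatted NVFP4 test files, show each of these changes.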


@@ -3,16 +3,20 @@
 import pytest
 import torch
 
-from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
-                                                    FLOAT8_E4M3_MAX,
-                                                    dequantize_nvfp4_to_dtype)
+from tests.kernels.quantization.nvfp4_utils import (
+    FLOAT4_E2M1_MAX,
+    FLOAT8_E4M3_MAX,
+    dequantize_nvfp4_to_dtype,
+)
 from vllm._custom_ops import scaled_fp4_quant
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.platforms import current_platform
 
 if not current_platform.has_device_capability(100):
-    pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
-                allow_module_level=True)
+    pytest.skip(
+        reason="Nvfp4 Requires compute capability of 10 or above.",
+        allow_module_level=True,
+    )
 
 FP4_DTYPE = torch.uint8
 FP8_DTYPE = current_platform.fp8_dtype()
@@ -30,24 +34,24 @@ def test_silu_mul_nvfp4_quant(
     shape: tuple[int, int],
 ) -> None:
     current_platform.seed_everything(42)
-    device = 'cuda:0'
+    device = "cuda:0"
     torch.set_default_device(device)
 
     x = torch.randn(shape, dtype=dtype)
 
     # ref op
     ref_output = SiluAndMul().forward_native(x)
-    ref_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
-                        torch.abs(ref_output).max().to(torch.float32))
-    ref_output_quant, ref_block_scale = scaled_fp4_quant(
-        ref_output, ref_global_scale)
+    ref_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(
+        ref_output
+    ).max().to(torch.float32)
+    ref_output_quant, ref_block_scale = scaled_fp4_quant(ref_output, ref_global_scale)
 
     # fused op
     fused_output_quant = torch.empty_like(ref_output_quant)
     fused_block_scale = torch.empty_like(ref_block_scale)
-    torch.ops._C.silu_and_mul_nvfp4_quant(fused_output_quant,
-                                          fused_block_scale, x,
-                                          ref_global_scale)
+    torch.ops._C.silu_and_mul_nvfp4_quant(
+        fused_output_quant, fused_block_scale, x, ref_global_scale
+    )
 
     # check dtype
     assert ref_output_quant.dtype == FP4_DTYPE
@@ -59,17 +63,14 @@ def test_silu_mul_nvfp4_quant(
     assert ref_block_scale.shape == fused_block_scale.shape
 
     # check dequantized output
-    ref_output_dequant = dequantize_nvfp4_to_dtype(ref_output_quant,
-                                                   ref_block_scale,
-                                                   ref_global_scale, dtype,
-                                                   device)
-    fused_output_dequant = dequantize_nvfp4_to_dtype(fused_output_quant,
-                                                     fused_block_scale,
-                                                     ref_global_scale, dtype,
-                                                     device)
+    ref_output_dequant = dequantize_nvfp4_to_dtype(
+        ref_output_quant, ref_block_scale, ref_global_scale, dtype, device
+    )
+    fused_output_dequant = dequantize_nvfp4_to_dtype(
+        fused_output_quant, fused_block_scale, ref_global_scale, dtype, device
+    )
 
     atol, rtol = 3e-1, 3e-1
-    torch.testing.assert_close(ref_output_dequant,
-                               fused_output_dequant,
-                               atol=atol,
-                               rtol=rtol)
+    torch.testing.assert_close(
+        ref_output_dequant, fused_output_dequant, atol=atol, rtol=rtol
+    )
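
For context on what this test checks: NVFP4 packs FP4 (E2M1) values together with FP8 (E4M3) per-block scales and a single FP32 global scale, so the largest magnitude the combined encoding can represent is FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX. The sketch below is illustrative, not code from this commit; it assumes only plain torch and the standard format maxima (6.0 for E2M1, 448.0 for E4M3) to show what the ref_global_scale expression accomplishes.

import torch

# Standard format maxima; assumed to match what nvfp4_utils exposes.
FLOAT4_E2M1_MAX = 6.0    # largest finite FP4 (E2M1) value
FLOAT8_E4M3_MAX = 448.0  # largest finite FP8 (E4M3) value

x = torch.randn(16, 32)
amax = x.abs().max().to(torch.float32)

# The test's global scale: chosen so that the tensor's largest magnitude,
# multiplied by it, lands exactly on the FP8 * FP4 representable ceiling.
global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / amax
ceiling = torch.tensor(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX)
assert torch.isclose(amax * global_scale, ceiling)

The loose tolerances at the end of the test (atol = rtol = 3e-1) likely reflect the same arithmetic: with only four bits per value, the reference and fused paths can round differently, so the dequantized outputs are compared coarsely rather than bit-exactly.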