Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -11,13 +11,22 @@ from vllm.platforms import current_platform
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing
|
||||
HIDDEN_SIZES = [8, 768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192,
|
||||
8199] # Arbitrary values for testing
|
||||
HIDDEN_SIZES = [
|
||||
8,
|
||||
768,
|
||||
769,
|
||||
770,
|
||||
771,
|
||||
5120,
|
||||
5124,
|
||||
5125,
|
||||
5126,
|
||||
8192,
|
||||
8199,
|
||||
] # Arbitrary values for testing
|
||||
ADD_RESIDUAL = [False, True]
|
||||
SEEDS = [0]
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@@ -63,11 +72,14 @@ def test_rms_norm(
|
||||
torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)
|
||||
|
||||
if residual is not None:
|
||||
opcheck(torch.ops._C.fused_add_rms_norm,
|
||||
(x, residual, layer.weight.data, layer.variance_epsilon))
|
||||
opcheck(
|
||||
torch.ops._C.fused_add_rms_norm,
|
||||
(x, residual, layer.weight.data, layer.variance_epsilon),
|
||||
)
|
||||
else:
|
||||
opcheck(torch.ops._C.rms_norm,
|
||||
(out, x, layer.weight.data, layer.variance_epsilon))
|
||||
opcheck(
|
||||
torch.ops._C.rms_norm, (out, x, layer.weight.data, layer.variance_epsilon)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@@ -98,7 +110,8 @@ def test_poly_norm(
|
||||
|
||||
opcheck(
|
||||
torch.ops._C.poly_norm,
|
||||
(out, x, layer.weight.data, layer.bias.data, layer.variance_epsilon))
|
||||
(out, x, layer.weight.data, layer.bias.data, layer.variance_epsilon),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@@ -144,7 +157,8 @@ def test_fused_rms_norm_quant(
|
||||
|
||||
if add_residual:
|
||||
torch.ops._C.fused_add_rms_norm_static_fp8_quant(
|
||||
out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6)
|
||||
out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6
|
||||
)
|
||||
|
||||
# Unfused kernel is in-place so it goes second
|
||||
# Also use a separate clone of x to avoid modifying the input
|
||||
@@ -152,29 +166,32 @@ def test_fused_rms_norm_quant(
|
||||
x_unfused = x_unfused_base[..., :hidden_size]
|
||||
assert x_unfused.is_contiguous() != strided_input
|
||||
torch.ops._C.fused_add_rms_norm(x_unfused, residual, weight, 1e-6)
|
||||
torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused.contiguous(),
|
||||
quant_scale_t)
|
||||
torch.ops._C.static_scaled_fp8_quant(
|
||||
out_quant, x_unfused.contiguous(), quant_scale_t
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
torch.testing.assert_close(residual_fused,
|
||||
residual,
|
||||
atol=1e-2,
|
||||
rtol=1e-2)
|
||||
torch.testing.assert_close(residual_fused, residual, atol=1e-2, rtol=1e-2)
|
||||
opcheck(
|
||||
torch.ops._C.fused_add_rms_norm_static_fp8_quant,
|
||||
(out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6))
|
||||
(out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6),
|
||||
)
|
||||
else:
|
||||
torch.ops._C.rms_norm_static_fp8_quant(out_quant_fused, x, weight,
|
||||
quant_scale_t, 1e-6)
|
||||
torch.ops._C.rms_norm_static_fp8_quant(
|
||||
out_quant_fused, x, weight, quant_scale_t, 1e-6
|
||||
)
|
||||
|
||||
torch.ops._C.rms_norm(out_norm, x, weight, 1e-6)
|
||||
torch.ops._C.static_scaled_fp8_quant(out_quant, out_norm,
|
||||
quant_scale_t)
|
||||
torch.ops._C.static_scaled_fp8_quant(out_quant, out_norm, quant_scale_t)
|
||||
|
||||
opcheck(torch.ops._C.rms_norm_static_fp8_quant,
|
||||
(out_quant_fused, x, weight, quant_scale_t, 1e-6))
|
||||
opcheck(
|
||||
torch.ops._C.rms_norm_static_fp8_quant,
|
||||
(out_quant_fused, x, weight, quant_scale_t, 1e-6),
|
||||
)
|
||||
|
||||
torch.testing.assert_close(out_quant.to(dtype=torch.float32),
|
||||
out_quant_fused.to(dtype=torch.float32),
|
||||
atol=1e-3,
|
||||
rtol=1e-3)
|
||||
torch.testing.assert_close(
|
||||
out_quant.to(dtype=torch.float32),
|
||||
out_quant_fused.to(dtype=torch.float32),
|
||||
atol=1e-3,
|
||||
rtol=1e-3,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user