Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -8,19 +8,23 @@ import torch
|
||||
|
||||
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
|
||||
GeluAndMul, MulAndSilu,
|
||||
NewGELU, QuickGELU,
|
||||
SiluAndMul, SwigluOAIAndMul)
|
||||
from vllm.model_executor.layers.activation import (
|
||||
FastGELU,
|
||||
FatreluAndMul,
|
||||
GeluAndMul,
|
||||
MulAndSilu,
|
||||
NewGELU,
|
||||
QuickGELU,
|
||||
SiluAndMul,
|
||||
SwigluOAIAndMul,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
|
||||
D = [512, 13824] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -73,24 +77,19 @@ def test_act_and_mul(
|
||||
out = layer(x)
|
||||
ref_out = layer.forward_native(x)
|
||||
if activation == "swigluoai_and_mul":
|
||||
|
||||
rtol = {
|
||||
#For fp16, change the relative tolerance from 1e-3 to 2e-3
|
||||
torch.float16:
|
||||
2e-3,
|
||||
torch.bfloat16:
|
||||
2e-2,
|
||||
torch.float:
|
||||
1.3e-6
|
||||
# For fp16, change the relative tolerance from 1e-3 to 2e-3
|
||||
torch.float16: 2e-3,
|
||||
torch.bfloat16: 2e-2,
|
||||
torch.float: 1.3e-6,
|
||||
}
|
||||
|
||||
def _get_rtol(output) -> float:
|
||||
return rtol[output.dtype]
|
||||
|
||||
torch.testing.assert_close(out,
|
||||
ref_out,
|
||||
atol=get_default_atol(out),
|
||||
rtol=_get_rtol(out))
|
||||
torch.testing.assert_close(
|
||||
out, ref_out, atol=get_default_atol(out), rtol=_get_rtol(out)
|
||||
)
|
||||
else:
|
||||
# The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
|
||||
# equivalent to the native PyTorch implementations, so we can do exact
|
||||
@@ -98,7 +97,7 @@ def test_act_and_mul(
|
||||
torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
|
||||
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
output_shape = x.shape[:-1] + (d,)
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
if activation == "fatrelu":
|
||||
opcheck(fn, (out, x, threshold))
|
||||
@@ -108,9 +107,14 @@ def test_act_and_mul(
|
||||
opcheck(fn, (out, x))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("activation", [(FastGELU, torch.ops._C.gelu_fast),
|
||||
(NewGELU, torch.ops._C.gelu_new),
|
||||
(QuickGELU, torch.ops._C.gelu_quick)])
|
||||
@pytest.mark.parametrize(
|
||||
"activation",
|
||||
[
|
||||
(FastGELU, torch.ops._C.gelu_fast),
|
||||
(NewGELU, torch.ops._C.gelu_new),
|
||||
(QuickGELU, torch.ops._C.gelu_quick),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("d", D)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@@ -132,10 +136,9 @@ def test_activation(
|
||||
fn = activation[1]
|
||||
out = layer(x)
|
||||
ref_out = layer.forward_native(x)
|
||||
torch.testing.assert_close(out,
|
||||
ref_out,
|
||||
atol=get_default_atol(out),
|
||||
rtol=get_default_rtol(out))
|
||||
torch.testing.assert_close(
|
||||
out, ref_out, atol=get_default_atol(out), rtol=get_default_rtol(out)
|
||||
)
|
||||
|
||||
out = torch.empty_like(x)
|
||||
opcheck(fn, (out, x))
|
||||
|
||||
Reference in New Issue
Block a user