Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -8,11 +8,11 @@ from tests.kernels.utils import opcheck
|
||||
from vllm._custom_ops import permute_cols
|
||||
|
||||
|
||||
@pytest.mark.parametrize('shape', [(1, 512), (544, 4096), (67, 8192)])
|
||||
@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("shape", [(1, 512), (544, 4096), (67, 8192)])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
def test_permute_cols(shape, dtype):
|
||||
x = torch.randn(shape, dtype=dtype).cuda()
|
||||
perm = torch.randperm(x.shape[1]).to(torch.int).cuda()
|
||||
opcheck(torch.ops._C.permute_cols, (x, perm))
|
||||
y = permute_cols(x, perm)
|
||||
torch.testing.assert_close(y, x[:, perm])
|
||||
torch.testing.assert_close(y, x[:, perm])
|
||||
|
||||
Reference in New Issue
Block a user