2025-02-02 14:58:18 -05:00
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
2025-06-03 11:20:17 -07:00
|
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
2025-02-02 14:58:18 -05:00
|
|
|
|
2024-09-23 13:46:26 -04:00
|
|
|
import pytest
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
from tests.kernels.utils import opcheck
|
|
|
|
|
from vllm._custom_ops import permute_cols
|
|
|
|
|
|
|
|
|
|
|
2025-10-05 15:06:22 +01:00
|
|
|
@pytest.mark.parametrize("shape", [(1, 512), (544, 4096), (67, 8192)])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_permute_cols(shape, dtype):
    """Compare ``permute_cols`` with plain column indexing on random input.

    For every parametrized ``shape`` / ``dtype`` pair, a random matrix and a
    random permutation of its column indices are created on the CPU and then
    moved to the CUDA device. The result of ``permute_cols`` must match
    ``matrix[:, permutation]``.
    """
    matrix = torch.randn(shape, dtype=dtype).cuda()
    num_cols = matrix.shape[1]
    col_order = torch.randperm(num_cols).to(torch.int).cuda()

    # Run opcheck on the raw registered op with the same arguments before
    # exercising the Python wrapper.
    op_args = (matrix, col_order)
    opcheck(torch.ops._C.permute_cols, op_args)

    permuted = permute_cols(matrix, col_order)
    reference = matrix[:, col_order]
    torch.testing.assert_close(permuted, reference)
|