Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -13,7 +13,7 @@ from vllm.platforms import current_platform
|
||||
def check_cuda_context():
|
||||
"""Check CUDA driver context status"""
|
||||
try:
|
||||
cuda = ctypes.CDLL('libcuda.so')
|
||||
cuda = ctypes.CDLL("libcuda.so")
|
||||
device = ctypes.c_int()
|
||||
result = cuda.cuCtxGetDevice(ctypes.byref(device))
|
||||
return (True, device.value) if result == 0 else (False, None)
|
||||
@@ -27,9 +27,11 @@ def run_cuda_test_in_thread(device_input, expected_device_id):
|
||||
# New thread should have no CUDA context initially
|
||||
valid_before, device_before = check_cuda_context()
|
||||
if valid_before:
|
||||
return False, \
|
||||
"CUDA context should not exist in new thread, " \
|
||||
f"got device {device_before}"
|
||||
return (
|
||||
False,
|
||||
"CUDA context should not exist in new thread, "
|
||||
f"got device {device_before}",
|
||||
)
|
||||
|
||||
# Test setting CUDA context
|
||||
current_platform.set_device(device_input)
|
||||
@@ -39,8 +41,7 @@ def run_cuda_test_in_thread(device_input, expected_device_id):
|
||||
if not valid_after:
|
||||
return False, "CUDA context should be valid after set_cuda_context"
|
||||
if device_id != expected_device_id:
|
||||
return False, \
|
||||
f"Expected device {expected_device_id}, got {device_id}"
|
||||
return False, f"Expected device {expected_device_id}, got {device_id}"
|
||||
|
||||
return True, "Success"
|
||||
except Exception as e:
|
||||
@@ -50,30 +51,30 @@ def run_cuda_test_in_thread(device_input, expected_device_id):
|
||||
class TestSetCudaContext:
|
||||
"""Test suite for the set_cuda_context function."""
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(),
|
||||
reason="CUDA not available")
|
||||
@pytest.mark.parametrize(argnames="device_input,expected_device_id",
|
||||
argvalues=[
|
||||
(0, 0),
|
||||
(torch.device('cuda:0'), 0),
|
||||
('cuda:0', 0),
|
||||
],
|
||||
ids=["int", "torch_device", "string"])
|
||||
def test_set_cuda_context_parametrized(self, device_input,
|
||||
expected_device_id):
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
|
||||
@pytest.mark.parametrize(
|
||||
argnames="device_input,expected_device_id",
|
||||
argvalues=[
|
||||
(0, 0),
|
||||
(torch.device("cuda:0"), 0),
|
||||
("cuda:0", 0),
|
||||
],
|
||||
ids=["int", "torch_device", "string"],
|
||||
)
|
||||
def test_set_cuda_context_parametrized(self, device_input, expected_device_id):
|
||||
"""Test setting CUDA context in isolated threads."""
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(run_cuda_test_in_thread, device_input,
|
||||
expected_device_id)
|
||||
future = executor.submit(
|
||||
run_cuda_test_in_thread, device_input, expected_device_id
|
||||
)
|
||||
success, message = future.result(timeout=30)
|
||||
assert success, message
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(),
|
||||
reason="CUDA not available")
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
|
||||
def test_set_cuda_context_invalid_device_type(self):
|
||||
"""Test error handling for invalid device type."""
|
||||
with pytest.raises(ValueError, match="Expected a cuda device"):
|
||||
current_platform.set_device(torch.device('cpu'))
|
||||
current_platform.set_device(torch.device("cpu"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user