[Hardware] Replace torch.cuda.device_count/current_device/set_device API (#36145)

Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Kunshang Ji
2026-03-12 22:57:47 +08:00
committed by GitHub
parent 2e693f48e7
commit 53ec16a705
89 changed files with 254 additions and 219 deletions

View File

@@ -203,7 +203,7 @@ class TestEngineRegistry:
def test_nccl_receive_weights_without_init_raises():
"""Test that receive_weights raises if init_transfer_engine wasn't called."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
config = WeightTransferConfig(backend="nccl")
@@ -336,7 +336,7 @@ def inference_receive_tensor(
@pytest.mark.skipif(
torch.cuda.device_count() < 2,
torch.accelerator.device_count() < 2,
reason="Need at least 2 GPUs to run NCCL weight transfer test.",
)
def test_nccl_weight_transfer_between_processes():
@@ -382,7 +382,7 @@ class TestIPCWeightTransferUpdateInfoValidation:
def test_valid_update_info(self):
"""Test creating valid IPCWeightTransferUpdateInfo."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
# Create a dummy tensor and IPC handle
@@ -404,7 +404,7 @@ class TestIPCWeightTransferUpdateInfoValidation:
def test_mismatched_dtype_names_raises(self):
"""Test that mismatched dtype_names length raises ValueError."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
dummy_tensor = torch.ones(10, 10, device="cuda:0")
@@ -422,7 +422,7 @@ class TestIPCWeightTransferUpdateInfoValidation:
def test_mismatched_shapes_raises(self):
"""Test that mismatched shapes length raises ValueError."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
dummy_tensor = torch.ones(10, 10, device="cuda:0")
@@ -440,7 +440,7 @@ class TestIPCWeightTransferUpdateInfoValidation:
def test_mismatched_ipc_handles_raises(self):
"""Test that mismatched ipc_handles length raises ValueError."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
dummy_tensor = torch.ones(10, 10, device="cuda:0")
@@ -458,7 +458,7 @@ class TestIPCWeightTransferUpdateInfoValidation:
def test_valid_update_info_from_pickled(self, monkeypatch):
"""Test creating IPCWeightTransferUpdateInfo from pickled handles."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
@@ -493,7 +493,7 @@ class TestIPCWeightTransferUpdateInfoValidation:
def test_both_handles_and_pickled_raises(self):
"""Test that providing both ipc_handles and ipc_handles_pickled raises."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
dummy_tensor = torch.ones(10, 10, device="cuda:0")
@@ -540,7 +540,7 @@ class TestIPCEngineParsing:
def test_parse_update_info_valid(self):
"""Test parsing valid update info dict."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
config = WeightTransferConfig(backend="ipc")
@@ -572,7 +572,7 @@ class TestIPCEngineParsing:
def test_parse_update_info_pickled(self, monkeypatch):
"""Test parsing update info with pickled IPC handles (HTTP path)."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
@@ -731,7 +731,7 @@ def inference_receive_ipc_tensor(
@pytest.mark.skipif(
torch.cuda.device_count() < 1,
torch.accelerator.device_count() < 1,
reason="Need at least 1 GPU to run IPC weight transfer test.",
)
@pytest.mark.parametrize("mode", ["ray", "http"])
@@ -789,7 +789,7 @@ def test_ipc_weight_transfer_between_processes(mode: str):
def test_ipc_receive_weights_missing_gpu_uuid_raises():
"""Test that receive_weights raises if GPU UUID not found in IPC handles."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
config = WeightTransferConfig(backend="ipc")