[Hardware] Replace torch.cuda.device_count/current_device/set_device API (#36145)

Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Kunshang Ji
2026-03-12 22:57:47 +08:00
committed by GitHub
parent 2e693f48e7
commit 53ec16a705
89 changed files with 254 additions and 219 deletions

View File

@@ -106,7 +106,7 @@ def mock_create_engine(config, parallel_config):
@create_new_process_for_each_test()
def test_get_world_size_tp1():
"""Test world_size is correctly configured for TP=1."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
llm = LLM(
@@ -125,7 +125,7 @@ def test_get_world_size_tp1():
def test_init_weight_transfer_engine_calls_engine():
"""Test that init_weight_transfer_engine calls the engine's
init_transfer_engine method."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
# Run in-process so mock.patch works (spawn won't inherit the mock)
@@ -174,7 +174,7 @@ def test_init_weight_transfer_engine_calls_engine():
@create_new_process_for_each_test()
def test_update_weights_calls_engine():
"""Test that update_weights calls the engine's receive_weights method."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
# Run in-process so mock.patch works (spawn won't inherit the mock)
@@ -233,7 +233,7 @@ def test_update_weights_calls_engine():
@create_new_process_for_each_test()
def test_full_weight_transfer_flow():
"""Test the complete weight transfer flow: init -> update."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
# Run in-process so mock.patch works (spawn won't inherit the mock)
@@ -294,7 +294,7 @@ def test_full_weight_transfer_flow():
@create_new_process_for_each_test()
def test_weight_transfer_config_backend():
"""Test that WeightTransferConfig backend is properly configured."""
if torch.cuda.device_count() < 1:
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
# Test with nccl backend