[Hardware] Replace torch.cuda.device_count/current_device/set_device API (#36145)

Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Kunshang Ji
2026-03-12 22:57:47 +08:00
committed by GitHub
parent 2e693f48e7
commit 53ec16a705
89 changed files with 254 additions and 219 deletions

View File

@@ -210,10 +210,9 @@ WIKITEXT_ACCURACY_CONFIGS = [
@pytest.mark.parametrize("config", WIKITEXT_ACCURACY_CONFIGS)
@pytest.mark.parametrize("tp_size", [1, 2])
def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
if torch.cuda.device_count() < tp_size:
pytest.skip(
f"This test requires >={tp_size} gpus, got only {torch.cuda.device_count()}"
)
device_count = torch.accelerator.device_count()
if device_count < tp_size:
pytest.skip(f"This test requires >={tp_size} gpus, got only {device_count}")
task = "wikitext"
rtol = 0.1
@@ -246,10 +245,9 @@ def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
reason="Read access to huggingface.co/amd is required for this test.",
)
def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig):
if torch.cuda.device_count() < 8:
pytest.skip(
f"This test requires >=8 gpus, got only {torch.cuda.device_count()}"
)
device_count = torch.accelerator.device_count()
if device_count < 8:
pytest.skip(f"This test requires >=8 gpus, got only {device_count}")
task = "gsm8k"
rtol = 0.03