[Hardware] Replace torch.cuda.device_count/current_device/set_device API (#36145)
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com> Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
@@ -210,10 +210,9 @@ WIKITEXT_ACCURACY_CONFIGS = [
|
||||
@pytest.mark.parametrize("config", WIKITEXT_ACCURACY_CONFIGS)
|
||||
@pytest.mark.parametrize("tp_size", [1, 2])
|
||||
def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
|
||||
if torch.cuda.device_count() < tp_size:
|
||||
pytest.skip(
|
||||
f"This test requires >={tp_size} gpus, got only {torch.cuda.device_count()}"
|
||||
)
|
||||
device_count = torch.accelerator.device_count()
|
||||
if device_count < tp_size:
|
||||
pytest.skip(f"This test requires >={tp_size} gpus, got only {device_count}")
|
||||
|
||||
task = "wikitext"
|
||||
rtol = 0.1
|
||||
@@ -246,10 +245,9 @@ def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
|
||||
reason="Read access to huggingface.co/amd is required for this test.",
|
||||
)
|
||||
def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig):
|
||||
if torch.cuda.device_count() < 8:
|
||||
pytest.skip(
|
||||
f"This test requires >=8 gpus, got only {torch.cuda.device_count()}"
|
||||
)
|
||||
device_count = torch.accelerator.device_count()
|
||||
if device_count < 8:
|
||||
pytest.skip(f"This test requires >=8 gpus, got only {device_count}")
|
||||
|
||||
task = "gsm8k"
|
||||
rtol = 0.03
|
||||
|
||||
Reference in New Issue
Block a user