[Hardware] Replace torch.cuda.device_count/current_device/set_device API (#36145)

Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Author: Kunshang Ji
Committed: 2026-03-12 22:57:47 +08:00 (by GitHub)
Parent: 2e693f48e7
Commit: 53ec16a705
89 changed files with 254 additions and 219 deletions
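
The change is mechanical: each torch.cuda device-management call is swapped for its device-agnostic torch.accelerator counterpart, so the same test code can drive CUDA, XPU, or any other supported accelerator backend. A minimal sketch of the mapping, assuming PyTorch >= 2.6 (where torch.accelerator is available) and an accelerator-enabled build:

    import torch

    n = torch.accelerator.device_count()            # was: torch.cuda.device_count()
    idx = torch.accelerator.current_device_index()  # was: torch.cuda.current_device()
    torch.accelerator.set_device_index("cuda:0")    # was: torch.cuda.set_device("cuda:0")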

@@ -35,7 +35,9 @@ NUM_BLOCKS = [1024, 10000]
 NUM_MAPPINGS = [256]  # Arbitrary values for testing
 SEEDS = [0]
-CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
+CUDA_DEVICES = [
+    f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)
+]
 # We assume fp8 is always enabled for testing.
 KV_CACHE_DTYPE = ["auto", "fp8"]
@@ -69,7 +71,7 @@ def test_reshape_and_cache(
         pytest.skip()
     set_random_seed(seed)
     torch.set_default_device(device)
-    torch.cuda.set_device(device)
+    torch.accelerator.set_device_index(device)
     # Create a random slot mapping.
     num_slots = block_size * num_blocks
     slot_mapping_lst = random.sample(range(num_slots), num_tokens)
@@ -192,7 +194,7 @@ def test_reshape_and_cache_flash(
 ) -> None:
     set_random_seed(seed)
     torch.set_default_device(device)
-    torch.cuda.set_device(device)
+    torch.accelerator.set_device_index(device)
     assert implementation in ["cuda", "triton"]
     if implementation == "triton" and kv_cache_layout == "HND":
         pytest.skip("Triton implementation only supports NHD layout.")
@@ -553,7 +555,7 @@ def test_concat_and_cache_mla(
 ) -> None:
     set_random_seed(seed)
     torch.set_default_device(device)
-    torch.cuda.set_device(device)
+    torch.accelerator.set_device_index(device)
     total_slots = num_blocks * block_size
     slot_mapping_lst = random.sample(range(total_slots), num_tokens)
@@ -632,7 +634,7 @@ def test_concat_and_cache_ds_mla(
     kv_cache_dtype = "fp8_ds_mla"
     set_random_seed(seed)
     torch.set_default_device(device)
-    torch.cuda.set_device(device)
+    torch.accelerator.set_device_index(device)
     total_slots = num_blocks * block_size
     slot_mapping_lst = random.sample(range(total_slots), num_tokens)
@@ -744,7 +746,7 @@ def test_swap_blocks_mla(
 ) -> None:
     set_random_seed(seed)
     torch.set_default_device(device)
-    torch.cuda.set_device(device)
+    torch.accelerator.set_device_index(device)
     entry_size = kv_lora_rank + qk_rope_head_dim
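
For code paths that must keep working on PyTorch builds without torch.accelerator, a hedged fallback sketch (the hasattr guard is an assumption, not part of this commit):

    import torch

    if hasattr(torch, "accelerator"):
        num_devices = torch.accelerator.device_count()
    else:
        # Fallback for older, CUDA-only builds.
        num_devices = torch.cuda.device_count()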