[Hardware] Replace torch.cuda.device_count/current_device/set_device API (#36145)
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
@@ -35,7 +35,9 @@ NUM_BLOCKS = [1024, 10000]
|
||||
|
||||
NUM_MAPPINGS = [256] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
# We assume fp8 is always enabled for testing.
|
||||
KV_CACHE_DTYPE = ["auto", "fp8"]
|
||||
@@ -69,7 +71,7 @@ def test_reshape_and_cache(
|
||||
pytest.skip()
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
torch.cuda.set_device(device)
|
||||
torch.accelerator.set_device_index(device)
|
||||
# Create a random slot mapping.
|
||||
num_slots = block_size * num_blocks
|
||||
slot_mapping_lst = random.sample(range(num_slots), num_tokens)
|
||||
@@ -192,7 +194,7 @@ def test_reshape_and_cache_flash(
|
||||
) -> None:
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
torch.cuda.set_device(device)
|
||||
torch.accelerator.set_device_index(device)
|
||||
assert implementation in ["cuda", "triton"]
|
||||
if implementation == "triton" and kv_cache_layout == "HND":
|
||||
pytest.skip("Triton implementation only supports NHD layout.")
|
||||
@@ -553,7 +555,7 @@ def test_concat_and_cache_mla(
|
||||
) -> None:
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
torch.cuda.set_device(device)
|
||||
torch.accelerator.set_device_index(device)
|
||||
|
||||
total_slots = num_blocks * block_size
|
||||
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
|
||||
@@ -632,7 +634,7 @@ def test_concat_and_cache_ds_mla(
|
||||
kv_cache_dtype = "fp8_ds_mla"
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
torch.cuda.set_device(device)
|
||||
torch.accelerator.set_device_index(device)
|
||||
|
||||
total_slots = num_blocks * block_size
|
||||
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
|
||||
@@ -744,7 +746,7 @@ def test_swap_blocks_mla(
|
||||
) -> None:
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
torch.cuda.set_device(device)
|
||||
torch.accelerator.set_device_index(device)
|
||||
|
||||
entry_size = kv_lora_rank + qk_rope_head_dim
|
||||
|
||||
|
||||
Reference in New Issue
Block a user