[Hardware] Replace torch.cuda.device_count/current_device/set_device API (#36145)

Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Kunshang Ji
2026-03-12 22:57:47 +08:00
committed by GitHub
parent 2e693f48e7
commit 53ec16a705
89 changed files with 254 additions and 219 deletions

View File

@@ -134,10 +134,8 @@ class TestTensors:
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
rank_tokens = (
torch.randn((m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0
)
device = torch.accelerator.current_device_index()
rank_tokens = torch.randn((m, k), device=device, dtype=dtype) / 10.0
rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max)
rank_token_scales = None
@@ -145,11 +143,13 @@ class TestTensors:
low=0,
high=config.num_experts,
size=(m, topk),
device=torch.cuda.current_device(),
device=device,
).to(dtype=torch.int64)
topk_weights = torch.randn(
topk_ids.shape, dtype=torch.float32, device=torch.cuda.current_device()
topk_ids.shape,
dtype=torch.float32,
device=device,
)
return TestTensors(
@@ -296,7 +296,8 @@ def deepep_deepgemm_moe_impl(
s = pgi.rank * num_local_experts
e = s + num_local_experts
expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32)
device = torch.accelerator.current_device_index()
return expert_map.to(device=device, dtype=torch.int32)
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
@@ -376,10 +377,11 @@ def _test_deepep_deepgemm_moe(
set_random_seed(pgi.rank)
w1 = w1.to(device=torch.cuda.current_device())
w2 = w2.to(device=torch.cuda.current_device())
w1_scale = w1_scale.to(device=torch.cuda.current_device())
w2_scale = w2_scale.to(device=torch.cuda.current_device())
device = torch.accelerator.current_device_index()
w1 = w1.to(device=device)
w2 = w2.to(device=device)
w1_scale = w1_scale.to(device=device)
w2_scale = w2_scale.to(device=device)
pg = torch.distributed.new_group(list(range(pgi.world_size)))
test_tensors = TestTensors.make(config, pgi.rank)