[Hardware] Replace torch.cuda.device_count/current_device/set_device API (#36145)
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com> Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
@@ -134,10 +134,8 @@ class TestTensors:
|
||||
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
rank_tokens = (
|
||||
torch.randn((m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0
|
||||
)
|
||||
device = torch.accelerator.current_device_index()
|
||||
rank_tokens = torch.randn((m, k), device=device, dtype=dtype) / 10.0
|
||||
rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max)
|
||||
rank_token_scales = None
|
||||
|
||||
@@ -145,11 +143,13 @@ class TestTensors:
|
||||
low=0,
|
||||
high=config.num_experts,
|
||||
size=(m, topk),
|
||||
device=torch.cuda.current_device(),
|
||||
device=device,
|
||||
).to(dtype=torch.int64)
|
||||
|
||||
topk_weights = torch.randn(
|
||||
topk_ids.shape, dtype=torch.float32, device=torch.cuda.current_device()
|
||||
topk_ids.shape,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
return TestTensors(
|
||||
@@ -296,7 +296,8 @@ def deepep_deepgemm_moe_impl(
|
||||
s = pgi.rank * num_local_experts
|
||||
e = s + num_local_experts
|
||||
expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
|
||||
return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32)
|
||||
device = torch.accelerator.current_device_index()
|
||||
return expert_map.to(device=device, dtype=torch.int32)
|
||||
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
@@ -376,10 +377,11 @@ def _test_deepep_deepgemm_moe(
|
||||
|
||||
set_random_seed(pgi.rank)
|
||||
|
||||
w1 = w1.to(device=torch.cuda.current_device())
|
||||
w2 = w2.to(device=torch.cuda.current_device())
|
||||
w1_scale = w1_scale.to(device=torch.cuda.current_device())
|
||||
w2_scale = w2_scale.to(device=torch.cuda.current_device())
|
||||
device = torch.accelerator.current_device_index()
|
||||
w1 = w1.to(device=device)
|
||||
w2 = w2.to(device=device)
|
||||
w1_scale = w1_scale.to(device=device)
|
||||
w2_scale = w2_scale.to(device=device)
|
||||
|
||||
pg = torch.distributed.new_group(list(range(pgi.world_size)))
|
||||
test_tensors = TestTensors.make(config, pgi.rank)
|
||||
|
||||
Reference in New Issue
Block a user