[Hardware] Replace torch.cuda.device_count/current_device/set_device API (#36145)
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com> Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
@@ -210,7 +210,8 @@ def deep_ep_moe_impl(
|
||||
s = pgi.rank * num_local_experts
|
||||
e = s + num_local_experts
|
||||
expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
|
||||
return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32)
|
||||
device = torch.accelerator.current_device_index()
|
||||
return expert_map.to(device=device, dtype=torch.int32)
|
||||
|
||||
hidden_size = test_tensors.rank_tokens.size(1)
|
||||
is_quantized = w1.dtype == torch.float8_e4m3fn
|
||||
@@ -365,15 +366,13 @@ def _deep_ep_moe(
|
||||
)
|
||||
|
||||
is_quantized = w1.dtype == torch.float8_e4m3fn
|
||||
w1 = w1.to(device=torch.cuda.current_device())
|
||||
w2 = w2.to(device=torch.cuda.current_device())
|
||||
device_idx = torch.accelerator.current_device_index()
|
||||
w1 = w1.to(device=device_idx)
|
||||
w2 = w2.to(device=device_idx)
|
||||
if is_quantized:
|
||||
w1_scale = w1_scale.to( # type: ignore
|
||||
device=torch.cuda.current_device()
|
||||
)
|
||||
w2_scale = w2_scale.to( # type: ignore
|
||||
device=torch.cuda.current_device()
|
||||
)
|
||||
assert w1_scale is not None and w2_scale is not None
|
||||
w1_scale = w1_scale.to(device=device_idx)
|
||||
w2_scale = w2_scale.to(device=device_idx)
|
||||
|
||||
pg = torch.distributed.new_group(list(range(pgi.world_size)))
|
||||
test_tensors = TestTensors.make(config, low_latency_mode)
|
||||
|
||||
Reference in New Issue
Block a user