[Hardware] Replace torch.cuda.device_count/current_device/set_device API (#36145)
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com> Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
@@ -322,7 +322,7 @@ class WeightTensors:
|
||||
)
|
||||
|
||||
def to_current_device(self):
|
||||
device = torch.cuda.current_device()
|
||||
device = torch.accelerator.current_device_index()
|
||||
self.w1 = self.w1.to(device=device)
|
||||
self.w2 = self.w2.to(device=device)
|
||||
|
||||
@@ -392,7 +392,8 @@ class RankTensors:
|
||||
Return hidden_states
|
||||
"""
|
||||
m, k, dtype = (config.M, config.K, config.dtype)
|
||||
a = torch.randn((m, k), device=torch.cuda.current_device(), dtype=dtype) / 15.0
|
||||
device = torch.accelerator.current_device_index()
|
||||
a = torch.randn((m, k), device=device, dtype=dtype) / 15.0
|
||||
|
||||
if config.quant_dtype is None:
|
||||
return a, None
|
||||
@@ -428,9 +429,10 @@ class RankTensors:
|
||||
topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk, False)
|
||||
|
||||
# distribute topk_ids evenly
|
||||
device = torch.accelerator.current_device_index()
|
||||
for mi in range(m):
|
||||
topk_ids[mi] = torch.randperm(config.E)[:topk]
|
||||
topk_ids = topk_ids.to(device=torch.cuda.current_device())
|
||||
topk_ids = topk_ids.to(device=device)
|
||||
|
||||
expert_map = None
|
||||
if config.world_size > 1 and config.supports_expert_map():
|
||||
@@ -440,9 +442,7 @@ class RankTensors:
|
||||
s = pgi.rank * num_local_experts
|
||||
e = s + num_local_experts
|
||||
expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
|
||||
expert_map = expert_map.to(
|
||||
device=torch.cuda.current_device(), dtype=torch.int32
|
||||
)
|
||||
expert_map = expert_map.to(device=device, dtype=torch.int32)
|
||||
|
||||
return RankTensors(
|
||||
hidden_states=hidden_states,
|
||||
@@ -558,7 +558,9 @@ def reference_moe_impl(
|
||||
|
||||
def _make_gscale(num_experts: int) -> torch.Tensor:
|
||||
return torch.ones(
|
||||
(num_experts,), device=torch.cuda.current_device(), dtype=torch.float32
|
||||
(num_experts,),
|
||||
device=torch.accelerator.current_device_index(),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user