[Hardware] Replace the torch.cuda.synchronize() API with torch.accelerator.synchronize() (#36085)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
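For context, torch.accelerator (introduced in PyTorch 2.6) is a device-agnostic frontend that dispatches to whichever accelerator backend is present (CUDA, XPU, ...), which is why the test no longer needs the CUDA-specific call. A minimal standalone sketch of the replacement call, not taken from this diff:

    import torch

    # torch.accelerator.synchronize() waits for all work queued on the current
    # accelerator, whatever the backend; on NVIDIA GPUs it behaves like
    # torch.cuda.synchronize().
    if torch.accelerator.is_available():
        device = torch.accelerator.current_accelerator()  # e.g. device(type='cuda')
        x = torch.ones(1024, device=device)
        y = x * 2  # enqueued asynchronously on the device's current stream
        torch.accelerator.synchronize()  # block until the multiply has finished
        print(y.sum().item())  # safe to read: all device work is done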
@@ -48,7 +48,7 @@ def graph_allreduce(
     data = torch.zeros(1)
     data = data.to(device=device)
     torch.distributed.all_reduce(data, group=group)
-    torch.cuda.synchronize()
+    torch.accelerator.synchronize()
     del data

     # we use the first group to communicate once
@@ -68,7 +68,7 @@ def graph_allreduce(
     inp2 = torch.randint(
         1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
     )
-    torch.cuda.synchronize()
+    torch.accelerator.synchronize()
     graph = torch.cuda.CUDAGraph()
     with torch.cuda.graph(graph, stream=graph_capture_context.stream):
         for i in range(num_communication):
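The synchronize before building the CUDA graph matters because stream capture records, rather than runs, subsequent kernels, so any prior device work must be finished before capture begins. A minimal sketch of that pattern with the device-agnostic call, assuming a CUDA machine (the tensor and the + 1 op are illustrative, not from the test):

    import torch

    assert torch.cuda.is_available()
    x = torch.zeros(8, device="cuda")
    torch.accelerator.synchronize()  # drain outstanding work before capture

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):  # ops below are recorded, not executed
        y = x + 1

    graph.replay()                   # runs the captured kernel(s)
    torch.accelerator.synchronize()  # wait for the replayed work
    print(y)                         # tensor of ones after replay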