[Hardware] Replace torch.cuda.synchronize() api with torch.accelerator.synchronize (#36085)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Author:    Kunshang Ji <kunshang.ji@intel.com>
Date:      2026-03-05 18:36:39 +08:00
Committed: GitHub
Parent:    0bfa229bf1
Commit:    66a2209645
59 changed files with 158 additions and 161 deletions
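For context: torch.accelerator.synchronize() is the backend-agnostic counterpart of torch.cuda.synchronize(). It blocks the host until all kernels queued on the current accelerator (CUDA, XPU, and so on) have finished, which lets the same test code run on non-NVIDIA devices. A minimal before/after sketch, assuming a PyTorch build that ships the torch.accelerator namespace and a machine with a supported accelerator:

import torch

# Pick whichever accelerator backend is present (cuda, xpu, ...).
device = torch.accelerator.current_accelerator()
x = torch.ones(1024, device=device)
y = x * 2  # kernel is launched asynchronously on the device stream

# Old, CUDA-only host/device sync:
#   torch.cuda.synchronize()
# New, backend-agnostic sync used throughout this commit:
torch.accelerator.synchronize()
print(y.sum().item())  # safe to read back: all queued work has completed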


@@ -68,7 +68,7 @@ def worker_fn():
)
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank)
tensor = pynccl_comm.all_reduce(tensor)
-torch.cuda.synchronize()
+torch.accelerator.synchronize()
assert torch.all(tensor == pynccl_comm.world_size).cpu().item()
@@ -93,11 +93,11 @@ def multiple_allreduce_worker_fn():
if torch.distributed.get_rank() in [0, 1]:
tensor = pynccl_comm.all_reduce(tensor)
tensor = pynccl_comm.all_reduce(tensor)
-torch.cuda.synchronize()
+torch.accelerator.synchronize()
assert torch.all(tensor == 4).cpu().item()
else:
tensor = pynccl_comm.all_reduce(tensor)
-torch.cuda.synchronize()
+torch.accelerator.synchronize()
assert torch.all(tensor == 2).cpu().item()
@@ -121,11 +121,11 @@ def multiple_allreduce_with_vllm_worker_fn():
if torch.distributed.get_rank() in [0, 1]:
tensor = tensor_model_parallel_all_reduce(tensor)
tensor = tensor_model_parallel_all_reduce(tensor)
-torch.cuda.synchronize()
+torch.accelerator.synchronize()
assert torch.all(tensor == 4).cpu().item()
else:
tensor = tensor_model_parallel_all_reduce(tensor)
-torch.cuda.synchronize()
+torch.accelerator.synchronize()
assert torch.all(tensor == 2).cpu().item()
@@ -147,12 +147,12 @@ def worker_fn_with_cudagraph():
)
# run something in the default stream to initialize torch engine
a = torch.ones((4, 4), device=f"cuda:{pynccl_comm.rank}")
-torch.cuda.synchronize()
+torch.accelerator.synchronize()
with torch.cuda.graph(graph):
a_out = pynccl_comm.all_reduce(a)
-torch.cuda.synchronize()
+torch.accelerator.synchronize()
graph.replay()
-torch.cuda.synchronize()
+torch.accelerator.synchronize()
assert torch.all(a_out == pynccl_comm.world_size).cpu().item()
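In the CUDA-graph test above, the host sync appears three times: after the warm-up work, after capture, and after replay. A rough standalone sketch of that pattern (shapes and the captured op are illustrative, not vLLM code; graph capture itself is still driven through the torch.cuda.graph API):

import torch

a = torch.ones((4, 4), device="cuda")
torch.accelerator.synchronize()      # warm-up work must be finished before capture

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):        # record the kernel launches into the graph
    a_out = a * 2                    # not executed yet, only captured
torch.accelerator.synchronize()      # make sure capture has fully completed

graph.replay()                       # re-launch the captured kernels
torch.accelerator.synchronize()      # wait for the replayed work before checking
assert torch.all(a_out == 2).cpu().item()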
@@ -180,7 +180,7 @@ def all_gather_worker_fn():
).to(device)
pynccl_comm.all_gather(result, tensor)
-torch.cuda.synchronize()
+torch.accelerator.synchronize()
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
@@ -215,7 +215,7 @@ def all_gatherv_worker_fn():
).to(device)
pynccl_comm.all_gatherv(result, tensor, sizes=sizes)
-torch.cuda.synchronize()
+torch.accelerator.synchronize()
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
@@ -255,7 +255,7 @@ def reduce_scatter_worker_fn():
).to(device)
pynccl_comm.reduce_scatter(result, tensor)
-torch.cuda.synchronize()
+torch.accelerator.synchronize()
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
@@ -293,7 +293,7 @@ def reduce_scatterv_worker_fn():
expected = sum(tensor[start:end] for tensor in all_tensors).to(device)
pynccl_comm.reduce_scatterv(result, tensor, sizes=sizes)
-torch.cuda.synchronize()
+torch.accelerator.synchronize()
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
@@ -325,7 +325,7 @@ def send_recv_worker_fn():
pynccl_comm.send(tensor, dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
else:
pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
-torch.cuda.synchronize()
+torch.accelerator.synchronize()
assert torch.all(tensor == 1).cpu().item()
@@ -355,7 +355,7 @@ def multiple_send_recv_worker_fn():
pynccl_comm.send(tensor, dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
else:
pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
-torch.cuda.synchronize()
+torch.accelerator.synchronize()
if torch.distributed.get_rank() in [0, 2]:
assert torch.all(tensor == 1).cpu().item()
else:
@@ -396,7 +396,7 @@ def broadcast_worker_fn():
pynccl_comm.broadcast(recv_tensors[i], src=i)
# the broadcast op might be launched in a different stream
# need to synchronize to make sure the tensor is ready
-torch.cuda.synchronize()
+torch.accelerator.synchronize()
assert torch.all(recv_tensors[i] == i).cpu().item()
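The comment in this last hunk is the reason a host-side sync is needed at all: a collective can be launched on a communication stream, so the default stream (or the host) may otherwise observe the tensor before the op has finished. A small illustration of the same hazard with an ordinary side stream (hypothetical code, not part of the test file; the stream API shown is CUDA-specific):

import torch

x = torch.zeros(1 << 20, device="cuda")
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())  # order the side stream after the allocation

with torch.cuda.stream(side):  # work launched on a non-default stream
    x.fill_(1)

# Without a sync the host could read x before fill_ has finished.
torch.accelerator.synchronize()  # blocks until every stream on the device is idle
assert torch.all(x == 1).cpu().item()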