[Hardware] Replace torch.cuda.synchronize() API with torch.accelerator.synchronize() (#36085)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
@@ -32,7 +32,7 @@ pointers = CustomAllreduce.create_shared_buffer(buffer_size_in_bytes)
|
||||
print(f"Rank {rank} has pointers {pointers}")
|
||||
|
||||
dist.barrier()
|
||||
-torch.cuda.synchronize()
|
||||
+torch.accelerator.synchronize()
|
||||
|
||||
if rank == 0:
|
||||
# the first rank tries to write to all buffers
|
||||
@@ -41,7 +41,7 @@ if rank == 0:
|
||||
lib.cudaMemset(pointer, byte_value, buffer_size_in_bytes)
|
||||
|
||||
dist.barrier()
|
||||
-torch.cuda.synchronize()
|
||||
+torch.accelerator.synchronize()
|
||||
|
||||
host_data = (ctypes.c_char * buffer_size_in_bytes)()
|
||||
|
||||
@@ -59,6 +59,6 @@ for p in pointers:
|
||||
print(f"Rank {rank} verified all buffers")
|
||||
|
||||
dist.barrier()
|
||||
-torch.cuda.synchronize()
|
||||
+torch.accelerator.synchronize()
|
||||
|
||||
CustomAllreduce.free_shared_buffer(pointers)
|
||||
|
||||
Reference in New Issue
Block a user