fix pynccl reduce_scatter (#23648)
Co-authored-by: hongchao <hongchao@msh.team>
This commit is contained in:
@@ -152,7 +152,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
|||||||
dtype=input_tensor.dtype,
|
dtype=input_tensor.dtype,
|
||||||
device=input_tensor.device)
|
device=input_tensor.device)
|
||||||
|
|
||||||
pynccl_comm.reduce_scatter(output, input_)
|
pynccl_comm.reduce_scatter(output, input_tensor)
|
||||||
|
|
||||||
# Reshape before returning
|
# Reshape before returning
|
||||||
return output.movedim(0, dim).contiguous()
|
return output.movedim(0, dim).contiguous()
|
||||||
@@ -186,9 +186,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
|||||||
device=input_tensor.device)
|
device=input_tensor.device)
|
||||||
|
|
||||||
if sizes is not None:
|
if sizes is not None:
|
||||||
pynccl_comm.reduce_scatterv(output, input_, sizes=sizes)
|
pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes)
|
||||||
else:
|
else:
|
||||||
pynccl_comm.reduce_scatter(output, input_)
|
pynccl_comm.reduce_scatter(output, input_tensor)
|
||||||
|
|
||||||
# Reshape before returning
|
# Reshape before returning
|
||||||
return output.movedim(0, dim).contiguous()
|
return output.movedim(0, dim).contiguous()
|
||||||
|
|||||||
Reference in New Issue
Block a user