[XPU] Bug fix for some unexpected errors when using the AgRs backend on XPU devices. (#36593)
Signed-off-by: yisheng <yi.sheng@intel.com>
This commit is contained in:
@@ -70,7 +70,7 @@ class XpuCommunicator(DeviceCommunicatorBase):
|
|||||||
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
|
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
|
||||||
)
|
)
|
||||||
|
|
||||||
dist.reduce_scatter_tensor(output, input_tensor)
|
dist.reduce_scatter_tensor(output, input_tensor, group=self.device_group)
|
||||||
|
|
||||||
# Reshape before returning
|
# Reshape before returning
|
||||||
return output.movedim(0, dim).contiguous()
|
return output.movedim(0, dim).contiguous()
|
||||||
@@ -103,9 +103,9 @@ class XpuCommunicator(DeviceCommunicatorBase):
|
|||||||
if sizes is not None and sizes.count(sizes[0]) != len(sizes):
|
if sizes is not None and sizes.count(sizes[0]) != len(sizes):
|
||||||
# If the input shapes differ across ranks, use reduce_scatter
|
# If the input shapes differ across ranks, use reduce_scatter
|
||||||
input_splits = list(input_tensor.split(sizes, dim=0))
|
input_splits = list(input_tensor.split(sizes, dim=0))
|
||||||
dist.reduce_scatter(output, input_splits)
|
dist.reduce_scatter(output, input_splits, group=self.device_group)
|
||||||
else:
|
else:
|
||||||
dist.reduce_scatter_tensor(output, input_tensor)
|
dist.reduce_scatter_tensor(output, input_tensor, group=self.device_group)
|
||||||
# Reshape before returning
|
# Reshape before returning
|
||||||
return output.movedim(0, dim).contiguous()
|
return output.movedim(0, dim).contiguous()
|
||||||
|
|
||||||
@@ -149,10 +149,10 @@ class XpuCommunicator(DeviceCommunicatorBase):
|
|||||||
device=input_.device,
|
device=input_.device,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
dist.all_gather(all_gather_list, input_)
|
dist.all_gather(all_gather_list, input_, group=self.device_group)
|
||||||
output_tensor = torch.cat(all_gather_list, dim=0)
|
output_tensor = torch.cat(all_gather_list, dim=0)
|
||||||
else:
|
else:
|
||||||
dist.all_gather([output_tensor], input_)
|
dist.all_gather([output_tensor], input_, group=self.device_group)
|
||||||
return output_tensor
|
return output_tensor
|
||||||
|
|
||||||
if isinstance(input_, torch.Tensor):
|
if isinstance(input_, torch.Tensor):
|
||||||
|
|||||||
@@ -85,6 +85,9 @@ class XPUWorker(Worker):
|
|||||||
current_platform.dist_backend,
|
current_platform.dist_backend,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# A global all_reduce is needed for overall oneCCL warm-up
|
||||||
|
torch.distributed.all_reduce(torch.zeros(1).xpu())
|
||||||
|
|
||||||
# Set random seed.
|
# Set random seed.
|
||||||
set_random_seed(self.model_config.seed)
|
set_random_seed(self.model_config.seed)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user