[XPU] Bug fix for some unexpected errors when using the AgRs backend on XPU devices. (#36593)

Signed-off-by: yisheng <yi.sheng@intel.com>
This commit is contained in:
YiSheng5
2026-03-11 17:17:46 +08:00
committed by GitHub
parent f4ae58b38b
commit c910eeb125
2 changed files with 8 additions and 5 deletions

View File

@@ -70,7 +70,7 @@ class XpuCommunicator(DeviceCommunicatorBase):
output_shape, dtype=input_tensor.dtype, device=input_tensor.device output_shape, dtype=input_tensor.dtype, device=input_tensor.device
) )
dist.reduce_scatter_tensor(output, input_tensor) dist.reduce_scatter_tensor(output, input_tensor, group=self.device_group)
# Reshape before returning # Reshape before returning
return output.movedim(0, dim).contiguous() return output.movedim(0, dim).contiguous()
@@ -103,9 +103,9 @@ class XpuCommunicator(DeviceCommunicatorBase):
if sizes is not None and sizes.count(sizes[0]) != len(sizes): if sizes is not None and sizes.count(sizes[0]) != len(sizes):
# if inputs shape in different ranks is not the same using reduce_scatter # if inputs shape in different ranks is not the same using reduce_scatter
input_splits = list(input_tensor.split(sizes, dim=0)) input_splits = list(input_tensor.split(sizes, dim=0))
dist.reduce_scatter(output, input_splits) dist.reduce_scatter(output, input_splits, group=self.device_group)
else: else:
dist.reduce_scatter_tensor(output, input_tensor) dist.reduce_scatter_tensor(output, input_tensor, group=self.device_group)
# Reshape before returning # Reshape before returning
return output.movedim(0, dim).contiguous() return output.movedim(0, dim).contiguous()
@@ -149,10 +149,10 @@ class XpuCommunicator(DeviceCommunicatorBase):
device=input_.device, device=input_.device,
) )
) )
dist.all_gather(all_gather_list, input_) dist.all_gather(all_gather_list, input_, group=self.device_group)
output_tensor = torch.cat(all_gather_list, dim=0) output_tensor = torch.cat(all_gather_list, dim=0)
else: else:
dist.all_gather([output_tensor], input_) dist.all_gather([output_tensor], input_, group=self.device_group)
return output_tensor return output_tensor
if isinstance(input_, torch.Tensor): if isinstance(input_, torch.Tensor):

View File

@@ -85,6 +85,9 @@ class XPUWorker(Worker):
current_platform.dist_backend, current_platform.dist_backend,
) )
# global all_reduce needed for overall oneccl warm up
torch.distributed.all_reduce(torch.zeros(1).xpu())
# Set random seed. # Set random seed.
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)