[XPU]Bug fix for some unexpected error when use AgRs backend on XPU device. (#36593)

Signed-off-by: yisheng <yi.sheng@intel.com>
This commit is contained in:
YiSheng5
2026-03-11 17:17:46 +08:00
committed by GitHub
parent f4ae58b38b
commit c910eeb125
2 changed files with 8 additions and 5 deletions

View File

@@ -70,7 +70,7 @@ class XpuCommunicator(DeviceCommunicatorBase):
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
)
dist.reduce_scatter_tensor(output, input_tensor)
dist.reduce_scatter_tensor(output, input_tensor, group=self.device_group)
# Reshape before returning
return output.movedim(0, dim).contiguous()
@@ -103,9 +103,9 @@ class XpuCommunicator(DeviceCommunicatorBase):
if sizes is not None and sizes.count(sizes[0]) != len(sizes):
# if inputs shape in different ranks is not the same using reduce_scatter
input_splits = list(input_tensor.split(sizes, dim=0))
dist.reduce_scatter(output, input_splits)
dist.reduce_scatter(output, input_splits, group=self.device_group)
else:
dist.reduce_scatter_tensor(output, input_tensor)
dist.reduce_scatter_tensor(output, input_tensor, group=self.device_group)
# Reshape before returning
return output.movedim(0, dim).contiguous()
@@ -149,10 +149,10 @@ class XpuCommunicator(DeviceCommunicatorBase):
device=input_.device,
)
)
dist.all_gather(all_gather_list, input_)
dist.all_gather(all_gather_list, input_, group=self.device_group)
output_tensor = torch.cat(all_gather_list, dim=0)
else:
dist.all_gather([output_tensor], input_)
dist.all_gather([output_tensor], input_, group=self.device_group)
return output_tensor
if isinstance(input_, torch.Tensor):

View File

@@ -85,6 +85,9 @@ class XPUWorker(Worker):
current_platform.dist_backend,
)
# global all_reduce needed for overall oneccl warm up
torch.distributed.all_reduce(torch.zeros(1).xpu())
# Set random seed.
set_random_seed(self.model_config.seed)