From c910eeb125003ebe19e0f4e6d27d335061597e81 Mon Sep 17 00:00:00 2001 From: YiSheng5 Date: Wed, 11 Mar 2026 17:17:46 +0800 Subject: [PATCH] [XPU] Fix unexpected errors when using the AgRs backend on XPU devices. (#36593) Signed-off-by: yisheng --- .../device_communicators/xpu_communicator.py | 10 +++++----- vllm/v1/worker/xpu_worker.py | 3 +++ 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py index 85c7f18e3..d2e9e89e5 100644 --- a/vllm/distributed/device_communicators/xpu_communicator.py +++ b/vllm/distributed/device_communicators/xpu_communicator.py @@ -70,7 +70,7 @@ class XpuCommunicator(DeviceCommunicatorBase): output_shape, dtype=input_tensor.dtype, device=input_tensor.device ) - dist.reduce_scatter_tensor(output, input_tensor) + dist.reduce_scatter_tensor(output, input_tensor, group=self.device_group) # Reshape before returning return output.movedim(0, dim).contiguous() @@ -103,9 +103,9 @@ class XpuCommunicator(DeviceCommunicatorBase): if sizes is not None and sizes.count(sizes[0]) != len(sizes): # if inputs shape in different ranks is not the same using reduce_scatter input_splits = list(input_tensor.split(sizes, dim=0)) - dist.reduce_scatter(output, input_splits) + dist.reduce_scatter(output, input_splits, group=self.device_group) else: - dist.reduce_scatter_tensor(output, input_tensor) + dist.reduce_scatter_tensor(output, input_tensor, group=self.device_group) # Reshape before returning return output.movedim(0, dim).contiguous() @@ -149,10 +149,10 @@ class XpuCommunicator(DeviceCommunicatorBase): device=input_.device, ) ) - dist.all_gather(all_gather_list, input_) + dist.all_gather(all_gather_list, input_, group=self.device_group) output_tensor = torch.cat(all_gather_list, dim=0) else: - dist.all_gather([output_tensor], input_) + dist.all_gather([output_tensor], input_, group=self.device_group) return output_tensor if 
isinstance(input_, torch.Tensor): diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py index 898c79087..112a71b37 100644 --- a/vllm/v1/worker/xpu_worker.py +++ b/vllm/v1/worker/xpu_worker.py @@ -85,6 +85,9 @@ class XPUWorker(Worker): current_platform.dist_backend, ) + # A global all_reduce is needed for the overall oneCCL warm-up. + torch.distributed.all_reduce(torch.zeros(1).xpu()) + # Set random seed. set_random_seed(self.model_config.seed)