[TPU] optimize the all-reduce performance (#15903)

Signed-off-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
Chengji Yao
2025-04-02 17:25:14 -07:00
committed by GitHub
parent 1b84eff03a
commit 01b6113659
3 changed files with 16 additions and 2 deletions

View File

@@ -22,6 +22,8 @@ if current_platform.is_tpu():
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla._internal import pjrt
from torch_xla.distributed.xla_multiprocessing import (
create_optimized_replica_groups)
if USE_RAY:
from vllm.executor import ray_utils
@@ -79,9 +81,12 @@ class TpuCommunicator(DeviceCommunicatorBase):
pjrt.initialize_multiprocess(local_rank, local_world_size)
xr._init_world_size_ordinal()
self.groups = create_optimized_replica_groups()
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
return xm.all_reduce(xm.REDUCE_SUM, input_)
# TODO: Remove the groups specification after XLA compiler can support
# auto-reordering the ring order for all-reduce.
return xm.all_reduce(xm.REDUCE_SUM, input_, groups=self.groups)
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
assert dim == -1, "TPUs only support dim=-1 for all-gather."