[TPU] optimize the all-reduce performance (#15903)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
@@ -22,6 +22,8 @@ if current_platform.is_tpu():
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.runtime as xr
|
||||
from torch_xla._internal import pjrt
|
||||
+ from torch_xla.distributed.xla_multiprocessing import (
+ create_optimized_replica_groups)
|
||||
|
||||
if USE_RAY:
|
||||
from vllm.executor import ray_utils
|
||||
@@ -79,9 +81,12 @@ class TpuCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
pjrt.initialize_multiprocess(local_rank, local_world_size)
|
||||
xr._init_world_size_ordinal()
|
||||
+ self.groups = create_optimized_replica_groups()
|
||||
|
||||
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
- return xm.all_reduce(xm.REDUCE_SUM, input_)
+ # TODO: Remove the groups specification after XLA compiler can support
+ # auto-reordering the ring order for all-reduce.
+ return xm.all_reduce(xm.REDUCE_SUM, input_, groups=self.groups)
|
||||
|
||||
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
assert dim == -1, "TPUs only support dim=-1 for all-gather."
|
||||
|
||||
Reference in New Issue
Block a user