[TPU] Reduce compilation time & Upgrade PyTorch XLA version (#6856)

This commit is contained in:
Woosuk Kwon
2024-07-27 10:28:33 -07:00
committed by GitHub
parent f954d0715c
commit fad5576c58
6 changed files with 24 additions and 7 deletions

View File

@@ -6,6 +6,7 @@ from vllm.platforms import current_platform
if current_platform.is_tpu():
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla._internal import pjrt
@@ -20,7 +21,7 @@ class TpuCommunicator:
local_rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
pjrt.initialize_multiprocess(local_rank, world_size)
xm._init_world_size_ordinal()
xr._init_world_size_ordinal()
def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
    """Sum-reduce ``x`` by delegating to ``xm.all_reduce``.

    Args:
        x: Tensor handed to ``xm.all_reduce`` as the value to reduce.

    Returns:
        Whatever ``xm.all_reduce`` returns for the ``xm.REDUCE_SUM``
        reduction of ``x``.
    """
    # NOTE(review): which replicas take part in the reduction is decided
    # by torch_xla, not by this method -- confirm against the initializer.
    summed = xm.all_reduce(xm.REDUCE_SUM, x)
    return summed