[TPU] Reduce compilation time & Upgrade PyTorch XLA version (#6856)
This commit is contained in:
@@ -6,6 +6,7 @@ from vllm.platforms import current_platform

 if current_platform.is_tpu():
     import torch_xla.core.xla_model as xm
+    import torch_xla.runtime as xr
     from torch_xla._internal import pjrt

@@ -20,7 +21,7 @@ class TpuCommunicator:
         local_rank = dist.get_rank(group)
         world_size = dist.get_world_size(group)
         pjrt.initialize_multiprocess(local_rank, world_size)
-        xm._init_world_size_ordinal()
+        xr._init_world_size_ordinal()

     def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
         return xm.all_reduce(xm.REDUCE_SUM, x)
Reference in New Issue
Block a user