[TPU] optimize the all-reduce performance (#15903)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
@@ -84,6 +84,12 @@ class TPUWorker:
 
     def init_device(self):
         os.environ["PJRT_DEVICE"] = "TPU"
+        # Note: the XLA compiler currently uses the 2D ring strategy on a
+        # 1D ring by mistake. The XLA TPU compiler flag
+        # `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary fix
+        # and will be removed once the bug in the XLA compiler is fixed.
+        os.environ["LIBTPU_INIT_ARGS"] = (
+            "--xla_tpu_force_1d_allreduce_at_chunk_count=1")
         torch.set_grad_enabled(False)
         torch.set_default_dtype(self.model_config.dtype)
 
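For context, here is a minimal standalone sketch (not part of this commit) of applying the same flag outside vLLM. It assumes a TPU VM with torch_xla installed; the worker function and tensor size are illustrative. LIBTPU_INIT_ARGS is read when the TPU runtime starts, so it must be exported before the first torch_xla device call in each process, which is why the commit sets it at the top of init_device.

import os

# Must be set before libtpu is initialized, i.e. before the first
# torch_xla device call in each process.
os.environ["PJRT_DEVICE"] = "TPU"
os.environ["LIBTPU_INIT_ARGS"] = (
    "--xla_tpu_force_1d_allreduce_at_chunk_count=1")

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _worker(index: int):
    # Each replica contributes one tensor; all_reduce sums them across
    # the (1D) ring of TPU devices.
    device = xm.xla_device()
    t = torch.ones(1024, device=device) * (index + 1)
    reduced = xm.all_reduce(xm.REDUCE_SUM, t)
    xm.mark_step()  # materialize the all-reduce
    print(f"replica {index}: {reduced[0].item()}")


if __name__ == "__main__":
    xmp.spawn(_worker, args=())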