[TPU] optimize the all-reduce performance (#15903)

Signed-off-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
Chengji Yao
2025-04-02 17:25:14 -07:00
committed by GitHub
parent 1b84eff03a
commit 01b6113659
3 changed files with 16 additions and 2 deletions

View File

@@ -84,6 +84,12 @@ class TPUWorker:
def init_device(self):
os.environ["PJRT_DEVICE"] = "TPU"
# Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
# ring, the xla tpu compiler flag
# `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to
# fix this. It will be removed after the bug in XLA compiler is fixed.
os.environ["LIBTPU_INIT_ARGS"] = (
"--xla_tpu_force_1d_allreduce_at_chunk_count=1")
torch.set_grad_enabled(False)
torch.set_default_dtype(self.model_config.dtype)