[Hardware][TPU] Initial support of model parallelism with single worker using SPMD (#18011)

Signed-off-by: Siyuan Liu <lsiyuan@google.com>
Co-authored-by: Hossein Sarshar <hossein.sarshar@gmail.com>
Co-authored-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
Siyuan Liu
2025-06-02 17:06:20 -07:00
committed by GitHub
parent c57d577e8d
commit 9112b443a0
11 changed files with 605 additions and 72 deletions

View File

@@ -49,7 +49,9 @@ def _make_synced_weight_loader(original_weight_loader):
def _synced_weight_loader(param, *args, **kwargs):
original_weight_loader(param, *args, **kwargs)
torch._sync(param)
# torch._sync doesn't support, is not needed for CPU tensors.
if param.device != torch.device("cpu"):
torch._sync(param)
return _synced_weight_loader