[TPU] Deprecate xm.mark_step in favor of `torch_xla.sync` (#25254)

Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
This commit is contained in:
Nicolò Lucchesi
2025-09-22 12:12:57 +02:00
committed by GitHub
parent a66d131381
commit 4cf71cc88a
5 changed files with 31 additions and 29 deletions

View File

@@ -211,16 +211,15 @@ class DefaultModelLoader(BaseModelLoader):
from vllm.platforms.tpu import USE_TPU_COMMONS
if not USE_TPU_COMMONS:
# In PyTorch XLA, we should call `xm.mark_step`
# In PyTorch XLA, we should call `torch_xla.sync`
# frequently so that not too many ops are accumulated
# in the XLA program. import torch_xla.core.xla_model
# as xm
import torch_xla.core.xla_model as xm
# in the XLA program.
import torch_xla
def _xla_weights_iterator(iterator: Generator):
for weights in iterator:
yield weights
xm.mark_step()
torch_xla.sync(wait=False)
weights_iterator = _xla_weights_iterator(weights_iterator)