[TPU] Deprecate xm.mark_step in favor of `torch_xla.sync (#25254)
Signed-off-by: NickLucche <nlucches@redhat.com> Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
This commit is contained in:
@@ -211,16 +211,15 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
from vllm.platforms.tpu import USE_TPU_COMMONS
|
||||
|
||||
if not USE_TPU_COMMONS:
|
||||
# In PyTorch XLA, we should call `xm.mark_step`
|
||||
# In PyTorch XLA, we should call `torch_xla.sync`
|
||||
# frequently so that not too many ops are accumulated
|
||||
# in the XLA program. import torch_xla.core.xla_model
|
||||
# as xm
|
||||
import torch_xla.core.xla_model as xm
|
||||
# in the XLA program.
|
||||
import torch_xla
|
||||
|
||||
def _xla_weights_iterator(iterator: Generator):
|
||||
for weights in iterator:
|
||||
yield weights
|
||||
xm.mark_step()
|
||||
torch_xla.sync(wait=False)
|
||||
|
||||
weights_iterator = _xla_weights_iterator(weights_iterator)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user