[TPU] make ptxla not imported when using tpu_commons (#23081)

Signed-off-by: Chengji Yao <chengjiyao@gmail.com>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: Chengji Yao <chengjiyao@gmail.com>
This commit is contained in:
Chengji Yao
2025-08-18 20:46:42 -07:00
committed by GitHub
parent a4454e9401
commit e9d6a3db69
6 changed files with 94 additions and 78 deletions

View File

@@ -207,16 +207,21 @@ class DefaultModelLoader(BaseModelLoader):
)
if current_platform.is_tpu():
# In PyTorch XLA, we should call `xm.mark_step` frequently so that
# not too many ops are accumulated in the XLA program.
import torch_xla.core.xla_model as xm
from vllm.platforms.tpu import USE_TPU_COMMONS
def _xla_weights_iterator(iterator: Generator):
for weights in iterator:
yield weights
xm.mark_step()
if not USE_TPU_COMMONS:
# In PyTorch XLA, we should call `xm.mark_step`
# frequently so that not too many ops are accumulated
# in the XLA program.
import torch_xla.core.xla_model as xm
weights_iterator = _xla_weights_iterator(weights_iterator)
def _xla_weights_iterator(iterator: Generator):
for weights in iterator:
yield weights
xm.mark_step()
weights_iterator = _xla_weights_iterator(weights_iterator)
if self.counter_before_loading_weights == 0.0:
self.counter_before_loading_weights = time.perf_counter()